28 November 2021

Primes in Clojure part 2: Interop

A couple weeks ago I published a blog post on computing prime numbers with Clojure. I got some good feedback on it, including pointers to other implementation techniques I'd overlooked.

In this post, I want to explore some of those, notably two based on the Java standard library.

Recap

First off, let's recap where we left off in the first post. We had two approaches: one checks each integer for primality; the other, the Sieve of Eratosthenes, eliminates non-primes so that we're left with only prime numbers (without having to check them).

In terms of implementation, here's what they might look like. First, the primality check on each number:

(defn check-based
  []
  (let [p (promise)]
    (deliver p (cons 2
                     (->> (iterate #(+ 2 %) 3)
                          (remove (fn [n]
                                    (->> @p
                                         (take-while #(<= (* % %) n))
                                         (some #(zero? (rem n %)))))))))
    @p))

We want to produce a list of all prime numbers, so it has to be lazy; here, both iterate and remove are lazy. Divisors come in pairs (if n = a × b, then one of a and b is at most the square root of n), so if there is no divisor up to the square root we don't need to check any larger ones; that's the take-while. We also don't need to check non-prime divisors, which is why we introduce a promise: it lets the remove operation refer to the list of primes we are in the process of constructing. This does not result in an infinite loop because, at run time, the remove operation never consumes more primes than we've already produced (thanks to the take-while condition). Finally, we use some to stop checking for divisors as soon as we find one.
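
To make the square-root cutoff concrete, here is a minimal sketch of just the divisor check, run against a fixed list of primes instead of the self-referential promise. The helper name smallest-prime-divisor is made up for this illustration and is not part of the implementation above.

```clojure
;; Hypothetical helper, for illustration only: returns the first prime in
;; `known-primes` that divides n, looking no further than sqrt(n), or nil
;; if there is none (meaning n is prime, provided the list is long enough).
(defn smallest-prime-divisor
  [known-primes n]
  (->> known-primes
       (take-while #(<= (* % %) n))
       (some #(when (zero? (rem n %)) %))))

(smallest-prime-divisor [2 3 5 7 11 13] 91) ;; => 7   (91 = 7 * 13)
(smallest-prime-divisor [2 3 5 7 11 13] 97) ;; => nil (only 2, 3, 5, 7 are tried)
```

For 97, the take-while stops at 11 because 121 is larger than 97, so only four divisions are needed to conclude that 97 is prime.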

Second, we had an approach based on the Sieve of Eratosthenes:

(defn sieve-based
  [sieve]
  (letfn [(helper [known-primes already-produced sieve-bound]
            (concat (drop already-produced known-primes)
                    (lazy-seq
                      (loop [new-bound (* 10 sieve-bound)]
                        (let [new-primes (sieve new-bound)]
                          (if (> (count new-primes) (count known-primes))
                            (helper new-primes
                                    (count known-primes)
                                    new-bound)
                            (recur (* 10 new-bound))))))))]
    (helper (sieve 10) 0 10)))

(defn array-sieve
  [bound]
  (let [candidates (long-array bound)]
    ;; initialize array to contain [0 0 2 3 4 5 ...]
    (loop [idx 2]
      (when (< idx bound)
        (aset-long candidates idx idx)
        (recur (inc idx))))
    ;; eliminate known non-primes
    (loop [idx 2]
      (when (< idx bound)
        (let [next-candidate (aget candidates idx)]
          (when-not (zero? next-candidate)
            (loop [update-idx (* idx idx)]
              (when (< update-idx bound)
                (aset-long candidates update-idx 0)
                (recur (+ idx update-idx)))))
          (recur (inc idx)))))
    ;; return a Clojure vector with only primes in it
    (vec (remove zero? candidates))))

This is a bit more code, but the principle is fairly simple: we start with a chalkboard where we've written down all of the naturals from 2 up to some limit (because chalkboards in Ancient Greece were not infinite), then declare 2 a prime, and remove all the multiples of 2 from the board. We don't need any test for that; we just increment our index by 2 each time. Then we take the next number still on the board (in our implementation, the next non-zero number), and keep going like that. We start the "erasure" process at (* n n) instead of (* 2 n) because, by construction, we've already eliminated all (* m n) numbers for any m smaller than n. It's very annoying that we have to pick a bound for the array-based sieve method to work, and there's going to be some repeated work as we increase the bound, but it's still very fast despite that.
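
As a small worked example of the erasure step, the sketch below lists which numbers get wiped off the board for a given prime when the bound is 30. The helper erased-by is a made-up name for this illustration; the sieve above does the same walk with a loop over array indices.

```clojure
;; Hypothetical helper, for illustration only: the multiples of n that the
;; sieve erases when it starts at n*n and steps by n, staying below `bound`.
(defn erased-by
  [n bound]
  (take-while #(< % bound) (iterate #(+ n %) (* n n))))

(erased-by 2 30) ;; => (4 6 8 10 12 14 16 18 20 22 24 26 28)
(erased-by 3 30) ;; => (9 12 15 18 21 24 27)
(erased-by 5 30) ;; => (25)
(erased-by 7 30) ;; => ()
```

Note that 6 never shows up in the list for 3: it was already erased as a multiple of 2, which is exactly why starting at (* n n) is safe. Some numbers (12, 18, 24) are still erased twice, which is harmless.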

Note: In my previous post, I screwed up the implementation of generate-sieve by hardcoding sieve-upto (instead of using the argument sieve-fn). That makes the switch from iterate to an explicit lazy seq arguably unnecessary, but I still like this form better.

Benchmarking

Just playing with various ways to define the prime number computation is fun, but I find it a lot more interesting to compare them for speed. As a first approximation, we could just time them, as I did in my previous post, but this time I'd like to go for something a bit more principled, and will use the criterium library instead, running in a standalone application (rather than directly at the REPL, which disables some JVM optimizations).

(defmacro bench
  [primes-fn]
  `(doseq [n# [100000 300000 1000000]]
     (println (format "%-28s %7d: %5.2f"
                      (pr-str '~primes-fn)
                      n#
                      (->> (crit/benchmark (nth ~primes-fn n#) {})
                           :mean first)))))

BigInteger

One possibly useful method from the Java standard library is BigInteger's isProbablePrime method. It's documented as:

Returns true if this BigInteger is probably prime, false if it's definitely composite. If certainty is ≤ 0, true is returned.

Given that this is a probabilistic approach, with a given certainty, I would expect this to be faster, for large non-primes, than actually checking for all divisors. And because it is guaranteed to return true for primes, we can use it as a filter with no semantic change to our existing check-based implementation:
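
Before wiring it into the filter, here is a minimal sketch of calling the method on its own. The wrapper probably-prime? is a made-up name for this illustration; it only shows the cases whose answer does not depend on chance (a prime, an even number, and a non-positive certainty).

```clojure
;; Hypothetical wrapper, for illustration only. BigInteger/valueOf takes a
;; long and isProbablePrime takes an int, hence the explicit coercions.
(defn probably-prime?
  [n certainty]
  (.isProbablePrime (java.math.BigInteger/valueOf (long n)) (int certainty)))

(probably-prime? 15485867 5) ;; => true  (a prime always passes)
(probably-prime? 16 5)       ;; => false (even numbers other than 2 are rejected)
(probably-prime? 16 0)       ;; => true  (certainty <= 0 always returns true)
```

For odd composites, the answer is false except with a small probability that shrinks as certainty grows, which is why the exact divisor check still has to run afterwards.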

(defn check-based-with-bigint
  [certainty]
  (let [p (promise)]
    (deliver p (cons 2
                     (->> (iterate #(+ 2 %) 3)
                          (filter (fn [n]
                                    (.isProbablePrime
                                      (java.math.BigInteger/valueOf (long n))
                                      certainty)))
                          (remove (fn [n]
                                    (->> @p
                                         (take-while #(<= (* % %) n))
                                         (some #(zero? (rem n %)))))))))
    @p))

Does it make it any faster?

t.core=> (time (nth (check-based) 1000000))
"Elapsed time: 107905.383266 msecs"
15485867
t.core=> (time (nth (check-based-with-bigint 2) 1000000))
"Elapsed time: 110825.211513 msecs"
15485867
t.core=> (time (nth (check-based-with-bigint 5) 1000000))
"Elapsed time: 119101.514406 msecs"
15485867
t.core=>

Oh well. They can't all be winners.

BitSet

Another possibly-useful class in the Java standard library is BitSet: in our sieve implementation, we only really need to keep track of which elements have been crossed out; the actual value of the element is not useful because, if it's not zero, it's always equal to its index. So why bother storing an integer at all? We could just store a bit for that. That's exactly what BitSet lets us do.
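
As a quick tour of the handful of BitSet methods a sieve needs, here is a minimal sketch on a ten-bit set. The names set-bits and demo are made up for this illustration.

```clojure
;; Hypothetical helper, for illustration only: the indices of all set bits,
;; in increasing order. nextSetBit returns -1 once there are no more.
(defn set-bits
  [^java.util.BitSet bits]
  (take-while #(>= % 0)
              (iterate #(.nextSetBit bits (inc %))
                       (.nextSetBit bits 0))))

(def demo
  (doto (java.util.BitSet. 10)
    (.set 2 10)   ; set bits 2..9; the upper end is exclusive
    (.clear 4)    ; clear individual bits by index
    (.clear 6)
    (.clear 8)
    (.clear 9)))

(set-bits demo)     ;; => (2 3 5 7)
(.get demo 4)       ;; => false
(.cardinality demo) ;; => 4
```

The index plays the role of the number itself, so a set bit means "still on the board" and a cleared bit means "crossed out".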

Here is the sieve implemented on top of BitSet instead of an array:

(defn bitset-sieve
  [^long bound]
  (let [candidates (java.util.BitSet. bound)]
    ;; mark every index in [2, bound) as a candidate prime
    (.set candidates 2 bound)
    ;; eliminate known non-primes
    (loop [idx 2]
      (when (< idx bound)
        (when (.get candidates idx)
          (loop [update-idx (* idx idx)]
            (when (< update-idx bound)
              (.clear candidates update-idx)
              (recur (+ idx update-idx)))))
        (recur (inc idx))))
    ;; return a Clojure seq with only primes in it
    (take-while pos? (iterate #(.nextSetBit candidates (inc %)) 2))))

If we compare this to the array-based implementation, it's really not more complicated, just using a slightly less common API. How does it do?

t.core=> (time (last (array-sieve 10000000)))
"Elapsed time: 4214.152658 msecs"
9999991
t.core=> (time (last (bitset-sieve 10000000)))
"Elapsed time: 391.648236 msecs"
9999991
t.core=>

So this one does seem promising. Let's now run a real benchmark with all that.

Benchmark results

Here's the benchmark we're running, based on the definitions above:

(defn -main
  [& args]
  (bench (check-based))
  (bench (check-based-with-bigint 2))
  (bench (check-based-with-bigint 5))
  (bench (sieve-based array-sieve))
  (bench (sieve-based bitset-sieve)))

and here's the result on my machine (the last column is the mean time in seconds, as reported by criterium):

$ java -server -jar target/uberjar/t-app-standalone.jar
(check-based)                 100000:  1.63
(check-based)                 300000:  7.73
(check-based)                1000000: 46.26
(check-based-with-bigint 2)   100000:  1.99
(check-based-with-bigint 2)   300000:  9.20
(check-based-with-bigint 2)  1000000: 47.43
(check-based-with-bigint 5)   100000:  2.32
(check-based-with-bigint 5)   300000:  9.77
(check-based-with-bigint 5)  1000000: 49.19
(sieve-based array-sieve)     100000:  2.87
(sieve-based array-sieve)     300000:  2.80
(sieve-based array-sieve)    1000000: 29.51
(sieve-based bitset-sieve)    100000:  0.18
(sieve-based bitset-sieve)    300000:  0.20
(sieve-based bitset-sieve)   1000000:  5.18
$

I must say I did not expect these numbers to be quite so much lower than the REPL ones, but the overall rankings seem to hold.

Conclusion

While the BigInteger route bore no fruit, the BitSet class gave us a very significant speed boost, at the cost of a very minor, local change to the code of array-sieve. While this post deals with an eminently useless problem, there's a very practical, broader lesson here. Clojure being built on the JVM was a very deliberate choice, and it was not just done to avoid having to reimplement a garbage collector. Integration with the JVM and its ecosystem is a cornerstone of Clojure's design, and it's often worth looking around in existing Java libraries for solutions to problems encountered in Clojure.

Tags: primes clojure