5 December 2021

"The Genuine Sieve of Eratosthenes" in Clojure

After my previous two posts on primes (first, second), a few people have reached out through various channels to suggest alternative ways to compute them. Eventually I ended up with a Haskell paper in my hands, and I'm apparently starting to make a habit out of translating those to Clojure, so I thought I'd give this one a try too.

There is some overlap with my previous posts, but I'm going to write this one to stand on its own as it doesn't feel right to assume every reader has read my entire blog.

What the sieve is not

The paper starts with the observation that a lot of sources seem confused about what the sieve is, and presents this code as a common example of something people call a sieve when it's actually not:

primes = sieve [2..]

sieve (p : xs) = p : sieve [x | x <- xs, x `mod` p > 0]

A direct translation to Clojure, for those unfamiliar with Haskell syntax, would yield:

(defn sieve
  [[p & xs]]
  (cons p
        (lazy-seq
          (sieve (for [x xs :when (pos? (mod x p))]
                   x)))))

(def primes (sieve (iterate inc 2)))

or, avoiding the memory leak caused by the top-level binding and switching to the more common explicit filter instead of a comprehension:

(defn unfaithful
  ([] (unfaithful (iterate inc 2)))
  ([[p & xs]] (cons p (lazy-seq (unfaithful (filter #(pos? (mod % p)) xs))))))

This doesn't work in Clojure because we keep wrapping the lazy seq in more and more filter calls (either explicitly, or implicitly through the for macro) and eventually those blow the (admittedly limited) JVM stack. On my machine, using OpenJDK 1.8.0_292, this blows up on the \(1798\)th prime. We could try to work around that by increasing the stack size, but that's a very limited fix at best, and stack size is a JVM-wide setting so changing it may have unforeseen consequences.

Haskell has a larger default stack, so that particular problem is less of an issue, but the paper goes on to explain that this approach is still very slow.

In the rest of the paper, the author refers to this implementation as "the unfaithful sieve"; I'll adopt the same name here to distinguish it from the real sieves later on.

What the sieve is

Eratosthenes did not have access to computers, so his technique was based on a blackboard. It worked as follows. First, write down all the integers you can fit on your blackboard, starting at \(2\). Then, for each number \(p\) (starting at \(2\)):

  1. \(p\) is prime.
  2. Cross off all of the multiples of \(p\) on the blackboard. You can easily do that without any division (or multiplication) by just erasing every \(p\)th number. As a minor optimization, one can start at \(p^2\) instead of at \(2p\).
  3. Find the first number after \(p\) that is not crossed. Set \(p\) to that and go back to 1.

As another optimization, if you know the maximum number you have written on the board, \(n\), you can stop as soon as \(p > \sqrt{n}\) and consider all the remaining numbers to be prime: any composite up to \(n\) has a prime factor no larger than \(\sqrt{n}\), so it has already been crossed off. [1]
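For concreteness, here is a minimal sketch of the blackboard procedure in Clojure. It assumes a fixed upper bound n and uses a mutable boolean array as the blackboard; the name blackboard-sieve is mine, not the paper's. It starts crossing at \(p^2\) and stops once \(p^2\) exceeds \(n\).

```clojure
(defn blackboard-sieve
  "Primes up to and including n, found by crossing off multiples
  on a boolean array (index i stands for the number i)."
  [n]
  (let [crossed (boolean-array (inc n))] ; every entry starts out false
    (loop [p 2]
      (when (<= (* p p) n)
        (when-not (aget crossed p)
          ;; p is prime: cross off p^2, p^2 + p, p^2 + 2p, ...
          ;; using additions only, no division.
          (loop [m (* p p)]
            (when (<= m n)
              (aset crossed m true)
              (recur (+ m p)))))
        (recur (inc p))))
    ;; whatever is still uncrossed is prime
    (vec (remove #(aget crossed %) (range 2 (inc n))))))

(blackboard-sieve 30)
;; => [2 3 5 7 11 13 17 19 23 29]
```

Note that the only arithmetic on the hot path is (+ m p), which is the whole point of the technique.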

Why the wrong one is bad

Is it obvious to you that this performs a lot less work than the unfaithful sieve above? It wasn't immediately obvious to me, but the paper walks through a simple example: what happens for \(17\). In the unfaithful sieve, the filter function for \(17\) will get applied to every single integer that is not a multiple of one of the primes under \(17\) (namely \(2\), \(3\), \(5\), \(7\), \(11\), and \(13\)). So, for example, we will be computing (mod 19 17) and (mod 23 17), whereas the "true" sieve flies right over those and jumps straight to \(34\), even if we don't implement the square optimization.

Looking at it from another angle, how many filters are we going to apply to \(19\)? Well, all of them, because \(19\) is prime: we will compute (mod 19 p) for p in [2 3 5 7 11 13 17]. Obviously, this does not get better as primes grow and the number of primes smaller than them grows accordingly. Whereas, in the real sieve, we do nothing to \(19\) because it's prime.

For non-primes, i.e. the numbers that we do want to cross off, the unfaithful sieve stops at the first factor it finds, whereas the real sieve may cross off the same number multiple times. For instance, \(18\) is divisible by both \(2\) and \(3\) and bigger than both \(2^2\) and \(3^2\), so it gets crossed off twice. But, even though the unfaithful sieve stops at the first factor it finds, it still tries a lot more candidates. \(18\) specifically will be filtered out by our first filter (as a multiple of \(2\)), but if we look instead at \(221\) (\(13 \times 17\)), the unfaithful sieve will compute (mod 221 p) for p in [2 3 5 7 11 13] (and stop at \(13\)). By contrast, the real sieve crosses this number off at most twice: once for \(13\) and, if we skip the square optimization, once for \(17\).

So the unfaithful sieve does a lot more checking, and that's even before you consider the difference in cost between the checks themselves: a simple addition for the real sieve, and a division for the unfaithful one.
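To make that counting argument something you can poke at, here is a small sketch that tallies the work each approach spends on a single number. All three function names are made up for this illustration, and small-primes is deliberately naive since it only ever sees tiny inputs.

```clojure
(defn small-primes
  "Primes below limit, by naive trial division (fine for tiny limits)."
  [limit]
  (filter (fn [n] (not-any? #(zero? (mod n %)) (range 2 n)))
          (range 2 limit)))

(defn unfaithful-checks
  "Number of (mod n p) tests the unfaithful sieve performs on n:
  one per prime p < n, stopping at the first p that divides n."
  [n]
  (let [[passed remaining] (split-with #(pos? (mod n %)) (small-primes n))]
    (+ (count passed) (if (seq remaining) 1 0))))

(defn real-crossings
  "Number of times the real sieve, starting each prime at p^2,
  crosses off n: once per prime factor p with p^2 <= n."
  [n]
  (count (filter #(and (zero? (mod n %)) (<= (* % %) n))
                 (small-primes n))))

(unfaithful-checks 19)  ;; => 7
(real-crossings 19)     ;; => 0
(unfaithful-checks 221) ;; => 6
(real-crossings 18)     ;; => 2
```

With the square optimization in place, (real-crossings 221) is 1, because \(17^2 = 289\) is already past \(221\).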

The paper goes on to calculate complexity classes; the details of the calculations are not very relevant for this blog post, so I'll just write down the results. The theoretical complexity of generating all the primes lower than \(n\) is found to be:

\[ \Theta (n \log \log n) \]

for the real sieve,

\[ \Theta\left(\frac{n^2}{(\log n)^2}\right) \]

for the unfaithful sieve, and

\[ \Theta\left(\frac{n\sqrt{n}}{(\log n)^2}\right) \]

for a slightly better implementation of the same idea as the unfaithful sieve, which would be written:

primes = 2 : [x | x <- [3..], isprime x]
isprime x = all (\p -> x `mod` p > 0) (factorsToTry x)
  where factorsToTry x = takeWhile (\p -> p * p <= x) primes

or, in Clojure:

(declare primes)

(defn prime?
  [x]
  (let [factors-to-try (fn [x] (take-while #(<= (* % %) x) primes))]
    (every? #(pos? (mod x %)) (factors-to-try x))))

(def primes
  (cons 2
        (for [x (iterate inc 3)
              :when (prime? x)]
          x)))

Avoiding the global memory leak and using more idiomatic Clojure yields:

(defn trial-division
  []
  (let [primes (promise)
        prime? (fn [x]
                 (->> @primes
                      (take-while #(<= (* % %) x))
                      (every? #(pos? (mod x %)))))]
    (deliver primes (cons 2 (filter prime? (iterate inc 3))))
    @primes))

where we use a promise to break the circular dependency between primes and prime?. Crucially, while it implements the same idea as the unfaithful sieve (with the additional square root optimization), this version does not rely on an ever-growing stack of nested filter calls, and thus is stack-safe.

Comparing performance is tricky when we can only go up to \(1797\), but as a data point:

t.core=> (time (nth (unfaithful) 1797))
"Elapsed time: 388.187915 msecs"
15383
t.core=> (time (nth (trial-division) 1797))
"Elapsed time: 30.80426 msecs"
15383
t.core=>

What about infinity?

The main drawback of the sieve approach at this point is that it is ill-suited to the task of producing an infinite list, which both other approaches handle in principle (for infinity less than \(1798\)). I have, in my previous posts, hacked around that by recomputing the sieve on increasingly large limits (using mutable arrays or bit sets), but the paper takes the better, harder route of actually defining an infinite sieve using a purely functional approach.

The obvious issue is step 2 in the sieve algorithm above: if we're dealing with an infinite list, we can't erase all of the multiples of a prime. So, instead, we store a list of all of the "next multiples". When we consider \(3\), for example, that list only contains \(4\), so \(3\) is prime. When we consider \(4\), the list contains \(4\) and \(9\), so \(4\) is not prime. When we consider \(10\), the list contains \(10\) (from \(2\)), \(12\) (from \(3\); note that we don't yet keep track of the \(12\) we'll get from \(2\), as the counter for \(2\) is still at \(10\)), \(25\) (from \(5\)) and \(49\) (from \(7\)). And so on. So we need to keep track of, for each prime we have found so far, the "next" value it would cross off. We also need to be able to efficiently ask our list for its minimum value. The algorithm becomes, roughly:

  1. Start with \(n = 3\) and \(p = [4]\).
  2. If \(n < \min p\), \(n\) is prime. Add \(n^2\) to the list, then increment \(n\) and go to 2.
  3. If \(n = \min p\), then \(n\) is not prime. Replace each \(n\) in \(p\) with the next value for the prime that had produced it, then increment \(n\) and go to 2.

The tricky bit is step 3, where we need to know not only what the next smallest non-prime is, but also which primes it can be composed of.
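To watch that bookkeeping evolve, here is a small tracing sketch. It assumes we start from \(2\) with an empty table and returns the table as it stands just before a given candidate is examined. The name table-before is hypothetical, and a sorted map is used purely so that the output prints in order.

```clojure
(defn table-before
  "State of the 'next multiple' table just before candidate n (n >= 2)
  is examined. Keys are the next composites to cross off; values are
  the primes that produce them."
  [n]
  (loop [x 2, table (sorted-map)]
    (cond
      (= x n) table
      ;; x is a known composite: move each prime that produced it
      ;; forward to its next multiple, x + p
      (contains? table x)
      (recur (inc x)
             (reduce (fn [t p] (update t (+ x p) (fnil conj []) p))
                     (dissoc table x)
                     (get table x)))
      ;; x is prime: the first multiple worth tracking is x^2
      :else
      (recur (inc x) (assoc table (* x x) [x])))))

(table-before 4)  ;; => {4 [2], 9 [3]}
(table-before 10) ;; => {10 [2], 12 [3], 25 [5], 49 [7]}
(table-before 12) ;; => {12 [3 2], 25 [5], 49 [7], 121 [11]}
```

The last call shows why provenance matters: by the time we reach \(12\), both \(2\) and \(3\) are parked on it, and both need to be moved forward.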

The paper proposes this first implementation, using a map of "next value" to "list of primes that produced it" to keep track of provenance:

sieve xs = sieve' xs Map.empty
  where
    sieve' [] table = []
    sieve' (x:xs) table = case Map.lookup x table of
        Nothing -> x : sieve' xs (Map.insert (x * x) [x] table)
        Just facts -> sieve' xs (foldl reinsert (Map.delete x table) facts)
      where
        reinsert table prime = Map.insertWith (++) (x + prime) [prime] table

A Clojure version of that would be:

(defn sieve
  ([s] (sieve s {}))
  ([[x & xs] table]
   (if-let [factors (get table x)]
     (sieve xs (reduce (fn [t prime]
                         (update t (+ x prime) concat [prime]))
                       (dissoc table x)
                       factors))
     (cons x (lazy-seq (sieve xs (assoc table (* x x) [x])))))))

primes is still defined by calling sieve on a list of integers starting at \(2\).

This implementation is much simpler than my kludgy hack around bounded, mutable sieves.

How does that look in terms of performance and how does it compare to the trial-division approach? Let's first define a simple bench function based on criterium:

(require '[criterium.core :as crit])

(defn bench
  [primes-fn label]
  (let [results (->> [1000 3000 10000 30000 100000]
                     (map (fn [s]
                            (format "%6.3f"
                                    (-> (crit/benchmark (nth (primes-fn) s) {})
                                        :mean first))))
                     (interpose " ")
                     (apply str))]
    (println (format "[%-15s %s]" label results))))

and now let's try running that:

t.core=> (bench trial-division :trial-division)
[:trial-division  0.010  0.041  0.207  0.932  4.938]
nil
t.core=> (bench #(sieve (iterate inc 2)) :sieve)
[:sieve           0.017  0.065  0.279  1.048  4.377]
nil
t.core=>

We can see that the trial division starts out faster, due to better constant factors, but that the sieve scales better, enough to actually be faster around the hundred thousandth prime. That advantage gets better as we move further along the list:

t.core=> (time (nth (trial-division) 1000000))
"Elapsed time: 124154.074597 msecs"
15485867
t.core=> (time (nth (sieve (iterate inc 2)) 1000000))
"Elapsed time: 69140.309762 msecs"
15485867
t.core=>

Data structures

Using a map here is not very efficient: we know we're only ever going to want to look at the smallest key, so we do not need to pay the logarithmic cost of random access in a tree-based map. Even for a trie-based map where access is roughly constant (\(\log_{32}n\)) like Clojure's, the constant factors are quite high. Finally, the map we're using is tailored for any object, which means it's introducing a lot of boxing for our integers.

The paper suggests replacing the map with a priority queue, offering this implementation:

sieve [] = []
sieve (x:xs) = x : sieve' xs (insertprime x xs PQ.empty)
  where
    insertprime p xs table = PQ.insert (p * p) (map (* p) xs) table
    sieve' [] table = []
    sieve' (x:xs) table
      | nextComposite <= x = sieve' xs (adjust table)
      | otherwise = x : sieve' xs (insertprime x xs table)
      where
        nextComposite = PQ.minKey table
        adjust table
          | n <= x = adjust (PQ.deleteMinAndInsert n' ns table)
          | otherwise = table
          where
            (n, n':ns) = PQ.minKeyValue table

where PQ is an imaginary package that supports the following API:

  • empty returns an empty queue.
  • minKey pq returns the minimum key in the queue.
  • minKeyValue pq returns a (key, value) pair where the key is minKey pq.
  • insert k v pq inserts the value v for key k in pq.
  • deleteMinAndInsert k v pq inserts the value v at key k and simultaneously deletes the previous minimum key. This one is an implementation-dependent optimization.

Translating this to Clojure raises the question of where we could be getting this PQ module from. The Clojure standard library has a sorted-map function, which returns a sorted map, so we'll start with that. In terms of API, it behaves like a map, except for the fact that traversals are guaranteed to be ordered by key (as opposed to the unspecified traversal order of standard Clojure maps).
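As a quick illustration of that ordering guarantee (demo-table is a throwaway name, and the values are arbitrary stand-ins for whatever we want to attach to each composite):

```clojure
;; Keys are given out of order on purpose.
(def demo-table (sorted-map 25 [5], 9 [3], 4 [2]))

(first demo-table)                   ;; => [4 [2]], the entry with the smallest key
(key (first demo-table))             ;; => 4
(first (dissoc demo-table 4))        ;; => [9 [3]]
(first (update demo-table 4 conj 3)) ;; => [4 [2 3]]
(keys (assoc demo-table 6 [2]))      ;; => (4 6 9 25)
```

So first plays the role of minKeyValue, assoc and update cover insert, and dissoc removes the minimum; that is all a sieve needs from its table.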

(let [insert-prime (fn [table x xs]
                     (assoc table (* x x) [(map #(* x %) xs)]))]
  (defn sieve-sm
    ;; The head of the input is prime by construction: emit it (as the
    ;; Haskell version does), then sieve the rest.
    ([[i & is]]
     (cons i (lazy-seq (sieve-sm is (insert-prime (sorted-map) i is)))))
    ([[x & xs] table]
     (let [[next-composite factors] (first table)]
       (if (> next-composite x) ;; x is prime
         (cons x (lazy-seq (sieve-sm xs (insert-prime table x xs))))
         (sieve-sm xs (reduce (fn [table [next-comp & future-comps]]
                                (update table next-comp
                                        conj future-comps))
                              (dissoc table next-composite)
                              factors)))))))

Unfortunately, this doesn't perform very well:

t.core=> (bench #(sieve-sm (iterate inc 2)) :sieve-sm)
[:sieve-sm        0.025  0.107  0.486  1.822  7.813]
nil
t.core=> (time (nth (sieve-sm (iterate inc 2)) 1000000))
"Elapsed time: 118290.949828 msecs"
15485917
t.core=>

So it looks like the Clojure sorted-map may not actually be an optimal choice here. But this is Clojure, which means we're on the JVM, and Java does have a (mutable) PriorityQueue, so let's try that:

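(import 'java.util.PriorityQueue)

(let [insert-prime (fn [^PriorityQueue table x xs]
                     (.add table [(* x x) (map #(* x %) xs)])
                     table)]
  (defn sieve-pq
    ;; The head of the input is prime by construction: emit it, then
    ;; sieve the rest with a fresh queue ordered by next composite.
    ([[i & is]] (cons i
                      (lazy-seq
                        (sieve-pq
                          is
                          (insert-prime
                            (PriorityQueue. 10 (fn [[x] [y]] (< x y)))
                            i is)))))
    ([[x & xs] ^PriorityQueue table]
     (let [[next-composite] (.peek table)]
       (if (> next-composite x)
         (cons x (lazy-seq (sieve-pq xs (insert-prime table x xs))))
         ;; x is composite: advance every entry currently parked on x
         (do (while (== x (first (.peek table)))
               (let [[_ [f & fs]] (.poll table)]
                 (.add table [f fs])))
             (sieve-pq xs table)))))))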

This performs suspiciously similarly to the sorted-map version, so I'll probably find an obvious issue with the way I've written this right after publishing this post. We'll see.

Wheels

Note that the switch to a priority queue is not the only change going on here. Instead of storing pairs of (Int, [Int]) representing the primes that generated a composite, we are now storing pairs of (Int, [[Int]]) where each list in the list is the list of all future composites (created by the insertprime function) of the primes that created (or added to) the key. And because we create those lists by multiplying the input list, we can now do smart things like skipping all of the multiples of 2 altogether:

t.core=> (bench #(cons 2 (sieve-sm (iterate (partial + 2) 3))) :sieve-sm-2)
[:sieve-sm-2      0.010  0.042  0.189  0.725  3.282]
nil
t.core=> (time (nth (cons 2 (sieve-sm (iterate #(+ % 2) 3))) 1000000))
"Elapsed time: 52191.14446 msecs"
15485867
t.core=>

This one does seem to work for us just as well as in the paper. Going further, if we can just skip right over the multiples of 2, why not also skip over the multiples of 3? Or 5? Or 7? Where should we stop?

It turns out that, while skipping all multiples of 2 is very easy, adding more "base primes" to skip is increasingly complex, and the returns diminish: the bigger the prime, the sparser its multiples.

The paper gives an implementation for skipping multiples of \(2\), \(3\), \(5\) and \(7\), as well as some argument for why that may be a sweet spot. The technique they use is reminiscent of a wheel turning on the number line, hence the title of this section.

Here's how that works:

wheel2357 = 2:4:2:4:6:2:6:4:2:4:6:6:2:6:4:2:6:4:6:8:4:2:4:2:4:8
            :6:4:6:2:4:6:2:6:6:4:2:4:6:2:6:4:2:4:2:10:2:10:wheel2357

spin (x:xs) n = n : spin xs (n+x)

primes = 2 : 3 : 5 : 7 : sieve (spin wheel2357 11)

Translation to Clojure is fairly straightforward:

(def wheel2357
  (cycle [2 4 2 4 6 2 6 4 2 4 6 6 2 6 4 2 6 4 6 8 4 2 4 2 4 8
          6 4 6 2 4 6 2 6 6 4 2 4 6 2 6 4 2 4 2 10 2 10]))

(defn spin
  [[x & xs] n]
  (cons n (lazy-seq (spin xs (+ n x)))))

(defn spin-primes
  [sieve-fn]
  (concat [2 3 5 7]
          (sieve-fn (spin wheel2357 11))))

and we get a nice speed boost out of it:

t.core=> (bench #(spin-primes sieve-pq) :sieve-sm)
[:sieve-sm        0.002  0.010  0.052  0.209  0.936]
nil
t.core=> (time (nth (spin-primes sieve-pq) 1000000))
"Elapsed time: 16418.24152 msecs"
15485917
t.core=>
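That long list of magic numbers in the wheel is not arbitrary, and it can be re-derived as a sanity check. The sketch below (both function names are mine) walks one full turn of \(2 \times 3 \times 5 \times 7 = 210\) starting at \(11\), keeps the numbers that none of the four base primes divide, and takes the gaps between neighbours.

```clojure
(defn coprime-to-2357?
  "True when n is divisible by none of 2, 3, 5 and 7."
  [n]
  (every? #(pos? (mod n %)) [2 3 5 7]))

(defn wheel-gaps
  "Gaps between consecutive numbers coprime to 2, 3, 5 and 7,
  from 11 up to and including 11 + 210."
  []
  (let [spokes (filter coprime-to-2357? (range 11 (+ 11 210 1)))]
    (mapv - (rest spokes) spokes)))

(= (wheel-gaps)
   [2 4 2 4 6 2 6 4 2 4 6 6 2 6 4 2 6 4 6 8 4 2 4 2 4 8
    6 4 6 2 4 6 2 6 6 4 2 4 6 2 6 4 2 4 2 10 2 10])
;; => true
```

There are 48 gaps and they sum to 210, so the pattern repeats exactly from there on, which is what makes it safe to cycle.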

Conclusion

I've had a lot of fun reading this paper. If you're not used to reading academic papers, this one may be a good first one (provided you have any interest in prime numbers; if you don't, thanks for reading so far despite that, I guess) as it is very well-written.

I really like the concept of this functional approach to the sieve; it feels a lot cleaner than my hacky implementation. However, the hacky bitset-based implementation is still faster on my machine, which makes me a bit sad.


  1. Incidentally, while I know the underlying math well enough that it's obvious to me once I've seen it stated, I had not thought about it myself and that optimization is not implemented in the code I presented in my previous two posts about primes.

Tags: primes clojure papers