14 February 2021

Runtime optimization with eval

Recently, I was discussing with a friend how JIT[1] can, at least theoretically, do things that no static compiler can do. We were discussing the nature of JITs and how they differ from traditional compilers. I argued that there is a very fine line between a JIT interpreter and the interpreter for any "dynamic" language; here is the argument I made.

One can think of a compiler as a program that reads code into some form of tree, and then applies a succession of transformations on that tree, generating successive trees. In a typed language, some of these steps may involve changing the type of the tree altogether, whereas some will "just" be rewriting some subtree or adding information while preserving the types. Various people will use various names to designate some of these trees ("abstract syntax tree" being the most common), though in practice compilers are different enough that such names are hard to define in a generic way. Regardless of types or names, conceptually we have a collection of trees and, at some point, the compiler decides it has done enough processing and it now is in possession of a tree that is ready for evaluation.

At this point, the tree will either be evaluated in-process, in which case the conceptual compiler above is really just the preprocessing part of an interpreter, or the tree will be written to a file for later evaluation, in which case we are dealing with a more traditional compiler.
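As a minimal sketch of that pipeline (the names fold-constants, tree and optimized are made up for illustration, and a real compiler has many more passes), here is a single tree transformation in Clojure, where code is conveniently already a tree. It reads a string into a form, rewrites additions of constants, and then evaluates the resulting tree in-process:

```clojure
(require '[clojure.walk :as walk])

;; Hypothetical one-pass "compiler": replaces (+ <numbers>...) subtrees
;; with their value, leaving every other node untouched.
(defn fold-constants [form]
  (walk/postwalk
    (fn [node]
      (if (and (seq? node)
               (= '+ (first node))
               (every? number? (rest node)))
        (apply + (rest node))
        node))
    form))

;; Reading: string -> tree.
(def tree (read-string "(* x (+ 1 2))"))

;; Transformation: tree -> tree. The result is the form (* x 3).
(def optimized (fold-constants tree))

;; Evaluation in-process, binding x to 5 around the optimized tree.
(eval (list 'let '[x 5] optimized))
;; => 15
```

Printing optimized with pr-str and writing it to a file instead of calling eval would be the "traditional compiler" branch of the same picture.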

From this perspective, most interpreters are composed of both a compiler and an evaluator. In some rare cases you could argue that you have an evaluator without an embedded compiler, but we'll ignore this case for now, as even modern CPUs don't really fit that description anymore. For this reason, we're also going to ignore the distinction between "machine" code and other forms of code: CPUs are really just another interpreter.[2]

Further, a dynamic language is one that has an eval function in its traditional sense: a function that takes a string, runs it through all of the tree transformations of the embedded compiler, and inserts the result into the tree currently being evaluated. A JIT interpreter is one where the evaluator can, while evaluating the tree, make transformations to the tree in order to improve performance. A traditional ahead-of-time interpreter for a static language is one where, once evaluation begins, no change can be made to the tree being evaluated.
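To make the eval definition concrete, here is a small sketch in Clojure (the names tree and square are illustrative only). In Clojure the string-to-tree step is read-string and eval takes the resulting tree, but the combination matches the description above: new code is compiled and becomes callable while the program is already running.

```clojure
;; The source text only exists as a string at runtime.
(def tree (read-string "(fn [x] (* x x))"))

;; eval runs the tree through the embedded compiler and hands back
;; a regular function, indistinguishable from one written in the file.
(def square (eval tree))

(square 7)
;; => 49
```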

So I think of the underlying machinery for a JIT as being the same one as the one that enables dynamic languages; the main difference is how that tree transformation is triggered.

Here is a concrete example that should help to illustrate how close those two concepts are, as well as why they can both beat static compilers. Over the past few days I have been trying to optimize my solutions to the Advent of Code 2020. Without digging into too much detail, on day 17, we need to compute the neighbourhood of a point in discrete space, for three- and four-dimensional spaces. I could have just written out the functions for each dimensionality, but I wanted to write something generic.

In Clojure, we can represent a point in space as a vector of integers: [1 2 3]. In 1d space, a neighbourhood function could be written as:

(defn neighbourhood
  [[x]]
  [[(inc x)] [(dec x)]])

such that:

=> (neighbourhood [0])
[[1] [-1]]

In 2d space, we could write:

(defn neighbourhood
  [[x y]]
  [[x (inc y)]
   [x (dec y)]
   [(inc x) y]
   [(inc x) (inc y)]
   [(inc x) (dec y)]
   [(dec x) y]
   [(dec x) (inc y)]
   [(dec x) (dec y)]])

You can probably guess why I wasn't thrilled to write the 3d and 4d versions of that by hand. (Hint: 26 and 80 elements, respectively.) The logic can be generalized quite easily to:

(defn neighbourhood
  [v]
  (->> (range (count v))
       (reduce (fn [n idx]
                 (mapcat (fn [t]
                           [t
                            (update t idx inc)
                            (update t idx dec)])
                         n))
               [v])
       rest))

However, this is a lot less efficient. In the 1d and 2d versions, the compiler can immediately see the length of the vector it has to allocate, and, beyond the initial input destructuring, all the operations are plain arithmetic on numbers. In the generalized version, there are many intermediate sequences being created.
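The size of the result depends only on the dimensionality: each coordinate can move by -1, 0 or +1, and the one combination where nothing moves is the point itself. A quick sketch (neighbour-count is a hypothetical helper, not part of the solution) shows where the 26 and 80 come from:

```clojure
;; 3 choices per coordinate, minus the all-zero offset (the point itself).
(defn neighbour-count [n]
  (dec (reduce * (repeat n 3))))

(map neighbour-count [1 2 3 4])
;; => (2 8 26 80)
```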

In the context of this problem, we're only dealing with one dimensionality at a time. This is where the dynamic aspect comes in. If the dimensionality were fixed at compile-time, a "smart enough" compiler could conceivably optimize the neighbourhood function just for that size. However, in this case, the code is expected to, at runtime, read a file, and discover the dimensionality from the input data in that file. The dimensionality will be fixed for the whole processing of that one file, but we may process multiple files over a single run, each with a potentially different dimensionality.

There is no way for a static compiler to generate all the possible dimensionalities in advance.

For each instance of the problem (i.e. each file read), the neighbourhood function is going to be called hundreds of thousands of times (always with the same dimensionality), and actual profiling showed it to be responsible for a significant part of the run time. (Always profile before optimizing!) Thus, optimizing it does matter.

We could try to isolate the dimension-dependent parts of the code with something like:

(defn mk-neighbourhood
  [n]
  (fn [v]
    (->> (range n)
         (reduce (fn [acc idx]
                   (mapcat (fn [t]
                             [t
                              (update t idx inc)
                              (update t idx dec)])
                           acc))
                 [v])
         rest)))

Then, hope that in a function that looks like:

(let [dimensionality ...
      neighbourhood (mk-neighbourhood dimensionality)]
  ... code that calls neighbourhood thousands of times ...)

the JIT would be able to take advantage of the fact that n is a constant. At this point, if the JIT was using something like normalization by evaluation, it could eventually generate the code we would have handwritten for that specific value of n.

In this specific case, the JIT in the JVM is not that smart, but Clojure is a dynamic language and thus has eval. We can therefore manually write what the code should look like to the JIT once n is fixed:

(defn mk-neighs [n]
  (let [args (vec (repeatedly n gensym))]
    `(fn [~args]
       [~@(->> (range n)
               (reduce (fn [acc idx]
                         (mapcat (fn [t]
                                   [t
                                    (update t idx (fn [s] `(inc ~s)))
                                    (update t idx (fn [s] `(dec ~s)))])
                                 acc))
                       [args])
               rest)])))

We can test that this generates the expected code:

t.day17=> (mk-neighs 1)
(clojure.core/fn
  [[G__57757]]
  [[(clojure.core/inc G__57757)]
   [(clojure.core/dec G__57757)]])
t.day17=> (mk-neighs 2)
(clojure.core/fn
  [[G__57760 G__57761]]
  [[G__57760 (clojure.core/inc G__57761)]
   [G__57760 (clojure.core/dec G__57761)]
   [(clojure.core/inc G__57760) G__57761]
   [(clojure.core/inc G__57760) (clojure.core/inc G__57761)]
   [(clojure.core/inc G__57760) (clojure.core/dec G__57761)]
   [(clojure.core/dec G__57760) G__57761]
   [(clojure.core/dec G__57760) (clojure.core/inc G__57761)]
   [(clojure.core/dec G__57760) (clojure.core/dec G__57761)]])

Not as human-readable as the hand-written versions, between the fully-namespaced functions and the generated variable names, but ultimately identical as far as the Clojure compiler is concerned.

And, finally, we can look at the performance difference using the criterium library:

t.day17=> (require '[criterium.core :as crit])
t.day17=> (defmacro b
            [e]
            `(first (:mean (crit/benchmark ~e {}))))
t.day17=> (def partial-neigh3 (mk-neighbourhood 3))
t.day17=> (def eval-neigh3 (eval (mk-neighs 3)))
t.day17=> (= (neighbourhood [1 2 3])
             (partial-neigh3 [1 2 3])
             (eval-neigh3 [1 2 3]))
t.day17=> (b (neighbourhood [1 2 3]))
t.day17=> (b (partial-neigh3 [1 2 3]))
t.day17=> (b (eval-neigh3 [1 2 3]))

A 10x improvement is nothing to sneeze at. In this case, further profiling showed that this was enough to move this function off the bottleneck.
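Since a single run may process several files with different dimensionalities, one possible way to use this is to compile each specialized function once and cache it. The following is only a sketch under that assumption: neighbourhood-form and neighbourhood-for are hypothetical names, and the generator is restated in the spirit of mk-neighs so that the block runs on its own.

```clojure
;; Code generator: returns the *form* of a neighbourhood function
;; specialized for n dimensions (same idea as mk-neighs).
(defn neighbourhood-form [n]
  (let [args (vec (repeatedly n gensym))]
    `(fn [~args]
       [~@(->> (range n)
               (reduce (fn [acc idx]
                         (mapcat (fn [t]
                                   [t
                                    (update t idx (fn [s] `(inc ~s)))
                                    (update t idx (fn [s] `(dec ~s)))])
                                 acc))
                       [args])
               ;; the first element is always the unmodified point
               rest)])))

;; Hypothetical cache: eval runs at most once per dimensionality,
;; so later files with the same n reuse the compiled function.
(def neighbourhood-for
  (memoize (fn [n] (eval (neighbourhood-form n)))))

((neighbourhood-for 2) [0 0])
;; => [[0 1] [0 -1] [1 0] [1 1] [1 -1] [-1 0] [-1 1] [-1 -1]]
```

The trade-off is that every new dimensionality pays the one-time cost of eval, which only makes sense when the resulting function is then called many times, as is the case here.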

  1. JIT generally stands for "Just In Time"; in this context, it refers to "just-in-time compilation", a process by which an interpreter can "compile" code at runtime. This is in opposition to "static" or "ahead-of-time" compilation, where the code is analyzed separately and all code generation is done before code starts executing.

  2. Yes, there are also trees and tree transformations going on inside modern CPUs.

Tags: clojure performance