16 May 2021

Monads, part five: What if I'm not writing Haskell?

In the past two posts of this series, I have shown what I believe to be a simple, effective technique for creating monads in Haskell. At this point one may be wondering how that translates to other languages, if indeed it applies at all.

In this post, I will take a look at what exactly we need from a language in order to apply this "structural monad" approach, with a detailed example in Clojure.

Requirements

Let's first look at what we need to be able to use this approach at all. There are three steps to consider:

  1. Defining the structure of a monad.
  2. Defining a run function that assigns the correct meaning to a monadic structure.
  3. Being able to write specific instances of monadic structures.

To define the structure, the first thing we need is a way to create a structure for the various monadic actions. This could be a type, as we did in Haskell; it could be a set of classes, as we did in Java. All we really need here is tagged collections of values; any way to represent that can work, from actual types to tuples, vectors, lists, or associative collections.
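
To make that concrete in the language used later in this post, here are a few ways a single "push 42" action could be represented in Clojure; the names are only for illustration:

[:push 42]                 ; a tagged vector (the representation used below)
{:tag :push, :value 42}    ; an associative collection with a tag key
(defrecord Push [value])   ; an actual type...
(->Push 42)                ; ...and an instance of it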

We also need support for first-class functions: at the very least, the structure representing a bind action will need to embed a function. In order to build up specific instances of monadic structures, we also need some form of lambda (i.e. anonymous function literal) syntax; otherwise we will not be able to nest bind calls, and therefore we won't be able to actually bind the results of previous steps. And without that, we can't even express the third monadic law.
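
For the curious, the law in question is associativity, and it is the lambda on its right-hand side that cannot be avoided. Sketched in the tagged-vector notation this post settles on later (m, f and g are placeholders, not runnable definitions):

;; Left-associated:   [:bind [:bind m f] g]
;; Right-associated:  [:bind m (fn [x] [:bind (f x) g])]
;; The law says these two must mean the same thing; the lambda on the
;; second line is what requires anonymous-function syntax.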

Finally, whatever means of tagging individual monadic actions we chose when defining the structure needs to be something we can dispatch on in the run function. In our Haskell examples, we used pattern matching to that effect; in our Java examples, we used inheritance, though we could also have gone for instanceof calls, or a switch on .getClass().getSimpleName(). If we represent individual monadic actions as associative data structures, say objects in JavaScript, we could just have a tag key with a string, and match on that.
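
As a small sketch of that last option, transplanted into Clojure maps rather than JavaScript objects (run-tagged is a made-up name; :pop and :push belong to the Stack monad we will revisit below):

(defn run-tagged [stack ma]
  (case (:tag ma)
    :pop  [(pop stack) (peek stack)]
    :push [(conj stack (:value ma)) nil]))

(run-tagged [1 2 3] {:tag :push, :value 4})
;; => [[1 2 3 4] nil]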

That's pretty much it, really: some way to group values, some way to tag those groups, and lambdas.

Nice-to-haves

Implementing and using monads in JavaScript is doable; it has all of the requirements. It may not be very pleasant, however, nor very (syntactically) lightweight. There are a number of things a language can have that will make this process easier. Here are some of them.

Types are the most obvious one. Types can help with any piece of code, but using monads usually means juggling slightly more complex types than one otherwise would, and the more complex the types, the more opportunities you have to screw up and the more value you get from a type system. Note that this does not necessarily mean "a type system that can express Monad"; as we saw in part one, the Java type system cannot express a Monad interface, yet it can still be leveraged to ensure that monadic values are properly constructed for a given monad.

Generalized tail call elimination is also a big plus. In simple cases monadic structures will not be very deeply nested, so this may not be a very apparent problem at first, but the most natural expression for the run method will often be a recursive one. It is always possible to rewrite recursive code to avoid recursion, but in the case of a monadic run method it would add accidental complexity to what may already be considered non-straightforward code.

Pattern matching can make writing the run method a lot cleaner. Alternatives like switch statements and inheritance can work, but switch statements will tend to be a bit more clunky as they will require some ad-hoc way to extract components in each branch, and inheritance will have the effect of splitting what should be a single function's logic across many classes, making it a bit difficult to keep track of the overall behaviour.

Finally, of course, having special syntax for building up monadic values is also a big plus. As we saw in our Java examples (mostly in parts one and two), it is possible to build up monadic structures simply by nesting literal bind calls (assuming we have lambdas), but it gets a bit tedious.

Clojure example: pre-monad code

I did not set out to write that code with a monad. At all. I'm still learning monads, you see. That's the main reason I wrote this series: to help clarify my own understanding. In parallel, I was writing some code to explore how interpreters work. I'll write more on that in future posts.

And as I was writing that code, and this series of posts, it dawned on me that the code I was trying to refactor could really benefit from a monad. So here we go. For context, this code is a compiler from an "ast-like" tree representation of a custom C-like programming language to a (custom-defined) register-machine-style list of opcodes. Concretely, I want to transform this:

(def ast
  [:do
   [:set 0 [:lit 100]]
   [:set 1 [:lit 1000]]
   [:while [:not= [:lit 0] [:var 1]]
    [:do
     [:set 0 [:add [:add [:add [:var 0] [:lit 4]] [:var 0]] [:lit 3]]]
     [:set 0 [:add [:add [:var 0] [:lit 2]] [:lit 4]]]
     [:set 1 [:add [:lit -1] [:var 1]]]]]
   [:return [:var 0]]])

into this:

(def register-code
  [[:load 0 100]
   [:load 1 1000]
   [:load 4 0]
   [:not= 5 4 1]
   [:jump-if-zero 5 17]
   [:load 7 4]
   [:add 8 0 7]
   [:add 9 8 0]
   [:load 10 3]
   [:add 0 9 10]
   [:load 12 2]
   [:add 13 0 12]
   [:load 14 4]
   [:add 0 13 14]
   [:load 16 -1]
   [:add 1 16 1]
   [:jump 2]
   [:return 0]])

The code I had for this transformation when I decided to use a monad looked like this:

(defn compile-register-ssa
  [ast]
  (let [max-var ((fn max-var [[op & [arg1 arg2 :as args]]]
                   (case op
                     :lit 0
                     :return (max-var arg1)
                     :not= (max (max-var arg1)
                                (max-var arg2))
                     :add (max (max-var arg1)
                               (max-var arg2))
                     :var arg1
                     :set (max arg1 (max-var arg2))
                     :do (reduce max (map max-var args))
                     :while (max (max-var arg1)
                                 (max-var arg2)))) ast)
        r (fn [i t] (or t (+ i max-var 1)))
        h (fn h [m [op & [arg1 arg2 :as args]]]
            (case op
              :return (let [[r m] (h (assoc m :ret nil) arg1)]
                        [nil (update m :code concat [[:return r]])])
              :lit [(r (-> m :code count) (:ret m)) (update m :code concat [[:load (r (-> m :code count) (:ret m)) arg1]])]
              :set (if (#{:lit :add} (first arg2))
                     (h (assoc m :ret arg1) arg2)
                     (let [[r m] (h (assoc m :ret nil) arg2)]
                       [nil
                         (update m :code concat [[:loadr arg1 r]])]))
              :do [nil (->> args
                            (reduce (fn [m el] (second (h (assoc m :ret nil) el)))
                                    m))]
              :while (let [start (-> m :code count)
                           [rcond m] (h (assoc m :ret nil) arg1)
                           m (update m :code concat [[:jump-if-zero rcond nil]])
                           jump-idx (-> m :code count dec)
                           [_ m] (h (assoc m :ret nil) arg2)
                           m (update m :code concat [[:jump start]])]
                       [nil (update m :code (fn [code]
                                              (update (vec code) jump-idx assoc 2 (count code))))])
              :var [arg1 m]
              :add (let [ret (:ret m)
                         [rleft m] (h (assoc m :ret nil) arg1)
                         [rright m] (h (assoc m :ret nil) arg2)
                         rresult (-> m :code count)]
                     [(r rresult ret)
                      (update m :code concat [[:add (r rresult ret) rleft rright]])])
              :not= (let [ret (:ret m)
                          [rleft m] (h (assoc m :ret nil) arg1)
                          [rright m] (h (assoc m :ret nil) arg2)
                          rresult (-> m :code count)]
                      [(r rresult ret)
                       (update m :code concat [[:not= (r rresult ret) rleft rright]])])))]
    (-> (h {:ret nil, :code []} ast)
        second
        :code
        vec)))

There's a lot of code, and a lot going on there. The main takeaway, though, is that there is some state that is being threaded through the whole computation, including multiple recursive calls in the case of :add and :not=. Let's take a look at :while, for example:

:while (let [start (-> m :code count)
             [rcond m] (h (assoc m :ret nil) arg1)
             m (update m :code concat [[:jump-if-zero rcond nil]])
             jump-idx (-> m :code count dec)
             [_ m] (h (assoc m :ret nil) arg2)
             m (update m :code concat [[:jump start]])]
         [nil (update m :code (fn [code]
                                (update (vec code) jump-idx assoc 2 (count code))))])

We are threading through two pieces of state:

  • the state of the loop itself in the form of ret, which is a switch for the code-generation loop that tells it whether the code being generated needs to allocate a new register for its result (when ret is nil), or instead write its result to a given, preexisting register (when ret is a number).
  • the state of the generated code in the form of code, represented here as a vector of opcodes, whose length is also used to determine the number of the next register to allocate.

So start is the position we're at in the code being generated; we save it so we can jump to it at the end of the while body. But the same expression (-> m :code count) in :add is used to generate a new register number. This is confusing.

I did not decide to add a monad to this code just because I could. I wanted to make three changes, and I thought it would be easier, overall, to first transform this code to a monadic form and then make the changes than to attempt them directly on this mess. These changes were:

  • Decouple the "code" state from the "loop" state (ret).
  • Decouple register numbers from length of code generated.
  • Lift constants out of the code: there is little point in reassigning constants to registers in the body of the inner loop.

Monads in Clojure

This code really looked like it could use a monadic approach, but I was not writing in Haskell, and Clojure has no built-in support for monads. It's a dynamically typed language with an emphasis on runtime values and a definite desire to stay away from category theory. How would one go about adding support for monads?

Going back to our list of requirements above, let's first look at how to represent a monadic structure. Clojure has literal syntax for vectors, which are commonly used for tree-like structures in what is known in the community as "the hiccup notation": each node of the tree is a vector whose first element is a keyword, and whose remaining elements are whatever children or attributes a node of that type (indicated by the keyword) has. Keywords are values that essentially just exist and can only be compared for equality.
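
For example, a small HTML fragment in hiccup notation would look something like this (hiccup itself is an HTML-templating convention; here we only borrow its shape):

[:div {:class "note"}
 [:p "Monads, " [:em "structurally"] "."]]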

This suggests a natural form for our monadic values. Since we are not defining types in Clojure, we need to look at a concrete example; let's pick this one from part two of this series, in the context of a Stack monad. In Haskell:

do
  a <- pop
  b <- pop
  push (a + b)

which in our "structural monads" approach from part three turns into:

Bind Pop (\a -> Bind Pop (\b -> Push (a + b)))

A direct equivalent in Clojure hiccup-like notation would look like:

[:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])]

Or, with a little bit of whitespace:

[:bind    [:pop]
   (fn [a]
[:bind    [:pop]
   (fn [b]
          [:push (+ a b)]
)])]

This could work, and be about as usable as the Java code was. So that's the production side. On the consuming side, i.e. the run function, we can switch on the keywords, then destructure once we know how many arguments we expect:

(defn run [stack ma]
  (case (first ma)
    :return (let [[_ a] ma]
              [stack a])
    :bind (let [[_ ma f] ma]
            (let [[stack a] (run stack ma)]
              (run stack (f a))))
    :pop [(pop stack) (peek stack)]
    :push (let [[_ a] ma]
            [(conj stack a), nil])))
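
As a quick sanity check, running the structure from above against a small stack (represented as a vector) would look something like this:

(run [1 2]
     [:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])])
;; => [[3] nil]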

This can work, but there is a lot of extra noise on both sides. Highly regular noise, which means we can exploit the fact that we're using a lisp and make up a macro. Or, rather, here, two macros. We'll keep both of them very simple, with little to no error handling and definitely no bells and whistles.

First, we can make a macro that creates the intermingled nested vectors-and-functions that represent a monadic value. The pattern here is very simple: all we need is a name to bind to and an expression that will become the body of the function. The following macro:

(defmacro mdo
  [bindings]
  (let [[n v & r] bindings]
    (if (empty? r)
      v
      [:bind v `(fn [~n] (mdo ~r))])))

will turn

(mdo [a [:pop]
      b [:pop]
      _ [:push (+ a b)]])

into:

[:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])]

It's not quite as nice as the Haskell notation, and doesn't support the "non-binding" syntax (we have to explicitly bind _), but it's very small and simple, and that counts for a lot with macros, as debugging them can be a bit of a pain.

On the other side, we have the destructuring. Clojure does not have built-in pattern matching, which forces us to first test for the type of the node, and then destructure it. We can, however, (syntactically) reorder these things using another simple macro:

(defmacro match
  [expr & cases]
  (let [e (gensym)]
    `(let [~e ~expr]
       (case (first ~e)
         ~@(->> (partition 2 cases)
                (mapcat (fn [[pat body]]
                          [(first pat) `(let [~(vec (cons '_ (rest pat))) ~e]
                                          ~body)])))))))

It's a little bit less simple, but it's not all that complex either. It will transform this:

(defn run [stack ma]
  (match ma
    [:return a] [stack a]
    [:bind ma f] (let [[stack a] (run stack ma)]
                   (run stack (f a)))
    [:pop] [(pop stack) (peek stack)]
    [:push a] [(conj stack a) nil]))

into (roughly):

(defn run [stack ma]
  (let [G__2036 ma]
    (case (first G__2036)
      :return (let [[_ a] G__2036]
                [stack a])
      :bind (let [[_ ma f] G__2036]
              (let [[stack a] (run stack ma)]
                (run stack (f a))))
      :pop (let [[_] G__2036]
             [(pop stack) (peek stack)])
      :push (let [[_ a] G__2036]
              [(conj stack a) nil]))))

The outer let ensures that the expression we match on is evaluated only once, and not twice; for pure code this is merely an optimization, but it also avoids duplicating any side effects.

Note that, because Clojure is dynamically typed, both of these macros will work equally well for any monadic structure. (match will actually work for any hiccup-like tree.) However, when writing the run function, you still have to know which monad you're targeting as you need to cover all the relevant cases. Not having to convince the compiler about which type a value has doesn't mean your code will magically work if you give it the wrong type at runtime.
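
As a throwaway illustration of that generality, here is match used on a tree that has nothing to do with monads:

(match [:point 1 2]
  [:point x y] (+ x y))
;; => 3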

Monadic compiler

So let's get back to our compiler. What monadic operations do we need? First, let's get the ret option out of the way: this is not part of the state we want to track with the monad; this is part of the state we want to track in the loop itself. So we'll make it an explicit argument to the recursive calls.

In the monadic state, we want to keep the code we're producing. We're obviously going to need :pure and :bind (I chose to avoid calling the monadic no-op :return to avoid confusion with the :return in the input ast). What other operations do we need?

  • We need to be able to :emit new code, taking a single opcode and returning the register we may have written to. This last bit should change at some point, but we're not changing this yet: this is a refactoring, we're not changing behaviour.
  • We need to be able to query the state for our :current-position, so we can :emit code that jumps to it.
  • We need to get a :free-register.
  • Perhaps less obviously, we need a way to insert a blank placeholder opcode now (:future) and fill it in later (:resolve), such as when we are done compiling the body of a :while loop and need to go back and write down the :jump-if-zero instruction that skips it.

This led me to the following run function:

(defn run [s ma]
  (match ma
    [:pure a] [s a]
    [:bind ma f] (let [[s a] (run s ma)]
                   (run s (f a)))
    [:current-position] [s (-> s :code count dec)]
    [:free-register] [s (next-reg s)]
    [:emit code] [(update s :code conj code) (next-reg s)]
    [:future] [(-> s
                 (update :code conj :placeholder)
                 (update :nested conj (-> s :code count)))
               nil]
    [:resolve f] [(-> s
                    (update :code assoc (-> s :nested peek) (f (-> s :code count)))
                    (update :nested pop))
                  nil]))

which should be run with an initial state of {:nested (), :code []}.
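
To make :future and :resolve a bit more concrete, here is a small, contrived run. It assumes next-reg, which is not shown in this post, returns the number of the next free register (for instance something like (-> s :code count)):

(run {:nested (), :code []}
     (mdo [_ [:future]
           _ [:emit [:load 0 42]]
           _ [:resolve (fn [pos] [:jump pos])]]))
;; => [{:nested (), :code [[:jump 2] [:load 0 42]]} nil]

The :future call reserves slot 0, the :emit fills slot 1, and the :resolve goes back and writes a jump to position 2 (the end of the code) into the reserved slot. Using this, the loop can be simplified rather mechanically to: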

(fn h [op & [ret]]
  (match op
    [:return a] (mdo [r (h a)
                      _ [:emit [:return r]]])
    [:lit v] (mdo [r (if ret
                       [:pure ret]
                       [:free-register])
                   _ [:emit [:load r v]]])
    [:set v exp] (if (#{:lit :add :not=} (first exp))
                   (h exp v)
                   (mdo [r (h exp)
                         _ [:emit [:loadr v r]]]))
    [:do & exps] (if (empty? exps)
                   [:pure nil]
                   (mdo [_ (h (first exps))
                         _ (h (cons :do (rest exps)))]))
    [:while con bod] (mdo [before-condition [:current-position]
                           condition (h con)
                           _ [:future]
                           body (h bod)
                           _ [:emit [:jump (inc before-condition)]]
                           _ [:resolve (fn [pos] [:jump-if-zero condition pos])]])
    [:var r] [:pure r]
    [:add e1 e2] (mdo [left (h e1)
                       right (h e2)
                       r (if ret
                           [:pure ret]
                           [:free-register])
                       _ [:emit [:add r left right]]])
    [:not= e1 e2] (mdo [left (h e1)
                        right (h e2)
                        r (if ret
                            [:pure ret]
                            [:free-register])
                        _ [:emit [:not= r left right]]])))

There is definitely still some non-trivial logic in there, but I found it a lot easier to work with. Let's take another look at that :while case:

    [:while con bod] (mdo [before-condition [:current-position]
                           condition (h con)
                           _ [:future]
                           body (h bod)
                           _ [:emit [:jump (inc before-condition)]]
                           _ [:resolve (fn [pos] [:jump-if-zero condition pos])]])

Assuming one understands the semantics of the monadic operations involved, this is a lot more intention-revealing than the previous form.
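
For completeness, wiring the monadic loop into a top-level function could look roughly like the following sketch, assuming the loop above is bound to a var named h and next-reg is defined as discussed earlier; the actual wiring in my code may differ slightly:

(defn compile-register [ast]
  (-> (run {:nested (), :code []} (h ast))
      first          ; run returns [state result]; we want the state
      :code))        ; and, from the state, the generated opcodes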

Conclusion

This concludes my series of blog posts on monads, for now at least. There is one more post I want to add, but it will come much later as I first need to learn the subject matter myself.

If you've read it all, thanks. I hope it's been useful, or at least mildly entertaining, to watch me struggle through it. I also hope it's encouraged you to give monads a try, even if you're not writing in Haskell.

Tags: monad-tutorial