Monads, part five: What if I'm not writing Haskell?
In the past two posts of this series, I have shown what I believe to be a simple, effective technique for creating monads in Haskell. At this point one may be wondering how that translates to other languages, if indeed it applies at all.
In this post, I will take a look at what exactly we need from a language in order to apply this "structural monad" approach, with a detailed example in Clojure.
Requirements
Let's first look at what we need to be able to use this approach at all. There are three steps to consider:
- Defining the structure of a monad.
- Defining a run function that assigns the correct meaning to a monadic structure.
- Being able to write specific instances of monadic structures.
To define the structure, the first thing we need is a way to create a structure for the various monadic actions. This could be a type, as we did in Haskell; this could be a set of classes, as we did in Java. All we need here is really just tagged collections of values; any way to represent that can work, from actual types to tuples, vectors, lists, or associative collections.
We also need support for first-class functions: at the very least, the
structure representing a bind action will need to embed a function. In order
to build up specific instances of monadic structures, we also need some form of
lambda (i.e. literal, anonymous function) syntax, otherwise we will not be able
to build up nested bind calls and therefore we won't be able to actually
bind the results of previous steps. And without that, we can't even express
the third monadic law.
Finally, whatever means of tagging individual monadic actions we chose when
defining the structure needs to be "switchable on" in the run function. In
our Haskell examples, we used pattern matching to that effect; in our Java
examples, we used inheritance, though we could also have gone for instanceof
calls, or a switch on .getClass().getSimpleName(). If we represent individual
monadic actions as associative data structures, say objects in JavaScript, we
could just have a tag key with a string, and match on that.
That's pretty much it, really: some way to group values, some way to tag those groups, and lambdas.
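As a rough illustration of the "associative structure with a tag key" option, here is a minimal sketch in Clojure using maps rather than objects. The :tag key, the constructor names (pure, bind, push, pop-action) and run-map are all made up for this example; the Stack monad itself is the one from the earlier posts.

```clojure
;; Hypothetical map-based encoding; every name here is an assumption
;; made for this sketch, not something defined elsewhere in the series.
(defn pure [a]    {:tag :pure, :value a})
(defn bind [ma f] {:tag :bind, :ma ma, :f f})
(def  pop-action  {:tag :pop})
(defn push [a]    {:tag :push, :value a})

(defn run-map
  "Interprets a map-encoded action against `stack` (a vector).
  Returns a pair [new-stack result]."
  [stack ma]
  (case (:tag ma)                       ; the tag is what we switch on
    :pure [stack (:value ma)]
    :bind (let [[stack a] (run-map stack (:ma ma))]
            (run-map stack ((:f ma) a)))
    :pop  [(pop stack) (peek stack)]
    :push [(conj stack (:value ma)) nil]))

;; Pop two values and push their sum, built from nested lambdas.
(def add-top-two
  (bind pop-action
        (fn [a]
          (bind pop-action
                (fn [b]
                  (push (+ a b)))))))

(run-map [1 2 3] add-top-two)
;; => [[1 5] nil]
```

All three requirements show up here: maps group the values, the :tag key tags them, and the lambdas passed to bind carry the results of earlier steps forward.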
Nice-to-haves
Implementing and using monads in JavaScript is doable; it has all of the requirements. It may not be very pleasant, however, nor very (syntactically) lightweight. There are a number of things a language can have that will make this process easier. Here are some of them.
Types are the most obvious one. Types can help with any piece of code, but
using monads usually means juggling slightly more complex types than one would
usually juggle, and the more complex types you use, the more opportunities you
have to screw up and the more value you get from a type system. Note that this
does not necessarily mean "a type system that can express Monad"; as we saw
in part one, the Java type system cannot express a Monad interface,
yet it can still be leveraged to ensure that monadic values are properly
constructed for a given monad.
Generalized tail call elimination is also a big plus. In simple cases monadic
structures will not be very deeply nested, so this may not be a very apparent
problem at first, but the most natural expression for the run method will
often be a recursive one. It is always possible to rewrite recursive code to
avoid recursion, but in the case of a monadic run method it would add
accidental complexity to what may already be considered as non-straightforward
code.
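To give a feel for that accidental complexity, here is one possible non-recursive interpreter for a Stack monad encoded as tagged vectors such as [:bind ma f], [:pure a], [:pop] and [:push a]. It is only a sketch: run-iter and conts are names invented for this example, and the explicit list of pending functions stands in for what the call stack would otherwise hold.

```clojure
;; Sketch of a stack-safe interpreter; `run-iter` and `conts` are
;; hypothetical names chosen for this example.
(defn run-iter
  "Like a recursive run, but keeps the pending :bind functions in an
  explicit list (`conts`) and loops with recur instead of recursing."
  [stack ma]
  (loop [stack stack
         ma    ma
         conts ()]
    (if (= :bind (first ma))
      ;; Descend into the left action, remembering what to do with its result.
      (let [[_ inner f] ma]
        (recur stack inner (conj conts f)))
      ;; A primitive action: perform it, then hand the result to the most
      ;; recently saved function, if there is one.
      (let [[stack a] (case (first ma)
                        :pure [stack (second ma)]
                        :pop  [(pop stack) (peek stack)]
                        :push [(conj stack (second ma)) nil])]
        (if (empty? conts)
          [stack a]
          (recur stack ((peek conts) a) (pop conts)))))))

;; A deliberately deep, left-nested chain of binds that increments 0
;; one hundred thousand times.
(def deep
  (reduce (fn [ma _] [:bind ma (fn [a] [:pure (inc a)])])
          [:pure 0]
          (range 100000)))

(run-iter [] deep)
;; => [[] 100000]
```

The logic is the same as in a recursive version, but the bookkeeping for "what comes next" is now spelled out by hand, which is the extra noise a language with tail call elimination would help avoid.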
Pattern matching can make writing the run method a lot cleaner. Alternatives
like switch statements and inheritance can work, but switch statements will
tend to be a bit more clunky as they will require some ad-hoc way to extract
components in each branch, and inheritance will have the effect of splitting
what should be a single function's logic across many classes, making it a bit
difficult to keep track of the overall behaviour.
Finally, of course, having special syntax for building up monadic values is
also a big plus. As we saw in our Java examples (mostly in parts one and
two), it is possible to build up monadic structures simply by nesting
literal bind calls (assuming we have lambdas), but it gets a bit tedious.
Clojure example: pre-monad code
I did not set out to write that code with a monad. At all. I'm still learning monads, you see. That's the main reason I wrote this series: to help me clarify my own understanding. In parallel, I was writing some code to explore how interpreters work. I'll write more on that in future posts.
And as I was writing that code, and this series of posts, it dawned on me that the code I was trying to refactor could really benefit from a monad. So here we go. For context, this code is a compiler from an AST-like tree representation of a custom C-like programming language to a (custom-defined) register-machine-style list of opcodes. Concretely, I want to transform this:
(def ast
  [:do
   [:set 0 [:lit 100]]
   [:set 1 [:lit 1000]]
   [:while [:not= [:lit 0] [:var 1]]
    [:do
     [:set 0 [:add [:add [:add [:var 0] [:lit 4]] [:var 0]] [:lit 3]]]
     [:set 0 [:add [:add [:var 0] [:lit 2]] [:lit 4]]]
     [:set 1 [:add [:lit -1] [:var 1]]]]]
   [:return [:var 0]]])
into this:
(def register-code
  [[:load 0 100]
   [:load 1 1000]
   [:load 4 0]
   [:not= 5 4 1]
   [:jump-if-zero 5 17]
   [:load 7 4]
   [:add 8 0 7]
   [:add 9 8 0]
   [:load 10 3]
   [:add 0 9 10]
   [:load 12 2]
   [:add 13 0 12]
   [:load 14 4]
   [:add 0 13 14]
   [:load 16 -1]
   [:add 1 16 1]
   [:jump 2]
   [:return 0]])
The code I had for this transformation when I decided to use a monad looked like this:
(defn compile-register-ssa
  [ast]
  (let [max-var ((fn max-var [[op & [arg1 arg2 :as args]]]
                   (case op
                     :lit 0
                     :return (max-var arg1)
                     :not= (max (max-var arg1)
                                (max-var arg2))
                     :add (max (max-var arg1)
                               (max-var arg2))
                     :var arg1
                     :set (max arg1 (max-var arg2))
                     :do (reduce max (map max-var args))
                     :while (max (max-var arg1)
                                 (max-var arg2)))) ast)
        r (fn [i t] (or t (+ i max-var 1)))
        h (fn h [m [op & [arg1 arg2 :as args]]]
            (case op
              :return (let [[r m] (h (assoc m :ret nil) arg1)]
                        [nil (update m :code concat [[:return r]])])
              :lit [(r (-> m :code count) (:ret m)) (update m :code concat [[:load (r (-> m :code count) (:ret m)) arg1]])]
              :set (if (#{:lit :add} (first arg2))
                     (h (assoc m :ret arg1) arg2)
                     (let [[r m] (h (assoc m :ret nil) arg2)]
                       [nil
                        (update m :code concat [[:loadr arg1 r]])]))
              :do [nil (->> args
                            (reduce (fn [m el] (second (h (assoc m :ret nil) el)))
                                    m))]
              :while (let [start (-> m :code count)
                           [rcond m] (h (assoc m :ret nil) arg1)
                           m (update m :code concat [[:jump-if-zero rcond nil]])
                           jump-idx (-> m :code count dec)
                           [_ m] (h (assoc m :ret nil) arg2)
                           m (update m :code concat [[:jump start]])]
                       [nil (update m :code (fn [code]
                                              (update (vec code) jump-idx assoc 2 (count code))))])
              :var [arg1 m]
              :add (let [ret (:ret m)
                         [rleft m] (h (assoc m :ret nil) arg1)
                         [rright m] (h (assoc m :ret nil) arg2)
                         rresult (-> m :code count)]
                     [(r rresult ret)
                      (update m :code concat [[:add (r rresult ret) rleft rright]])])
              :not= (let [ret (:ret m)
                          [rleft m] (h (assoc m :ret nil) arg1)
                          [rright m] (h (assoc m :ret nil) arg2)
                          rresult (-> m :code count)]
                      [(r rresult ret)
                       (update m :code concat [[:not= (r rresult ret) rleft rright]])])))]
    (-> (h {:ret nil, :code []} ast)
        second
        :code
        vec)))
There's a lot of code, and a lot going on there. The main takeaway, though, is
that there is some state that is being threaded through the whole computation,
including multiple recursive calls in the case of :add and :not=. Let's
take a look at :while, for example:
:while (let [start (-> m :code count)
             [rcond m] (h (assoc m :ret nil) arg1)
             m (update m :code concat [[:jump-if-zero rcond nil]])
             jump-idx (-> m :code count dec)
             [_ m] (h (assoc m :ret nil) arg2)
             m (update m :code concat [[:jump start]])]
         [nil (update m :code (fn [code]
                                (update (vec code) jump-idx assoc 2 (count code))))])
We are threading through two pieces of state:
- the state of the loop itself in the form of ret, which is a switch for the code-generation loop that tells it whether the code being generated needs to allocate a new register for its result (when ret is nil), or instead write its result to a given, preexisting register (when ret is a number).
- the state of the generated code in the form of code, which is represented here as a vector of opcodes, which is also used to determine the number of the next register to allocate.
So start is the position we're at in the code being generated; we save it so
we can jump to it at the end of the while body. But the same expression
(-> m :code count) in :add is used to generate a new register number. This is
confusing.
I did not decide to add a monad to this code just because I could. I wanted to make three changes that I thought would be easier to do, overall, by first transforming this code to a monadic form, then making the changes, than by trying to do the changes directly on this mess. These changes were:
- Decouple the "code" state from the "loop" state (ret).
- Decouple register numbers from length of code generated.
- Lift constants out of the code: there is little point in reassigning constants to registers in the body of the inner loop.
Monads in Clojure
This code really looked like it could use a monadic approach, but I was not writing in Haskell, and Clojure has no support for monads. It's a dynamically typed language with an emphasis on runtime values and a definite desire to stay away from category theory. How would one go about adding support for monads?
Going back to our list of requirements above, let's first look at how to represent a monadic structure. Clojure has literal syntax for vectors, which are commonly used for tree-like structures in what is known in the community as "the hiccup notation": each node of the tree is a vector whose first element is a keyword, and where the other elements are whatever children or attributes nodes of that type (indicated by the keyword) have. Keywords are a type of value that essentially just exists and can only be compared for equality.
This suggests a natural form for our monadic values. Clojure does not define
types, so we need to look at a concrete example; let's pick this one from part
two of this series, in the context of a Stack monad. In Haskell:
do
  a <- pop
  b <- pop
  push (a + b)
which in our "structural monads" approach from part three turns into:
Bind Pop (\a -> Bind Pop (\b -> Push (a + b)))
A direct equivalent in Clojure hiccup-like notation would look like:
[:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])]
Or, with a little bit of whitespace:
[:bind [:pop]
 (fn [a]
   [:bind [:pop]
    (fn [b]
      [:push (+ a b)])])]
This could work, and be about as usable as the Java code was. So that's the
production side. On the consuming side, i.e. the run method, we can switch on
the keywords, then destructure once we know how many arguments we expect:
(defn run [stack ma]
  (case (first ma)
    :pure (let [[_ a] ma]
            [stack a])
    :bind (let [[_ ma f] ma]
            (let [[stack a] (run stack ma)]
              (run stack (f a))))
    :pop [(pop stack) (peek stack)]
    :push (let [[_ a] ma]
            [(conj stack a), nil])))
This can work, but there is a lot of extra noise on both sides. Highly regular noise, which means we can exploit the fact that we're using a lisp and make up a macro. Or, rather, here, two macros. We'll keep both of them very simple, with little to no error handling and definitely no bells and whistles.
First, we can make a macro that creates the intermingled nested vectors-and-functions that represent a monadic value. The pattern here is very simple: all we need is a name to bind to and an expression that will become the body of the function. The following macro:
(defmacro mdo
  [bindings]
  (let [[n v & r] bindings]
    (if (empty? r)
      v
      [:bind v `(fn [~n] (mdo ~r))])))
will turn
(mdo [a [:pop]
      b [:pop]
      _ [:push (+ a b)]])
into:
[:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])]
It's not quite as nice as the Haskell notation, and doesn't support the
"non-binding" syntax (we have to explicitly bind _), but it's very small and
simple, and that counts for a lot with macros, as debugging them can be a bit
of a pain.
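One way to gain some confidence in such a macro is to check that a value built with it behaves like the hand-nested version. The snippet below is a small self-contained sketch: it repeats the macro and a plain recursive run for the Stack monad, and the names by-hand and with-mdo are made up for this example.

```clojure
;; The macro, repeated so this snippet stands alone.
(defmacro mdo
  [bindings]
  (let [[n v & r] bindings]
    (if (empty? r)
      v
      [:bind v `(fn [~n] (mdo ~r))])))

;; A plain recursive interpreter for the vector-encoded Stack monad.
(defn run [stack ma]
  (case (first ma)
    :pure (let [[_ a] ma] [stack a])
    :bind (let [[_ ma f] ma
                [stack a] (run stack ma)]
            (run stack (f a)))
    :pop  [(pop stack) (peek stack)]
    :push (let [[_ a] ma] [(conj stack a) nil])))

;; Hypothetical names for the two ways of building the same action.
(def by-hand
  [:bind [:pop] (fn [a] [:bind [:pop] (fn [b] [:push (+ a b)])])])

(def with-mdo
  (mdo [a [:pop]
        b [:pop]
        _ [:push (+ a b)]]))

(= (run [10 20 30] by-hand)
   (run [10 20 30] with-mdo))
;; => true

(run [10 20 30] with-mdo)
;; => [[10 50] nil]
```

Note that the name in the last binding pair is never used: the value of the whole mdo form is simply the final action.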
On the other side, we have the destructuring. Clojure does not have pattern matching, which forces us to first test for the type of node, and then destructure it. We can, however, (syntactically) reorder these things using another simple macro:
(defmacro match
  [expr & cases]
  (let [e (gensym)]
    `(let [~e ~expr]
       (case (first ~e)
         ~@(->> (partition 2 cases)
                (mapcat (fn [[pat body]]
                          [(first pat) `(let [~(vec (cons '_ (rest pat))) ~e]
                                          ~body)])))))))
It's a little bit less simple, but it's not all that complex either. It will transform this:
(defn run [stack ma]
  (match ma
    [:return a] [stack a]
    [:bind ma f] (let [[stack a] (run stack ma)]
                   (run stack (f a)))
    [:pop] [(pop stack) (peek stack)]
    [:push a] [(conj stack a) nil]))
into (roughly):
(defn run [stack ma]
  (let [G__2036 ma]
    (case (first G__2036)
      :return (let [[_ a] G__2036]
                [stack a])
      :bind (let [[_ ma f] G__2036]
              (let [[stack a] (run stack ma)]
                (run stack (f a))))
      :pop (let [[_] G__2036]
             [(pop stack) (peek stack)])
      :push (let [[_ a] G__2036]
              [(conj stack a) nil]))))
The outer let is there as an optimization (assuming pure code) to ensure that
the expression we match on is evaluated once, and not twice.
Note that, because Clojure is dynamically typed, both of these macros will work
equally well for any monadic structure. (match will actually work for any
hiccup-like tree.) However, when writing the run function, you still have to
know which monad you're targeting as you need to cover all the relevant cases.
Not having to convince the compiler about which type a value has doesn't mean
your code will magically work if you give it the wrong type at runtime.
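To illustrate the point about match working on any hiccup-like tree, here is a small sketch that uses it on something that has nothing to do with monads. The :num, :add and :mul tags and the evaluate function are invented for this example.

```clojure
;; The match macro, repeated so this snippet stands alone.
(defmacro match
  [expr & cases]
  (let [e (gensym)]
    `(let [~e ~expr]
       (case (first ~e)
         ~@(->> (partition 2 cases)
                (mapcat (fn [[pat body]]
                          [(first pat) `(let [~(vec (cons '_ (rest pat))) ~e]
                                          ~body)])))))))

;; Hypothetical arithmetic trees in hiccup style.
(defn evaluate [tree]
  (match tree
    [:num n]    n
    [:add a b]  (+ (evaluate a) (evaluate b))
    ;; Patterns are ordinary destructuring forms, so `&` works too.
    [:mul & xs] (reduce * (map evaluate xs))))

(evaluate [:add [:num 1] [:mul [:num 2] [:num 3] [:num 4]]])
;; => 25
```

Since the macro expands to a case with no default branch, a tree with an unknown tag fails at runtime with an exception, which is the "wrong type at runtime" situation described above.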
Monadic compiler
So let's get back to our compiler. What monadic operations do we need? First,
let's get the ret option out of the way: this is not part of the state we
want to track with the monad; this is part of the state we want to track in the
loop itself. So we'll make it an explicit argument to the recursive calls.
In the monadic state, we want to keep the code we're producing. We're obviously
going to need :pure and :bind (I chose to avoid calling the monadic no-op
:return to avoid confusion with the :return in the input ast). What other
operations do we need?
- We need to be able to :emit new code, taking a single opcode and returning the register we may have written to. This last bit should change at some point, but we're not changing this yet: this is a refactoring, we're not changing behaviour.
- We need to be able to query the state for our :current-position, so we can :emit code that jumps to it.
- We need to get a :free-register.
- Perhaps less obviously, we need a way to insert a placeholder, blank opcode now in order to :resolve it in the :future, such as when we are done compiling the body of a :while loop and need to go back and write down the :jump instruction to skip it.
This led me to the following run function:
(defn run [s ma]
  (match ma
    [:pure a] [s a]
    [:bind ma f] (let [[s a] (run s ma)]
                   (run s (f a)))
    [:current-position] [s (-> s :code count dec)]
    [:free-register] [s (next-reg s)]
    [:emit code] [(update s :code conj code) (next-reg s)]
    [:future] [(-> s
                   (update :code conj :placeholder)
                   (update :nested conj (-> s :code count)))
               nil]
    [:resolve f] [(-> s
                      (update :code assoc (-> s :nested peek) (f (-> s :code count)))
                      (update :nested pop))
                  nil]))
which should be run with an initial state of {:nested (), :code []}. Using
this, the loop can be simplified rather mechanically to:
(fn h [op & [ret]]
  (match op
    [:return a] (mdo [r (h a)
                      _ [:emit [:return r]]])
    [:lit v] (mdo [r (if ret
                       [:pure ret]
                       [:free-register])
                   _ [:emit [:load r v]]])
    [:set v exp] (if (#{:lit :add :not=} (first exp))
                   (h exp v)
                   (mdo [r (h exp)
                         _ [:emit [:loadr v r]]]))
    [:do & exps] (if (empty? exps)
                   [:pure nil]
                   (mdo [_ (h (first exps))
                         _ (h (cons :do (rest exps)))]))
    [:while con bod] (mdo [before-condition [:current-position]
                           condition (h con)
                           _ [:future]
                           body (h bod)
                           _ [:emit [:jump (inc before-condition)]]
                           _ [:resolve (fn [pos] [:jump-if-zero condition pos])]])
    [:var r] [:pure r]
    [:add e1 e2] (mdo [left (h e1)
                       right (h e2)
                       r (if ret
                           [:pure ret]
                           [:free-register])
                       _ [:emit [:add r left right]]])
    [:not= e1 e2] (mdo [left (h e1)
                        right (h e2)
                        r (if ret
                            [:pure ret]
                            [:free-register])
                        _ [:emit [:not= r left right]]])))
There is definitely still some non-trivial logic in there, but I found it a lot
easier to work with. Let's take another look at that :while case:
[:while con bod] (mdo [before-condition [:current-position]
                       condition (h con)
                       _ [:future]
                       body (h bod)
                       _ [:emit [:jump (inc before-condition)]]
                       _ [:resolve (fn [pos] [:jump-if-zero condition pos])]])
Assuming one understands the semantics of the monadic operations involved, this is a lot more intention-revealing than the previous form.
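For readers who want to see the :future and :resolve pair in isolation, here is a cut-down, self-contained sketch. It is not the interpreter above: run-mini is a hypothetical name, :free-register is left out, and :emit returns nil here so that the sketch does not need a register-allocation helper. The two macros are repeated unchanged.

```clojure
(defmacro mdo
  [bindings]
  (let [[n v & r] bindings]
    (if (empty? r)
      v
      [:bind v `(fn [~n] (mdo ~r))])))

(defmacro match
  [expr & cases]
  (let [e (gensym)]
    `(let [~e ~expr]
       (case (first ~e)
         ~@(->> (partition 2 cases)
                (mapcat (fn [[pat body]]
                          [(first pat) `(let [~(vec (cons '_ (rest pat))) ~e]
                                          ~body)])))))))

;; Cut-down interpreter (an assumption of this sketch): just enough
;; operations to show a placeholder being reserved and filled in later.
(defn run-mini [s ma]
  (match ma
    [:pure a]    [s a]
    [:bind ma f] (let [[s a] (run-mini s ma)]
                   (run-mini s (f a)))
    [:emit code] [(update s :code conj code) nil]
    ;; Reserve a slot and remember its index on the :nested stack.
    [:future]    [(-> s
                      (update :code conj :placeholder)
                      (update :nested conj (-> s :code count)))
                  nil]
    ;; Fill the most recent slot, telling f where the code currently ends.
    [:resolve f] [(-> s
                      (update :code assoc (-> s :nested peek) (f (-> s :code count)))
                      (update :nested pop))
                  nil]))

(def program
  (mdo [_ [:emit [:load 0 1]]
        _ [:future]                    ; forward jump, target not known yet
        _ [:emit [:load 1 2]]          ; "body" that the jump may skip
        _ [:emit [:load 2 3]]
        _ [:resolve (fn [pos] [:jump-if-zero 0 pos])]]))

(-> (run-mini {:nested (), :code []} program)
    first
    :code)
;; => [[:load 0 1] [:jump-if-zero 0 4] [:load 1 2] [:load 2 3]]
```

The jump target, 4, is the index just past the last opcode of the body, which is only known once the body has been emitted.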
Conclusion
This concludes my series of blog posts on monads, for now at least. There is one more post I want to add, but it will come much later as I first need to learn the subject matter myself.
If you've read it all, thanks. I hope it's been useful, or at least mildly entertaining, to watch me struggle through it. I also hope it's encouraged you to give monads a try, even if you're not writing in Haskell.