18 April 2021

Monads, part one: What's a monad?

If you've read part zero, you know that the question "What is a monad?" is not a very good one, because "monad" is not a thing; it's a set of properties. A better question would be "What are the properties of a thing that make it a monad?". This is the question I will answer in this post.

It will probably be a fairly unsatisfying answer. By the time you reach the end of this post, you may be left wondering "Ok, so what? Why would I ever want to do this? It looks like so much unnecessary extra complexity." Those questions, specifically:

  • Why should you use monads?
  • When should you use monads?
  • How do you make a monad?

will be addressed in future posts.

From my previous post, we know that in order for a type to be a monad, we need to define two functions operating on this type, namely >>= ("bind") and return. Their types are defined by:

class Monad m where
    (>>=) :: m a -> (a -> m b) -> m b
    return :: a -> m a

and their protocol is defined by:

return a >>= h  === h a
m >>= return    === m
(m >>= g) >>= h === m >>= (\x -> g x >>= h)

What does it mean? If you're unfamiliar with Haskell notation, it may be helpful to translate those to Java. Let's start with the Monad interface. A first attempt could look like:

import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
    Monad<A> pure(A a);
}

We've changed the names: >>= is not a valid identifier in Java, so we replaced it with its pronunciation, while return is a reserved keyword, so we used pure instead. This is not entirely arbitrary: in the Haskell standard library, pure (from the Applicative class) and return are two names for the same function.

This does not look quite right, though. The bind function uses the object itself as its first argument, such that the Haskell ma >>= f would correspond to ma.bind(f), but what does pure a translate to? In the Haskell version, we don't have a monadic value to start with. The pure function should be a constructor, and constructors are not part of Java interfaces. We could work around that by defining two interfaces:

import java.util.function.Function;

public interface Monad {
    <A> Monadic<A> pure(A a);
    interface Monadic<A> {
        <B> Monadic<B> bind(Function<A, Monadic<B>> f);
    }
}

but that gets tedious pretty quickly and doesn't offer much benefit. Let's just leave pure out and assume that our types have constructors. The interface becomes:

import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
}

Now, let us turn to the protocol. In pseudo-Java, it could read:

pure(a).bind(h)      === h.apply(a)
m.bind(a -> pure(a)) === m
m.bind(g).bind(h)    === m.bind(x -> g.apply(x).bind(h))

where pure is a stand-in for a constructor.
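To turn that pseudo-Java into something we can actually run, here is a minimal sketch that borrows java.util.Optional as a stand-in monadic type. That choice is an assumption of this example, not the type we build below: Optional.of plays the role of pure and flatMap plays the role of bind. The class name LawsSketch and the functions g and h are made up for illustration.

```java
import java.util.Optional;
import java.util.function.Function;

// Sketch only: Optional stands in for "some monadic type".
final class LawsSketch {
    // Two arbitrary monadic functions, invented for this example.
    static final Function<Integer, Optional<Integer>> g = x -> Optional.of(x + 1);
    static final Function<Integer, Optional<String>> h = x -> Optional.of("n=" + x);

    // pure(a).bind(h) === h.apply(a)
    static boolean leftIdentity(int a) {
        return Optional.of(a).flatMap(h).equals(h.apply(a));
    }

    // m.bind(a -> pure(a)) === m
    static boolean rightIdentity(Optional<Integer> m) {
        return m.flatMap(a -> Optional.of(a)).equals(m);
    }

    // m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
    static boolean associativity(Optional<Integer> m) {
        return m.flatMap(g).flatMap(h)
                .equals(m.flatMap(x -> g.apply(x).flatMap(h)));
    }

    public static void main(String[] args) {
        System.out.println(leftIdentity(41));
        System.out.println(rightIdentity(Optional.of(41)));
        System.out.println(rightIdentity(Optional.empty()));
        System.out.println(associativity(Optional.of(41)));
        System.out.println(associativity(Optional.empty()));
    }
}
```

Checking a handful of values proves nothing in general, of course; it only shows what each line of the protocol asks for.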

So what does this all mean? If pure/return is just a constructor, the core meaning of "being a monad" must be in the bind method. Looking at the protocol, the first line tells us that if we have a value "wrapped" in the monad, we can "send a function into" the monad and apply it on the value. The second line tells us that "sending" the constructor "into" the monad is a no-op. And the third line tells us that, if we have two functions we want to "send into" a monad, we can either send them one at a time, or first construct a combination and send that, and the result will be the same.

At this point, it's probably still pretty nebulous. It has something to do with functions and function composition, but let's take a look at a concrete example to hopefully make all of that clearer.

One of the simplest monadic types is Maybe. It represents a potential value. Java uses null by default to represent the absence of a value, but let us assume that we don't like that approach and we want to be a bit more explicit. Maybe could be defined as:

public interface Maybe<A> {
    A get();
    static <A> Maybe<A> just(A a) {
        return new Maybe<A>() {
            public A get() {
                return a;
            }
        };
    }
    static <A> Maybe<A> none() {
        return new Maybe<A>() {
            public A get() {
                return null;
            }
        };
    }
}

That's not very useful, though. If the only operation is get, we haven't gained anything. We could make explicit subclasses for Just and None, and that would be marginally better, because we could now use instanceof on them. However, turning:

if (null == v) {
    ...
} else {
    ...
}

into

if (v instanceof None) {
    ...
} else {
    ...
}

is hardly a win. If every operation has to check for null, checking for instanceof None is not better. Now, let's imagine that it is ok for a sequence of operations to return null if any operation in the chain returns null. They will all have code that looks like:

if (null == v) { return null; }
else {
    ...
}

Now, switching to Maybe can actually help, because we have a fixed implementation for what to do when the argument is None, which means we can make use of polymorphism. If we redefine our Maybe interface to be:

import java.util.function.Function;

public interface Maybe<A> {
    A get();

    <B> Maybe<B> bind(Function<A, Maybe<B>> f);
    static <A> Maybe<A> just(A a) { return new Impl.Just<>(a); }
    static <A> Maybe<A> none() { return new Impl.None<>(); }

    static class Impl {
        private Impl() {}
        private static class Just<A> implements Maybe<A> {
            private final A a;
            private Just(A a) {
                this.a = a;
            }
            public A get() { return a; }
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return f.apply(a);
            }
        }
        private static class None<A> implements Maybe<A> {
            private None() {}
            public A get() { return null; }
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return new None<>();
            }
        }
    }
}

we can now use it to avoid some of that repetition. Let's take an example from arithmetic for simplicity. Some operations are safe (subtraction, addition, multiplication), but some are not, such as square root or division. We could define "safe" functions that return a Maybe, and write:

public class Arith {

    public static Maybe<Double> add(Double a, Double b) {
        return Maybe.just(a + b);
    }

    public static Maybe<Double> sub(Double a, Double b) {
        return Maybe.just(a - b);
    }

    public static Maybe<Double> mul(Double a, Double b) {
        return Maybe.just(a * b);
    }

    public static Maybe<Double> div(Double a, Double b) {
        return b == 0.0 ? Maybe.none() : Maybe.just(a / b);
    }

    public static Maybe<Double> sqrt(Double a) {
        return a < 0.0 ? Maybe.none() : Maybe.just(Math.sqrt(a));
    }

    public static Maybe<Double> sqr_inv(Double a) {
        return Maybe.just(a)
            .bind(x -> sqrt(x))
            .bind(x -> div(1.0, x));
    }

    public static void main(String[] args) {
        Double d =
            Maybe.just(5.0)
            .bind(r -> add(r, 7.0))
            .bind(r -> sub(r, 2.0))
            .bind(r -> sqr_inv(r))
            .bind(r -> mul(r, 3.0))
            .get();
        System.out.println(d);
    }
}

where none of the arithmetic functions had to check whether its argument was null.
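For contrast, here is roughly what part of the same chain might look like with null standing in for None. This is a sketch of the style we are trying to avoid, not code from the example above; the class NullArith and its method names are invented here.

```java
// Sketch only: the null-propagating style that Maybe.bind replaces.
final class NullArith {
    static Double div(Double a, Double b) {
        if (a == null || b == null) { return null; }
        if (b == 0.0) { return null; }
        return a / b;
    }

    static Double sqrt(Double a) {
        if (a == null) { return null; }
        if (a < 0.0) { return null; }
        return Math.sqrt(a);
    }

    // Every step has to repeat the same "did the previous step fail?" check.
    static Double sqrInv(Double a) {
        if (a == null) { return null; }
        Double root = sqrt(a);
        if (root == null) { return null; }
        return div(1.0, root);
    }

    public static void main(String[] args) {
        System.out.println(sqrInv(4.0));   // prints 0.5
        System.out.println(sqrInv(-4.0));  // prints null
        System.out.println(sqrInv(0.0));   // prints null
    }
}
```

In the Maybe version, all of those null checks collapse into the single None implementation of bind.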

Note that we did not make Maybe implement Monad. This is because, in order for Maybe to work as expected, its bind method really has to return an instance of Maybe, not a generic Monad. (Remember, "monad" is not a thing, just a set of properties; in this case, the thing is Maybe.) This is not specific to Maybe and will be true of any monadic type: bind needs to return the monadic type, not a generic Monad instance. To the best of my knowledge, there is no way to express that property on a Java interface, which leaves us in a bit of an awkward place in terms of expressing monads in Java: return cannot be part of the interface because it is a constructor, and bind has a type that we cannot express generically.
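To see concretely what goes wrong, here is a sketch of what we would get if our types did implement the generic Monad interface from earlier. The types Box and Other are hypothetical, made up purely for this illustration.

```java
import java.util.function.Function;

// The generic interface from earlier in the post.
interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
}

// Hypothetical monadic type #1: always holds exactly one value.
final class Box<A> implements Monad<A> {
    final A value;
    Box(A value) { this.value = value; }
    public <B> Monad<B> bind(Function<A, Monad<B>> f) { return f.apply(value); }
}

// Hypothetical monadic type #2: unrelated to Box, except that it also implements Monad.
final class Other<A> implements Monad<A> {
    final A value;
    Other(A value) { this.value = value; }
    public <B> Monad<B> bind(Function<A, Monad<B>> f) { return f.apply(value); }
}

final class LooseBind {
    static Monad<Integer> mixed() {
        Function<Integer, Monad<Integer>> intoOther = x -> new Other<>(x + 1);
        // This compiles, even though we start in Box and silently end up in Other.
        return new Box<>(1).bind(intoOther);
    }

    public static void main(String[] args) {
        Monad<Integer> result = mixed();
        // The static type is only Monad<Integer>: nothing specific to Box is
        // reachable without a cast, and the cast to Box would fail here.
        System.out.println(result instanceof Box);   // prints false
        System.out.println(result instanceof Other); // prints true
    }
}
```

The compiler is perfectly happy with a bind that changes the monadic type halfway through, which is exactly what a monad must not allow.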

The conclusion is simple: there is no way to express the concept of a monad in terms of Java types. This does not, however, mean that we cannot use monads in Java: we have just demonstrated the use of Maybe as a monad.

To prove (or verify) that our definition is indeed a monad, we can take another look at the protocol, a.k.a. "the monad laws". The first law states that:

pure(a).bind(h) === h.apply(a)

In the case of our Maybe monad, pure is just. Looking at the code, we can see that just(a).bind(h) will expand directly to h.apply(a), so that one checks out.

The second law states that:

m.bind(a -> pure(a)) === m

where m is either a Just or a None. If m is of type None, we know that none().bind(f) is none() for any f (for any reasonable definition of .equals; see footnote 1). If m is just(x), then just(x).bind(a -> just(a)) is (a -> just(a)).apply(x) which is just(x).

Finally, the third law states that:

m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))

This requires a bit more of a case analysis as None may pop up at any stage.

  • If m is none() to start with, the case is pretty simple. On the left, none().bind(g).bind(h) evaluates to none().bind(h) which evaluates to none(), while on the right, none().bind(x -> g.apply(x).bind(h)) evaluates directly to none().
  • If m is just(a), then m.bind(g) will evaluate to g.apply(a), and the left side is g.apply(a).bind(h). On the right side, just(a).bind(x -> g.apply(x).bind(h)) evaluates to (x -> g.apply(x).bind(h)).apply(a) which evaluates to g.apply(a).bind(h).
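The case analysis above can also be checked mechanically on a few sample values. The sketch below uses a condensed Maybe (an abstract class rather than the interface above, to keep it short) with one possible equals; that definition is an assumption of this example, as are the arbitrary functions g and h.

```java
import java.util.Objects;
import java.util.function.Function;

// Condensed stand-in for the Maybe above, with structural equality added.
abstract class Maybe<A> {
    abstract <B> Maybe<B> bind(Function<A, Maybe<B>> f);

    static <A> Maybe<A> just(A a) { return new Just<>(a); }
    static <A> Maybe<A> none() { return new None<>(); }

    private static final class Just<A> extends Maybe<A> {
        private final A a;
        Just(A a) { this.a = a; }
        <B> Maybe<B> bind(Function<A, Maybe<B>> f) { return f.apply(a); }
        // Two Justs are equal when their contents are equal.
        @Override public boolean equals(Object o) {
            return o instanceof Just && Objects.equals(a, ((Just<?>) o).a);
        }
        @Override public int hashCode() { return Objects.hashCode(a); }
    }

    private static final class None<A> extends Maybe<A> {
        <B> Maybe<B> bind(Function<A, Maybe<B>> f) { return new None<>(); }
        // All Nones are equal to each other.
        @Override public boolean equals(Object o) { return o instanceof None; }
        @Override public int hashCode() { return 0; }
    }
}

final class MaybeLaws {
    // Arbitrary monadic functions that can each produce a None.
    static final Function<Integer, Maybe<Integer>> g =
        x -> x < 0 ? Maybe.<Integer>none() : Maybe.just(x * 2);
    static final Function<Integer, Maybe<String>> h =
        x -> x > 100 ? Maybe.<String>none() : Maybe.just("v" + x);

    // Checks the three laws for one monadic value m and one naked value a.
    static boolean holdFor(Maybe<Integer> m, int a) {
        boolean first = Maybe.just(a).bind(h).equals(h.apply(a));
        boolean second = m.bind(x -> Maybe.just(x)).equals(m);
        boolean third = m.bind(g).bind(h)
            .equals(m.bind(x -> g.apply(x).bind(h)));
        return first && second && third;
    }

    public static void main(String[] args) {
        System.out.println(holdFor(Maybe.none(), 7));    // m is None from the start
        System.out.println(holdFor(Maybe.just(3), 7));   // no None anywhere
        System.out.println(holdFor(Maybe.just(-1), 7));  // g returns None
        System.out.println(holdFor(Maybe.just(80), 7));  // h returns None
    }
}
```

Each of the four lines prints true, covering a None that shows up at the start, in the middle, and at the end of the chain.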

So we have one example of a thing that happens to have the monadic properties. Let's try and explain monads in a slightly more precise way now. A monadic type is a type that can carry out monadic computations, i.e. functions that take in "naked" arguments and return values "wrapped" in the monadic type.

The monad laws ensure that you can compose monadic computations; in the Arith example above, the third monadic law allows us to combine sqrt and div into sqr_inv.
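One way to picture this composition is as a helper that glues two monadic functions into one. The sketch below is written against java.util.Optional (with flatMap in the role of bind) rather than our Maybe, purely to keep it self-contained; the helper compose, the function inv and the class name are all invented here.

```java
import java.util.Optional;
import java.util.function.Function;

final class ComposeSketch {
    // Glue two monadic functions together: run g, then send h "into" the result.
    static <A, B, C> Function<A, Optional<C>> compose(
            Function<A, Optional<B>> g,
            Function<B, Optional<C>> h) {
        return x -> g.apply(x).flatMap(h);
    }

    static Optional<Double> sqrt(Double a) {
        return a < 0.0 ? Optional.empty() : Optional.of(Math.sqrt(a));
    }

    static Optional<Double> inv(Double a) {
        return a == 0.0 ? Optional.empty() : Optional.of(1.0 / a);
    }

    // Same idea as sqr_inv in Arith: built once, reusable as a single step.
    static final Function<Double, Optional<Double>> SQR_INV =
        compose(ComposeSketch::sqrt, ComposeSketch::inv);

    public static void main(String[] args) {
        // Sending the two steps in one at a time...
        Optional<Double> oneAtATime =
            Optional.of(4.0).flatMap(ComposeSketch::sqrt).flatMap(ComposeSketch::inv);
        // ...or sending in the pre-built combination gives the same result.
        Optional<Double> combined = Optional.of(4.0).flatMap(SQR_INV);
        System.out.println(oneAtATime.equals(combined)); // prints true
        System.out.println(Optional.of(-4.0).flatMap(SQR_INV).isPresent()); // prints false
    }
}
```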

So that's it, that's all "a monad" is: a type Type<A> that has a single-argument constructor and a bind function with the signature

<B> Type<B> bind(Function<A, Type<B>> f);

Note that the monad abstraction does not define any way to "get out". You can send computations in by calling bind, but how you get anything out is defined by the thing, not by its monadic properties. In this case we defined a simple get method on Maybe that returns either a value or null.
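To illustrate that the way out really is up to the type, a Maybe could just as well force the caller to say what should happen in the None case. The sketch below adds a hypothetical orElse method to a condensed Maybe; neither that method nor this class layout comes from the code above.

```java
import java.util.function.Function;

// Condensed Maybe with a different "exit" than get().
abstract class Maybe<A> {
    abstract <B> Maybe<B> bind(Function<A, Maybe<B>> f);

    // Hypothetical exit: not part of the monadic properties, just something this type offers.
    abstract A orElse(A fallback);

    static <A> Maybe<A> just(A a) {
        return new Maybe<A>() {
            <B> Maybe<B> bind(Function<A, Maybe<B>> f) { return f.apply(a); }
            A orElse(A fallback) { return a; }
        };
    }

    static <A> Maybe<A> none() {
        return new Maybe<A>() {
            <B> Maybe<B> bind(Function<A, Maybe<B>> f) { return none(); }
            A orElse(A fallback) { return fallback; }
        };
    }
}

final class ExitSketch {
    static Maybe<Double> inv(Double x) {
        return x == 0.0 ? Maybe.<Double>none() : Maybe.just(1.0 / x);
    }

    // The caller never sees null: a failed computation yields the fallback instead.
    static double safeInv(double x) {
        return Maybe.just(x).bind(ExitSketch::inv).orElse(Double.NaN);
    }

    public static void main(String[] args) {
        System.out.println(safeInv(4.0)); // prints 0.25
        System.out.println(safeInv(0.0)); // prints NaN
    }
}
```

Nothing about bind changed between the two versions; only the exit did.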

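What does it mean to carry a computation? Well, it can mean a lot of things, really. For example, let us next define a monad that keeps track of all the values it has seen. (Strictly speaking, because pure itself adds an entry to the log, this type only satisfies the first two laws if we compare the wrapped values and ignore the log; the third law holds for both.)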

import java.util.function.Function;
import java.util.List;
import java.util.ArrayList;

public interface Logger<A> {
    <B> Logger<B> bind(Function<A, Logger<B>> f);
    static <A> Logger<A> pure(A a) {
        return new Impl.Wrapped<>(a);
    }

    A getValue();
    List<String> getLog();

    static class Impl {
        private Impl() {}
        private static class Wrapped<A> implements Logger<A> {
            public final A a;
            public final List<String> log;
            public Wrapped(A a) {
                this.a = a;
                this.log = new ArrayList<>();
                this.log.add(a.toString());
            }
            public Wrapped(A a, List<String> log) {
                this.a = a;
                this.log = log;
            }
            // Run f on the current value, then append the log it produced
            // to a copy of our own log.
            public <B> Logger<B> bind(Function<A, Logger<B>> f) {
                Logger<B> mb = f.apply(a);
                List<String> new_log = new ArrayList<>(log);
                new_log.addAll(mb.getLog());
                return new Wrapped<>(mb.getValue(), new_log);
            }
            public A getValue() {
                return a;
            }
            public List<String> getLog() {
                return new ArrayList<>(log);
            }
        }
    }

    static <A> Logger<A> _return(A a) {
        return pure(a);
    }

    public static void main(String[] args) {
        Logger<Integer> log1 = pure(1)
            .bind(a -> pure(a + 3))
            .bind(b -> pure(b + 4))
            .bind(c -> pure(c * 3));
        Logger<Integer> log2 = pure(1)
            .bind(a -> pure(a + 3)
            .bind(b -> pure(b + 4)
            .bind(c -> pure(c * 3))));
        Logger<Integer> log3 =
            pure(1)    .bind(a ->
            pure(a + 3).bind(b ->
            pure(b + 4).bind(c ->
            pure(c * 3))));
        Logger<Integer> log4 =
            pure    (    1)    .bind(a ->
            pure    (a + 3).bind(b ->
            pure    (b + 4).bind(c ->
            _return (c * b + a))));
        System.out.println(log1.getLog());
        System.out.println(log2.getLog());
        System.out.println(log3.getLog());
        System.out.println(log4.getLog());
    }
}

If you run this code, you'll see that it prints:

[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 33]

indicating that the first three monadic values log1, log2 and log3 are equivalent. Looking more closely, this should come as no surprise: log2 can be constructed from log1 by applying the third monadic law, whereas log3 is really just log2 with weird indentation and line breaks.

In the format used by log3, one can think of the line:

            pure(a + 3).bind(b ->

as meaning "execute the pure (i.e. non-monadic) computation a + 3, and bind its result to the variable name b". Hopefully this gives a bit more meaning to the names chosen for pure and bind.

Finally, in log4, we show that this notion of binding variables is not just for show and we can actually use these bindings in subsequent steps. We also renamed the last pure to _return, hopefully again illustrating why that name makes sense for that function.
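To make the "binding" reading a bit more tangible, here is the same shape of computation next to its plain-Java equivalent. This is only a sketch: it uses java.util.Optional (with flatMap in the role of bind and Optional.of in the role of pure) so that it can stand alone, and the class and method names are invented.

```java
import java.util.Optional;

final class BindingSketch {
    // Each nested lambda parameter stays in scope for everything to its right,
    // which is why the last line can still see a and b.
    static Optional<Integer> nested() {
        return
            Optional.of(1)    .flatMap(a ->
            Optional.of(a + 3).flatMap(b ->
            Optional.of(b + 4).flatMap(c ->
            Optional.of(c * b + a))));
    }

    // The same steps written as ordinary local variable declarations.
    static int plain() {
        int a = 1;
        int b = a + 3;
        int c = b + 4;
        return c * b + a;
    }

    public static void main(String[] args) {
        System.out.println(nested()); // prints Optional[33]
        System.out.println(plain());  // prints 33
    }
}
```

A flat chain such as the one used for log1 could not compute c * b + a, because by the time the last function runs, a and b are no longer in scope.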

There you have it. A monad is a type that defines a bind method with well-defined properties, and these properties allow you to use bind to construct computations. Not impressed yet? I'll leave you with a final example of implementing green threads in about 60 lines of Java (plus a small demo).

import java.util.function.Function;
import java.util.function.Supplier;
import java.util.List;
import java.util.LinkedList;
import java.util.Queue;

public interface Fiber<A> {
    <B> Fiber<B> bind(Function<A, Fiber<B>> f);
    <B> B ifReturn(Function<A, B> f, Supplier<B> s);
    Fiber<A> step();

    static class Impl {
        private Impl() {}
        private final static class Return<A> implements Fiber<A> {
            public final A a;
            public Return(A a) {
                this.a = a;
            }
            public <B> Fiber<B> bind(Function<A, Fiber<B>> f) {
                return new Bind<A, B>(this, f);
            }
            public <B> B ifReturn(
                    Function<A, B> f,
                    Supplier<B> s) {
                return f.apply(a);
            }
            public Fiber<A> step() {
                return this;
            }
        }
        private final static class Bind<A, B> implements Fiber<B> {
            public final Fiber<A> ma;
            public final Function<A, Fiber<B>> f;
            public Bind(Fiber<A> ma, Function<A, Fiber<B>> f) {
                this.ma = ma;
                this.f = f;
            }
            public <C> Fiber<C> bind(Function<B, Fiber<C>> f) {
                return new Bind<>(this, f);
            }
            public <C> C ifReturn(
                    Function<B, C> f,
                    Supplier<C> s) {
                return s.get();
            }
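            // If ma has finished, feed its value to f; otherwise advance ma
            // by one step and keep f waiting for its result.
            public Fiber<B> step() {
                return ma.ifReturn(f, () -> new Bind<A, B>(ma.step(), f));
            }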
        }
    }

    static <A> Fiber<A> pure(A a) {
        return new Impl.Return<A>(a);
    }

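    // Round-robin scheduler: take the fiber at the head of the queue; if it
    // has finished, collect its result, otherwise advance it by one step and
    // put it back at the end of the queue.
    static <A> List<A> exec(List<Fiber<A>> ls) {
        List<A> rets = new LinkedList<>();
        Queue<Fiber<A>> q = new LinkedList<>(ls);
        while(!q.isEmpty()) {
            Fiber<A> f = q.remove();
            f.ifReturn(
                a -> rets.add(a),
                () -> q.add(f.step()));
        }
        return rets;
    }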

    static <A> Fiber<A> _return(A a) {
        return pure(a);
    }

    static Fiber<Void> debug(String msg) {
        System.out.println(msg);
        return pure(null);
    }


    public static void main(String[] args) {
        Fiber<Double> f1 =
            pure          (5) .bind(a ->
            debug("thread 1") .bind(   _1 ->
            pure      (a + 3) .bind(b ->
            debug("thread 1") .bind(   _2 ->
            pure      (b + 5) .bind(c ->
            debug("thread 1") .bind(   _3 ->
            pure      (b * c) .bind(i ->
            debug("thread 1") .bind(   _4 ->
            _return (1.0 * i)
            ))))))));
        Fiber<Double> f2 =
            pure                (1) .bind(a ->
            debug("thread 2") .bind(    _1 ->
            pure                (2) .bind(b ->
            debug("thread 2") .bind(    _2 ->
            pure                (1) .bind(c ->
            debug("thread 2") .bind(    _3 ->
            pure            (b * b) .bind(b2 ->
            debug("thread 2") .bind(    _4 ->
            pure            (a * c) .bind(ac ->
            debug("thread 2") .bind(    _5 ->
            pure      (b2 - 4 * ac) .bind(delta ->
            debug("thread 2") .bind(    _6 ->
            pure (Math.sqrt(delta)) .bind(rac ->
            debug("thread 2") .bind(    _7 ->
            _return ((-b + rac) / (2 * a))
            ))))))))))))));

        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f1); add(f2); }}));
        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f2); add(f1); }}));
    }
}

This is obviously highly non-idiomatic Java code. If you're anything like me when I was first learning about monads, at this point you're probably feeling like you understand what a monad is, and it seems about as cool and practical as the Y combinator.

This post is long enough as it is, but rest assured that I will try to convince you monads can be a useful practical tool in a future one.


  1. Left as an exercise for the reader.

Tags: monad-tutorial