Monads, part one: What's a monad?
If you've read part zero, you know that the question "What is a monad?" is not a very good one, because "monad" is not a thing, it's a set of properties. A better question would be "what are the properties of a thing that make it a monad?". This is the question I will answer in this post.
It will probably be a fairly unsatisfying answer. By the time you reach the end of this post, you may be left wondering "Ok, so what? Why would I ever want to do this? It looks like so much unnecessary extra complexity." Those questions, specifically:
- Why should you use monads?
- When should you use monads?
- How do you make a monad?
will be addressed in future posts.
From my previous post, we know that in order for a type to be a monad, we need to define two functions operating on this type, namely >>= ("bind") and return. Their types are defined by:
class Monad m where
    (>>=) :: m a -> (a -> m b) -> m b
    return :: a -> m a
and their protocol is defined by:
return a >>= h === h a
m >>= return === m
(m >>= g) >>= h === m >>= (\x -> g x >>= h)
What does it mean? If you're unfamiliar with Haskell notation, it may be helpful to translate those to Java. Let's start with the Monad interface. A first attempt could look like:
import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
    Monad<A> pure(A a);
}
We've changed the names: >>= is not a valid identifier in Java, so we replaced it with its pronunciation, while return is a reserved keyword, so we used pure instead. This is not entirely arbitrary: the Haskell standard library also defines pure, which is equivalent to return.
This does not look quite right, though. The bind function uses the object itself as its first argument, such that the Haskell ma >>= f would correspond to ma.bind(f), but what does pure a translate to? In the Haskell version, we don't have a monadic value to start with. The pure function should be a constructor, and constructors are not part of Java interfaces. We could work around that by defining two interfaces:
import java.util.function.Function;

public interface Monad {
    <A> Monadic<A> pure(A a);

    interface Monadic<A> {
        <B> Monadic<B> bind(Function<A, Monadic<B>> f);
    }
}
but that gets tedious pretty quickly and doesn't offer much benefit. Let's just leave pure out and assume that our types have constructors. The interface becomes:
import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
}
Now, let us turn to the protocol. In pseudo-Java, it could read:
pure(a).bind(h) === h.apply(a)
m.bind(a -> pure(a)) === m
m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
where pure is a stand-in for a constructor.
So what does this all mean? If pure/return is just a constructor, the core meaning of "being a monad" must be in the bind method. Looking at the protocol, the first line tells us that if we have a value "wrapped" in the monad, we can "send a function into" the monad and apply it on the value. The second line tells us that "sending" the constructor "into" the monad is a no-op. And the third line tells us that, if we have two functions we want to "send into" a monad, we can either send them one at a time, or first construct a combination and send that, and the result will be the same.
At this point, it's probably still pretty nebulous. It has something to do with functions and function composition, but let's take a look at a concrete example to hopefully make all of that clearer.
One of the simplest monadic types is Maybe. It represents a potential value. Java uses null by default to represent the absence of a value, but let us assume that we don't like that approach and we want to be a bit more explicit. Maybe could be defined as:
public interface Maybe<A> {
    A get();

    static <A> Maybe<A> just(A a) {
        return new Maybe<A>() {
            public A get() {
                return a;
            }
        };
    }

    static <A> Maybe<A> none() {
        return new Maybe<A>() {
            public A get() {
                return null;
            }
        };
    }
}
That's not very useful, though. If the only operation is get, we haven't gained anything. We could make explicit subclasses for Just and None, and that would be marginally better, because we could now call instanceof on them. However, turning:
if (null == v) {
    ...
} else {
    ...
}
into
if (v instanceof None) {
    ...
} else {
    ...
}
is hardly a win. If every operation has to check for null, checking for instanceof None is not better. Now, let's imagine that it is ok for a sequence of operations to return null if any operation in the chain returns null. They will all have code that looks like:
if (null == v) { return null; }
else {
    ...
}
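To make that repetition concrete, here is a small sketch of the null-propagating style. The class and method names (NullChain, sqrt, inverse, addOne) are made up for illustration and are not part of the code in this post; the point is only that every function starts with the same guard.

```java
// Hypothetical helpers: each one returns null to signal "no result".
final class NullChain {
    static Double sqrt(Double a) {
        if (null == a) { return null; }   // the repeated guard
        if (a < 0.0) { return null; }     // no real square root
        return Math.sqrt(a);
    }

    static Double inverse(Double a) {
        if (null == a) { return null; }   // the repeated guard
        if (a == 0.0) { return null; }    // division by zero
        return 1.0 / a;
    }

    static Double addOne(Double a) {
        if (null == a) { return null; }   // the repeated guard, again
        return a + 1.0;
    }

    public static void main(String[] args) {
        System.out.println(addOne(inverse(sqrt(4.0))));   // prints 1.5
        System.out.println(addOne(inverse(sqrt(-4.0))));  // prints null
    }
}
```

Even addOne, which cannot fail on its own, has to carry the guard so that a null produced earlier in the chain passes through safely.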
Now, switching to Maybe can actually help, because we have a fixed implementation for what to do when the argument is None, which means we can make use of polymorphism. If we redefine our Maybe interface to be:
import java.util.function.Function;

public interface Maybe<A> {
    A get();
    <B> Maybe<B> bind(Function<A, Maybe<B>> f);

    static <A> Maybe<A> just(A a) { return new Impl.Just<>(a); }
    static <A> Maybe<A> none() { return new Impl.None<>(); }

    static class Impl {
        private Impl() {}

        private static class Just<A> implements Maybe<A> {
            private final A a;
            private Just(A a) {
                this.a = a;
            }
            public A get() { return a; }
            // A present value: hand it to the next step.
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return f.apply(a);
            }
        }

        private static class None<A> implements Maybe<A> {
            private None() {}
            public A get() { return null; }
            // No value: skip the next step entirely.
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return new None<>();
            }
        }
    }
}
we can now use it to avoid some of that repetition. Let's take an example from arithmetic for simplicity. Some operations are safe (subtraction, addition, multiplication), but some are not, such as square root or division. We could define "safe" functions that return a Maybe, and write:
public class Arith {
    public static Maybe<Double> add(Double a, Double b) {
        return Maybe.just(a + b);
    }
    public static Maybe<Double> sub(Double a, Double b) {
        return Maybe.just(a - b);
    }
    public static Maybe<Double> mul(Double a, Double b) {
        return Maybe.just(a * b);
    }
    public static Maybe<Double> div(Double a, Double b) {
        return b == 0.0 ? Maybe.none() : Maybe.just(a / b);
    }
    public static Maybe<Double> sqrt(Double a) {
        return a < 0.0 ? Maybe.none() : Maybe.just(Math.sqrt(a));
    }
    public static Maybe<Double> sqr_inv(Double a) {
        return Maybe.just(a)
            .bind(x -> sqrt(x))
            .bind(x -> div(1.0, x));
    }

    public static void main(String[] args) {
        Double d =
            Maybe.just(5.0)
                .bind(r -> add(r, 7.0))
                .bind(r -> sub(r, 2.0))
                .bind(r -> sqr_inv(r))
                .bind(r -> mul(r, 3.0))
                .get();
        System.out.println(d);
    }
}
where none of the arithmetic functions had to check whether its argument was null.
Note that we did not make Maybe implement Monad. This is because, in order for Maybe to work as expected, its bind method really has to return an instance of Maybe, not a generic Monad. (Remember, "monad" is not a thing, just a set of properties; in this case, the thing is Maybe.) This is not specific to Maybe and will be true of any monadic type: bind needs to return the monadic type, not a generic Monad instance. To the best of my knowledge, there is no way to express that property on a Java interface, which leaves us in a bit of an awkward place in terms of expressing monads in Java: return cannot be part of the interface because it is a constructor, and bind has a type that we cannot express generically.
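The following sketch shows what goes wrong in practice if a type does implement the generic Monad interface from earlier. Box and TypeLoss are hypothetical names introduced only for this illustration; Box is about the smallest type with a bind method one can write.

```java
import java.util.function.Function;

// The generic interface from earlier in the post.
interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
}

// Hypothetical minimal monadic type, used only for this illustration.
final class Box<A> implements Monad<A> {
    final A value;

    Box(A value) { this.value = value; }

    @Override
    public <B> Monad<B> bind(Function<A, Monad<B>> f) {
        // f is only known to return "some Monad<B>", so that is all
        // this method can promise to its caller.
        return f.apply(value);
    }
}

final class TypeLoss {
    static int run() {
        Monad<Integer> r = new Box<>(1).bind(x -> new Box<>(x + 1));
        // r.value would not compile: the static type of r is
        // Monad<Integer>, even though the object is a Box at runtime.
        return ((Box<Integer>) r).value;  // a downcast is needed to get out
    }

    public static void main(String[] args) {
        System.out.println(run());  // prints 2
    }
}
```

After a single bind, the compiler no longer knows it is dealing with a Box, so every Box-specific operation requires a cast. Declaring a covariant return type on Box.bind does not help either, because the function argument is still only typed as returning Monad<B>.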
The conclusion is simple: there is no way to express the concept of a monad in terms of Java types. This does not, however, mean that we cannot use monads in Java: we have just demonstrated the use of Maybe as a monad.
To prove (or verify) that our definition is indeed a monad, we can take another look at the protocol, a.k.a. "the monad laws". The first law states that:
pure(a).bind(h) === h.apply(a)
In the case of our Maybe monad, pure is just. Looking at the code, we can see that just(a).bind(h) will expand directly to h.apply(a), so that one checks out.
The second law states that:
m.bind(a -> pure(a)) === m
where m is either a Just or a None. If m is of type None, we know that none().bind(f) is none() for any f (for any reasonable definition of .equals [1]). If m is just(x), then just(x).bind(a -> just(a)) is (a -> just(a)).apply(x) which is just(x).
Finally, the third law states that:
m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
This requires a bit more of a case analysis as None may pop up at any stage.
- If m is none() to start with, the case is pretty simple. On the left, none().bind(g).bind(h) evaluates to none().bind(h) which evaluates to none(), while on the right, none().bind(x -> g.apply(x).bind(h)) evaluates directly to none().
- If m is just(a), then m.bind(g) will evaluate to g.apply(a), and the left side is g.apply(a).bind(h). On the right side, just(a).bind(x -> g.apply(x).bind(h)) evaluates to (x -> g.apply(x).bind(h)).apply(a) which evaluates to g.apply(a).bind(h).
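The case analysis above can also be spot-checked mechanically. The sketch below is an assumption-laden variant of Maybe, here called EqMaybe, with one possible "reasonable" definition of equals added; the sample functions G and H and the class MaybeLaws are likewise made up for the check. Passing a few samples is of course not a proof, only a sanity check of the reasoning.

```java
import java.util.Objects;
import java.util.function.Function;

// Variant of Maybe with equals/hashCode added (an assumption about what a
// "reasonable definition" looks like), so results can be compared.
interface EqMaybe<A> {
    <B> EqMaybe<B> bind(Function<A, EqMaybe<B>> f);

    static <A> EqMaybe<A> just(A a) { return new Just<>(a); }
    static <A> EqMaybe<A> none() { return new None<>(); }

    final class Just<A> implements EqMaybe<A> {
        private final A a;
        Just(A a) { this.a = a; }
        public <B> EqMaybe<B> bind(Function<A, EqMaybe<B>> f) { return f.apply(a); }
        @Override public boolean equals(Object o) {
            return o instanceof Just && Objects.equals(a, ((Just<?>) o).a);
        }
        @Override public int hashCode() { return Objects.hashCode(a); }
    }

    final class None<A> implements EqMaybe<A> {
        public <B> EqMaybe<B> bind(Function<A, EqMaybe<B>> f) { return new None<>(); }
        // All None values are considered equal.
        @Override public boolean equals(Object o) { return o instanceof None; }
        @Override public int hashCode() { return 0; }
    }
}

final class MaybeLaws {
    // Sample monadic functions; each can fail for some inputs.
    static final Function<Integer, EqMaybe<Integer>> G =
        x -> x < 0 ? EqMaybe.none() : EqMaybe.just(x + 1);
    static final Function<Integer, EqMaybe<Integer>> H =
        x -> x % 2 == 0 ? EqMaybe.none() : EqMaybe.just(x * 10);

    // Checks the three laws for one monadic value m and one plain value a.
    static boolean lawsHold(EqMaybe<Integer> m, Integer a) {
        boolean first = EqMaybe.just(a).bind(H).equals(H.apply(a));
        boolean second = m.bind(x -> EqMaybe.just(x)).equals(m);
        boolean third = m.bind(G).bind(H)
                         .equals(m.bind(x -> G.apply(x).bind(H)));
        return first && second && third;
    }

    public static void main(String[] args) {
        System.out.println(lawsHold(EqMaybe.just(3), 3));   // prints true
        System.out.println(lawsHold(EqMaybe.none(), 4));    // prints true
        System.out.println(lawsHold(EqMaybe.just(-1), 0));  // prints true
    }
}
```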
So we have one example of a thing that happens to have the monadic properties. Let's try and explain monads in a slightly more precise way now. A monadic type is a type that can carry out monadic computations, i.e. functions that take in "naked" arguments and return values "wrapped" in the monadic type.
The monad laws ensure that you can compose monadic computations; in the Arith example above, the third monadic law allows us to combine sqrt and div into sqr_inv.
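That composition can be captured once and for all in a helper. The sketch below uses java.util.Optional from the JDK as a stand-in for Maybe, with Optional.flatMap playing the role of bind; the class Kleisli and its methods compose, sqrt and inverse are hypothetical names, not part of the code in this post.

```java
import java.util.Optional;
import java.util.function.Function;

final class Kleisli {
    // Hypothetical helper: builds exactly the function that appears on the
    // right-hand side of the third law, x -> g.apply(x).bind(h).
    static <A, B, C> Function<A, Optional<C>> compose(
            Function<A, Optional<B>> g,
            Function<B, Optional<C>> h) {
        return x -> g.apply(x).flatMap(h);
    }

    static Optional<Double> sqrt(Double a) {
        return a < 0.0 ? Optional.empty() : Optional.of(Math.sqrt(a));
    }

    static Optional<Double> inverse(Double a) {
        return a == 0.0 ? Optional.empty() : Optional.of(1.0 / a);
    }

    // Plays the same role as sqr_inv in the Arith example.
    static final Function<Double, Optional<Double>> SQR_INV =
        compose(Kleisli::sqrt, Kleisli::inverse);

    public static void main(String[] args) {
        System.out.println(SQR_INV.apply(4.0));   // prints Optional[0.5]
        System.out.println(SQR_INV.apply(-4.0));  // prints Optional.empty
        System.out.println(SQR_INV.apply(0.0));   // prints Optional.empty
    }
}
```

Because of the third law, it makes no difference whether a caller binds sqrt and then inverse, or binds the pre-composed SQR_INV in one step.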
So that's it, that's all "a monad" is: a type Type<A> that has a single-argument constructor and a bind function with the signature
<B> Type<B> bind(Function<A, Type<B>> f);
Note that the monad abstraction does not define any way to "get out". You can send computations in by calling bind, but how you get anything out is defined by the thing, not by its monadic properties. In this case we defined a simple get method on Maybe that returns either a value or null.
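What does it mean to carry a computation? Well, it can mean a lot of things, really. For example, let us next define a monad that keeps track of all the values it has seen. (Strictly speaking, because pure itself adds an entry to the log, the first two laws only hold for this type if you compare the wrapped values and ignore the log; the third law holds for the log as well.)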
import java.util.function.Function;
import java.util.List;
import java.util.ArrayList;

public interface Logger<A> {
    <B> Logger<B> bind(Function<A, Logger<B>> f);
    static <A> Logger<A> pure(A a) {
        return new Impl.Wrapped<>(a);
    }
    A getValue();
    List<String> getLog();

    static class Impl {
        private Impl() {}

        private static class Wrapped<A> implements Logger<A> {
            public final A a;
            public final List<String> log;
            // Wrapping a fresh value records it as the first log entry.
            public Wrapped(A a) {
                this.a = a;
                this.log = new ArrayList<>();
                this.log.add(a.toString());
            }
            public Wrapped(A a, List<String> log) {
                this.a = a;
                this.log = log;
            }
            // Run f on the value, then append its log to ours.
            public <B> Logger<B> bind(Function<A, Logger<B>> f) {
                Logger<B> mb = f.apply(a);
                List<String> new_log = new ArrayList<>(log);
                new_log.addAll(mb.getLog());
                return new Wrapped<>(mb.getValue(), new_log);
            }
            public A getValue() {
                return a;
            }
            public List<String> getLog() {
                return new ArrayList<>(log);
            }
        }
    }

    static <A> Logger<A> _return(A a) {
        return pure(a);
    }

    public static void main(String[] args) {
        Logger<Integer> log1 = pure(1)
            .bind(a -> pure(a + 3))
            .bind(b -> pure(b + 4))
            .bind(c -> pure(c * 3));
        Logger<Integer> log2 = pure(1)
            .bind(a -> pure(a + 3)
                .bind(b -> pure(b + 4)
                    .bind(c -> pure(c * 3))));
        Logger<Integer> log3 =
            pure(1)    .bind(a ->
            pure(a + 3).bind(b ->
            pure(b + 4).bind(c ->
            pure(c * 3))));
        Logger<Integer> log4 =
            pure    (    1).bind(a ->
            pure    (a + 3).bind(b ->
            pure    (b + 4).bind(c ->
            _return (c * b + a))));
        System.out.println(log1.getLog());
        System.out.println(log2.getLog());
        System.out.println(log3.getLog());
        System.out.println(log4.getLog());
    }
}
If you run this code, you'll see that it prints:
[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 33]
indicating that the first three monadic values log1, log2 and log3 are equivalent. Looking more closely, this should come as no surprise: log2 can be constructed from log1 by applying the third monadic law, whereas log3 is really just log2 with weird indentation and line breaks.
In the format used by log3, one can think of the line:
pure(a + 3).bind(b ->
as meaning "execute the pure (i.e. non-monadic) computation a + 3, and bind its result to the variable name b". Hopefully this gives a bit more meaning to the names chosen for pure and bind.
Finally, in log4, we show that this notion of binding variables is not just for show and we can actually use these bindings in subsequent steps. We also renamed the last pure to _return, hopefully again illustrating why that name makes sense for that function.
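For comparison, here is the log4 computation written as ordinary Java statements, with the corresponding monadic line next to each one. PlainSteps is a hypothetical name, and this version only computes the final value; it does not keep a log.

```java
final class PlainSteps {
    static int compute() {
        int a = 1;         // pure    (    1).bind(a ->
        int b = a + 3;     // pure    (a + 3).bind(b ->
        int c = b + 4;     // pure    (b + 4).bind(c ->
        return c * b + a;  // _return (c * b + a)
    }

    public static void main(String[] args) {
        System.out.println(compute());  // prints 33
    }
}
```

The result matches the last entry of the log printed for log4. Each local variable declaration lines up with one bind, and the final return lines up with _return.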
There you have it. A monad is a type that defines a bind method with well-defined properties, and these properties allow you to use bind to construct computations. Not impressed yet? I'll leave you with a final example of implementing green threads in about 60 lines of Java.
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.List;
import java.util.LinkedList;
import java.util.Queue;

public interface Fiber<A> {
    <B> Fiber<B> bind(Function<A, Fiber<B>> f);
    // Calls f if this fiber has finished with a value, s otherwise.
    <B> B ifReturn(Function<A, B> f, Supplier<B> s);
    // Advances the computation by one step.
    Fiber<A> step();

    static class Impl {
        private Impl() {}

        // A finished computation holding its result.
        private final static class Return<A> implements Fiber<A> {
            public final A a;
            public Return(A a) {
                this.a = a;
            }
            public <B> Fiber<B> bind(Function<A, Fiber<B>> f) {
                return new Bind<A, B>(this, f);
            }
            public <B> B ifReturn(
                    Function<A, B> f,
                    Supplier<B> s) {
                return f.apply(a);
            }
            public Fiber<A> step() {
                return this;
            }
        }

        // A suspended computation: run ma, then feed its result to f.
        private final static class Bind<A, B> implements Fiber<B> {
            public final Fiber<A> ma;
            public final Function<A, Fiber<B>> f;
            public Bind(Fiber<A> ma, Function<A, Fiber<B>> f) {
                this.ma = ma;
                this.f = f;
            }
            public <C> Fiber<C> bind(Function<B, Fiber<C>> f) {
                return new Bind<>(this, f);
            }
            public <C> C ifReturn(
                    Function<B, C> f,
                    Supplier<C> s) {
                return s.get();
            }
            public Fiber<B> step() {
                return ma.ifReturn(f, () -> new Bind<A, B>(ma.step(), f));
            }
        }
    }

    static <A> Fiber<A> pure(A a) {
        return new Impl.Return<A>(a);
    }

    // Round-robin scheduler: step each unfinished fiber once, then requeue it.
    static <A> List<A> exec(List<Fiber<A>> ls) {
        List<A> rets = new LinkedList<>();
        Queue<Fiber<A>> q = new LinkedList<>(ls);
        while(!q.isEmpty()) {
            Fiber<A> f = q.remove();
            f.ifReturn(
                a -> rets.add(a),
                () -> q.add(f.step()));
        }
        return rets;
    }

    static <A> Fiber<A> _return(A a) {
        return pure(a);
    }

    static Fiber<Void> debug(String msg) {
        System.out.println(msg);
        return pure(null);
    }

    public static void main(String[] args) {
        Fiber<Double> f1 =
            pure (5) .bind(a ->
            debug("thread 1") .bind( _1 ->
            pure (a + 3) .bind(b ->
            debug("thread 1") .bind( _2 ->
            pure (b + 5) .bind(c ->
            debug("thread 1") .bind( _3 ->
            pure (b * c) .bind(i ->
            debug("thread 1") .bind( _4 ->
            _return (1.0 * i)
            ))))))));
        Fiber<Double> f2 =
            pure (1) .bind(a ->
            debug("thread 2") .bind( _1 ->
            pure (2) .bind(b ->
            debug("thread 2") .bind( _2 ->
            pure (1) .bind(c ->
            debug("thread 2") .bind( _3 ->
            pure (b * b) .bind(b2 ->
            debug("thread 2") .bind( _4 ->
            pure (a * c) .bind(ac ->
            debug("thread 2") .bind( _5 ->
            pure (b2 - 4 * ac) .bind(delta ->
            debug("thread 2") .bind( _6 ->
            pure (Math.sqrt(delta)) .bind(rac ->
            debug("thread 2") .bind( _7 ->
            _return ((-b + rac) / (2 * a))
            ))))))))))))));
        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f1); add(f2); }}));
        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f2); add(f1); }}));
    }
}
This is obviously highly non-idiomatic Java code. If you're anything like me when I was first learning about monads, at this point you're probably feeling like you understand what a monad is, and it seems about as cool and practical as the Y combinator.
This post is long enough as it is, but rest assured that I will try to convince you monads can be a useful practical tool in a future one.
[1] Left as an exercise for the reader.