## Monads, part one: ~~What's a monad?~~

If you've read part zero, you know that the question "What is a monad?" is
not a very good one, because "monad" is not a *thing*, it's a set of
properties. A better question would be "what are the properties of a thing that
make it a monad?". This is the question I will answer in this post.

It will probably be a fairly unsatisfying answer. By the time you reach the end of this post, you may be left wondering "Ok, so what? Why would I ever want to do this? It looks like so much unnecessary extra complexity." Those questions, specifically:

- Why should you use monads?
- When should you use monads?
- How do you make a monad?

will be addressed in future posts.

From my previous post, we know that in order for a type to be a
monad, we need to define two functions operating on this type, namely `>>=`
("bind") and `return`. Their types are defined by:

```
class Monad m where
  (>>=) :: m a -> (a -> m b) -> m b
  return :: a -> m a
```

and their protocol is defined by:

```
return a >>= h === h a
m >>= return === m
(m >>= g) >>= h === m >>= (\x -> g x >>= h)
```

What does it mean? If you're unfamiliar with Haskell notation, it may be
helpful to translate those to Java. Let's start with the `Monad` interface. A
first attempt could look like:

```
import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
    Monad<A> pure(A a);
}
```

We've changed the names: `>>=` is not a valid identifier in Java, so we
replaced it with its pronunciation, while `return` is a reserved keyword,
so we used `pure` instead. This is not entirely arbitrary: the Haskell standard
library also defines `pure`, which is equivalent to `return`.

This does not look quite right, though. The `bind` function uses the object
itself as its first argument, such that the Haskell `ma >>= f` would correspond
to `ma.bind(f)`, but what does `pure a` translate to? In the Haskell version,
we don't have a monadic value to start with. The `pure` function should be a
constructor, and constructors are not part of Java interfaces. We could work
around that by defining two interfaces:

```
import java.util.function.Function;

public interface Monad {
    <A> Monadic<A> pure(A a);

    interface Monadic<A> {
        <B> Monadic<B> bind(Function<A, Monadic<B>> f);
    }
}
```

but that gets tedious pretty quickly and doesn't offer much benefit. Let's just
leave `pure` out and assume that our types have constructors. The interface
becomes:

```
import java.util.function.Function;

public interface Monad<A> {
    <B> Monad<B> bind(Function<A, Monad<B>> f);
}
```

Now, let us turn to the protocol. In pseudo-Java, it could read:

```
pure(a).bind(h) === h.apply(a)
m.bind(a -> pure(a)) === m
m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
```

where `pure` is a stand-in for a constructor.

So what does this all mean? If `pure`/`return` is just a constructor, the core
meaning of "being a monad" must be in the `bind` method. Looking at the
protocol, the first line tells us that if we have a value "wrapped" in the
monad, we can "send a function into" the monad and apply it on the value. The
second line tells us that "sending" the constructor "into" the monad is a
no-op. And the third line tells us that, if we have two functions we want to
"send into" a monad, we can either send them one at a time, or first construct
a combination and send that, and the result will be the same.

At this point, it's probably still pretty nebulous. It has something to do with functions and function composition, but let's take a look at a concrete example to hopefully make all of that clearer.

One of the simplest monadic types is `Maybe`. It represents a potential value.
Java uses `null` by default to represent the absence of a value, but let us
assume that we don't like that approach and we want to be a bit more explicit.
`Maybe` could be defined as:

```
public interface Maybe<A> {
    A get();

    static <A> Maybe<A> just(A a) {
        return new Maybe<A>() {
            public A get() {
                return a;
            }
        };
    }

    static <A> Maybe<A> none() {
        return new Maybe<A>() {
            public A get() {
                return null;
            }
        };
    }
}
```

That's not very useful, though. If the only operation is `get`, we haven't
gained anything. We could make explicit subclasses for `Just` and `None`, and
that would be marginally better, because we could now call `instanceof` on
them. However, turning:

```
if (null == v) {
    ...
} else {
    ...
}
```

into

```
if (v instanceof None) {
    ...
} else {
    ...
}
```

is hardly a win. If every operation has to check for `null`, checking for
`instanceof None` is not better. Now, let's imagine that it is ok for a
sequence of operations to return `null` if any operation in the chain returns
`null`. They will all have code that looks like:

```
if (null == v) { return null; }
else {
    ...
}
```

Now, switching to `Maybe` can actually help, because we have a fixed
implementation for what to do when the argument is `None`, which means we can
make use of polymorphism. If we redefine our `Maybe` interface to be:

```
import java.util.function.Function;

public interface Maybe<A> {
    A get();
    <B> Maybe<B> bind(Function<A, Maybe<B>> f);

    static <A> Maybe<A> just(A a) { return new Impl.Just<>(a); }
    static <A> Maybe<A> none() { return new Impl.None<>(); }

    static class Impl {
        private Impl() {}

        private static class Just<A> implements Maybe<A> {
            private final A a;
            private Just(A a) {
                this.a = a;
            }
            public A get() { return a; }
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return f.apply(a);
            }
        }

        private static class None<A> implements Maybe<A> {
            private None() {}
            public A get() { return null; }
            public <B> Maybe<B> bind(Function<A, Maybe<B>> f) {
                return new None<>();
            }
        }
    }
}
```

we can now use it to avoid some of that repetition. Let's take an example from
arithmetic for simplicity. Some operations are safe (subtraction, addition,
multiplication), but some are not, such as square root or division. We could
define "safe" functions that return a `Maybe`, and write:

```
public class Arith {
    public static Maybe<Double> add(Double a, Double b) {
        return Maybe.just(a + b);
    }

    public static Maybe<Double> sub(Double a, Double b) {
        return Maybe.just(a - b);
    }

    public static Maybe<Double> mul(Double a, Double b) {
        return Maybe.just(a * b);
    }

    public static Maybe<Double> div(Double a, Double b) {
        return b == 0.0 ? Maybe.none() : Maybe.just(a / b);
    }

    public static Maybe<Double> sqrt(Double a) {
        return a < 0.0 ? Maybe.none() : Maybe.just(Math.sqrt(a));
    }

    public static Maybe<Double> sqr_inv(Double a) {
        return Maybe.just(a)
            .bind(x -> sqrt(x))
            .bind(x -> div(1.0, x));
    }

    public static void main(String[] args) {
        Double d =
            Maybe.just(5.0)
                .bind(r -> add(r, 7.0))
                .bind(r -> sub(r, 2.0))
                .bind(r -> sqr_inv(r))
                .bind(r -> mul(r, 3.0))
                .get();
        System.out.println(d);
    }
}
```

where none of the arithmetic functions had to check whether its argument was
`null`.
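For contrast, here is a sketch of roughly the same computation in the `null`-returning style described earlier, where every function has to guard its own arguments. The class `NullArith` and its method names are made up for this illustration and are not part of the code above.

```java
// Hypothetical null-based counterpart to Arith; names are illustrative only.
final class NullArith {
    // Returns null when either argument is missing or the division is undefined.
    static Double div(Double a, Double b) {
        if (a == null || b == null) { return null; }
        return b == 0.0 ? null : a / b;
    }

    // Returns null when the argument is missing or negative.
    static Double sqrt(Double a) {
        if (a == null) { return null; }
        return a < 0.0 ? null : Math.sqrt(a);
    }

    // Multiplication itself never fails, but it still has to propagate null.
    static Double mul(Double a, Double b) {
        if (a == null || b == null) { return null; }
        return a * b;
    }

    // Composition only works because div and mul repeat the null check.
    static Double sqrInv(Double a) {
        return div(1.0, sqrt(a));
    }

    public static void main(String[] args) {
        System.out.println(mul(sqrInv(4.0), 3.0));
        System.out.println(mul(sqrInv(-4.0), 3.0));
    }
}
```

The guard that `Maybe` keeps in one place (the `bind` of `None`) is repeated here at the top of every function.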

Note that we did not make `Maybe` implement `Monad`. This is because, in order
for `Maybe` to work as expected, its `bind` method really has to return an
instance of `Maybe`, not a generic `Monad`. (Remember, "monad" is not a thing,
just a set of properties; in this case, the thing is `Maybe`.) This is not
specific to `Maybe` and will be true of any monadic type: `bind` needs to
return the monadic type, not a generic `Monad` instance. To the best of my
knowledge, there is no way to express that property on a Java interface, which
leaves us in a bit of an awkward place in terms of expressing monads in Java:
`return` cannot be part of the interface because it is a constructor, and
`bind` has a type that we cannot express generically.

The conclusion is simple: there is no way to express the concept of a monad in
terms of Java types. This does not, however, mean that we cannot *use* monads
in Java: we have just demonstrated the use of `Maybe` as a monad.
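The Java standard library has a type of the same shape: `java.util.Optional`, whose `flatMap` method plays the role of `bind` and whose `Optional.of` plays the role of `pure`. The following is a minimal sketch of the square-root-then-invert computation on top of it; `OptionalArith` and its method names are assumptions made for this illustration.

```java
import java.util.Optional;

// Sketch: part of the arithmetic example rewritten against java.util.Optional.
// OptionalArith and its methods are illustrative names only.
final class OptionalArith {
    // Empty when dividing by zero.
    static Optional<Double> div(double a, double b) {
        return b == 0.0 ? Optional.empty() : Optional.of(a / b);
    }

    // Empty for negative input.
    static Optional<Double> sqrt(double a) {
        return a < 0.0 ? Optional.empty() : Optional.of(Math.sqrt(a));
    }

    // flatMap is used exactly where the Maybe version uses bind.
    static Optional<Double> sqrInv(double a) {
        return Optional.of(a)
            .flatMap(x -> sqrt(x))
            .flatMap(x -> div(1.0, x));
    }

    public static void main(String[] args) {
        System.out.println(sqrInv(4.0));
        System.out.println(sqrInv(-4.0));
        System.out.println(sqrInv(0.0));
    }
}
```

Note that `Optional` does not implement any `Monad` interface either; it simply happens to offer a method with the right signature.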

To prove (or verify) that our definition is indeed a monad, we can take another look at the protocol, a.k.a. "the monad laws". The first law states that:

```
pure(a).bind(h) === h.apply(a)
```

In the case of our `Maybe` monad, `pure` is `just`. Looking at the code, we can
see that `just(a).bind(h)` will expand directly to `h.apply(a)`, so that one
checks out.

The second law states that:

```
m.bind(a -> pure(a)) === m
```

where `m` is either a `Just` or a `None`. If `m` is of type `None`, we know
that `none().bind(f)` is `none()` for any `f` (for any reasonable definition of
`.equals`<sup>1</sup>). If `m` is `just(x)`, then `just(x).bind(a -> just(a))`
is `(a -> just(a)).apply(x)` which is `just(x)`.

Finally, the third law states that:

```
m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
```

This requires a bit more of a case analysis as `None` may pop up at any stage.

- If `m` is `none()` to start with, the case is pretty simple. On the left,
  `none().bind(g).bind(h)` evaluates to `none().bind(h)` which evaluates to
  `none()`, while on the right, `none().bind(x -> g.apply(x).bind(h))`
  evaluates directly to `none()`.
- If `m` is `just(a)`, then `m.bind(g)` will evaluate to `g.apply(a)`, and the
  left side is `g.apply(a).bind(h)`. On the right side,
  `just(a).bind(x -> g.apply(x).bind(h))` evaluates to
  `(x -> g.apply(x).bind(h)).apply(a)` which evaluates to
  `g.apply(a).bind(h)`.
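The case analysis above can also be spot-checked by running both sides of each law on a few sample values. The sketch below is not a proof, and it does not reuse the `Maybe` interface from this post: `MaybeLaws`, the compact stand-in `M` (which adds an `equals` method so results can be compared), and the sample functions `g` and `h` are all hypothetical.

```java
import java.util.Objects;
import java.util.function.Function;

// Sketch only: a compact stand-in for Maybe, with equals() added so that
// the two sides of each law can be compared on sample values.
final class MaybeLaws {
    static final class M<A> {
        private final A value; // null encodes None in this sketch

        private M(A value) { this.value = value; }

        static <A> M<A> just(A a) { return new M<>(Objects.requireNonNull(a)); }
        static <A> M<A> none() { return new M<>(null); }

        <B> M<B> bind(Function<A, M<B>> f) {
            return value == null ? M.<B>none() : f.apply(value);
        }

        @Override public boolean equals(Object o) {
            return o instanceof M && Objects.equals(value, ((M<?>) o).value);
        }

        @Override public int hashCode() { return Objects.hashCode(value); }
    }

    // Two arbitrary monadic functions, each of which can produce None.
    static M<Integer> g(Integer x) { return x < 0 ? M.none() : M.just(x + 1); }
    static M<Integer> h(Integer x) { return x % 2 == 0 ? M.just(x * 10) : M.none(); }

    // pure(a).bind(h) === h.apply(a)
    static boolean leftIdentity(Integer a) {
        return M.just(a).bind(MaybeLaws::h).equals(h(a));
    }

    // m.bind(a -> pure(a)) === m
    static boolean rightIdentity(M<Integer> m) {
        return m.bind(a -> M.just(a)).equals(m);
    }

    // m.bind(g).bind(h) === m.bind(x -> g.apply(x).bind(h))
    static boolean associativity(M<Integer> m) {
        M<Integer> left = m.bind(MaybeLaws::g).bind(MaybeLaws::h);
        M<Integer> right = m.bind(x -> g(x).bind(MaybeLaws::h));
        return left.equals(right);
    }

    public static void main(String[] args) {
        boolean ok = rightIdentity(M.<Integer>none()) && associativity(M.<Integer>none());
        for (int a = -2; a <= 3; a++) {
            ok &= leftIdentity(a) && rightIdentity(M.just(a)) && associativity(M.just(a));
        }
        System.out.println(ok ? "all three laws hold on the samples" : "a law failed");
    }
}
```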

So we have one example of a thing that happens to have the monadic properties. Let's try and explain monads in a slightly more precise way now. A monadic type is a type that can carry out monadic computations, i.e. functions that take in "naked" arguments and return values "wrapped" in the monadic type.

The monad laws ensure that you can compose monadic computations; in the `Arith`
example above, the third monadic law allows us to combine `sqrt` and `div` into
`sqr_inv`.

So that's it, that's all "a monad" is: a type `Type<A>` that has a
single-argument constructor and a `bind` function with the signature

```
<B> Type<B> bind(Function<A, Type<B>> f);
```

Note that the monad abstraction *does not* define any way to "get out". You can
send computations in by calling `bind`, but how you get anything out is defined
by the *thing*, not by its monadic properties. In this case we defined a simple
`get` method on `Maybe` that returns either a value or `null`.

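What does it mean to carry a computation? Well, it can mean a lot of things, really. For example, let us next define a monad that keeps track of all the values it has seen. (Strictly speaking, because its `pure` also writes to the log, the first two laws only hold for this type if we ignore the log when comparing values; the third law holds either way.)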

```
import java.util.function.Function;
import java.util.List;
import java.util.ArrayList;

public interface Logger<A> {
    <B> Logger<B> bind(Function<A, Logger<B>> f);

    static <A> Logger<A> pure(A a) {
        return new Impl.Wrapped<>(a);
    }

    A getValue();
    List<String> getLog();

    static class Impl {
        private Impl() {}

        private static class Wrapped<A> implements Logger<A> {
            public final A a;
            public final List<String> log;

            public Wrapped(A a) {
                this.a = a;
                this.log = new ArrayList<>();
                this.log.add(a.toString());
            }

            public Wrapped(A a, List<String> log) {
                this.a = a;
                this.log = log;
            }

            public <B> Logger<B> bind(Function<A, Logger<B>> f) {
                Logger<B> mb = f.apply(a);
                List<String> new_log = new ArrayList<>(log);
                new_log.addAll(mb.getLog());
                return new Wrapped<>(mb.getValue(), new_log);
            }

            public A getValue() {
                return a;
            }

            public List<String> getLog() {
                return new ArrayList<>(log);
            }
        }
    }

    static <A> Logger<A> _return(A a) {
        return pure(a);
    }

    public static void main(String[] args) {
        Logger<Integer> log1 = pure(1)
            .bind(a -> pure(a + 3))
            .bind(b -> pure(b + 4))
            .bind(c -> pure(c * 3));
        Logger<Integer> log2 = pure(1)
            .bind(a -> pure(a + 3)
                .bind(b -> pure(b + 4)
                    .bind(c -> pure(c * 3))));
        Logger<Integer> log3 =
            pure(1)    .bind(a ->
            pure(a + 3).bind(b ->
            pure(b + 4).bind(c ->
            pure(c * 3))));
        Logger<Integer> log4 =
            pure    (1)        .bind(a ->
            pure    (a + 3)    .bind(b ->
            pure    (b + 4)    .bind(c ->
            _return (c * b + a))));
        System.out.println(log1.getLog());
        System.out.println(log2.getLog());
        System.out.println(log3.getLog());
        System.out.println(log4.getLog());
    }
}
```

If you run this code, you'll see that it prints:

```
[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 24]
[1, 4, 8, 33]
```

indicating that the first three monadic values `log1`, `log2` and `log3` are
equivalent. Looking more closely, this should come as no surprise: `log2` can
be constructed from `log1` by applying the third monadic law, whereas `log3` is
really just `log2` with weird indentation and line breaks.

In the format used by `log3`, one can think of the line:

```
pure(a + 3).bind(b ->
```

as meaning "execute the *pure* (i.e. non-monadic) computation `a + 3`, and
*bind* its result to the variable name *b*". Hopefully this gives a bit more
meaning to the names chosen for `pure` and `bind`.

Finally, in `log4`, we show that this notion of binding variables is not just
for show and we can actually use these bindings in subsequent steps. We also
renamed the last `pure` to `_return`, hopefully again illustrating why that
name makes sense for that function.
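Since wrapping a value with `pure` also records it, it is worth checking which of the three laws survive when the log is part of what we compare. The sketch below uses a stripped-down stand-in for `Logger` rather than the interface above; `LogLaws`, `Log`, `g` and `h` are hypothetical names introduced only for this check.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Sketch: a stripped-down stand-in for Logger, used only to compare the
// logs produced by the two sides of each monad law.
final class LogLaws {
    static final class Log<A> {
        final A value;
        final List<String> log;

        Log(A value, List<String> log) { this.value = value; this.log = log; }

        // As in Logger: wrapping a value also records it.
        static <A> Log<A> pure(A a) {
            List<String> log = new ArrayList<>();
            log.add(a.toString());
            return new Log<>(a, log);
        }

        // As in Logger: this log followed by the log of f's result.
        <B> Log<B> bind(Function<A, Log<B>> f) {
            Log<B> mb = f.apply(value);
            List<String> merged = new ArrayList<>(log);
            merged.addAll(mb.log);
            return new Log<>(mb.value, merged);
        }
    }

    static Log<Integer> g(Integer x) { return Log.pure(x + 3); }
    static Log<Integer> h(Integer x) { return Log.pure(x * 2); }

    // pure(a).bind(h) versus h.apply(a)
    static boolean leftIdentityOnLog(int a) {
        return Log.pure(a).bind(LogLaws::h).log.equals(h(a).log);
    }

    // m.bind(x -> pure(x)) versus m
    static boolean rightIdentityOnLog(Log<Integer> m) {
        return m.bind(x -> Log.pure(x)).log.equals(m.log);
    }

    // m.bind(g).bind(h) versus m.bind(x -> g.apply(x).bind(h))
    static boolean associativityOnLog(Log<Integer> m) {
        List<String> left = m.bind(LogLaws::g).bind(LogLaws::h).log;
        List<String> right = m.bind(x -> g(x).bind(LogLaws::h)).log;
        return left.equals(right);
    }

    public static void main(String[] args) {
        System.out.println("left identity:  " + leftIdentityOnLog(1));
        System.out.println("right identity: " + rightIdentityOnLog(Log.pure(1)));
        System.out.println("associativity:  " + associativityOnLog(Log.pure(1)));
    }
}
```

On these inputs the two identity checks come out `false` (the extra `pure` leaves an extra log entry) while the associativity check comes out `true`, which is the property `log1` and `log2` rely on.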

There you have it. A monad is a type that defines a `bind` method with
well-defined properties, and these properties allow you to use `bind` to
construct computations. Not impressed yet? I'll leave you with a final example
of implementing *green threads* in about 60 lines of Java.

```
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.List;
import java.util.LinkedList;
import java.util.Queue;

public interface Fiber<A> {
    <B> Fiber<B> bind(Function<A, Fiber<B>> f);
    <B> B ifReturn(Function<A, B> f, Supplier<B> s);
    Fiber<A> step();

    static class Impl {
        private Impl() {}

        private final static class Return<A> implements Fiber<A> {
            public final A a;

            public Return(A a) {
                this.a = a;
            }

            public <B> Fiber<B> bind(Function<A, Fiber<B>> f) {
                return new Bind<A, B>(this, f);
            }

            public <B> B ifReturn(
                    Function<A, B> f,
                    Supplier<B> s) {
                return f.apply(a);
            }

            public Fiber<A> step() {
                return this;
            }
        }

        private final static class Bind<A, B> implements Fiber<B> {
            public final Fiber<A> ma;
            public final Function<A, Fiber<B>> f;

            public Bind(Fiber<A> ma, Function<A, Fiber<B>> f) {
                this.ma = ma;
                this.f = f;
            }

            public <C> Fiber<C> bind(Function<B, Fiber<C>> f) {
                return new Bind<>(this, f);
            }

            public <C> C ifReturn(
                    Function<B, C> f,
                    Supplier<C> s) {
                return s.get();
            }

            public Fiber<B> step() {
                return ma.ifReturn(f, () -> new Bind<A, B>(ma.step(), f));
            }
        }
    }

    static <A> Fiber<A> pure(A a) {
        return new Impl.Return<A>(a);
    }

    static <A> List<A> exec(List<Fiber<A>> ls) {
        List<A> rets = new LinkedList<>();
        Queue<Fiber<A>> q = new LinkedList<>(ls);
        while(!q.isEmpty()) {
            Fiber<A> f = q.remove();
            f.ifReturn(
                a -> rets.add(a),
                () -> q.add(f.step()));
        }
        return rets;
    }

    static <A> Fiber<A> _return(A a) {
        return pure(a);
    }

    static Fiber<Void> debug(String msg) {
        System.out.println(msg);
        return pure(null);
    }

    public static void main(String[] args) {
        Fiber<Double> f1 =
            pure    (5)                 .bind(a ->
            debug("thread 1")           .bind( _1 ->
            pure    (a + 3)             .bind(b ->
            debug("thread 1")           .bind( _2 ->
            pure    (b + 5)             .bind(c ->
            debug("thread 1")           .bind( _3 ->
            pure    (b * c)             .bind(i ->
            debug("thread 1")           .bind( _4 ->
            _return (1.0 * i)
            ))))))));
        Fiber<Double> f2 =
            pure    (1)                 .bind(a ->
            debug("thread 2")           .bind( _1 ->
            pure    (2)                 .bind(b ->
            debug("thread 2")           .bind( _2 ->
            pure    (1)                 .bind(c ->
            debug("thread 2")           .bind( _3 ->
            pure    (b * b)             .bind(b2 ->
            debug("thread 2")           .bind( _4 ->
            pure    (a * c)             .bind(ac ->
            debug("thread 2")           .bind( _5 ->
            pure    (b2 - 4 * ac)       .bind(delta ->
            debug("thread 2")           .bind( _6 ->
            pure    (Math.sqrt(delta))  .bind(rac ->
            debug("thread 2")           .bind( _7 ->
            _return ((-b + rac) / (2 * a))
            ))))))))))))));
        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f1); add(f2); }}));
        System.out.println(exec(new LinkedList<Fiber<Double>>() {{ add(f2); add(f1); }}));
    }
}
```

This is obviously highly non-idiomatic Java code. If you're anything like me
when I was first learning about monads, at this point you're probably feeling
like you understand what a monad *is*, and it seems about as cool and practical
as the Y combinator.

This post is long enough as it is, but rest assured that I will try to convince you monads can be a useful practical tool in a future one.

1. Left as an exercise for the reader.