A State Monad Tutorial

Code that is intended to be typed in a ghci session, or that is output from ghci, is shown on a green background.

Code that is intended to be loaded from a source file by ghci or other Haskell code is shown on a blue background.

Monads tend to be a difficult concept for beginners to understand. In this guide, I try a different approach by explaining non-monadic values and building from there. If you want to understand monads, please don’t breeze through this tutorial, as you will need to think to understand them.

I will first explain the State type. Let us import it and look at its type:

> :m Control.Monad.State.Strict

For performance reasons, we will almost always want to use the strict version. Often, the functionality of the lazy version will only be useful when we do some debugging.

> :i State

newtype State s a = State {runState :: s -> (a, s)}

There should be more lines, but for the purposes of this tutorial, we only care about the first line, which tells us the type of the State constructor.
Even though State is defined in a sub-module of Control.Monad, we will not be using its monadic properties yet.

In case you’re confused about the type, we’ll go over it. To make things less confusing, we’ll give the data constructor a different name from the type constructor, for now:

newtype StateType s a = State {runState :: s -> (a, s)}

As we don’t need the functionality or flexibility of “data”, we use newtype to avoid the overhead “data” would give us, although data would work too. If it helps you more clearly understand it, read “newtype” as “data”.

It looks like State wraps a function. Yes, that is what it does. It just wraps a function that takes an initial thing of type s, and returns a tuple that contains some other value that can be pretty much anything and a new value of the same type s.

Let us review the brackets. They are simply record syntax, which eliminates the need for pattern matching by automatically providing us with convenient accessor functions. Here’s an example of how Maybe could be written with record syntax:
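The ghci examples that follow use a Counter type. Here is a minimal sketch of how it could be declared with record syntax, assuming both fields are Int. The field names are taken from the ghci output further down; exampleCounter and valueByPatternMatch are names made up for illustration, and deriving Eq is only there so that counters can be compared.

```haskell
-- A counter that remembers its current value and how many times it has
-- been incremented. The Int field types are an assumption.
data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

-- Record syntax generates the accessor functions for us:
--   value              :: Counter -> Int
--   numberOfIncrements :: Counter -> Int
exampleCounter :: Counter
exampleCounter = Counter { value = 42, numberOfIncrements = 3 }

-- Without record syntax we would have to pattern match by hand.
valueByPatternMatch :: Counter -> Int
valueByPatternMatch (Counter v _) = v
```

With this loaded, "value exampleCounter" and "valueByPatternMatch exampleCounter" give the same result.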

We can define a counter whose value is 42 and whose value has been incremented 3 times like this in ghci:

> let ourCounter = Counter 42 3

Or, if you prefer the other syntax:

> let ourCounter = Counter {value = 42, numberOfIncrements = 3}

On a side note, we can construct a new counter by modifying an existing counter:

> ourCounter{value = 24}
Counter {value = 24, numberOfIncrements = 3}

Let us try writing a State function. Start a text editor, and just write the import statement and one function. We have an example of one such function here, but, for experience, you should write your own:

incrementCounter is a function that takes an initial counter, and returns the value of the initial counter and an incremented counter, both contained in a tuple. It’s important that you understand what the record syntax does here: oldValue and oldNumOfIncrements are bound to the value and the number of increments of the initial counter, and the latter usage is simply constructing a new counter.
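A function matching this description might look like the following sketch. It assumes the Counter record has two Int fields, and it leaves out the import statement because nothing from Control.Monad.State is needed yet.

```haskell
-- Assumed definition of the counter used throughout the tutorial.
data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

-- Takes an initial counter and returns a tuple of
--   * the value of the initial counter, and
--   * a new counter with both fields incremented.
incrementCounter :: Counter -> (Int, Counter)
incrementCounter (Counter {value = oldValue, numberOfIncrements = oldNumOfIncrements}) =
  ( oldValue
  , Counter {value = succ oldValue, numberOfIncrements = succ oldNumOfIncrements}
  )
```

The first use of record syntax is a pattern that names the old fields; the second one builds the new counter.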

In case you have not yet been exposed to the “succ” function: in this case, it returns the successor of our Int. We could also have replaced “succ” with “+ 1”:

The only difference between incrementCounterState and incrementCounter is that incrementCounterState is wrapped in the State constructor, and in order to get our State function, we just call runState on incrementCounterState.

Uh oh! Can you figure out why this is happening? GHCi expects incrementCounterState to be a function (one that takes a Counter; GHCi doesn’t care what it returns), but incrementCounterState is really of type “StateType Counter Int”. We forgot to unwrap the function contained in the State value! Since incrementCounterState is really just a value that contains a function, we’ll unwrap it to get that function by calling runState. Let’s see what the type of runState incrementCounterState is:
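To make the wrapping and unwrapping concrete, here is a small sketch written as a source file. It assumes the StateType and Counter definitions from earlier (with Int fields); the name result is made up for illustration.

```haskell
newtype StateType s a = State { runState :: s -> (a, s) }

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementCounter :: Counter -> (Int, Counter)
incrementCounter (Counter {value = oldValue, numberOfIncrements = oldNumOfIncrements}) =
  (oldValue, Counter {value = succ oldValue, numberOfIncrements = succ oldNumOfIncrements})

-- The same function, wrapped in the State constructor.
incrementCounterState :: StateType Counter Int
incrementCounterState = State incrementCounter

-- runState unwraps it again:
--   runState incrementCounterState :: Counter -> (Int, Counter)
-- so we can apply the unwrapped function to an initial counter.
result :: (Int, Counter)
result = runState incrementCounterState (Counter 42 3)
-- result == (42, Counter {value = 43, numberOfIncrements = 4})
```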

Hooray! It worked! We incremented our counter and retrieved the value of the first counter!

Right, so State pretty much just wraps a function that takes some value, which we call the initial state, and returns a tuple containing a value of *any* type together with another “state”, which must have the same type as the initial state.

Up to this point, we have not used any of the monadic properties of State. Before we step there, let’s first exercise our skills with State a bit.

How about a function that doubles the value of a counter and increments its numberOfIncrements? Can you write that?

Now let’s jump right in to the monadic properties of State! For now, I won’t worry you with do notation or the usage of “gets”, “put”, “>>”, and the like; we will only focus on bind (>>=) and return.

What is (>>=)? Let us look at its type:

> :t (>>=)
(>>=) :: (Monad m) => m a -> (a -> m b) -> m b

It takes a monadic value (in our case, a State computation, that is, a function wrapped with the State constructor) and a function, and returns a monadic value that “binds” those two together. The function takes an unwrapped value and returns a new monadic value. How the value of the first argument is unwrapped depends on the monad: each monad defines its own way for bind to do this.

Let’s put this in the context of State:

(>>=) :: (State s) a -> (a -> (State s) b) -> (State s) b

Which can be written as

(>>=) :: State s a -> (a -> State s b) -> State s b

Notice how we just replaced m with (State s).

So, (>>=) pretty much combines two monadic computations together. How does this work for State? Let’s take a look at its implementation along with return.
Since bind and return are functions of the Monad class, the definition needs to begin with an “instance” line:
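For reference, here is a sketch of what such an instance could look like for our StateType. Current versions of GHC only accept a Monad instance when Functor and Applicative instances exist too, so they are included (written with liftM and ap from Control.Monad). The library’s real definition is organized differently; this is only meant to show the idea.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

-- Required boilerplate on current GHC versions.
instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  -- Leave the state alone and hand back the given value.
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  -- Run the first computation with the initial state, feed its returned
  -- value to f, and run the computation f produces with the new state.
  State firstComputation >>= f = State $ \initialState ->
    let (returnedValue, newState) = firstComputation initialState
    in  runState (f returnedValue) newState
```

The lambda on the right-hand side of (>>=) is the “final State computation”: a function from an initial state to a returned value and a new state.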

But, yikes! That piece of code can look confusing, because it certainly confused me! Well, it takes one State, which can be incrementCounterState, and a function. What does that function do, exactly? We know it needs to give us the final State computation, which is a function that takes an initial state and returns a value and a new state. At some point, you should think about State’s bind and think about how it works.

We already have a State computation, incrementCounterState, which increments a counter. Can we write a function that increments a counter twice? Sure! Besides the obvious method of writing a new state computation and using succ twice, we can combine incrementCounterState:
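One way this could be written is sketched below. It assumes the StateType, Counter and incrementCounterState definitions from earlier, together with a Monad instance like the one described above; the lambda’s parameter name follows the text.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementCounterState :: StateType Counter Int
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

-- Bind incrementCounterState to a function that ignores the returned
-- value and simply gives back incrementCounterState again.
incrementCounterTwice :: StateType Counter Int
incrementCounterTwice =
  incrementCounterState >>= \theValueOfTheOriginalCounter -> incrementCounterState
```

Running it with "runState incrementCounterTwice (Counter 42 3)" should give 43 as the returned value (the value before the second increment) and a counter whose fields are 44 and 5.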

(Again, the indentation of incrementCounterTwice’ is a bit awkward, so you might want to pay attention to the parentheses. As long as you understand the first State computation, it is not necessary that you understand the other one.)

So, using (>>=), we were able to combine two States (or State computations), which happen to be incrementCounterState, and end up with another monad. In this example, we ignored theValueOfTheOriginalCounter, although we could have used it for whatever purpose.

This is our first time using bind, so let’s take a good look at what we’re doing. By binding incrementCounterState to a function that ignores the returned value it is given and returns a State computation, we end up with another State computation.

We can even combine incrementCounterTwice and another State computation:

In this monad, a value of type GameState is being threaded through. foo can be of type (State GameState a), where a can be anything, as long as it is the same type that bar takes as its parameter. bar would be of type (a -> State GameState b); its returned value type can differ from that of foo. bar, when given foo’s returned value and then an unwrapped state value, would produce a returned value and a new state. quux would be called with that returned value, and then return a State computation. But quux needs to return a State GameState String, since the returned value type of aCombinedThing needs to be the same as the returned value type of quux.

In case you’re confused about how this works, think of it as a wrapped function that takes an initial unwrapped state value and passes it on to foo; foo returns a new state, which is passed on to bar, and the new state from bar is passed on to quux. On top of that, the returned value of each computation is passed to the function that follows it.
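Here is a small sketch of such a chain. Everything about GameState and the bodies of foo, bar and quux is hypothetical (the tutorial only fixes their general types); the point is only to show how the types line up.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

-- Hypothetical game state with a single field.
data GameState = GameState { score :: Int } deriving (Show, Eq)

-- Returns the old score and adds one to it.
foo :: StateType GameState Int
foo = State $ \g -> (score g, g { score = score g + 1 })

-- Takes foo's returned value; returns a Bool and doubles the score.
bar :: Int -> StateType GameState Bool
bar n = State $ \g -> (even n, g { score = score g * 2 })

-- Takes bar's returned value; returns a String, leaves the state alone.
quux :: Bool -> StateType GameState String
quux wasEven = State $ \g -> (if wasEven then "even" else "odd", g)

aCombinedThing :: StateType GameState String
aCombinedThing = foo >>= bar >>= quux
```

Starting from "GameState 4", the state goes 4, 5, 10, while the returned values go 4, True, "even".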

What is this “>>” function? It’s actually quite simple, as it behaves just like >>=, only it doesn’t pass on the returned value of the first monad. (This function is actually part of the Monad class, but its default definition rarely needs to be changed. Older versions of the Monad class had one more function, known as fail, but it was pretty unpopular and it was commonly suggested not to use it; since GHC 8.8 it lives in a separate MonadFail class, so I won’t explain it.) It can be defined as:
a >> b = a >>= \_ -> b

This will, when passed an initial state value, pass the state onto foo, discard foo’s returned value, but take the new state and give it to bar, discard bar’s returned value, and pass bar’s new state to quux. quux returns a returned value of type String and a new GameState. Sometimes, State will not even need a return value, or it won’t make sense; we can use Haskell’s void, known as unit, which is written as ().

In this case, we really don’t care about the returned values of any of these. Since the returned values of foo and bar are discarded (even if any are “()”), they can be of any type, but since the type of aCombinedThing is State GameState (), the returned value type of quux needs to be (), so quux is of type State GameState () too.
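A sketch of a chain built with (>>) only. As before, GameState and the bodies of foo, bar and quux are hypothetical; foo and bar deliberately return values of different types to show that those values are thrown away.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

-- Hypothetical game state with a single field.
data GameState = GameState { score :: Int } deriving (Show, Eq)

foo :: StateType GameState Int
foo = State $ \g -> (score g, g { score = score g + 1 })

bar :: StateType GameState String
bar = State $ \g -> ("doubled", g { score = score g * 2 })

-- The last computation decides the returned value type: unit.
quux :: StateType GameState ()
quux = State $ \g -> ((), g { score = score g - 3 })

aCombinedThing :: StateType GameState ()
aCombinedThing = foo >> bar >> quux
```

Only the state survives from foo and bar: starting from "GameState 4" it becomes 5, then 10, then 7, and the returned value is ().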

How about a function that increments a counter twice, but uses the intermediate value (the returned value after the initial counter was incremented once, which is the initial value of the initial counter because of how we defined our State computations) to do something, perhaps constructing a GameState from that value at the end? We’ll do something even simpler: we will write a State computation that increments a counter twice; the return value of this computation is going to be the string representation of the intermediate counter value.
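A possible shape for this computation is sketched below, under the same assumptions as before (StateType, a Counter with Int fields, and incrementCounterState wrapping incrementCounter). It uses return for the final step, which is just one way to produce the String.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementCounterState :: StateType Counter Int
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

-- Increment twice; keep the value returned by the first increment and
-- return its string representation at the end.
kindOfIncrementCounter :: StateType Counter String
kindOfIncrementCounter =
  incrementCounterState >>= \initialCounterValue ->
    incrementCounterState >>= \_ ->
      return (show initialCounterValue)
```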

Notice that the type of the returned value of incrementCounter is not String, but is Int, but that’s OK because it’s just passed to our lambda function, which assigns the name “initialCounterValue” to it. When the state finally reaches our last State computation, the final State computation does need to return a returned value of type String, which it does.

So we see that kindOfIncrementCounter is a State computation that, when “run” with an initial state, will increment it twice and return it along with the string representation of the initial counter value. initialCounterValue refers to the returned value of incrementCounter; remember that we defined incrementCounter to return the original value of the counter it takes?

What is return? Return turns a value into a monad. It wraps, lifts, raises, or injects a value into a monad, depending on your preferred terminology. Which State computation does “return 3” produce? Let’s see:
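A quick experiment, sketched as a source file with our own StateType and an instance like the one discussed earlier (the name threeAndUnchangedState is made up for illustration):

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

-- Run "return 3" with some initial counter and look at both halves
-- of the resulting tuple.
threeAndUnchangedState :: (Int, Counter)
threeAndUnchangedState = runState (return 3) (Counter 42 3)
-- threeAndUnchangedState == (3, Counter {value = 42, numberOfIncrements = 3})
```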

It seems that “return x” creates a State computation that leaves the state unchanged (the state might be passed to a later State computation) and returns it along with x as the returned value.

How can we be sure? By looking at how return is defined for the State monad, of course!

instance Monad (State s) where
    return a = State $ \s -> (a, s)
    …

We were right! “return x” wraps, under the State constructor, a function that returns x together with whatever state is passed to it.

Now, let’s look at three useful functions and how they’re defined for the State monad, before we learn about “do” notation.

get :: State s s
get = State $ \s -> (s, s)

Keep in mind there’s nothing magical about these. Even though they are indirectly defined in Control.Monad.State.Class, which is exported by Control.Monad.State.Lazy and Control.Monad.State.Strict, we can still define them ourselves, perhaps as I have done here.

get takes a state, leaves the state unchanged, and also returns that state as the returned value. This is useful because we can pass the current state to the next function. Because we can pass the current state to the next function, we can also use that state later.

Consider writing a State computation that can increment a counter, halve its value, increment the counter, and return the value of the counter after the first increment (along with the final state). We can do this without any helper functions:
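A sketch of what a version without helper functions could look like, with every step written as an explicit State-wrapped lambda. It assumes StateType, a Counter with Int fields and a Monad instance as before; the local helper step is only there to avoid repeating the increment, and thatOneState follows the name used in the text.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementHalveAndIncrementCounter :: StateType Counter Int
incrementHalveAndIncrementCounter =
  State (\c -> ((), step c)) >>= \_ ->                           -- increment
  State (\c -> (c, c)) >>= \thatOneState ->                      -- remember the state
  State (\c -> ((), c { value = value c `div` 2 })) >>= \_ ->    -- halve the value
  State (\c -> ((), step c)) >>= \_ ->                           -- increment again
  return (value thatOneState)                                    -- value after first increment
  where
    step c = Counter (succ (value c)) (succ (numberOfIncrements c))
```

Starting from "Counter 42 3", the value goes 43, 21 (integer division), 22, and the returned value is 43.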

Sure, we could have avoided referring to a previous state by using the returned value of incrementCounter, which is all that we want out of thatOneState, instead of our own State computation here, which does the very same thing except it doesn’t return anything useful; but for the intents of this part of the tutorial, we’ll pretend that we never defined incrementCounter to return that.

Let’s simplify the code by substituting our common code by a simple value (function wrapped in State constructor):

Even though incrementCounter, unlike our previous code, does return a meaningful value (the value of the initial counter), the functionality is no different, since the returned value is discarded. Do you remember a function that discards the unwrapped value of a monad? Yes, (>>)! Let’s try that:

gets takes a function, and behaves like get, only it applies the function to the state. This is especially useful when our state is a type defined with record syntax, as we can extract certain things contained in the state, such as Counter. Let’s use gets instead of manually applying “value” to get the value out of thatOneState:

modify takes a function and returns a State computation that takes an initial state and returns a tuple of unit, (), and the result of applying the function to the state. In other words, modify modifies the state. Since nothing really meaningful can come out of this, the returned value of this computation is unit (“()”), which we can only sensibly ignore, as only one value is of type “()”.

We have a State computation that does what modify does, so we can simplify our function further:

In case it isn’t obvious, put is a function that takes a state and returns a State computation that takes an initial state, ignores it, and sets a new state that will continue in the chain or thread of State computations. Since no return value would really make sense, it returns unit.
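To tie the helpers together, here is a sketch that defines get, gets, modify and put ourselves for our StateType and uses two of them to rewrite incrementHalveAndIncrementCounter. These are hand-written stand-ins, not the library’s own definitions; resetCounter is a made-up example of put.

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

get :: StateType s s
get = State $ \s -> (s, s)          -- return the state, leave it unchanged

gets :: (s -> a) -> StateType s a
gets f = State $ \s -> (f s, s)     -- like get, but apply f to the state first

modify :: (s -> s) -> StateType s ()
modify f = State $ \s -> ((), f s)  -- replace the state with f applied to it

put :: s -> StateType s ()
put s = State $ \_ -> ((), s)       -- ignore the old state, install a new one

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementCounterState :: StateType Counter Int
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

incrementHalveAndIncrementCounter :: StateType Counter Int
incrementHalveAndIncrementCounter =
  incrementCounterState >>
  gets value >>= \valueAfterFirstIncrement ->
  modify (\c -> c { value = value c `div` 2 }) >>
  incrementCounterState >>
  return valueAfterFirstIncrement

-- Hypothetical use of put: throw the current counter away.
resetCounter :: StateType Counter ()
resetCounter = put (Counter 0 0)
```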

It seems that there are common patterns with Monads. This is indeed the case, and Haskell acknowledges this by supplying the programmer with the convenience known as do notation.

do blocks are simply “syntactic sugar”, meaning that they are directly translated into expressions using (>>=) and (>>). The translation is simple:

do foo
   bar

translates to

foo >>
bar

What about (>>=)?

Well,

do returnedValue <- foo
   bar

translates to

foo >>=
\returnedValue -> bar

Let’s try simplifying our State computation, incrementHalveAndIncrementCounter, using do notation:
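A sketch of how the do-notation version might read, using the same assumed definitions as before (our own StateType, hand-written gets and modify, and a Counter with Int fields):

```haskell
import Control.Monad (ap, liftM)

newtype StateType s a = State { runState :: s -> (a, s) }

instance Functor (StateType s) where
  fmap = liftM

instance Applicative (StateType s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (StateType s) where
  return = pure
  State g >>= f = State $ \s ->
    let (a, s') = g s
    in  runState (f a) s'

gets :: (s -> a) -> StateType s a
gets f = State $ \s -> (f s, s)

modify :: (s -> s) -> StateType s ()
modify f = State $ \s -> ((), f s)

data Counter = Counter
  { value              :: Int
  , numberOfIncrements :: Int
  } deriving (Show, Eq)

incrementCounterState :: StateType Counter Int
incrementCounterState = State $ \(Counter v n) -> (v, Counter (succ v) (succ n))

-- Each line without "<-" is chained with (>>); each line with "<-" is
-- chained with (>>=) and names the returned value.
incrementHalveAndIncrementCounter :: StateType Counter Int
incrementHalveAndIncrementCounter = do
  incrementCounterState
  valueAfterFirstIncrement <- gets value
  modify (\c -> c { value = value c `div` 2 })
  incrementCounterState
  return valueAfterFirstIncrement
```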

Much, much nicer! We have simplified our code using (>>), helper functions, and do notation to the point at which it is effortlessly readable by a person who has not written our code.

My goal is to teach others to help them understand the State monad. I hope that a more “conversational” style of writing has helped me achieve my goal.

By now, you should be able to figure out how the State monad works. After this point, you should be able to easily understand the other monads. I hope you will be as overjoyed as I was when I finally understood Monads! Good luck!