Wednesday, April 20, 2011

Ugly memoization

Here's a problem that I recently ran into. I have a function taking a
string and computing some value. I call this function a lot, but a
lot of the time the argument has occurred before. The function is
reasonably expensive, about 10 µs. Only about 1/5 of the calls to the
function have a new argument.

So naturally I want to memoize the function. Luckily, Hackage has a
couple of packages for memoization. I found data-memocombinators and
MemoTrie and decided to try them.
The basic idea with memoization is that you have a function like

memo :: (a->b) -> (a->b)

I.e., you give a function to memo and you get a new function of the
same type back. This new function behaves like the original one, but
it remembers every time it is used and the next time it gets the same
argument it will just return the remembered result.
This is only safe in a pure language, but luckily Haskell is pure.
In an imperative language you can use a mutable memo table that stores
all the argument-result pairs and updates the memo table each time the
function is used. But how is it even possible to implement that in a
pure language? The idea is to lazily construct the whole memo table
in the call to memo, and it will then be lazily filled in.
Assuming that all values of the argument type a can be
enumerated by a method enumerate, we could write
memo like this:
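A minimal sketch of such a memo; the Enumerate class and its enumerate method are the hypothetical pieces here:

```haskell
import Data.Maybe (fromJust)

-- Hypothetical class whose instances can enumerate all their values.
class Enumerate a where
    enumerate :: [a]

memo :: (Enumerate a, Eq a) => (a -> b) -> (a -> b)
memo f = \ x -> fromJust (lookup x table)
  where
    -- The whole memo table is built lazily from f alone; each (f y)
    -- is a thunk that is forced at most once, on first lookup.
    table = [ (y, f y) | y <- enumerate ]
```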

Note how the memo table is constructed given just f, and this memo
table is then used in the returned function.
The type of this function would be something like

memo :: (Enumerate a, Eq a) => (a->b) -> (a->b)

assuming that the class Enumerate has the magic method enumerate.

This is just a very simplified example; if you tried to use it,
performance would be terrible because the returned function does a
linear lookup in a list. Instead we want some kind of search tree,
which is what the two packages I mentioned implement. The MemoTrie
package does this in a really beautiful way; I recommend reading
Conal's blog post about it.
OK, enough preliminaries.
I used criterion to perform the benchmarking, and I tried with no
memoization (none), memo-combinators (comb), and
MemoTrie (beau). I had a test function taking about 10 µs,
and then I called this function with different numbers of argument
repetitions: 1, 2, 5, and 10. I.e., 5 means that each argument occurred
5 times as the memoized function was called.

Time per call (µs), by repetition factor:

          1      2      5     10
none   10.7   10.7   10.7   10.7
comb   62.6   52.2   45.8   43.4
beau   27.6   17.0   10.4    8.1

So with no memoization the time per call was 10.7 µs all the time, no
surprise there. With the memo combinators it was much slower than no
memoization; the overhead of looking something up is bigger than the
cost of computing the result. So that was a failure. MemoTrie does
better: at a repetition factor of about five it starts to break even,
and at ten it's a little faster to memoize.

Since I estimated my repetition factor in the real code to be about
five, even the fastest memoization would not be any better than
recomputation. So now what? Give up? Of course not! It's time to
get dirty.

Once you know a function can be implemented in a pure way, there's no
harm in implementing the same function in an impure way, as long as it
presents the pure interface. So let's write the memo function the way
it would be done in, e.g., Scheme or ML. We will use a reference to
hold a memo table that gets updated on each call. Here's the code,
with the type that the function gets.
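A sketch of such a memoIO, assuming a Data.Map-based memo table:

```haskell
import Data.IORef
import qualified Data.Map as M

-- Impure memoization: the returned function mutates a hidden memo table.
memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
    v <- newIORef M.empty                -- reference to an empty table
    let f' x = do
            m <- readIORef v
            case M.lookup x m of
                Just r  -> return r                          -- seen before
                Nothing -> do let r = f x
                              writeIORef v (M.insert x r m)  -- remember it
                              return r
    return f'
```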

The memoIO function allocates a reference holding an empty memo table.
We then define a new function, f', which, when called,
gets the memo table and looks up the argument. If the argument is
in the table we just return the result; if it's not, we
compute the result, store it in the table, and return it.
Good old imperative programming (see below why this code is not good
imperative code).

But, horror, now the type is all wrong, there's IO in two places.
The function we want to implement is actually pure. So what to do?
Well, if you have a function involving the IO type, but you can prove
it is actually pure, then (and only then) you are allowed to use
unsafePerformIO.

I'll wave my hands instead of a proof (but more later), and here we go
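A sketch of the pure wrapper; the IORef-based memoIO the text describes is repeated so the fragment stands alone:

```haskell
import Data.IORef
import qualified Data.Map as M
import System.IO.Unsafe (unsafePerformIO)

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
    v <- newIORef M.empty
    let f' x = do
            m <- readIORef v
            case M.lookup x m of
                Just r  -> return r
                Nothing -> do let r = f x
                              writeIORef v (M.insert x r m)
                              return r
    return f'

-- Wrap the impure memoIO in a pure interface.
-- NB: with optimizations on, GHC may duplicate f' unless you take
-- care (e.g. NOINLINE tricks); this is a sketch, not bulletproof code.
memo :: (Ord a) => (a -> b) -> (a -> b)
memo f = let f' = unsafePerformIO (memoIO f) in \ x -> unsafePerformIO (f' x)
```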

Wow, two unsafePerformIO on the same line. It doesn't get
much less safe than that.
Let's benchmark again:

Time per call (µs), by repetition factor:

          1      2      5     10
none   10.7   10.7   10.7   10.7
comb   62.6   52.2   45.8   43.4
beau   27.6   17.0   10.4    8.1
ugly   13.9    7.7    3.9    2.7

Not too shabby: the ugly memoization is actually a win already at a
repetition factor of two, and costs just a small overhead if each
argument occurs once. We have a winner!

Not so fast, there's

A snag

My real code can actually be multi-threaded, so the memo function had
better work in a multi-threaded setting. Well, it doesn't. There are
no guarantees about readIORef and writeIORef when
doing multi-threading.
So we have to rewrite it. Actually, the code I first wrote is the one
below; I hardly ever use IORef, because I want my code to work with
multi-threading.
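A sketch of the MVar version, again assuming a Data.Map memo table:

```haskell
import Control.Concurrent.MVar
import qualified Data.Map as M

-- Thread-safe memoization: the MVar doubles as a lock around the table.
memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
    v <- newMVar M.empty
    let f' x = do
            m <- takeMVar v              -- take the lock (and the table)
            case M.lookup x m of
                Just r  -> do putMVar v m
                              return r
                Nothing -> do let r = f x              -- just a thunk
                              putMVar v (M.insert x r m)
                              return r
    return f'
```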

So now we use an MVar instead. This makes it thread-safe.
Only one thread can execute between the takeMVar and the
putMVar, which guarantees that only one thread at a time can
update the memo table. If two threads try at the same time, one has to
wait a little. How long? The time it takes for the lookup, plus some
small constant. Remember that Haskell is lazy: the (f x)
is not actually computed with the lock held, which is good.
So I think this is a perfectly reasonable memoIO. And we can
do the same unsafe trick as before and make it pure. Performance of
this version is the same as with the IORef.

Ahhhh, bliss.
But wait, there's

Another snag

That might look reasonable, but in fact the memo function is
now broken. It appears to work, but here's a simple use that fails:

sid :: String -> String
sid = memo id
fcn s = sid (sid s)

What will happen here? The outer call to sid will execute
the takeMVar and then do the lookup. Doing the lookup will
evaluate the argument, x. But this argument is another call
to sid, which will try to execute the takeMVar.
Disaster has struck: deadlock.

What happened here? The introduction of unsafePerformIO
ruined the sequencing guaranteed by the IO monad that would have
prevented the deadlock if we had used memoIO. I got what I
deserved for using unsafePerformIO.

Can it be repaired? Well, we could make sure x is fully
evaluated before grabbing the lock. I settled for a different repair,
where the lock is held for a shorter portion of the code.
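A sketch of this repair: readMVar does the lookup without keeping the lock, and the lock is taken only around the insert:

```haskell
import Control.Concurrent.MVar
import qualified Data.Map as M

memoIO :: (Ord a) => (a -> b) -> IO (a -> IO b)
memoIO f = do
    v <- newMVar M.empty
    let f' x = do
            m <- readMVar v              -- no lock held after this
            case M.lookup x m of
                Just r  -> return r
                Nothing -> do let r = f x
                              -- lock held only for the insert itself
                              m' <- takeMVar v
                              putMVar v (M.insert x r m')
                              return r
    return f'
```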

This solution has its own problem. It's now possible for several threads
to compute (f x) for the same x and the result of
all but one of those will be lost by overwriting the table. This is a
price I'm willing to pay for this application.

Moral

Yes, you can use imperative programming to implement pure functions.
But the onus is on you to prove that it is safe. This is not as easy
as you might think. I believe my final version is correct (with the
multiple computation caveat), but I'm not 100% sure.

Comments

I was more worried that your first definition didn't permit memoization of recursive functions (as it would deadlock during the recursive call)!

Now: how do you do evaluate-just-once memoization that permits recursion or parallelism? With the pure solution it's easy but evidently slow. It seems like we ought to be able to insert a thunk and then return the retrieved thunk, evaluating it after we're done mutating the table.

@Jan Recursion is not as bad as you might think in the first MVar version. The only crucial evaluation happening with the lock held is of the argument to the function. The actual call to f is just a thunk.

Have you released this library on hackage? Really fast (and not just beautiful) memoization seems like a worthy endeavor, and I have some code which I'd like to try it on... I've tried both MemoTrie and memocombinators and gotten some good speedups, but it's still not as fast as I like and the main bottleneck is in the memo code.

@Francisco It's not straightforward. First, to memoize you need to be able to build some kind of table indexed by the function argument. How would you do that? Haskell values do not come with any equality, ordering, or hashing. And even if they did, how far would you evaluate the argument?

But much worse, there's a cost associated with memoization. Both in time for the lookup and insertion in the table, and also in space for storing the table. This isn't a cost you want to pay by default.

Could you please elaborate on what might go wrong with the IORef version in a multithreaded program? Is it just that the function might be recomputed on the same arguments multiple times, or could something worse happen?

@Bernie From what I understand, there are no atomicity guarantees whatsoever for IORef, meaning that not even read & write are guaranteed to be atomic. So then even those could give the wrong results. I'm not sure if that currently happens on any platform.

The documentation says atomicModifyIORef is the safe way to access and modify an IORef in a multi-threaded setting.

@francisco: What got me confused in the past is the distinction between memoization (you call a function twice with the same arguments, and it is only evaluated the first time) and the funny treatment of names in Haskell: when you write something like "x = f 4 2", you cannot change the meaning of x in the future (you can just define a new x to shadow the old one). The meaning is not calculated where it is defined, but lazily, where it is required for the result of the program. The analogy to memoization is that you can use x several times and it is only calculated once. The difference is that it's a name, not a complex expression, i.e., a function call.

Excuse me for this simple question. Does the fact that the Map is referenced by an IORef make it mutable? In other words, does the insert operation create a new map (copying memory), or does it just modify the old one?

The arguments to a function are thunked, meaning the arguments get evaluated only once, and only when they are needed inside the function. This is not the same as checking if a function call was previously called with the same arguments.

If the argument is a function call, the thunk will evaluate it once without checking whether that function had been called elsewhere with the same arguments.

Thunks are conceptually similar to parameterless anonymous functions with a closure on the argument, a boolean, and a variable to store the result of the argument evaluation. Thus thunks incur no lookup costs, because they are parameterless. The cost of the thunk is the check on the boolean.

Thunks give the same amount of memoization as call-by-value (which doesn't use thunks). Neither call-by-need nor call-by-value memoize function calls. Rather both do not evaluate the same argument more than once. Call-by-need delays that evaluation with a thunk until the argument is first used within the function.

Apologies for being so verbose, but Google didn't find an explanation that made this clear, so I figured I would cover all the angles in this comment.
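The thunk sharing described in the comment above can be observed directly; in this small illustration, Debug.Trace prints a message when the argument's thunk is actually forced, and under call-by-need it prints only once even though the argument is used twice:

```haskell
import Debug.Trace (trace)

addToItself :: Int -> Int
addToItself y = y + y   -- y is used twice, but its thunk is forced only once

-- Running this forces the argument a single time, so the trace
-- message "forcing argument" appears once, not twice.
example :: Int
example = addToItself (trace "forcing argument" 21)
```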

Is this implementation asynchronous-exception-safe? I'm worried there could be an asynchronous exception between the takeMVar and the putMVar, and the MVar would be empty forever. Perhaps we need to use protect or modifyMVar, as in http://community.haskell.org/~simonmar/par-tutorial.pdf