
8 Feb 2016

Recursive Function Memoization with the State Monad

The State Monad

What is the State Monad? If you have been following this blog, you already know the answer. In fact, the parser that we have seen in the previous post is one embodiment of it. In general terms, the state monad is just a glorified function that takes a state and computes from that a result value and some new state. Apart from embedding this function, the state monad has a few bells and whistles that help in combining such functions together. In functional programming languages, the state monad is used to simulate external mutable state. The state monad's role is to pass down the state through the calls of the function. The state parameter disappears from the functions actually used.

You have already seen a practical application of this in our parser. In the parser, the state values have been character sequences. We were able to process the input sequence without ever seeing a corresponding parameter in our methods.
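
To make the idea concrete without relying on the parser code, here is a minimal sketch of a state-passing computation. All names in it (StateSketch, State, Pair, freshLabel, twoLabels) are made up for this illustration and are not the StateMonad interface shown later in this post; the state is simply an Integer counter.

```java
import java.util.function.Function;

// Hypothetical illustration only; not the StateMonad interface from this post.
class StateSketch {

    /** A result value paired with the successor state. */
    static final class Pair<T, S> {
        final T value;
        final S state;
        Pair(T value, S state) { this.value = value; this.state = state; }
    }

    /** A stateful computation, i.e. a function S -> (T, S). */
    interface State<T, S> {
        Pair<T, S> run(S s);

        /** Run this computation, hand its value to f, and run f's result on the new state. */
        default <V> State<V, S> bind(Function<? super T, State<V, S>> f) {
            return s -> {
                Pair<T, S> p = run(s);
                return f.apply(p.value).run(p.state);
            };
        }
    }

    /** Yields a label built from the current counter and increments the counter. */
    static State<String, Integer> freshLabel() {
        return n -> new Pair<>("L" + n, n + 1);
    }

    /** Draws two labels in sequence; the counter never shows up as a parameter. */
    static State<String, Integer> twoLabels() {
        return freshLabel().bind(a ->
               freshLabel().bind(b ->
               s -> new Pair<>(a + "," + b, s)));
    }

    public static void main(String[] args) {
        Pair<String, Integer> p = twoLabels().run(0);
        System.out.println(p.value + " / next counter = " + p.state);
    }
}
```

Running twoLabels against the initial state 0 yields the value "L0,L1" and leaves the counter at 2, although neither method ever mentions a counter parameter.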

Besides parsing, another often-mentioned application of the state monad is recursive function memoization. In this case, the state will consist of pre-computed function values (a memo).  We'll see how to memoize a recursive function that computes the Fibonacci numbers without tacking an extra memo argument onto the function (such as a HashMap). It is the state monad that will keep a memo of already computed values between recursive calls.

You can google this stuff. People write about it often. For example, here are two articles that both deal with memoizing Fibonacci numbers, in a fashion similar to but not identical with mine: one in Java by Pierre-Yves Saumont and another in Scala by Tony Morris. But both miss a point I am going to make below, namely how to factor out the memoization logic from the function itself. (Saumont, by the way, is in the process of writing what looks to be an interesting book.)

I'll first show you how memoizing Fibonacci numbers is done, and define a generic memoize method that will enable memoization for any function. Only then will I show you how to derive the underlying state monad implementation through a few little refactorings of our SimpleParser. This way, you won't be inundated with details, having no idea what they lead up to. I'm hoping that the code is clear enough to grasp its intent even without knowing the internals. If it doesn't work for you, try reading this post from the end to the beginning.

And in the last section, I'll give you my evaluation of the whole thing.

Memoizing Fibonacci numbers

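I expect you all know from your introduction to algorithms course that the following naive version will blow up in your face, because it recomputes the same values over and over and the number of calls grows exponentially with n: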

BigInteger fibNaive(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    BigInteger x = fibNaive(n - 1);
    BigInteger y = fibNaive(n - 2);
    return x.add(y);
} 

One way to make the algorithm feasible is to remember previously computed values.

BigInteger fibMemo(int n, Map<Integer, BigInteger> memo) {
    BigInteger value = memo.get(n);
    if (value != null) {
       return value;
    }
    BigInteger x = fibMemo(n - 1, memo);
    BigInteger y = fibMemo(n - 2, memo);
    BigInteger z = x.add(y);
    memo.put(n, z);
    return z;
}

The above is written in a non-functional style with an explicit extra argument. (I'll name it the explicit version.) You would call it like this, with a memo already containing the first two Fibonacci numbers:

BigInteger fibMemo(int n) {
    if (n <= 2) {
        return BigInteger.ONE;
    }
    Map<Integer, BigInteger> memo = new HashMap<>();
    memo.put(1, BigInteger.ONE);
    memo.put(2, BigInteger.ONE);
    return fibMemo(n, memo);
}

Now wouldn't it be nice if we could get rid of that extra argument and the tedium of looking up and storing values, to derive a memoizing version that is closer to the naive version? And, yes, we can. Simply apply a new function, called memoize, to the original function.

StateMonad<BigInteger, Memo<Integer,BigInteger>> fib(int n) {
    return memoize(this::fib).apply(n-1).bind(x ->              // x = memoize (fib) (n-1)
           memoize(this::fib).apply(n-2).bind(y ->              // y = memoize (fib) (n-2)
           StateMonad.get(_s -> x.add(y))));                    // return x.add(y)
}

Basically, this is only a notational variant of the naive version, a bit difficult to read at first because Java lacks the syntactic sugar other languages have. I have tried to indicate such sugar in the inline comments. Having had time to get used to it, I find it doesn't read too badly: just ignore the clutter of "apply" and "bind", and remember that the assignment is notated at the end of the line instead of the beginning.

Well, we also had to change the return type a bit: it no longer represents a value, but a computation that yields this value (the function represented by the state monad). As you have perhaps guessed, the method bind above is the equivalent of then in our parser: the signature is the same, except that the continuation we put in and the combined computation we get out are not parsers, but instances of a more general StateMonad class that can have a memo as its state. As with the parser, the function inside the monad takes no arguments except the state; all other arguments are supplied beforehand.


My main point is that here we have a very clear separation of concerns: there is one entity, memoize, concerned with caching computed values; another, the StateMonad, concerned with passing the cached values between function calls; and finally fib itself, which embodies the definition of the Fibonacci numbers and is concerned with making recursive calls and combining their results. This is in contrast to the explicit version and also to both posts quoted above, where these concerns are more intertwined.

For me, this is of the essence of functional programming, and therein lies its beauty: that it is so neat and modular on a small scale.

In order to get a value from the computation that is returned by fib, we need to evaluate it against some suitable initial state. The function that does this is called evalState. (And this you also already know from the parser, where it was called just eval.)

BigInteger fibMonadic(int n) {
    return n <= 2 ? BigInteger.ONE
                  : fib(n).evalState(new Memo<Integer,BigInteger>().insert(1, BigInteger.ONE).insert(2, BigInteger.ONE));
}

The class Memo is intended to be a functional equivalent to HashMap.

Implementing generic memoization

So, what is memoize? The signature of memoize is obvious. It takes a function of the type of fib and returns another function of the same type. We can abstract over the input and result value types Integer and BigInteger of fib and replace them with generic type parameters T and R.

Now think of what memoize has to do. It takes a function as its parameter. It must return a new function that
  • when given an argument, tries to look up a previously computed value from the memo
  • if it finds the value, returns a function that will yield this value
  • otherwise, applies the given function to the given argument, stores that result in the memo, and returns a function that will yield this value plus an updated state
As usual with functional maps, the memo shall return an Optional upon lookup, so that we can make the case distinction with Optional's methods map and orElse. Without further ado, here's the code:

static <T,R> Function<T, StateMonad<R, Memo<T,R>>> memoize(Function<T, StateMonad<R, Memo<T,R>>> f) {
    return t -> {   
                  StateMonad<Optional<R>, Memo<T,R>> s = StateMonad.get(m -> m.lookup(t)); // create a computation that would try to find a cached entry
                  return s.bind(v ->                                                       // perform the computation, call the result v and do: 
                              v.map(s::unit)                                               // if value is present, return a computation that will yield the value
                               .orElse(                                                    // otherwise  
                                     f.apply(t).bind(r ->                                  //   compute r = f(t)  
                                     s.mod(m -> m.insert(t, r)).map(_v -> r)               //   store r in the memo, then yield r as the value
                              )));                                                         //   return a computation that will yield r
                };
}


The appeal of memoize is that it is completely general and can be applied to any unary function.

What about functions that take multiple arguments? Well, in Java you cannot abstract over functions with arbitrary arity. I suggest taking a suitable class that can represent an n-ary tuple, like the one from the wonderful jOOλ library, and treating an n-ary function as a unary function of such a tuple. (They should have added tuples to Java; every functional programming person is asking for them.)
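
Here is a rough sketch of the tuple idea for a binary function. To keep it self-contained, it makes two assumptions that are not part of this post's code: Pair is a hand-rolled stand-in for a library tuple class, and memoize is a plain map-backed helper rather than the monadic version. The binomial coefficient serves only as an example of a recursive function of two arguments.

```java
import java.math.BigInteger;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

// Hypothetical illustration: a binary function memoized as a unary function of a pair.
class TupleArgSketch {

    /** Hand-rolled immutable pair; equals/hashCode make it usable as a map key. */
    static final class Pair<A, B> {
        final A first;
        final B second;
        Pair(A first, B second) { this.first = first; this.second = second; }
        @Override public boolean equals(Object o) {
            if (!(o instanceof Pair)) return false;
            Pair<?, ?> p = (Pair<?, ?>) o;
            return Objects.equals(first, p.first) && Objects.equals(second, p.second);
        }
        @Override public int hashCode() { return Objects.hash(first, second); }
    }

    private final Map<Pair<Integer, Integer>, BigInteger> memo = new HashMap<>();

    /** Plain map-backed memoization of a unary function (stand-in for the monadic memoize). */
    private <T, R> Function<T, R> memoize(Map<T, R> cache, Function<T, R> f) {
        return t -> {
            R r = cache.get(t);
            if (r == null) {
                r = f.apply(t);
                cache.put(t, r);
            }
            return r;
        };
    }

    /** Binomial coefficient as a unary function of the pair (n, k); assumes 0 <= k <= n. */
    BigInteger binomial(Pair<Integer, Integer> nk) {
        int n = nk.first, k = nk.second;
        if (k == 0 || k == n) return BigInteger.ONE;
        Function<Pair<Integer, Integer>, BigInteger> c = memoize(memo, this::binomial);
        return c.apply(new Pair<>(n - 1, k - 1)).add(c.apply(new Pair<>(n - 1, k)));
    }

    public static void main(String[] args) {
        System.out.println(new TupleArgSketch().binomial(new Pair<>(30, 15)));
    }
}
```

The only extra requirement compared to the unary case is that the tuple class implements equals and hashCode by value, since it is used as the lookup key.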

Implementing the state monad

The state monad itself is not so interesting. As I have noted above, it should be self-explanatory to readers familiar with the simple parser. Here's how you might change the parser to derive the monad in a few simple steps, mainly consisting of renamings:
  • Abstract over the type of the state (replace CharSequence with a type variable S)
  • Rename a few methods to what's customary in the functional world (e. g. Haskell)
    • eval to evalState
    • parse to runState
    • then to bind
  • Delete many, many1, and orElse. (We don't need orElse in our example, but you might nevertheless decide to keep it. I guess that in some contexts, it might be useful for backtracking.)
  • Simplify evalState by removing the checking for unused input that was completely specific to parsing
  • You might also wish to rename a few variables. (I renamed inp for "input" to s for "state", p for "parser" to m for "monad" etc.)
All that remains is to define the new state-manipulating method mod and the static value-extracting utility get, which are used in the definition of memoize. That is very little work, and here is the complete resulting code. (I've left out the overloaded version of bind/then that we don't need for this example. Furthermore, the state monad is usually given a few more convenience methods, which we also don't need and which I'm not going to discuss here.)

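import java.util.Objects;
import java.util.function.Function;
// Tuple and the static helpers tuple(...) and empty() are assumed to come from the parser post.

@FunctionalInterface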
public interface StateMonad<T,S> {
    
    default T evalState(S s) {
        Objects.requireNonNull(s);
        Tuple<T, S> t = runState(s);
        if (t.isEmpty()) {
            throw new IllegalArgumentException("Invalid state: " + s);
        }
        return t.first();
    }

    /** 
     * The type of the functional interface. A state monad is an abstraction over a function:
     * StateMonad<T,S> :: S -> Tuple<T, S> 
     * In other words, a state monad represents a stateful computation that derives a value and 
     * some new state from an input state.
     */
    abstract Tuple<T, S> runState(S s);
    
    // Monadic operations

    default <V> StateMonad<V,S> unit(V v) {
        return inp -> tuple(v, inp);
    }

    default <V> StateMonad<V,S> bind(Function<? super T, StateMonad<V,S>> f) {
        Objects.requireNonNull(f);
        StateMonad<V,S> m = s -> {
            Tuple<T, S> t = runState(s);
            if (t.isEmpty()) {
                return empty();
            }
            return f.apply(t.first()).runState(t.second());
        };
        return m;
    }
    
    default <V> StateMonad<V,S> map(Function<? super T, V> f) {
        Objects.requireNonNull(f);
        Function<V, StateMonad<V,S>> unit = x -> unit(x);
        return bind(unit.compose(f));
    }
       
    // Additional functions
       
    /** Modify the current state with the given function. */
    default StateMonad<T, S> mod(Function<S, S> f) {
        return s -> runState(f.apply(s)); 
    }

    /** Create a computation that extracts a value from the state with the given function. */
    static <V, S> StateMonad<V, S> get(Function<S, V> stateProjector) {
        return s -> tuple(stateProjector.apply(s), s);
    }
}

Of course we could have inlined that static call to the convenience function StateMonad.get in memoize (i. e. used the lambda directly), but that would have made the use of tuples visible in memoize.

Discussion

The trick I've shown you in this post is certainly appealing in its cleverness, and sometimes I quite admire it. But most of the time I feel that it is too clever by half. You've got to think practically. The functional version is (in Java) no better with regard to conciseness or readability than the explicit version. The effort to rewrite the naive definition to make use of memoization is also about the same in each case, because you have to make all those additional syntactic changes in addition to just throwing in a call to memoize. And as for performance, the monadic solution is about 5 times slower on my machine than the explicit one (for 100 ≤ n ≤ 1000).

Then there is the matter of representing state. In OO, state is most naturally represented as a member variable of a class. Here's equivalent code that fits into this paradigm. (Well, almost equivalent. This solution is not thread-safe. On the other hand, one might reuse the Memoizer instance.)

static class Memoizer<T, R> {
  private final Map<T, R> memo;

  public Memoizer(Map<T, R> memo) {
    this.memo = memo;
  }

  public Function<T, R> memoize(Function<T, R> f) {
    return t -> {
      R r = memo.get(t);
      if (r == null) {
        r = f.apply(t);
        memo.put(t, r);
      }
      return r;
    };
  }
}


Memoizer<Integer, BigInteger> m = new Memoizer<>(new HashMap<>());

BigInteger fib(int n) {
  if (n <= 2) return BigInteger.ONE;
  return m.memoize(this::fib).apply(n - 1).add(
         m.memoize(this::fib).apply(n - 2));
} 

And of course there is a simple, idiomatic, non-recursive, constant-space solution with mutable variables:

BigInteger fib(int n) {
    if (n <= 2) return BigInteger.ONE;

    BigInteger x = BigInteger.ONE;
    BigInteger y = BigInteger.ONE;
    BigInteger z = null;
    for (int i = 3; i <= n; i++) {
        z = x.add(y);
        x = y;
        y = z;
    }
    return z;
}

So be careful when you take over an idea from functional programming. You may be able to retain some of the beauty, but in other respects you won't always win. In fact, sometimes you'll lose. Don't overdo it. Functional thinking may offer you a different - and useful - perspective on a problem, as I hope to have shown with the parser and some other posts. But a functional solution is not guaranteed to be per se more readable, more maintainable, or more efficient than a good old vanilla object-oriented or imperative solution. Often, it will be none of these. Java is not a functional programming language!

Addendum 

For the sake of completeness, here's the Memo class that I have used for demo purposes. It's a mutable map masquerading as a functional data structure. Very bad. Don't copy it. You may put trace statements in these two methods to see in what order elements are computed and retrieved, and that they are indeed computed only once. The method names are the same as in TotallyLazy's PersistentMap.

class Memo<K, V> extends HashMap<K, V> {

    public Optional<V> lookup(K key) {
        Optional<V> value = Optional.ofNullable(get(key));
        return value;
    }

    public Memo<K, V> insert(K key, V value) {
        put(key, value);
        return this;
    }
}
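
If you want a memo that does not mutate in place, here is a minimal sketch of a copy-on-insert variant. The class name ImmutableMemo is my own invention for this illustration; the method names lookup and insert follow Memo above. Copying the whole map on every insert costs time linear in its size, so this only demonstrates the contract; a real persistent map would share structure between versions.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Hypothetical illustration: insert returns a new memo and leaves the old one untouched.
class ImmutableMemo<K, V> {

    private final Map<K, V> entries;

    ImmutableMemo() {
        this(new HashMap<K, V>());
    }

    private ImmutableMemo(Map<K, V> entries) {
        this.entries = entries;
    }

    /** Looks up a cached value; empty if the key has not been inserted. */
    Optional<V> lookup(K key) {
        return Optional.ofNullable(entries.get(key));
    }

    /** Returns a new memo containing the additional entry (naive full copy). */
    ImmutableMemo<K, V> insert(K key, V value) {
        Map<K, V> copy = new HashMap<>(entries);
        copy.put(key, value);
        return new ImmutableMemo<>(copy);
    }

    public static void main(String[] args) {
        ImmutableMemo<Integer, String> before = new ImmutableMemo<>();
        ImmutableMemo<Integer, String> after = before.insert(1, "one");
        System.out.println(before.lookup(1).isPresent() + " " + after.lookup(1).isPresent());
    }
}
```

Since the state monad threads the state returned by each step on to the next one, a class like this could in principle take the place of Memo, at the price of the copying.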
