
27 Apr 2015

Easy exhaustive search with Java 8 Streams

I have just been reading this post by Mark Dominus on Haskell. It discusses how the Haskell list monad can hide some of the glue code involved in exhaustive searches. Java 8 Streams resemble Haskell lists in also being (loosely speaking) monadic, with flatMap playing the role of >>=, so they lend themselves to the same style of coding.
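To make the analogy concrete: Stream.flatMap plays the role of Haskell's >>= for lists, Stream.of(x) the role of return, and Stream.empty() the role of a failed guard. The following minimal sketch (the class name DistinctPairs is made up for illustration) uses that correspondence to enumerate the ordered pairs of distinct elements of a list, a two-variable version of the search below.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical example class, not part of the original post.
class DistinctPairs {

    // Roughly the Haskell: do { x <- xs; y <- xs; guard (x /= y); return (x, y) }
    static List<String> pairs(List<Integer> xs) {
        return xs.stream().flatMap(x ->          // x <- xs
               xs.stream().flatMap(y ->          // y <- xs
                   x.equals(y)
                       ? Stream.<String>empty()                  // guard fails: no result
                       : Stream.of("(" + x + "," + y + ")")))    // return (x, y)
               .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // prints [(1,2), (1,3), (2,1), (2,3), (3,1), (3,2)]
        System.out.println(pairs(Arrays.asList(1, 2, 3)));
    }
}
```

Each extra letter in the puzzle simply adds one more nested flatMap level, with the "distinct" guard folded into the list of candidates instead of a separate test.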

The example used in that post is the well-known cryptarithmetic puzzle in which you are asked to find all possible ways of mapping the letters S, E, N, D, M, O, R, Y to distinct digits 0 through 9 (S may be assumed not to be 0) so that the following sum comes out true:

    S E N D
  + M O R E
  ---------
  M O N E Y
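Written out with place values, the puzzle asks for distinct digits satisfying a single equation; each of the three numbers is what one toNumber() call in the code computes:

```latex
(1000S + 100E + 10N + D) + (1000M + 100O + 10R + E)
  = 10000M + 1000O + 100N + 10E + Y,
\qquad S,E,N,D,M,O,R,Y \in \{0,\dots,9\} \text{ pairwise distinct},\; S \neq 0.
```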

Here's my Java 8 port of Mark's Haskell example.

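import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;
import static java.util.stream.Collectors.toList;

public class SendMoreMoney {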

    static final List<Integer> DIGITS = unmodifiableList(asList(0,1,2,3,4,5,6,7,8,9));
    
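    public static void main(String[] args) {
        // Each flatMap level binds one letter to a digit not yet used by the letters
        // before it; the innermost lambda yields one solution string or nothing.
        List<String> solutions =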
            remove(DIGITS, 0).stream().flatMap( s ->
            remove(DIGITS, s).stream().flatMap( e ->
            remove(DIGITS, s, e).stream().flatMap( n ->
            remove(DIGITS, s, e, n).stream().flatMap( d ->
            remove(DIGITS, s, e, n, d).stream().flatMap( m ->
            remove(DIGITS, s, e, n, d, m).stream().flatMap( o ->
            remove(DIGITS, s, e, n, d, m, o).stream().flatMap( r ->
            remove(DIGITS, s, e, n, d, m, o, r).stream().flatMap( y ->
                { int send = toNumber(s, e, n, d);
                  int more = toNumber(m, o, r, e);
                  int money = toNumber(m, o, n, e, y);
                  return  send + more == money ? Stream.of(solution(send, more, money)) : Stream.empty();
                }
            ))))))))
            .collect(toList());

        System.out.println(solutions);
    }

    static String solution(int send, int more, int money) {
        return "(" + send + "," + more + "," + money + ")";
    }
    
    static final int toNumber(Integer... digits) {
        assert digits.length > 0;
        return Stream.of(digits).reduce((x,y) -> 10*x + y).get();
    }
    
    static final List<Integer> remove(List<Integer> xs, Integer... ys) {
        // this naive implementation is O(n^2).
        List<Integer> zs = new ArrayList<>(xs);
        zs.removeAll(asList(ys));
        return zs;
    }
}

The minor optimization of not recomputing send and more unnecessarily (send depends only on s, e, n, d, and more only on m, o, r, e) is left out for the sake of readability. The helper methods remove() (which implements list difference), toNumber(), and solution() are straightforward. Of these, toNumber() is again a lot like the corresponding Haskell code. solution() returns a String because Java has no tuples.
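For comparison, here is a sketch of the omitted optimization of not recomputing send and more: send is computed once d is known and more once r is known, so the innermost lambda only builds money. The class name SendMoreMoneyHoisted is hypothetical, and the helpers repeat the ones above so that the block compiles on its own.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical variant of SendMoreMoney that hoists send and more out of the inner levels.
class SendMoreMoneyHoisted {

    static final List<Integer> DIGITS = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

    static List<String> solve() {
        return
            remove(DIGITS, 0).stream().flatMap(s ->
            remove(DIGITS, s).stream().flatMap(e ->
            remove(DIGITS, s, e).stream().flatMap(n ->
            remove(DIGITS, s, e, n).stream().flatMap(d -> {
                int send = toNumber(s, e, n, d);   // fixed once d is chosen
                return
            remove(DIGITS, s, e, n, d).stream().flatMap(m ->
            remove(DIGITS, s, e, n, d, m).stream().flatMap(o ->
            remove(DIGITS, s, e, n, d, m, o).stream().flatMap(r -> {
                int more = toNumber(m, o, r, e);   // fixed once r is chosen
                return
            remove(DIGITS, s, e, n, d, m, o, r).stream().flatMap(y -> {
                int money = toNumber(m, o, n, e, y);
                return send + more == money
                    ? Stream.of("(" + send + "," + more + "," + money + ")")
                    : Stream.<String>empty();
            }); })));
            }))))
            .collect(Collectors.toList());
    }

    static int toNumber(Integer... digits) {
        // Fold digits left to right: toNumber(9, 5, 6, 7) == 9567.
        return Arrays.stream(digits).reduce(0, (x, y) -> 10 * x + y);
    }

    static List<Integer> remove(List<Integer> xs, Integer... ys) {
        List<Integer> zs = new ArrayList<>(xs);   // copy, then take the list difference
        zs.removeAll(Arrays.asList(ys));
        return zs;
    }

    public static void main(String[] args) {
        System.out.println(solve());
    }
}
```

The price is that the d and r lambdas become block lambdas, which breaks up the tidy staircase of the original layout.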

It's a pity that Java forces the nested method calls on us (hence the eight closing parentheses), but the formatting goes some way toward hiding this. All in all, I think this is quite nice.

But how fast is it? I ran a simple micro-benchmark with JMH 1.9.1 (available from Maven Central) on my laptop, a quad-core machine with an Intel i7 processor.

Here are the measurement parameters:

# JMH 1.9.1 (released 5 days ago)
# VM invoker: C:\Program Files\Java\jdk1.8.0_25\jre\bin\java.exe
# VM options: -Dfile.encoding=UTF-8
# Warmup: 5 iterations, 1 s each
# Measurement: 25 iterations, 1 s each
# Timeout: 10 min per iteration
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Average time, time/op

I measured the flatMap solution against an equivalent formulation with eight nested for-each loops and an external accumulator variable. The flatMap solution is about half as fast: it takes roughly 2.1 times as long per operation. Here is a representative measurement:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  662.377 ± 3.747  ms/op
measureForLoopSearchPerformance  avgt   25  316.105 ± 3.823  ms/op
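The loop formulation used as the baseline is not shown in the post. A minimal sketch of what it might look like, assuming the loops are plain enhanced for statements collecting into a list (the class name SendMoreMoneyLoops is hypothetical, and this is not the benchmarked code):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical reconstruction of the nested-loop baseline; not the author's benchmark code.
class SendMoreMoneyLoops {

    static final List<Integer> DIGITS = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

    static List<String> solve() {
        List<String> solutions = new ArrayList<>();   // the external accumulator
        for (Integer s : remove(DIGITS, 0))
        for (Integer e : remove(DIGITS, s))
        for (Integer n : remove(DIGITS, s, e))
        for (Integer d : remove(DIGITS, s, e, n))
        for (Integer m : remove(DIGITS, s, e, n, d))
        for (Integer o : remove(DIGITS, s, e, n, d, m))
        for (Integer r : remove(DIGITS, s, e, n, d, m, o))
        for (Integer y : remove(DIGITS, s, e, n, d, m, o, r)) {
            int send = toNumber(s, e, n, d);
            int more = toNumber(m, o, r, e);
            int money = toNumber(m, o, n, e, y);
            if (send + more == money) {
                solutions.add("(" + send + "," + more + "," + money + ")");
            }
        }
        return solutions;
    }

    static int toNumber(Integer... digits) {
        int result = 0;
        for (int digit : digits) {
            result = 10 * result + digit;   // same left fold as the Stream version
        }
        return result;
    }

    static List<Integer> remove(List<Integer> xs, Integer... ys) {
        List<Integer> zs = new ArrayList<>(xs);
        zs.removeAll(Arrays.asList(ys));
        return zs;
    }

    public static void main(String[] args) {
        System.out.println(solve());
    }
}
```

Since the loops visit the candidate assignments in the same order as the nested flatMaps, this should print the same list as the Stream version.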


The nice thing about Streams is that they are so easily parallelized. Just add .parallel() to the first line, like this:

   remove(DIGITS, 0).stream().parallel().flatMap( s ->

leaving everything else unchanged, and the (parallel) flatMap version becomes roughly twice as fast as the (serial) for-loop version (about 1.9 times), and nearly four times as fast as the serial flatMap measurement above:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  168.278 ± 1.700  ms/op
measureForLoopSearchPerformance  avgt   25  315.806 ± 2.878  ms/op
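The one-word change is safe because parallel() switches the whole pipeline to parallel mode and collect(toList()) still respects encounter order, so the parallel run should return exactly the same list, in the same order, as the serial one. (In OpenJDK 8 the streams returned by the flatMap mappers are consumed sequentially, so the parallelism comes from splitting the nine candidate values of S across worker threads.) A sketch that checks the ordering claim; the class name and the boolean parameter are hypothetical:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical harness: the same search as SendMoreMoney, run serially or in parallel.
class SendMoreMoneyParallel {

    static final List<Integer> DIGITS = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

    static List<String> solve(boolean parallel) {
        Stream<Integer> first = remove(DIGITS, 0).stream();
        if (parallel) {
            first = first.parallel();   // applies to the whole pipeline, not just this stage
        }
        return first.flatMap(s ->
            remove(DIGITS, s).stream().flatMap(e ->
            remove(DIGITS, s, e).stream().flatMap(n ->
            remove(DIGITS, s, e, n).stream().flatMap(d ->
            remove(DIGITS, s, e, n, d).stream().flatMap(m ->
            remove(DIGITS, s, e, n, d, m).stream().flatMap(o ->
            remove(DIGITS, s, e, n, d, m, o).stream().flatMap(r ->
            remove(DIGITS, s, e, n, d, m, o, r).stream().flatMap(y -> {
                int send = toNumber(s, e, n, d);
                int more = toNumber(m, o, r, e);
                int money = toNumber(m, o, n, e, y);
                return send + more == money
                    ? Stream.of("(" + send + "," + more + "," + money + ")")
                    : Stream.<String>empty();
            }))))))))
            .collect(Collectors.toList());   // keeps encounter order, even in parallel
    }

    static int toNumber(Integer... digits) {
        return Arrays.stream(digits).reduce(0, (x, y) -> 10 * x + y);
    }

    static List<Integer> remove(List<Integer> xs, Integer... ys) {
        List<Integer> zs = new ArrayList<>(xs);   // a fresh copy, so no shared mutable state
        zs.removeAll(Arrays.asList(ys));
        return zs;
    }

    public static void main(String[] args) {
        // prints true: same solutions, same order
        System.out.println(solve(false).equals(solve(true)));
    }
}
```

The lambdas are stateless and remove() never mutates DIGITS, which is what makes the parallel run free of data races.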
