27 Apr 2015

Yield Return in Java (comment on Benji Weber)

Benji Weber has said that a feature C# developers often miss in Java is yield return, and he considers some very complex ways of bringing it to Java for implementing iterators and generators. In particular, he discusses a threading approach in some detail, which makes it seem really hard to generate and print the integers from one to five.

In fact, with Java 8, there is no need for any of that, because the required behavior is already built into the lazy execution model of streams. Iterators and generators are part and parcel of the Stream API.

The following code, for example, prints the positive integers in ascending order and never terminates on its own:

Stream.iterate(1, x->x+1).forEach(System.out::println);

Throwing in a limit(5) will give you the one-to-five example.

The reason this works is that each stream element is only generated (lazily) when a downstream method explicitly asks for it. Contrary to appearances, no list of five elements is ever constructed in the snippet:

Stream.iterate(1, x->x+1).limit(5).forEach(System.out::println);
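To make the laziness visible, here is a minimal sketch that pulls elements from the same kind of stream through a plain Iterator and counts how often the step function runs. The class and method names (LazyGeneratorDemo, take, stepCallsFor) are made up for this illustration.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

class LazyGeneratorDemo {

    // Pulls the first n elements one at a time, the way a caller of a
    // C#-style generator would. The stream itself is unbounded.
    static List<Integer> take(int n) {
        Iterator<Integer> it = Stream.iterate(1, x -> x + 1).iterator();
        List<Integer> out = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            out.add(it.next());
        }
        return out;
    }

    // Counts how many times the step function is invoked when only
    // n elements are consumed from the unbounded stream.
    static int stepCallsFor(int n) {
        AtomicInteger calls = new AtomicInteger();
        Stream.iterate(1, x -> { calls.incrementAndGet(); return x + 1; })
              .limit(n)
              .forEach(x -> { });
        return calls.get();
    }

    public static void main(String[] args) {
        System.out.println(take(5));
        System.out.println("step function calls: " + stepCallsFor(5));
    }
}
```

The step function runs only a handful of times, on the order of the number of elements consumed, even though the stream has no upper bound.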

To my eyes, the Java 8 code is even nicer to read than the corresponding code in C#.

Easy exhaustive search with Java 8 Streams

I have just been reading this post by Mark Dominus on Haskell. It discusses how the Haskell list monad can be used to hide some of the glue code involved in doing exhaustive searches. Java 8 Streams, which are somewhat similar to Haskell lists in also being monadic, lend themselves to the same style of coding.
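As a small warm-up for this style, the sketch below uses flatMap the way Haskell's list monad uses bind: each nested lambda picks a candidate, and the innermost one returns either a one-element stream (keep) or an empty stream (discard). The toy problem and the names FlatMapSearch and pairsSummingTo are my own, not taken from Mark's post.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class FlatMapSearch {

    // Finds all pairs (a, b) with 1 <= a < b <= 9 and a + b == target.
    static List<String> pairsSummingTo(int target) {
        return IntStream.rangeClosed(1, 9).boxed().flatMap(a ->
               IntStream.rangeClosed(a + 1, 9).boxed().flatMap(b ->
                   a + b == target
                       ? Stream.of("(" + a + "," + b + ")")   // keep this candidate
                       : Stream.<String>empty()))             // discard it
               .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // prints [(1,9), (2,8), (3,7), (4,6)]
        System.out.println(pairsSummingTo(10));
    }
}
```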

The example used in the post I have quoted is the well-known cryptarithmetic puzzle in which you are asked to find all possible ways of mapping the letters S, E, N, D, M, O, R, Y to distinct digits 0 through 9 (where we may assume that S is not 0) so that the following addition holds:

    S E N D
  + M O R E
  ---------
  M O N E Y

Here's my Java 8 port of Mark's Haskell example.

import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;
import static java.util.stream.Collectors.toList;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

public class SendMoreMoney {

    static final List<Integer> DIGITS = unmodifiableList(asList(0,1,2,3,4,5,6,7,8,9));

    public static void main(String[] args) {
        List<String> solutions =
            remove(DIGITS, 0).stream().flatMap( s ->
            remove(DIGITS, s).stream().flatMap( e ->
            remove(DIGITS, s, e).stream().flatMap( n ->
            remove(DIGITS, s, e, n).stream().flatMap( d ->
            remove(DIGITS, s, e, n, d).stream().flatMap( m ->
            remove(DIGITS, s, e, n, d, m).stream().flatMap( o ->
            remove(DIGITS, s, e, n, d, m, o).stream().flatMap( r ->
            remove(DIGITS, s, e, n, d, m, o, r).stream().flatMap( y ->
                { int send = toNumber(s, e, n, d);
                  int more = toNumber(m, o, r, e);
                  int money = toNumber(m, o, n, e, y);
                  return  send + more == money ? Stream.of(solution(send, more, money)) : Stream.empty();
                }))))))))
            .collect(toList());
        System.out.println(solutions);
    }

    // Java has no tuples, so a solution is rendered as a String.
    static String solution(int send, int more, int money) {
        return "(" + send + "," + more + "," + money + ")";
    }

    // Folds the digits into a decimal number, most significant digit first.
    static final int toNumber(Integer... digits) {
        assert digits.length > 0;
        return Stream.of(digits).reduce((x,y) -> 10*x + y).get();
    }

    // List difference: a copy of xs without any of the elements in ys.
    static final List<Integer> remove(List<Integer> xs, Integer... ys) {
        // this naive implementation is O(n^2).
        List<Integer> zs = new ArrayList<>(xs);
        zs.removeAll(asList(ys));
        return zs;
    }
}

The minor optimization of not unnecessarily recomputing "send" and "more" is left out for the sake of readability. The methods remove() (which implements list difference), toNumber(), and solution() have simple implementations. Of these, toNumber() is again a lot like the corresponding Haskell code. Method solution() here returns a String because Java does not have tuples.

Too bad that in Java one must have the nested method calls, but the formatting goes some way to hide this. All in all, I think this is quite nice.

But how fast is it? I did a simple micro-benchmark with JMH 1.9.1 (available from Maven Central) on my laptop computer, which is a quad-core machine with an Intel i7 processor.

Here are the measurement parameters:

# JMH 1.9.1 (released 5 days ago)
# VM invoker: C:\Program Files\Java\jdk1.8.0_25\jre\bin\java.exe
# VM options: -Dfile.encoding=UTF-8
# Warmup: 5 iterations, 1 s each
# Measurement: 25 iterations, 1 s each
# Timeout: 10 min per iteration
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Average time, time/op

I measured the flatMap solution against the equivalent formulation with eight nested forEach-loops and an external accumulator variable. The flatMap solution is about half as fast. Here's a representative measurement:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  662.377 ± 3.747  ms/op
measureForLoopSearchPerformance  avgt   25  316.105 ± 3.823  ms/op
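For comparison, the loop-based formulation might look roughly like the sketch below. This is a reconstruction under assumptions, not the benchmarked code: I use enhanced for loops, a plain ArrayList as the external accumulator, and the hypothetical names SendMoreMoneyLoops and solve().

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class SendMoreMoneyLoops {

    static final List<Integer> DIGITS = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

    // List difference: a copy of xs without any of the elements in ys.
    static List<Integer> remove(List<Integer> xs, Integer... ys) {
        List<Integer> zs = new ArrayList<>(xs);
        zs.removeAll(Arrays.asList(ys));
        return zs;
    }

    // Folds the digits into a decimal number, most significant digit first.
    static int toNumber(int... digits) {
        int result = 0;
        for (int d : digits) {
            result = 10 * result + d;
        }
        return result;
    }

    static List<String> solve() {
        List<String> solutions = new ArrayList<>();   // external accumulator
        for (int s : remove(DIGITS, 0))
        for (int e : remove(DIGITS, s))
        for (int n : remove(DIGITS, s, e))
        for (int d : remove(DIGITS, s, e, n))
        for (int m : remove(DIGITS, s, e, n, d))
        for (int o : remove(DIGITS, s, e, n, d, m))
        for (int r : remove(DIGITS, s, e, n, d, m, o))
        for (int y : remove(DIGITS, s, e, n, d, m, o, r)) {
            int send = toNumber(s, e, n, d);
            int more = toNumber(m, o, r, e);
            int money = toNumber(m, o, n, e, y);
            if (send + more == money) {
                solutions.add("(" + send + "," + more + "," + money + ")");
            }
        }
        return solutions;
    }

    public static void main(String[] args) {
        System.out.println(solve());
    }
}
```

The candidate lists and the test are the same as in the flatMap version; only the glue differs, and the result has to be collected by hand.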

The nice thing about Streams is that they are so easy to parallelize. Just throw in a .parallel() in the first line like this:

   remove(DIGITS, 0).stream().parallel().flatMap( s ->

leaving everything else unchanged, and the (parallel) flatMap version becomes twice as fast as the (serial) for-loop version:

Benchmark                        Mode  Cnt    Score   Error  Units
measureFlatMapSearchPerformance  avgt   25  168.278 ± 1.700  ms/op
measureForLoopSearchPerformance  avgt   25  315.806 ± 2.878  ms/op
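A quick way to convince yourself that .parallel() changes only how the search is scheduled, not what it finds, is to run both variants and compare the results. The sketch below does this on a deliberately tiny toy search (digit triples a < b < c that sum to 15); the problem and the name ParallelSearchDemo are made up for illustration, and a search this small says nothing about speed.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

class ParallelSearchDemo {

    // Toy search: all digit triples a < b < c with a + b + c == 15.
    static List<String> search(boolean parallel) {
        Stream<Integer> candidates = IntStream.rangeClosed(0, 9).boxed();
        if (parallel) {
            candidates = candidates.parallel();   // the only difference
        }
        return candidates.flatMap(a ->
               IntStream.rangeClosed(a + 1, 9).boxed().flatMap(b ->
               IntStream.rangeClosed(b + 1, 9).boxed().flatMap(c ->
                   a + b + c == 15
                       ? Stream.of("(" + a + "," + b + "," + c + ")")
                       : Stream.<String>empty())))
               .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> serial = search(false);
        List<String> parallel = search(true);
        System.out.println(serial);
        System.out.println("same results: " + serial.equals(parallel));
    }
}
```

Because the source stream is ordered and the results are collected into a list, the parallel run yields its results in the same encounter order as the serial one.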