14 Jun 2013

Higher-Order Functions in Java 8 (Part 2) - Laziness

In this post I'd like to show that the typical elegance that results from the combination of lazy evaluation and higher-order functions in functional programming languages can sometimes also be had in Java. As our example, let's look at generating the first n primes, for some arbitrary n. We'll use the method of trial division, which determines the primality of a number k by testing whether it is divisible by any prime smaller than k. We will see how to use Java lambda expressions to generate an infinite sequence of integers and filter out the non-primes from it.

Of course, we will not really generate an infinite sequence, but rather an unbounded one, and limit ourselves to taking out n elements. This is where lazy evaluation comes in. Being lazy means computing a value only when it is actually needed. (The opposite, immediate evaluation, is called being strict. The article Functional Programming For the Rest of Us, which offers an elementary and informal introduction to functional programming concepts from a Java perspective, contains, among other things, a discussion of laziness and its advantages as well as disadvantages.)
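In Java terms, the difference can be seen in a minimal sketch (the class StrictVsLazy and its methods are hypothetical, made up for this illustration): an argument passed as a plain value is computed before the call even if the method ignores it, while an argument passed as a Supplier is computed only if the method actually asks for it.

```java
import java.util.function.Supplier;

public class StrictVsLazy {
    static int evaluations = 0;

    // Stands in for a costly computation; counts how often it runs.
    static int expensive() {
        evaluations++;
        return 42;
    }

    // Strict: the caller evaluates the argument before the call, whether or not it is used.
    static int strict(boolean needed, int value) {
        return needed ? value : 0;
    }

    // Lazy: the caller passes a computation, which runs only if (and when) it is needed.
    static int lazy(boolean needed, Supplier<Integer> value) {
        return needed ? value.get() : 0;
    }

    public static void main(String[] args) {
        strict(false, expensive());
        System.out.println("after strict call: " + evaluations);   // 1
        lazy(false, StrictVsLazy::expensive);
        System.out.println("after lazy call:   " + evaluations);   // still 1
        lazy(true, StrictVsLazy::expensive);
        System.out.println("after forcing it:  " + evaluations);   // 2
    }
}
```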

Some functional languages, like Haskell, are inherently lazy; others, like Scala, have both kinds of evaluation. Java 8 now also falls into this latter category. Traditionally, Java is a strict language: all method arguments are evaluated before a method is called, and so on. The new Streams API, however, is inherently lazy: intermediate operations such as filter or map only describe the computation, and elements are pulled through the pipeline only when there is demand for them, i.e. only when some terminal operation like collect or forEach is called. And lambda expressions allow passing functions as arguments for delayed evaluation.
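To see the laziness of streams at work, here is a small sketch that counts how often a filter predicate runs (the class LazyStreams and the method evens are made up for this illustration). Building the pipeline evaluates nothing, and the terminal operation pulls only as many elements as limit asks for.

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class LazyStreams {
    // An unbounded stream of the even numbers 2, 4, 6, ...; every predicate call is counted.
    static Stream<Integer> evens(AtomicInteger predicateCalls) {
        return Stream.iterate(1, x -> x + 1)
                .filter(x -> {
                    predicateCalls.incrementAndGet();
                    return x % 2 == 0;
                });
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        // Intermediate operations only describe the pipeline; nothing runs yet.
        Stream<Integer> firstThree = evens(calls).limit(3);
        System.out.println("before collect: " + calls.get());   // 0
        // The terminal operation pulls elements one at a time and stops after three evens.
        List<Integer> result = firstThree.collect(Collectors.toList());
        System.out.println(result + ", predicate calls: " + calls.get());   // [2, 4, 6], predicate calls: 6
    }
}
```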

Let's return to our example. In Scala, we might write the following code, lifted from Louis Botterill's blog:

def primes: Stream[Int] = {
  var is = Stream from 2
  def sieve(numbers: Stream[Int]): Stream[Int] = {
    Stream.cons(
      numbers.head,
      sieve(for (x <- numbers.tail if x % numbers.head > 0) yield x))
  }
  sieve(is)
}

By the way, this algorithm is not the Sieve of Eratosthenes, although it is often presented as such. This paper by Melissa O'Neill has all the details. It's an interesting read: O'Neill describes an optimized implementation of the genuine sieve, making use of heap-like data structures in Haskell. I'd be interested to know if someone has ported this to Java 8. We'll stick with the unfaithful sieve (O'Neill's term) for now.
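For comparison, here is a rough sketch of a genuine sieve in Java, in the spirit of O'Neill's incremental approach: a priority queue records, for each prime found so far, the next multiple that still has to be crossed off, starting at p * p. This is a simplified illustration, not a port of her Haskell code: it uses java.util.PriorityQueue, returns an ordinary list instead of a lazy stream, and the names GenuineSieve and firstPrimes are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

public class GenuineSieve {
    // Returns the first n primes, crossing off multiples instead of testing divisibility.
    public static List<Long> firstPrimes(int n) {
        List<Long> primes = new ArrayList<>();
        // Each entry is {next multiple to cross off, the prime it belongs to},
        // ordered by the next multiple.
        PriorityQueue<long[]> crossOff = new PriorityQueue<long[]>((a, b) -> Long.compare(a[0], b[0]));
        for (long candidate = 2; primes.size() < n; candidate++) {
            if (crossOff.isEmpty() || crossOff.peek()[0] > candidate) {
                // No prime has this candidate as its next multiple, so it is prime.
                primes.add(candidate);
                // Smaller multiples of it are already handled by smaller primes.
                crossOff.add(new long[] {candidate * candidate, candidate});
            } else {
                // The candidate is composite: advance every entry that has reached it.
                while (crossOff.peek()[0] == candidate) {
                    long[] entry = crossOff.poll();
                    crossOff.add(new long[] {entry[0] + entry[1], entry[1]});
                }
            }
        }
        return primes;
    }

    public static void main(String[] args) {
        System.out.println(firstPrimes(10));   // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
```

Unlike the unfaithful sieve, each composite is touched only by the primes that actually divide it, instead of being tested against every smaller prime.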

What's important in the above code is the ability to access the head of the stream and still operate lazily on the rest. Unfortunately, in Java findFirst is also one of those terminal operations on streams, so there is no obvious way to port this bit to Java without creating our own lazy list structure. We'll do that by building on the simple ConsList implementation alluded to in the previous post, and extending it in several ways:
  • add a factory method that accepts a UnaryOperator<T> (a Function<T,T>) to generate the elements in the tail in sequence
  • add a factory method that accepts a Supplier<FilterableConsList<T>> to create a tail
  • add an instance method that accepts a Predicate<T> to filter out non-matching elements from a list
  • make the implementation class lazy for the tail (but strict for the head).
As a further goodie, we can easily implement the Collection interface by providing an iterator, and make the thing streamable by providing a spliterator of unknown size that is based on the iterator. The code is lengthy but straightforward; it is shown at the bottom of this post. With it, the driver for generating the first 100 primes becomes:
 LazyConsList<Integer> s = sieve(LazyConsList.newList(2, x -> x + 1));   
 s.stream().limit(100).forEach(x->System.out.print(x + " "));   
where the interesting sieve-method is this:
 private LazyConsList<Integer> sieve(FilterableConsList<Integer> s) {  
  Integer p = s.head();  
  return cons(p, () -> sieve(s.tail().filter(x -> x % p != 0)));  
 }
Now I think that looks pretty neat (and very similar to the corresponding Scala code). That's all I wanted to show. However, as Java 8 is still very new, it may not be quite obvious to many readers what's going on here. Let me try to explain.

The code declares very clearly what it takes to output the first 100 primes: create a list of integers starting at 2, sieve it, take the first 100 elements and output them one by one. Sieving consists of taking the first element of the list, then taking the rest of the list with all multiples of that first element removed, and sieving that again.

Procedurally, the picture is very different: some method calls are deferred until the tail of a cons-list is accessed. In particular, there is no infinite loop, because every call to sieve in the expression () -> sieve(...) is delayed until forEach requests another prime for output. Here's the sequence:
  1. forEach causes the stream to get the next element from its underlying iterator
  2. whereupon the iterator will return the first list element and change its own state by retrieving the tail of its list, which will be a list based on a Supplier, namely the one which came out of the previous call to sieve
  3. the iterator's call to tail will in turn lead to strict evaluation of the method arguments to sieve, accessing the tail of the generator-based list and creating a new list from it by adding a new filter
  4. so that when sieve is executed, a new cons-cell with a supplied tail is constructed, the head of which is the first element of the filtered list from step 3 and therefore a prime
  5. which will be the next element returned from the stream when processing continues at step 1.
Contrary to appearances (and to the suggestive wording of the declarative reading of the program), no list containing 100 primes is ever constructed and then output. Instead, you might say that the processing takes place "vertically": a series of one-element lists is constructed, and their heads are output one after the other. You can see this clearly in a debugger, or by adding some logging when entering and leaving the sieve.
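If you'd rather not step through LazyConsList in a debugger, the following self-contained sketch shows the same vertical processing in a trace. It uses a stripped-down lazy list (infinite only, no Nil, no Collection interface) whose names Lazy, from, outputPrimes and TRACE are hypothetical; the sieve has the same shape as the one above, plus a trace entry on entering it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;

public class LoggedSieve {
    // Hypothetical stripped-down lazy list: strict head, tail computed on first access and cached.
    static final class Lazy<T> {
        final T head;
        private Supplier<Lazy<T>> tailSupplier;
        private Lazy<T> tailCache;

        Lazy(T head, Supplier<Lazy<T>> tailSupplier) {
            this.head = head;
            this.tailSupplier = tailSupplier;
        }

        Lazy<T> tail() {
            if (tailCache == null) {
                tailCache = tailSupplier.get();
                tailSupplier = null;   // the closure is no longer needed
            }
            return tailCache;
        }

        // Skip non-matching elements now; filter the rest only when it is accessed.
        Lazy<T> filter(Predicate<T> p) {
            Lazy<T> list = this;
            while (!p.test(list.head)) {
                list = list.tail();
            }
            Lazy<T> found = list;
            return new Lazy<>(found.head, () -> found.tail().filter(p));
        }
    }

    static final List<String> TRACE = new ArrayList<>();

    static Lazy<Integer> from(int n) {
        return new Lazy<>(n, () -> from(n + 1));
    }

    // Same shape as the sieve above, plus a trace entry on entering it.
    static Lazy<Integer> sieve(Lazy<Integer> s) {
        int p = s.head;
        TRACE.add("enter sieve " + p);
        return new Lazy<>(p, () -> sieve(s.tail().filter(x -> x % p != 0)));
    }

    // Output the first n primes, forcing the next cell only when it is needed.
    static void outputPrimes(int n) {
        Lazy<Integer> current = sieve(from(2));
        for (int i = 0; i < n; i++) {
            if (i > 0) {
                current = current.tail();
            }
            TRACE.add("output " + current.head);
        }
    }

    public static void main(String[] args) {
        outputPrimes(3);
        TRACE.forEach(System.out::println);
    }
}
```

For three primes the trace alternates: enter sieve 2, output 2, enter sieve 3, output 3, enter sieve 5, output 5. (The iterator of LazyConsList retrieves the tail before handing out the head, so there the entry for the next prime shows up just before the current one is printed.)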

-- Sebastian

Here's the complete code for our lazy list. I'm sure there's much to improve. One thing is that the toString() method that is inherited from the superclass is based on the list's iterator and will not terminate for an infinite list. I'd be grateful for suggestions or error reports (perhaps someone might like to write some unit tests?). First the interface:
 
import java.util.function.Predicate;

/**  
  * Interface for a list that may be filtered.  
  * @param <T> element type  
  * @author Sebastian Millies  
  */  
 public interface FilterableConsList<T> extends ConsList<T> {  
   
    /**  
    * Applies the given filter to this list.  
    *  
    * @param f a filter predicate  
     * @return a list containing the elements of this list for which the filter   
     * predicate returns <code>true</code>  
    */  
   FilterableConsList<T> filter(Predicate<T> f);  
     
   /**  
    * Returns the tail of this list.  
    *  
    * @return tail  
    * @throws EmptyListException if the list is empty  
    */  
   @Override  
   FilterableConsList<T> tail();  
 }
And then the implementation:
 import java.util.AbstractCollection;  
 import java.util.Iterator;  
 import java.util.Objects;  
 import java.util.Spliterator;  
 import java.util.Spliterators;  
 import java.util.function.Predicate;  
 import java.util.function.Supplier;  
 import java.util.function.UnaryOperator;  
   
 /**  
  * An immutable, filterable list that can be based on a generator function and  
  * is strict for the head and lazy for the tail. The tail can also be computed  
  * on demand by a Supplier instance. Via the Collection interface, this class   
  * is connected to the streams API.  
  *  
  * @param <T> element type  
  * @author Sebastian Millies  
  */  
 public abstract class LazyConsList<T> extends AbstractCollection<T> implements FilterableConsList<T> {  
   
   // ------------- Factory methods  
   
   /**  
    * Create a non-empty list based on a generator function. The generator is  
    * evaluated when the first element in the tail of the list is accessed.  
    * @param <T> element type  
    * @param head the first element  
    * @param next the generator function.  
    * @return a new list with head as its head and the tail generated by next  
    */  
   public static <T> LazyConsList<T> newList(final T head, final UnaryOperator<T> next) {  
     Objects.requireNonNull(next);  
     return new FunctionTail<>(head, next, x->true);  
   }  
   
   /**  
    * Create a list containing the given elements.  
    *  
    * @param <T> element type  
    * @param elems elements  
    * @return a new list  
    */  
   @SafeVarargs
   public static <T> LazyConsList<T> newList(T... elems) {  
     LazyConsList<T> list = nil();  
     for (int i = elems.length - 1; i >= 0; i--) {  
       list = cons(elems[i], list);  
     }  
     return list;  
   }  
     
   /**  
    * Add an element at the front of a list.  
    * @param <T> element type  
    * @param elem new element  
    * @param list list to extend  
    * @return a new list with elem as its head and list as its tail  
    */  
   public static <T> LazyConsList<T> cons(T elem, FilterableConsList<T> list) {  
     Objects.requireNonNull(list);  
     return new Cons<>(elem, list, x->true);  
   }  
   
   /**  
    * Add an element at the front of a list that will be supplied by the given supplier.  
    * @param <T> element type  
    * @param elem new element  
    * @param listSupplier list to extend  
    * @return a new list with elem as its head and list as its tail  
    */  
   public static <T> LazyConsList<T> cons(T elem, Supplier<FilterableConsList<T>> listSupplier) {  
     Objects.requireNonNull(listSupplier);  
     return new SuppliedTail<>(elem, listSupplier, x->true);  
   }  
     
   /**  
    * Create an empty list.  
    * @param <T> element type  
    * @return an empty list  
    */  
   public static <T> LazyConsList<T> nil() {  
     return new Nil<>();  
   }  
   
   /**  
    * Utility method that walks through the specified list and its  
    * tails until it finds the first one with a head matching the filter,  
    * or reaches the end of the list.  
    * @param <T> element type  
    * @param list list that is filtered  
    * @param filter element filter  
    * @return the specified list or one of its tails (possibly the empty list)  
    */  
   protected static <T> FilterableConsList<T> applyFilter(FilterableConsList<T> list, Predicate<T> filter) {  
     FilterableConsList<T> filtered = list;  
     // stop at the empty list, which has no head to test
     while (!filtered.isNil() && !filter.test(filtered.head())) {  
       filtered = filtered.tail();  
     }  
     return filtered;  
   }  
   
   
   // ------------- Constructor  
   
   protected LazyConsList() {  
   }  
     
   // ------------- Collection interface  
     
   @Override  
   public final Iterator<T> iterator() {  
     return new Iterator<T>() {  
       private ConsList<T> current = LazyConsList.this;  
         
       @Override  
       public boolean hasNext() {  
         return !current.isNil();  
       }  
   
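       @Override  
       public T next() {  
         // the Iterator contract requires NoSuchElementException when exhausted
         if (current.isNil()) {
           throw new java.util.NoSuchElementException();
         }
         T head = current.head();  
         current = current.tail();  
         return head;  
       }  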
     };  
   }  
   
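   @Override  
   public final int size() {  
     throw new UnsupportedOperationException("the size of a lazy list cannot be determined");  
   }  
   
   // AbstractCollection.isEmpty() is based on size(), so answer it directly
   @Override
   public boolean isEmpty() {
     return isNil();
   }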
     
   @Override  
   public Spliterator<T> spliterator() {  
     // ORDERED: the list has a well-defined encounter order (e.g. for limit)
     return Spliterators.spliteratorUnknownSize(iterator(), Spliterator.ORDERED);  
   }  
     
   // ------------- concrete subclasses  
     
   /**  
    * A non-empty list.  
    * @param <T> element type  
    */  
   private static class Cons<T> extends LazyConsList<T> {  
   
     private final T head;  
     private final FilterableConsList<T> tail;  
     private final Predicate<T> filter;  
       
     public Cons(T head, FilterableConsList<T> tail, Predicate<T> filter) {  
       assert filter.test(head);  
       this.head = head;  
       this.tail = tail;  
       this.filter = filter;  
     }  
   
     @Override  
     public T head() {  
       return head;  
     }  
   
     @Override  
     public FilterableConsList<T> tail() {  
       return tail;  
     }  
   
     @Override  
     public boolean isNil() {  
       return false;  
     }  
       
     @Override  
     public FilterableConsList<T> filter(Predicate<T> f) {  
       Predicate<T> newFilter = filter.and(f);  
       FilterableConsList<T> filtered = applyFilter(this, newFilter);  
       if (filtered.isNil()) {
         return filtered;
       }
       // the stored tail is not filtered by itself, so filter it lazily on access
       return cons(filtered.head(), () -> filtered.tail().filter(newFilter));
     }  
   }  
   
   /**  
    * A non-empty list based on a generator function. The tail is lazily computed  
    * on demand and cached.  
    * @param <T> element type  
    */  
   private static class FunctionTail<T> extends LazyConsList<T> {  
   
     private final T head;  
     private final UnaryOperator<T> next;  
     private final Predicate<T> filter;  
     private FilterableConsList<T> tailCache;  
   
     public FunctionTail(T head, UnaryOperator<T> next, Predicate<T> filter) {  
       assert filter.test(head);  
       this.head = head;  
       this.next = next;  
       this.filter = filter;  
     }  
   
     @Override  
     public T head() {  
       return head;  
     }  
   
     @Override  
     public FilterableConsList<T> tail() {  
       // construct a new lazy list with the first element that passes the filter.  
       // use the generator function to construct the elements.  
       if (tailCache == null) {  
         T nextHead = head;  
         do {  
          nextHead = next.apply(nextHead);  
         } while (!filter.test(nextHead));  
         tailCache = new FunctionTail<>(nextHead, next, filter);  
       }  
       return tailCache;  
     }  
   
     @Override  
     public boolean isNil() {  
       return false;  
     }  
        
     @Override  
     public FilterableConsList<T> filter(Predicate<T> f) {  
       Predicate<T> newFilter = filter.and(f);  
       FilterableConsList<T> filtered = applyFilter(this, newFilter);  
       return new FunctionTail<>(filtered.head(), next, newFilter);  
     }  
   }  
   
   /**  
    * A non-empty list based on a supplied computation. The tail is lazily computed  
    * on demand and cached.  
    * @param <T> element type  
   */  
   private static class SuppliedTail<T> extends LazyConsList<T> {  
   
     private final T head;  
     private final Supplier<FilterableConsList<T>> supplier;  
     private final Predicate<T> filter;  
     private FilterableConsList<T> tailCache;  
   
     public SuppliedTail(T head, Supplier<FilterableConsList<T>> supplier, Predicate<T> filter) {  
       assert filter.test(head);  
       this.head = head;  
       this.supplier = supplier;  
       this.filter = filter;  
     }  
   
     @Override  
     public T head() {  
       return head;  
     }  
   
     @Override  
     public FilterableConsList<T> tail() {  
       // construct a new lazy list with the first element that passes the filter.  
       // delegate to the supplied function to create the tail of the current list.
       if (tailCache == null) {  
         tailCache = applyFilter(supplier.get(), filter);  
       }  
       return tailCache;  
     }  
   
     @Override  
     public boolean isNil() {  
       return false;  
     }  
   
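     @Override  
     public FilterableConsList<T> filter(Predicate<T> f) {  
       Predicate<T> newFilter = filter.and(f);  
       FilterableConsList<T> filtered = applyFilter(this, newFilter);  
       if (filtered.isNil()) {
         return filtered;
       }
       // filter the rest lazily, so the supplier is not forced before it is needed
       return cons(filtered.head(), () -> filtered.tail().filter(newFilter));
     }  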
   }  
     
   /**  
    * An empty list. Trying to access components of this class will result in   
    * an EmptyListException at runtime.  
    * @param <T> element type  
    */  
   private static class Nil<T> extends LazyConsList<T> {  
   
     @Override  
     public T head() {  
       throw new EmptyListException();  
     }  
   
     @Override  
     public FilterableConsList<T> tail() {  
       throw new EmptyListException();  
     }  
   
     @Override  
     public boolean isNil() {  
       return true;  
     }  
   
     @Override  
     public FilterableConsList<T> filter(Predicate<T> f) {  
       return this;  
     }  
      
   }  
 }  
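The spliterator() override above is the whole bridge between the lazy list and the Streams API. Reduced to JDK classes only, the pattern looks like the following sketch (IteratorStreams, from and stream are hypothetical names): an unbounded iterator is wrapped with Spliterators.spliteratorUnknownSize and turned into a sequential stream with StreamSupport.stream. Spliterator.ORDERED is passed so that limit is defined to take the first elements in encounter order.

```java
import java.util.Iterator;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class IteratorStreams {
    // An unbounded iterator over n, n+1, n+2, ... (hasNext() is always true).
    static Iterator<Integer> from(int n) {
        return new Iterator<Integer>() {
            private int next = n;

            @Override
            public boolean hasNext() {
                return true;
            }

            @Override
            public Integer next() {
                return next++;
            }
        };
    }

    // Wrap the iterator in a spliterator of unknown size, then in a sequential stream.
    static <T> Stream<T> stream(Iterator<T> it) {
        Spliterator<T> split = Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED);
        return StreamSupport.stream(split, false);
    }

    public static void main(String[] args) {
        List<Integer> firstFive = stream(from(10)).limit(5).collect(Collectors.toList());
        System.out.println(firstFive);   // [10, 11, 12, 13, 14]
    }
}
```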

Higher-Order Functions in Java 8 (Part 1)

This post will show how the first few examples (pages 4-7) from John Hughes' 1984 paper Why Functional Programming Matters ([WHYFP]) may be implemented in Java 8. The paper has been revised several times; here's a link to the 1990 version. It's a useful starting point for getting acquainted with functional programming idioms. What's more, we'll reconstruct part of the new Java Collection (or Streams) API, namely the map/reduce operations.

What this post is not:
I'd like to make one very technical point: lambda expressions in Java are often presented as syntactic sugar for (anonymous) inner classes. They're not. They are translated to bytecode differently and make use of the invokedynamic instruction, which was added to the JVM in Java 7. If you would like to know more about this, read Brian Goetz's article Translation of Lambda Expressions and consult the Java API docs, perhaps starting with LambdaMetafactory.

The code examples have been compiled and tested with Java Lambda b88 and the NetBeans IDE Development version (build jdk8lambda-1731-on-20130523). They may or may not work with the latest builds, which can be downloaded from
Set up your environment as follows:
  1. Download and unzip NetBeans and Java 8
  2. Run NetBeans with command line parameter --jdkhome <path to jdk8>
Before we can get started, we need to define the recursive list data structure with which we're going to work: a list is something that is either nil (empty) or a cons-cell consisting of a head and a tail which is itself a list. Here's the interface:
 
   public interface ConsList<T> {  
   /**  
    * Returns the first element of this list.   
    *  
    * @return head  
    * @throws EmptyListException if the list is empty  
    */  
   T head();  

   /**  
    * Returns the tail of this list.  
    *  
    * @return tail  
    * @throws EmptyListException if the list is empty  
    */  
   ConsList<T> tail();  

   /**  
    * Tests if this list is nil. Nil is the empty list.  
    *  
    * @return <code>true</code> if the list is empty, otherwise <code>false</code>  
    */  
   boolean isNil();   
 }  
The implementation class (here called SimpleConsList) is not shown. It should provide some static factory methods (see the signatures below) and a useful toString method. Implementation guidelines can be found here. (An enhanced version, which you can use with only minimal adjustments to some code shown later, is included in the follow-up to this post.)
   
   /**  
    * Create a list containing the given elements.  
    *  
    * @param <T> element type  
    * @param elems elements  
    * @return a new list  
    */ 
   public static <T> SimpleConsList<T> newList(T... elems);  

   /**  
    * Add an element at the front of a list.  
    * @param <T> element type  
    * @param elem new element  
    * @param list list to extend  
    * @return a new list with elem as its head and list as its tail  
    */  
   public static <T> SimpleConsList<T> cons(T elem, ConsList<T> list);  

   /**  
    * Create an empty list.  
    * @param <T> element type  
    * @return an empty list  
    */  
   public static <T> SimpleConsList<T> nil();  
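Since SimpleConsList is not shown, here is one possible minimal implementation of these signatures. It is a sketch, not the author's code: a single shared empty list, immutable cons cells, and a toString in the style of java.util.List. The EmptyListException class is an assumption (the interface mentions it, but it is not shown), and the ConsList interface is repeated so that the sketch compiles on its own.

```java
import java.util.Objects;
import java.util.StringJoiner;

// Copy of the ConsList interface from this post, so that the file compiles on its own.
interface ConsList<T> {
    T head();
    ConsList<T> tail();
    boolean isNil();
}

// Assumption: the post refers to EmptyListException without showing it.
class EmptyListException extends RuntimeException {
    EmptyListException() {
        super("empty list");
    }
}

public final class SimpleConsList<T> implements ConsList<T> {
    // The single shared empty list; its (null) head and tail are never exposed.
    private static final SimpleConsList<?> NIL = new SimpleConsList<>(null, null);

    private final T head;
    private final ConsList<T> tail;

    private SimpleConsList(T head, ConsList<T> tail) {
        this.head = head;
        this.tail = tail;
    }

    @SafeVarargs
    public static <T> SimpleConsList<T> newList(T... elems) {
        SimpleConsList<T> list = nil();
        for (int i = elems.length - 1; i >= 0; i--) {
            list = cons(elems[i], list);
        }
        return list;
    }

    public static <T> SimpleConsList<T> cons(T elem, ConsList<T> list) {
        return new SimpleConsList<>(elem, Objects.requireNonNull(list));
    }

    @SuppressWarnings("unchecked")   // safe: NIL holds no elements
    public static <T> SimpleConsList<T> nil() {
        return (SimpleConsList<T>) NIL;
    }

    @Override
    public T head() {
        if (isNil()) {
            throw new EmptyListException();
        }
        return head;
    }

    @Override
    public ConsList<T> tail() {
        if (isNil()) {
            throw new EmptyListException();
        }
        return tail;
    }

    @Override
    public boolean isNil() {
        return this == NIL;
    }

    // Prints e.g. [1, 2, 3]; works with any ConsList implementation in the tail.
    @Override
    public String toString() {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (ConsList<T> list = this; !list.isNil(); list = list.tail()) {
            joiner.add(String.valueOf(list.head()));
        }
        return joiner.toString();
    }

    public static void main(String[] args) {
        System.out.println(newList(1, 2, 3));   // [1, 2, 3]
    }
}
```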

Here's the motivating example taken from [WHYFP]: Suppose we wanted to add all the elements of a list of integers. Consider the following code:
 
   public static Integer sum(ConsList<Integer> list) {  
     if (list.isNil()) {  
       return 0;  
     } else {  
       return list.head() + sum(list.tail());  
     }  
   } 

What's really Integer-specific here? Only the type, the neutral element, and the addition operator. This means that the computation of a sum can be modularized into a general recursive pattern and the specific parts. This recursive pattern is conventionally called foldr. Here it is:
   public static <T, R> R foldr(BiFunction<T, R, R> f, R x, ConsList<T> list) {  
     if (list.isNil()) {  
       return x;  
     }  
     return f.apply(list.head(), foldr(f, x, list.tail()));  
   }   
BiFunction is a functional interface from the package java.util.function. It is the type of functions taking two arguments. Now we may print out the sum of a list of integers like this:
  
  ConsList<Integer> ints = SimpleConsList.newList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
  System.out.println("Sum = " + foldr((x, y) -> x + y, 0, ints)); 
foldr is called a higher-order function, because it takes another function as its argument. (This is sometimes expressed from a programming perspective as having "code as data".)

There is also a closely related way to traverse a list, called foldl, which does the function application from left to right:
   public static <T, R> R foldl(BiFunction<T, R, R> f, R x, ConsList<T> list) {  
     if (list.isNil()) {  
       return x;  
     }  
     return foldl(f, f.apply(list.head(), x), list.tail());  
   }   
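The difference between the two folds shows up as soon as the function is not commutative. In the following self-contained sketch (the nested ConsList, nil, cons and newList are a hypothetical stand-in for SimpleConsList; foldr and foldl are copied from above), prepending each character to the accumulated string gives back the original order with foldr and the reversed order with foldl, just like copy and reverse below.

```java
import java.util.function.BiFunction;

public class Folds {
    // Hypothetical minimal cons list, standing in for SimpleConsList in this sketch.
    interface ConsList<T> {
        T head();
        ConsList<T> tail();
        boolean isNil();
    }

    static <T> ConsList<T> nil() {
        return new ConsList<T>() {
            @Override public T head() { throw new IllegalStateException("empty list"); }
            @Override public ConsList<T> tail() { throw new IllegalStateException("empty list"); }
            @Override public boolean isNil() { return true; }
        };
    }

    static <T> ConsList<T> cons(T first, ConsList<T> rest) {
        return new ConsList<T>() {
            @Override public T head() { return first; }
            @Override public ConsList<T> tail() { return rest; }
            @Override public boolean isNil() { return false; }
        };
    }

    @SafeVarargs
    static <T> ConsList<T> newList(T... elems) {
        ConsList<T> list = nil();
        for (int i = elems.length - 1; i >= 0; i--) {
            list = cons(elems[i], list);
        }
        return list;
    }

    // foldr f x [a, b, c] = f(a, f(b, f(c, x))): the rightmost element is combined first.
    static <T, R> R foldr(BiFunction<T, R, R> f, R x, ConsList<T> list) {
        if (list.isNil()) {
            return x;
        }
        return f.apply(list.head(), foldr(f, x, list.tail()));
    }

    // foldl f x [a, b, c] = f(c, f(b, f(a, x))): the leftmost element is combined first.
    static <T, R> R foldl(BiFunction<T, R, R> f, R x, ConsList<T> list) {
        if (list.isNil()) {
            return x;
        }
        return foldl(f, f.apply(list.head(), x), list.tail());
    }

    public static void main(String[] args) {
        ConsList<Character> chars = newList('a', 'b', 'c');
        System.out.println(foldr((Character c, String s) -> c + s, "", chars));   // abc
        System.out.println(foldl((Character c, String s) -> c + s, "", chars));   // cba
        System.out.println(foldr((Character c, Integer n) -> n + 1, 0, chars));  // 3
    }
}
```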

In fact, a variant of foldl has been built into the Java Streams API. It's called reduce, because it may be used to "reduce" all the elements of a list to a single value (e.g. their sum). The naming is unfortunate, in my opinion, because, as we shall see, it relates only to a special (if typical) case of what can be done with it, but it is the usual name in functional programming. Here's the corresponding code using that built-in method on a java.util.List:
 List<Integer> intList = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);  
 System.out.println("Sum = " + intList.stream().reduce(0, (x, y) -> x + y));  

There is no counterpart to foldr in the Java API, perhaps because foldl is tail recursive. However, reduce also differs in important ways from foldl: it requires the folding function to be associative (as in our example), and it is not guaranteed to process the stream elements in left-to-right order (a parallel stream may group the elements arbitrarily when combining partial results).
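A small sketch of these points, using only the Streams API (the class and method names are hypothetical): with an associative function and its identity, sequential and parallel reduction agree; with subtraction, the result is in practice only predictable for a sequential stream (and even then the API documents associativity as a requirement), so only that case is shown. The three-argument form of reduce lets the result type differ from the element type, much like foldl's BiFunction<T, R, R>, but note that its accumulator takes the accumulated value first and the element second, the reverse of the foldl above.

```java
import java.util.Arrays;
import java.util.List;

public class ReduceDemo {
    // Subtraction is not associative: only meaningful for a sequential stream.
    static int leftDifference(List<Integer> xs) {
        return xs.stream().reduce(0, (acc, x) -> acc - x);
    }

    // Addition is associative and 0 is its identity, so a parallel stream is safe.
    static int parallelSum(List<Integer> xs) {
        return xs.parallelStream().reduce(0, Integer::sum);
    }

    // Three-argument form: accumulate Strings into an Integer; the combiner merges partial results.
    static int totalLength(List<String> words) {
        return words.stream().reduce(0, (acc, w) -> acc + w.length(), Integer::sum);
    }

    public static void main(String[] args) {
        List<Integer> ints = Arrays.asList(1, 2, 3, 4);
        System.out.println(leftDifference(ints));                          // ((((0 - 1) - 2) - 3) - 4) = -10
        System.out.println(parallelSum(ints));                             // 10
        System.out.println(totalLength(Arrays.asList("map", "reduce")));   // 9
    }
}
```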

The following code demonstrates how foldr (and its relative foldl) can be used to write many other functions on lists with no more programming effort. Here's how it works (cf. [WHYFP]):
  • length increments 0 as many times as there are cons'es
  • copy cons'es the list elements onto the front of an empty list from right to left
  • reverse is similar, except it uses foldl to do the cons-ing from left to right. Just for fun, the lambda expression has been replaced with a method reference.
  • append cons'es the elements of chars1 onto the front of chars2 
     ConsList<Integer> ints = SimpleConsList.newList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);  
     ConsList<Boolean> bools = SimpleConsList.newList(true, true, false);  
     ConsList<Character> chars1 = SimpleConsList.newList('a', 'b', 'c');  
     ConsList<Character> chars2 = SimpleConsList.newList('x', 'y', 'z'); 
 
     System.out.println("Product = " + foldr((x, y) -> x * y, 1, ints));  
     System.out.println("One true = " + foldr((x, y) -> x || y, false, bools));  
     System.out.println("All true = " + foldr((x, y) -> x && y, true, bools));  
     System.out.println("Length = " + foldr((Boolean x, Integer y) -> y + 1, 0, bools));  
     System.out.println("Copy = " + foldr((x, y) -> cons(x, y), nil(), chars1));  
     System.out.println("Reverse = " + foldl(SimpleConsList::cons, nil(), chars1));  
     System.out.println("Append = " + foldr((x, y) -> cons(x, y), chars2, chars1));  

There's one more higher-order function I'd like to mention, namely map. This function takes a function argument f and applies f to all elements of a list, yielding a new list. For example, we may use it to double every element in a list:
   System.out.println("Double all = " + map(x -> 2 * x, ints));
The map-function is also part of the Java Streams API. Following the derivation in [WHYFP] we may reconstruct it like this:
  // map f = foldr (cons . f ) nil   
  public static <T, R> ConsList<R> map(Function<T, R> f, ConsList<T> list) {  
    Function<R, Function<ConsList<R>, ConsList<R>>> cons = curry((x, y) -> cons(x, y));  
    return foldr((T x, ConsList<R> y) -> cons.compose(f).apply(x).apply(y), nil(), list);  
  }  
Here we observe several things. First, map makes use of functional composition, a standard operator which is built into the interface java.util.function.Function. The expression cons.compose(f) returns a new function that first applies f to its argument and then applies cons to the result of that application. (Download the JDK source code and study how compose is implemented.) 

In order to apply functional composition, we must first "curry" the binary function. Currying transforms a function of n arguments into a chain of n functions that each take a single argument. Of course, we could just have defined cons in the method above as x -> y -> cons(x, y), but it's perhaps instructive to see a method that curries a BiFunction (side note: unfortunately, I see no general way to curry an n-ary function for arbitrary n in Java; the type system is just not flexible enough):
public static <T, U, R> Function<T, Function<U, R>> curry(BiFunction<T, U, R> f) {  
    return (T t) -> (U u) -> f.apply(t, u);  
}  
Note that no actual function invocation takes place, i.e. apply is not invoked until both function arguments have been supplied. After all this effort, and although the code may look intimidating, map is just like our simple copy-list example above, except that the list elements are not just copied, but first transformed through f.
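A short sketch of curry and of the direction of composition (CurryCompose, add, addTen and square are names made up for this example): applying the curried function to one argument yields a new function and performs no addition yet, and g.compose(f) applies f first, whereas g.andThen(f) applies g first.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

public class CurryCompose {
    // The curry method from above: a BiFunction becomes a chain of two one-argument functions.
    static <T, U, R> Function<T, Function<U, R>> curry(BiFunction<T, U, R> f) {
        return t -> u -> f.apply(t, u);
    }

    public static void main(String[] args) {
        Function<Integer, Function<Integer, Integer>> add = curry((x, y) -> x + y);
        Function<Integer, Integer> addTen = add.apply(10);   // partial application: no addition yet
        Function<Integer, Integer> square = x -> x * x;

        System.out.println(addTen.compose(square).apply(3));   // 10 + 3 * 3 = 19
        System.out.println(addTen.andThen(square).apply(3));   // (10 + 3) * (10 + 3) = 169
    }
}
```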

Now we have all the pieces in hand to sum the elements of a matrix, where a matrix (implementation not shown) is represented as a list of lists. The function sumList adds up the elements of a single row; mapping it over the matrix yields the row totals, and the leftmost application of sumList then adds up the row totals to get the sum of the whole matrix.
 Matrix<Integer> matrix = new Matrix<>(ints, ints, ints);  
 Function<ConsList<Integer>, Integer> sumList = x -> foldr((a, b) -> a + b, 0, x);  
 System.out.println("Matrix sum = " + sumList.compose((Matrix<Integer> m) -> map(x -> sumList.apply(x), m)).apply(matrix));  

In my opinion, this is not too verbose, but Java 8 has certainly not been turned into Scala or Haskell. In Scala, the same expression would read something like val summatrix: List[List[Int]] => Int = sum.compose(map(sum))

In fact, by partially evaluating the above expression with respect to functional composition and application to the matrix argument, we can manually derive the following equivalent expression, which is somewhat easier to understand: sumList.apply(map(x -> sumList.apply(x), matrix))

Mark Mahieu has written on the topic of partial evaluation in Java.
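For comparison, the same composition can be written against java.util.List and the Streams API instead of the post's Matrix, ConsList, foldr and map. In this sketch (MatrixSum, SUM_LIST and SUM_MATRIX are hypothetical names), a matrix is simply a List<List<Integer>>, and mapping over the rows is done with Stream.map:

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

public class MatrixSum {
    // Sum of a single row.
    static final Function<List<Integer>, Integer> SUM_LIST =
            row -> row.stream().reduce(0, Integer::sum);

    // sumMatrix = sumList . map(sumList): first the row totals, then their sum.
    static final Function<List<List<Integer>>, Integer> SUM_MATRIX =
            SUM_LIST.compose((List<List<Integer>> m) ->
                    m.stream().map(SUM_LIST).collect(Collectors.toList()));

    public static void main(String[] args) {
        List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        List<List<Integer>> matrix = Arrays.asList(ints, ints, ints);
        System.out.println("Matrix sum = " + SUM_MATRIX.apply(matrix));   // 3 * 55 = 165
    }
}
```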

In this post you have seen how Java 8 makes it possible to pass functions around not only to apply them to data but also to manipulate them and combine them into other functions. You have also seen how the map-/reduce operations from the Java 8 Streams API may be understood in these terms. This has been textbook stuff. In another post, I'll discuss the more exciting (I hope) topic of how lambda expressions are especially useful when combined with lazy evaluation.

-- Sebastian

Note: If you're curious how to use higher-order functions without language support for functional expressions, there are Java 7 (no lambda) versions of foldr, foldl, and map on Adrian Walker's blog.