
3 Jun 2015

Recursive Parallel Search with Shallow Backtracking

The implementation in the previous post had a serious disadvantage: it generated substitutions for all the letters and only then checked whether those substitutions constituted a solution. A more efficient approach rejects a partial substitution out of hand as soon as it is clear that it cannot lead to a solution. This is called "shallow backtracking". In this post I show how to examine the letters as they occur from right to left in the operands, and interleave the checking of arithmetic constraints with the generation of substitutions.

This solution combines the advantages of flatMap-based search in Java 8, parallelization, persistent collections, and recursion. The code below is completely general for cryptarithmetic puzzles with two summands, because the constraints are no longer problem-specific: they encode only the rules of long addition plus a general side condition (no leading zeros).
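To make "the rules of long addition" concrete, here is a minimal sketch of the check for a single column. The class name ColumnCheck and the method nextCarry are hypothetical and exist only for this illustration; the full solver further down folds the same arithmetic into its own method.

```java
// Hypothetical helper for illustration only, not part of the solver in this post.
final class ColumnCheck {

    static final int PRUNED = -1;

    // Returns the carry into the next column to the left,
    // or PRUNED if x + y + carry does not end in the digit z.
    static int nextCarry(int x, int y, int z, int carry) {
        int sum = x + y + carry;
        return sum % 10 == z ? sum / 10 : PRUNED;
    }

    public static void main(String[] args) {
        // rightmost column of send + more = money with d=7, e=5, y=2: 7 + 5 = 12
        System.out.println(nextCarry(7, 5, 2, 0)); // prints 1
        // the same column with y=3 is rejected immediately
        System.out.println(nextCarry(7, 5, 3, 0)); // prints -1
    }
}
```

A rejected column means that no letters further to the left are ever tried for that partial substitution, which is where the savings come from.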

The implementation is also a bit more idiomatic (for Java), with less clutter in the method interfaces, because we keep the read-only operands in instance variables instead of passing them around explicitly.

The really nice thing is that this approach is about 25 times faster than the original parallel flatMap solution with exhaustive search.

Here's the code:
import static com.googlecode.totallylazy.collections.PersistentList.constructors.list;
import static com.googlecode.totallylazy.collections.PersistentMap.constructors.map;
import static java.util.stream.Collectors.toList;

import java.util.Collection;
import java.util.Map;
import java.util.stream.Stream;

import com.googlecode.totallylazy.collections.PersistentList;
import com.googlecode.totallylazy.collections.PersistentMap;

public class SendMoreMoneyShallow {

    static final int PRUNED = -1;
    static final char PADDING = ' ';
    static final PersistentList<Integer> DIGITS = list(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
    
    // padded puzzle arguments: op1 + op2 = op3
    final String op1;
    final String op2;
    final String op3;
    
    public static void main(String[] args) {
        SendMoreMoneyShallow puzzle = new SendMoreMoneyShallow(" send", " more", "money");
        Collection<String> solutions = puzzle.solve();
        System.out.println("There are " + solutions.size() + " solutions: " + solutions);
    }

    public SendMoreMoneyShallow(String op1, String op2, String op3) {
        // the arguments come padded with blanks so they all have the same length
        // there is no need to reverse the strings, because we have random access and can process them backwards
        assert op1.length() == op3.length();
        assert op2.length() == op3.length();
        this.op1 = op1;
        this.op2 = op2;
        this.op3 = op3;
    }
    
    public Collection<String> solve() {
        PersistentMap<Character, Integer> subst = map();
        Collection<String> solutions = go(op1.length() - 1, subst, 0).collect(toList());
        return solutions;
    }

    Stream<String> go(int i, PersistentMap<Character, Integer> subst, int carry) {
        // Each level of recursion substitutes up to three letters with digits (one column). The recursion
        // ends when we run out of columns. At this point all column constraints have already been checked,
        // so the substitutions represent a solution, provided that no carry is left over from the leftmost
        // column (which can only happen if op3 is not longer than both summands).
        if (i < 0) {
            return carry == 0 ? solution(subst) : Stream.empty();
        }
        
        // the state consists of partial substitutions and the carry. Every time we have made a substitution for a column
        // of letters (from right to left), we immediately check constraints.
        Character sx = op1.charAt(i);
        Character sy = op2.charAt(i);
        Character sz = op3.charAt(i);
        return candidates(sx, subst).stream().parallel().flatMap(x -> {
                // padding is not a letter: recording it would block the digit 0 for letters in later columns
                PersistentMap<Character, Integer> substX = sx == PADDING ? subst : subst.insert(sx, x);
                return candidates(sy, substX).stream().flatMap(y -> {
                    PersistentMap<Character, Integer> substXY = sy == PADDING ? substX : substX.insert(sy, y);
                    return candidates(sz, substXY).stream().flatMap(z -> {
                        PersistentMap<Character, Integer> substXYZ = sz == PADDING ? substXY : substXY.insert(sz, z);
                        // recurse if not pruned, moving one column to the left with the substitutions we have just
                        // made and the value for carry that results from checking the arithmetic constraints
                        int nextCarry = prune(substXYZ, carry, x, y, z);
                        return nextCarry == PRUNED ? Stream.empty() : go(i - 1, substXYZ, nextCarry);
                    });});});
    }

    int prune(PersistentMap<Character, Integer> subst, int carry, Integer x, Integer y, Integer z) {
        // none of the most significant digits may be 0, and we cannot be sure the substitutions have already been made
        if (subst.getOrDefault(mostSignificantLetter(op1), 1) == 0
                || subst.getOrDefault(mostSignificantLetter(op2), 1) == 0
                || subst.getOrDefault(mostSignificantLetter(op3), 1) == 0) {
            return PRUNED;
        }

        // the column sum must be correct
        int zPrime = x + y + carry;
        if (zPrime % 10 != z) {
            return PRUNED;
        }

        // return next carry
        return zPrime / 10;
    }

    PersistentList<Integer> candidates(Character letter, PersistentMap<Character, Integer> subst) {
        if (letter == PADDING) {
            return list(0);
        }
        // if we have a substitution, use that, otherwise consider only those digits that have not yet been assigned
        return subst.containsKey(letter) ? list(subst.get(letter)) : DIGITS.deleteAll(subst.values());
    }

    Stream<String> solution(PersistentMap<Character, Integer> subst) {
        // transform the set of substitutions to a solution (in this case a String because Java has no tuples)
        int a = toNumber(subst, op1.trim());
        int b = toNumber(subst, op2.trim());
        int c = toNumber(subst, op3.trim());
        return Stream.of("(" + a + "," + b + "," + c + ")");
    }

    static final int toNumber(Map<Character, Integer> subst, String word) {
        // return the integer corresponding to the given word according to the substitutions
        assert word.length() > 0;
        return word.chars().map(x -> subst.get((char)x)).reduce((x, y) -> 10 * x + y).getAsInt();
    }

    static char mostSignificantLetter(String op) {
        return op.trim().charAt(0);
    }
}

And here's a representative performance measurement:

# JMH 1.9.1 (released 40 days ago)
# VM invoker: C:\Program Files\Java\jdk1.8.0_25\jre\bin\java.exe
# VM options: -Dfile.encoding=UTF-8
# Warmup: 5 iterations, 1 s each
# Measurement: 25 iterations, 1 s each
# Timeout: 10 min per iteration
# Threads: 1 thread, will synchronize iterations
# Benchmark mode: Average time, time/op
# Benchmark: java8.streams.SendMoreMoneyShallowBenchmark.measureShallowBacktrackingPerformance


Benchmark                              Mode  Cnt  Score   Error  Units
measureShallowBacktrackingPerformance  avgt   25  6.319 ± 0.082  ms/op


There is room for still more improvement, because the substitution for "z" in each round is in fact determined by the previous substitutions and need not be guessed.
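
As a rough illustration of that idea, the sketch below computes z as (x + y + carry) % 10 and only checks that this digit is compatible with the substitutions made so far. It is a simplified, sequential variant that uses java.util collections instead of persistent ones. The class name DerivedZSketch and the helpers assign and leadingZero are my assumptions for this illustration, and I have not benchmarked it.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch only: sequential, java.util collections, not the code measured above.
final class DerivedZSketch {

    static final char PADDING = ' ';

    final String op1, op2, op3; // padded with blanks to equal length
    final List<String> solutions = new ArrayList<>();

    DerivedZSketch(String op1, String op2, String op3) {
        this.op1 = op1;
        this.op2 = op2;
        this.op3 = op3;
    }

    List<String> solve() {
        go(op1.length() - 1, new HashMap<>(), 0);
        return solutions;
    }

    void go(int i, Map<Character, Integer> subst, int carry) {
        if (i < 0) {
            // a leftover carry would mean the sum has more digits than op3
            if (carry == 0) {
                solutions.add("(" + toNumber(subst, op1) + "," + toNumber(subst, op2)
                        + "," + toNumber(subst, op3) + ")");
            }
            return;
        }
        char sx = op1.charAt(i), sy = op2.charAt(i), sz = op3.charAt(i);
        for (int x : candidates(sx, subst)) {
            Map<Character, Integer> substX = assign(subst, sx, x);
            for (int y : candidates(sy, substX)) {
                Map<Character, Integer> substXY = assign(substX, sy, y);
                int sum = x + y + carry;
                int z = sum % 10; // determined by x, y and carry, not guessed
                // z must agree with an earlier substitution for sz, or still be unused
                if (!candidates(sz, substXY).contains(z)) continue;
                Map<Character, Integer> substXYZ = assign(substXY, sz, z);
                if (!leadingZero(substXYZ)) go(i - 1, substXYZ, sum / 10);
            }
        }
    }

    // digits that may stand for the letter, given the substitutions so far
    static List<Integer> candidates(char letter, Map<Character, Integer> subst) {
        List<Integer> result = new ArrayList<>();
        Integer known = subst.get(letter);
        if (letter == PADDING) {
            result.add(0);
        } else if (known != null) {
            result.add(known);
        } else {
            for (int d = 0; d <= 9; d++) if (!subst.containsValue(d)) result.add(d);
        }
        return result;
    }

    // copy-on-write stand-in for a persistent map; padding is never recorded
    static Map<Character, Integer> assign(Map<Character, Integer> subst, char letter, int digit) {
        if (letter == PADDING) return subst;
        Map<Character, Integer> copy = new HashMap<>(subst);
        copy.put(letter, digit);
        return copy;
    }

    boolean leadingZero(Map<Character, Integer> subst) {
        for (String op : new String[] {op1, op2, op3}) {
            if (subst.getOrDefault(op.trim().charAt(0), 1) == 0) return true;
        }
        return false;
    }

    static int toNumber(Map<Character, Integer> subst, String op) {
        int n = 0;
        for (char c : op.trim().toCharArray()) n = 10 * n + subst.get(c);
        return n;
    }

    public static void main(String[] args) {
        System.out.println(new DerivedZSketch(" send", " more", "money").solve());
    }
}
```

Compared with the listing above, the innermost flatMap over candidates for z disappears: each pair (x, y) leads to at most one recursive call instead of up to ten checks.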