Scala Functional Programming Patterns

The problem – grouping continuous integers

A while back, I wanted a simple and elegant way to solve the following problem:

Given: A sorted list of numbers

When: We group these numbers

Then: Each number in a group is higher by one than its predecessor.

So, let's be good and write a few tests first:

 @Test
 public void testFourGroup() {
  List<Integer> list = Lists.newArrayList(1, 2, 3, 4, 5, 9, 11, 20, 21, 22);
  List<List<Integer>> groups = groupThem.groupThem(list);
  assertThat(groups.size(), equalTo(4));
  assertThat(groups.get(0), contains(1, 2, 3, 4, 5));
  assertThat(groups.get(1), contains(9));
  assertThat(groups.get(2), contains(11));
  assertThat(groups.get(3), contains(20, 21, 22));
 }
 @Test
 public void testNoGroup() {
  List<Integer> emptyList = Lists.newArrayList();
  List<List<Integer>> groups = groupThem.groupThem(emptyList);
  assertThat(groups, emptyIterable());
 }
 @Test
 public void testOnlyOneGroup() {
  List<Integer> list = Lists.newArrayList(1);
  List<List<Integer>> groups = groupThem.groupThem(list);
  assertThat(groups.size(), equalTo(1));
  assertThat(groups.get(0), contains(1));
 }

We will make use of the excellent Hamcrest matchers (and the excellent Guava library) to help us express succinctly what we want our code to do.

Java code

You know the drill! Let's roll up our sleeves and dish out some code. The following looks pretty good:

 public List<List<Integer>> groupThem(final List<Integer> list) {
  final List<Integer> inputList = Collections.unmodifiableList(list);
  final List<List<Integer>> result = Lists.newArrayList();
  int i = 0;
  while (i < inputList.size()) {
   i = pickUpNextGroup(i, inputList, result); // i must progress
  }
  return result;
 }

 private int pickUpNextGroup(final int start, final List<Integer> inputList,
   final List<List<Integer>> result) {
  Validate.isTrue(!inputList.isEmpty(),
    "Input list should have at least one element");
  // start must be a valid index, because we read inputList.get(start) below
  Validate.isTrue(start >= 0 && start < inputList.size(), "Invalid start index");

  final List<Integer> group = Lists.newArrayList();

  int currElem = inputList.get(start);
  group.add(currElem); // We will have at least one element in the group

  int next = start + 1; // next index may be out of range

  while (next < inputList.size()) {
   final int nextElem = inputList.get(next); // next is in range
   if (nextElem - currElem == 1) { // grouping condition
    group.add(nextElem);
    currElem = nextElem; // setup for next iteration
   } else {
    break; // this group is done
   }
   ++next;
  }
  result.add(group); // add the group to result list
  Validate.isTrue(next > start); // make sure we keep moving
  return next; // next index to start iterating from
                // could be past the last valid index
 }

This code has a lot of subtlety. We use the Validate class from Apache Commons Lang 3 to assert the invariants. Download this book's source code to check out the unit test.

We are using the excellent Guava library to work with lists. Note how easily it lets us create and populate a list in one line. Refer to https://code.google.com/p/guava-libraries/wiki/CollectionUtilitiesExplained for more details.

We are careful to first wrap the input list in an unmodifiable view, so that our own code cannot modify it by accident. We also need to arrange the iteration carefully, to make sure that it eventually terminates and that we never access a non-existent list index (try to say it without gulping).

Going scalaish

Remember the sermon on being abstract? Our Java code deals with higher-level and lower-level details all at the same time. We would rather ignore the details selectively. In particular, we don't want to handle every corner case of iteration ourselves: the Java code is worried stiff about looking at two consecutive elements and about the corner cases.

What if this was somehow handled for us so we can focus on the business at hand? Let's take a look at the following example:

def groupNumbers(list: List[Int]) = {
  def groupThem(lst: List[Int], acc: List[Int]): List[List[Int]] = lst match {
    case Nil => acc.reverse :: Nil
    case x :: xs =>
      acc match {
        case Nil => groupThem(xs, x :: acc)
        case y :: ys if (x - y == 1) =>
          groupThem(xs, x :: acc) // x extends the current group
        case _ =>
          acc.reverse :: groupThem(xs, x :: List()) // close the group, start a new one
      }
  }
  groupThem(list, List())
}
groupNumbers(x) // x: a sorted List[Int] defined elsewhere

Thinking recursively...

This version looks a lot better, doesn't it? We use features such as nested functions and recursion, and we get immutability for free: a list in Scala is immutable. In Java, we need to work a bit more to make things immutable. In Scala, it is the other way around: you need to go a little further to bring in mutability. We also use persistent data structures, which, despite the name, have nothing to do with databases. We will look at them in detail in a while.
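As a small illustration of what persistence means here, the following sketch prepends an element to an immutable list. The names PersistentListSketch, base, and extended are made up for this example. The original list is left untouched, and the new list reuses it as its tail instead of copying it:

```scala
object PersistentListSketch {
  val base: List[Int] = List(2, 3)
  val extended: List[Int] = 1 :: base // one new cell in front of base

  // base still holds the same elements after the prepend
  val baseUnchanged: Boolean = base == List(2, 3)

  // eq is reference equality: the tail of extended is the very same object as base
  val tailIsShared: Boolean = extended.tail eq base
}
```

Sharing the tail is safe precisely because neither list can ever be mutated.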

The grouping function can still be improved by making it tail-recursive. Take a look at the following version:

import scala.annotation.tailrec

def groupNumbers(list: List[Int]): List[List[Int]] = {
  @tailrec
  def groupThem(lst: List[Int], result: List[List[Int]], acc: List[Int]): List[List[Int]] = lst match {
    case Nil => acc.reverse :: result
    case x :: xs =>
      acc match {
        case Nil => groupThem(xs, result, x :: acc)
        case y :: ys if (x - y == 1) =>
          groupThem(xs, result, x :: acc)
        case _ =>
          groupThem(xs, acc.reverse :: result, x :: List())
      }
  }
  val r = groupThem(list, List(), List()) // groups are accumulated in reverse order
  r.reverse
}

The @tailrec annotation does some good to our code. We put it on a method to make sure that the method is compiled with tail call optimization (TCO). TCO converts the recursive form into a loop.

If the Scala compiler cannot apply TCO, it flags an error. Recursion and TCO are covered in detail in Chapter 3, Recursion and Chasing Your Own Tail.
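Here is a minimal sketch of the idea, using two hypothetical helpers (sumTo and sumToNaive) that are not part of the book's code. In sumTo, the recursive call is the last thing the method does, so the compiler can turn it into a loop, and a million nested calls should not overflow the stack:

```scala
import scala.annotation.tailrec

object TailRecSketch {
  // Sums 1..n. The recursive call is in tail position, so @tailrec is accepted.
  @tailrec
  def sumTo(n: Int, acc: Long = 0L): Long =
    if (n <= 0) acc else sumTo(n - 1, acc + n)

  // Not tail-recursive: the addition happens after the recursive call returns.
  // Putting @tailrec on this method would be a compile-time error.
  def sumToNaive(n: Int): Long =
    if (n <= 0) 0L else n + sumToNaive(n - 1)
}
```

The accumulator parameter plays the same role as result and acc in groupThem: it carries the work done so far, so nothing is left to do after the recursive call.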

Now, having set the stage—let's look at the reusability feature...

Reusability – the commonality/variability analysis

Instead of a fixed grouping criterion, could we make the Java version more dynamic? Instead of expressing the following condition directly in the code, we could wrap it up as a predicate:

if (nextElem - currElem == 1) { // grouping condition 

A predicate is a function that returns a Boolean value. Ours should take two numbers and tell us whether the grouping condition holds for them.

Now there could be multiple predicates. For example, we might want to group elements that differ by 2 or 3. We can define an interface as follows:

public interface Predicate {
        boolean apply(int nextElem, int currElem);
}

public class MyPredicate implements Predicate {

    public boolean apply(int nextElem, int currElem) {
        return nextElem - currElem == 1;
    }
}

Why would we do this? We would do this for two reasons:

  • We want to reuse the algorithm that correctly picks up two consecutive elements. This algorithm handles all the corner cases, as before.
  • The actual grouping criterion in our case is nextElem - currElem == 1, but we can reuse the previous code to regroup the numbers using other criteria.

Here is how grouping would look; I am just showing the changed code:

public List<List<Integer>> groupThem(List<Integer> list, Predicate myPredicate) {
    ...
        for (int i = 0; i < inputList.size();) {
            i = pickUpNextGroup(i, inputList, myPredicate, result); // i must progress
            ...
        }
}

private int pickUpNextGroup(int start, List<Integer> inputList, Predicate myPredicate,
        List<List<Integer>> result) {
    ...
    final int nextElem = inputList.get(next); // next is in range
    // if (nextElem - currElem == 1) { // grouping condition
    if (myPredicate.apply(nextElem, currElem)) { // grouping condition
        group.add(nextElem);
    ...
    }
}

It is some work to define and use the predicate. Scala makes expressing this variability pretty easy, as we can use functions. Just pass on your criteria as a function and you are good to go:

def groupNumbers(list: List[Int])(f: (Int, Int) => Boolean): List[List[Int]] = { // 1
  def groupThem(lst: List[Int], acc: List[Int]): List[List[Int]] = lst match {
    case Nil => acc.reverse :: Nil
    case x :: xs =>
      acc match {
        case Nil => groupThem(xs, x :: acc)
        case y :: ys if (f(y, x)) => // 2
          groupThem(xs, x :: acc)
        case _ =>
          acc.reverse :: groupThem(xs, x :: List())
      }
  }
  groupThem(list, List())
}
val p = groupNumbers(x) _ // 3
p((x, y) => y - x == 1) // 4
p((x, y) => y - x == 2) // 5

At 1, we declare the function parameter f: (Int, Int) => Boolean.

Note

The f: (Int, Int) => Boolean syntax means that f is a function parameter. The function takes two Int params and returns Boolean.

Here is a quick REPL session to understand this better:

scala> val x : (Int, Int) => Boolean = (x: Int, y: Int) => x > y
x: (Int, Int) => Boolean = <function2>
scala> x(3, 4)
res0: Boolean = false
scala> x(3, 2)
res1: Boolean = true
scala> :t x
(Int, Int) => Boolean

We are using the :t command of the REPL to look at the type of x.

This is a function—a bit of code we can pass around—that corresponds to the previous Java predicate snippet. It is a function that takes two integers and returns a Boolean.

(A function value is in fact an instance of a class that mixes in a trait, Function2 in our case.)
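To make that remark concrete, here is a small sketch with made-up names (Function2Sketch, literal, and explicit). It writes the same predicate twice: once as a function literal, and once spelled out by hand as an anonymous class that implements Function2 and its apply method:

```scala
object Function2Sketch {
  // The short form: a function literal
  val literal: (Int, Int) => Boolean = (a, b) => b - a == 1

  // Roughly the long-hand form: an object with an apply method
  val explicit: Function2[Int, Int, Boolean] = new Function2[Int, Int, Boolean] {
    def apply(a: Int, b: Int): Boolean = b - a == 1
  }

  // Both are Function2 values, so both can be called with two arguments
  val literalIsFunction2: Boolean = literal.isInstanceOf[Function2[_, _, _]]
}
```

Calling literal(1, 2) is shorthand for literal.apply(1, 2), which is why both values can be used interchangeably wherever an (Int, Int) => Boolean is expected.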

At 2, we use it. Now comes the fun part.

At 3, we hold on to everything else except the function. And at 4 and 5, we just pass different functions. It would be more correct to say that we pass function literals. Lo and behold, we get the answers right.
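For readers who want to run the whole thing, here is a self-contained sketch with a concrete input. It is a variation under a few assumptions, not the listing above verbatim: the helper is tail-recursive, the names GroupingSketch, loop, sample, and grouper are made up, and an empty input yields no groups at all. The partial application is written with an explicit function type instead of the trailing underscore:

```scala
import scala.annotation.tailrec

object GroupingSketch {
  def groupNumbers(list: List[Int])(f: (Int, Int) => Boolean): List[List[Int]] = {
    @tailrec
    def loop(rest: List[Int], done: List[List[Int]], acc: List[Int]): List[List[Int]] =
      rest match {
        case Nil =>
          // flush the last group, if there is one, and restore the original order
          if (acc.isEmpty) done.reverse else (acc.reverse :: done).reverse
        case x :: xs =>
          acc match {
            case y :: _ if !f(y, x) => loop(xs, acc.reverse :: done, List(x)) // start a new group
            case _                  => loop(xs, done, x :: acc)               // extend the current group
          }
      }
    loop(list, Nil, Nil)
  }

  val sample: List[Int] = List(1, 2, 3, 4, 5, 9, 11, 20, 21, 22)

  // Fix the list, leave the criterion open
  val grouper: ((Int, Int) => Boolean) => List[List[Int]] = groupNumbers(sample)

  val byOne: List[List[Int]] = grouper((x, y) => y - x == 1)
  val byTwo: List[List[Int]] = grouper((x, y) => y - x == 2)
}
```

With this input, byOne contains the four groups from the Java test (1 to 5, then 9, then 11, then 20 to 22), while byTwo keeps only 9 and 11 together and puts every other number in a group of its own.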

This code also shows some currying and partially applied functions in action… We cover both of these features in Chapter 6, Currying Favors with Your code.
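As a quick preview, and only as a sketch with made-up names (CurryingSketch, add3, and so on), the two ideas can be told apart like this. Currying turns a function of several arguments into a chain of one-argument functions, while partial application fixes some arguments and leaves a hole for the others:

```scala
object CurryingSketch {
  val add3: (Int, Int, Int) => Int = (a, b, c) => a + b + c

  // Currying: supply one argument at a time
  val curried: Int => Int => Int => Int = add3.curried
  val six: Int = curried(1)(2)(3)

  // Partial application: fix the first and third arguments, leave the second open
  val addToFour: Int => Int = add3(1, _: Int, 3)
  val nine: Int = addToFour(5)
}
```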

The one-liner shockers

Have you ever heard the phrase "it is just a one-liner"? You see all the work done in a single line of code, and it seems almost magical. The Unix shell thrives on these one-liners. For example:

 seq 1 100 | paste -s -d '*' | bc

Phew! This single command line generates a sequence of numbers, 1 to 100, joins them into a multiplication expression, and feeds it to bc, a calculator that does the actual multiplication.
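For comparison, the same computation can be sketched as a Scala one-liner. The wrapper object and the name productUpTo are made up for this example; BigInt is used because the product of 1 to 100 does not fit in a Long:

```scala
object ProductSketch {
  // Multiply 1 * 2 * ... * n without overflowing
  def productUpTo(n: Int): BigInt = (1 to n).map(BigInt(_)).product
}
```

Calling ProductSketch.productUpTo(100) plays the role of the whole shell pipeline: the range generates the numbers, and product does the multiplication.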

Scala has a legion of these. Here is one applied to our problem, using the Scalaz library.

Note

To try the following snippet, you need to install the scalaz library. It is not part of the Scala standard library. Here are a few simple instructions to install it.

Create a directory of your choice and switch to it. Next, download and install the Simple Build Tool (sbt) from http://www.scala-sbt.org/download.html.

/Users/Atulkhot/TryScalaz> sbt
[info] Set current project to tryscalaz (in build file:/Users/Atulkhot/TryScalaz/)
> set scalaVersion := "2.11.7"
… // output elided
> set libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.1.0"
… // output elided
> console

Now the following snippet should work after installing the scalaz library:

scala> import scalaz.syntax.std.list._
import scalaz.syntax.std.list._
scala> List(1, 2, 3, 4, 6, 8, 9) groupWhen ((x,y) => y - x == 1)
res3: List[scalaz.NonEmptyList[Int]] = List(NonEmptyList(1, 2, 3, 4), NonEmptyList(6), NonEmptyList(8, 9))

Using functions, the solution is just a one-liner. A different criterion is just a different function:

scala> List(1, 3, 4, 6, 8, 9) groupWhen ((x,y) => y - x == 2)
res4: List[scalaz.NonEmptyList[Int]] = List(NonEmptyList(1, 3), NonEmptyList(4, 6, 8), NonEmptyList(9))

So we are not writing any of the supporting code; we are just stating our criteria for grouping.
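If you would rather not install Scalaz, the behavior can be approximated with the standard library alone. The following is only a sketch: groupWhenSketch is a made-up name, it is not how Scalaz implements groupWhen, and it returns plain lists instead of NonEmptyList values. It folds from the right and either pushes each element onto the front group or starts a new group:

```scala
object GroupWhenSketch {
  def groupWhenSketch[A](xs: List[A])(p: (A, A) => Boolean): List[List[A]] =
    xs.foldRight(List.empty[List[A]]) {
      // x and the head of the front group satisfy p: x joins that group
      case (x, (head :: tail) :: groups) if p(x, head) => (x :: head :: tail) :: groups
      // otherwise x starts a group of its own
      case (x, groups) => List(x) :: groups
    }
}
```

Here too, the criterion is the only thing the caller writes; the traversal and the corner cases live inside the helper.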

You can read more about Scalaz at http://eed3si9n.com/learning-scalaz/.