A lightweight introduction to Recursion Schemes in Scala

Last December, I attended a talk by Zainab Ali entitled Topiary and the art of origami. This talk focused on how to write a decision-tree-learning program in Scala. Zainab showed how, starting from a few non-recursive functions, one can build a decision tree data type, a learning algorithm, and a prediction algorithm, each in one line. She could do this by using Matryoshka, a library that implements recursion schemes in Scala.

Recursion schemes are an important tool in functional programming. Despite what the jargon suggests, they are not hard to understand. Here is a step-by-step introduction to basic recursion schemes in Scala.

Lists

A simple and commonplace recursive data type is the List. Scala provides a List[A] type, which is generic on the type A of elements. Since recursion schemes work just as well with non-generic collections, for convenience I am using a numbers-only list type instead:

sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

This List type is recursive because the Cons class has a value parameter which is a List itself.

Step 1: list fold

How do you add up all the numbers in a list? In a purely functional style, without mutable variables or while loops, one has to use recursion:

def sum(list: List): Int =
  list match {
    case Nil              => 0
    case Cons(head, tail) => head + sum(tail)
  }

However, explicit recursion should not be overused: it is very powerful, and that power can make programs confusing and difficult to read. This is why functional programming tends to hide it behind higher-order functions like map or fold. Here is a fold for our List type:

def fold( zero: Int, op: (Int, Int) => Int)(list: List): Int =
  list match {
    case Nil              => zero
    case Cons(head, tail) => op( head, fold(zero, op)(tail) )
  }

Apart from the list, fold takes as parameters the result for Nil, and a function op used to combine the head of a Cons with the result of the recursive fold on the tail. We can now write the sum as a special case of fold, like this:

def add(a: Int, b: Int): Int = a + b
def sum(list: List): Int = fold(0, add)(list)

Now there is no recursion in the code of sum: it has been moved out of sum and into fold.
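As a quick check that fold really captures the recursion, here is a minimal self-contained sketch (runnable as a Scala script) that reuses the fold above for two more reductions. The product and length functions are illustrative additions, not part of the original post, and the List and Nil names shadow Scala's built-in ones, just as in the post.

```scala
// Numbers-only list from the post (shadows Scala's built-in List and Nil)
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

// fold holds the recursion; each reduction only supplies zero and op
def fold(zero: Int, op: (Int, Int) => Int)(list: List): Int =
  list match {
    case Nil              => zero
    case Cons(head, tail) => op(head, fold(zero, op)(tail))
  }

def sum(list: List): Int     = fold(0, _ + _)(list)
def product(list: List): Int = fold(1, _ * _)(list)
// The head is ignored: each Cons just adds 1 to the count of its tail
def length(list: List): Int  = fold(0, (_, rest) => rest + 1)(list)

val example = Cons(2, Cons(3, Cons(4, Nil)))
println(sum(example))     // 9
println(product(example)) // 24
println(length(example))  // 3
```

Only zero and op change from one function to the next; the traversal itself lives in fold.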

Step 2: list unfold

We can also use recursion to write a digits function, which gets the list of decimal digits ([0-9]) of a number, starting from the least significant one:

def digits(seed: Int): List =
  if (seed == 0)
    Nil
  else
    Cons(seed % 10, digits(seed / 10) )

Unlike sum, which consumes a list, digits uses recursion to build one. Nevertheless, just as we moved the recursion of sum into fold, we can take the recursion out of digits and put it into a higher-order function called unfold:

def unfold(
  isEnd: Int => Boolean, op: Int => (Int, Int) )( seed: Int
): List =
  if (isEnd(seed))
    Nil
  else {
    val (head, next) = op(seed)
    Cons(head, unfold(isEnd, op)(next) )
  }

def isZero(x: Int): Boolean = x == 0
def div10(x: Int): (Int, Int) = (x % 10, x / 10)
def digits(seed: Int): List = unfold(isZero, div10)(seed)

Apart from the seed, unfold takes as parameters a predicate isEnd that marks the end of the list, and a function op of type Int => (Int, Int) that splits a seed into the value at the head and the seed from which to build the tail.
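A minimal self-contained sketch of Step 2 as a Scala script, with a worked example. It shows that digits produces the digits least significant first, and that, under this definition, digits(0) is the empty list.

```scala
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

// unfold from Step 2: stop when isEnd holds, otherwise split the seed with op
def unfold(isEnd: Int => Boolean, op: Int => (Int, Int))(seed: Int): List =
  if (isEnd(seed)) Nil
  else {
    val (head, next) = op(seed)
    Cons(head, unfold(isEnd, op)(next))
  }

def isZero(x: Int): Boolean = x == 0
def div10(x: Int): (Int, Int) = (x % 10, x / 10)
def digits(seed: Int): List = unfold(isZero, div10)(seed)

// Digits come out least significant first
println(digits(1234)) // Cons(4,Cons(3,Cons(2,Cons(1,Nil))))
println(digits(0))    // Nil
```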

Step 3: The optional cons type

fold and unfold are opposites: unfold grows a list from a seed, whereas fold reduces a list to a result. However, there is a sort of symmetry between their definitions. To see this symmetry, we use an OCons type alias:

type OCons = Option[(Int, Int)]

Intuitively, an OCons value describes an optional Cons: either a recursive case, Some((head, rest)), or the base case, None. We can rewrite fold and unfold around this type:

def fold(out: OCons => Int)(list: List): Int =
  list match {
    case Nil              => out( None )
    case Cons(head, tail) => out( Some(( head, fold(out)(tail) )) )
  }

def unfold( into: Int => OCons)(seed: Int): List =
  into(seed) match {
    case None               => Nil
    case Some((head, next)) => Cons(head, unfold(into)(next) )
  }

Both fold and unfold now take as a parameter a function on OCons, but with reversed types: in fold the function gets a number out of an OCons, whereas in unfold it splits a number into an OCons.

The new types of fold and unfold show another similarity:

fold:   (OCons => Int) => (List => Int)
unfold: (Int => OCons) => (Int => List)

So, we can see fold and unfold as each lifting a single-step function, out or into, into a loop that performs that step many times.
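To make the lifting concrete, this sketch passes single-step functions to the OCons-based fold and unfold. The names sumStep, digitStep and digitSum are hypothetical, chosen for illustration; chaining the unfold and the fold computes the sum of the decimal digits of a number.

```scala
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

// One step of a list: None for Nil, Some((head, rest)) for Cons
type OCons = Option[(Int, Int)]

def fold(out: OCons => Int)(list: List): Int =
  list match {
    case Nil              => out(None)
    case Cons(head, tail) => out(Some((head, fold(out)(tail))))
  }

def unfold(into: Int => OCons)(seed: Int): List =
  into(seed) match {
    case None               => Nil
    case Some((head, next)) => Cons(head, unfold(into)(next))
  }

// Hypothetical single-step functions:
// sumStep adds one head to the already-folded rest,
// digitStep peels one decimal digit off the seed.
def sumStep(step: OCons): Int = step match {
  case None              => 0
  case Some((head, acc)) => head + acc
}
def digitStep(seed: Int): OCons =
  if (seed == 0) None else Some((seed % 10, seed / 10))

// Lifting both steps and chaining them gives the digit sum of a number
def digitSum(n: Int): Int = fold(sumStep)(unfold(digitStep)(n))

println(digitSum(1234)) // 10
```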

Step 4: Tree Folds

Another recursive data type is the type of binary trees of numbers:

sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

Unlike List, the Tree type appears recursively twice in the Node case class, once for each of the left and right subtrees.

As we did for lists, we can write a sum function for trees, using recursion:

def sum(tree: Tree): Int =
  tree match {
    case Leaf            => 0
    case Node(ll,top,rr) => sum(ll) + top + sum(rr)
  }

Unlike the sum for lists, here we have two recursive calls, one per subtree. Despite this, we can also take the recursion out of sum and into a tree-fold function:

def fold( zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
  tree match {
    case Leaf =>
      zero
    case Node(ll,top,rr) =>
      op( fold(zero, op)(ll), top, fold(zero, op)(rr) )
  }

def add3(a: Int, b: Int, c: Int) = a + b + c
def sum(tree: Tree): Int = fold(0, add3)(tree)
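The same tree fold computes other properties just by changing zero and op. The size and depth functions and the example tree in this sketch are illustrative additions.

```scala
sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

def fold(zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
  tree match {
    case Leaf              => zero
    case Node(ll, top, rr) => op(fold(zero, op)(ll), top, fold(zero, op)(rr))
  }

def sum(tree: Tree): Int   = fold(0, (l, top, r) => l + top + r)(tree)
def size(tree: Tree): Int  = fold(0, (l, _, r) => l + 1 + r)(tree)
def depth(tree: Tree): Int = fold(0, (l, _, r) => 1 + math.max(l, r))(tree)

//      2
//     / \
//    1   5
//       /
//      4
val example = Node(Node(Leaf, 1, Leaf), 2, Node(Node(Leaf, 4, Leaf), 5, Leaf))
println(sum(example))   // 12
println(size(example))  // 4
println(depth(example)) // 3
```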

Step 5: Tree unfolds

We can write a function digits to lay the digits of a large number out in a binary tree, much like the digits function from Step 2.

def digits(seed: Int): Tree =
  if (seed == 0) Leaf
  else {
    val (pref, mid, suff) = splitNumber(seed)
    Node( digits(pref), mid, digits(suff) )
  }

// splitNumber: split a number's digits in the middle,
// for example, splitNumber(56784197) = (567, 8, 4197)
def splitNumber(seed: Int): (Int, Int, Int) = ???  // body omitted

As in the case of lists, we can extract the recursion from the digits function into an unfold function for trees:

def unfold(
  isEnd: Int => Boolean, op: Int => (Int, Int, Int) )(seed: Int
): Tree =
  if (isEnd(seed))
    Leaf
  else {
    val (ll, top, rr) = op(seed)
    Node( unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr) )
  }

def digits(seed: Int): Tree = unfold(isZero, splitNumber)(seed)
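The post leaves the body of splitNumber out. This sketch fills it in with a hypothetical string-based version, assuming a non-negative seed, so that the tree unfold can actually run. One caveat of this helper: a suffix such as "03" is read back as the number 3, and a suffix "0" becomes the end-marking seed 0, so some zero digits can drop out of the tree (which does not affect digit sums).

```scala
sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

def unfold(isEnd: Int => Boolean, op: Int => (Int, Int, Int))(seed: Int): Tree =
  if (isEnd(seed)) Leaf
  else {
    val (ll, top, rr) = op(seed)
    Node(unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr))
  }

def isZero(x: Int): Boolean = x == 0

// Hypothetical splitNumber (the post omits its body); assumes seed >= 0.
// The middle digit becomes `top`; the digits before and after it become new seeds.
def splitNumber(seed: Int): (Int, Int, Int) = {
  val s = seed.toString
  val m = (s.length - 1) / 2
  def toNum(part: String): Int = if (part.isEmpty) 0 else part.toInt
  (toNum(s.take(m)), s(m).asDigit, toNum(s.drop(m + 1)))
}

def digits(seed: Int): Tree = unfold(isZero, splitNumber)(seed)

println(splitNumber(56784197)) // (567,8,4197)
println(digits(123))           // Node(Node(Leaf,1,Leaf),2,Node(Leaf,3,Leaf))
```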

Step 6: The optional node type

Like the fold and unfold for lists, the fold and unfold for trees are opposites, yet similar. As we did with OCons, we can show the symmetry between the fold and unfold functions for trees using an ONode type:

type ONode = Option[(Int, Int, Int)]

def fold(out: ONode => Int)(tree: Tree): Int =
  tree match {
    case Leaf =>
      out( None )
    case Node(ll, top, rr) =>
      out( Some(( fold(out)(ll), top, fold(out)(rr) )) )
  }

def unfold(into: Int => ONode)(seed: Int): Tree =
  into(seed) match {
    case None =>
      Leaf
    case Some( (ll, top, rr) ) =>
      Node( unfold(into)(ll), top, unfold(into)(rr) )
  }

Here, as with OCons, an ONode value tells whether there is a recursive case (Some) or not (None). It holds three numbers instead of two, because a tree node has two recursive occurrences, not one.
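Mirroring the list version, this sketch passes single-step functions to the ONode-based fold and unfold. Both step functions are hypothetical: splitStep is a string-based way (assuming a non-negative seed) to split a number's digits around the middle one, and sumStep adds up one node. Their composition computes the digit sum of a number through a tree instead of a list.

```scala
sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

// One step of a tree: None for Leaf, Some((left, top, right)) for Node
type ONode = Option[(Int, Int, Int)]

def fold(out: ONode => Int)(tree: Tree): Int =
  tree match {
    case Leaf              => out(None)
    case Node(ll, top, rr) => out(Some((fold(out)(ll), top, fold(out)(rr))))
  }

def unfold(into: Int => ONode)(seed: Int): Tree =
  into(seed) match {
    case None                => Leaf
    case Some((ll, top, rr)) => Node(unfold(into)(ll), top, unfold(into)(rr))
  }

// Hypothetical algebra: add up one node, given the sums of its subtrees
def sumStep(step: ONode): Int = step match {
  case None              => 0
  case Some((l, top, r)) => l + top + r
}

// Hypothetical coalgebra (assumes seed >= 0): the middle digit becomes top,
// the digits before and after it become the seeds of the two subtrees
def splitStep(seed: Int): ONode =
  if (seed == 0) None
  else {
    val s = seed.toString
    val m = (s.length - 1) / 2
    def toNum(part: String): Int = if (part.isEmpty) 0 else part.toInt
    Some((toNum(s.take(m)), s(m).asDigit, toNum(s.drop(m + 1))))
  }

def digitSum(n: Int): Int = fold(sumStep)(unfold(splitStep)(n))

println(digitSum(56784197)) // 47
```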

Step 7: Using maps

Let us compare the fold-unfold functions for lists based on the OCons alias, with those for trees that use the ONode alias. We can see some similarities between them:

  • The fold functions for lists and trees both take a function to get the number out of the OCons or ONode, respectively.
  • The unfold functions both take a function to split a number into an OCons or ONode.
  • There is a one-to-one translation between the cases of the data structure and those of the non-recursive type alias.
  • The same recursive call, fold(out) or unfold(into), is applied to each recursive occurrence of the data type.

To highlight this similarity, we make OCons and ONode generic on a type parameter R, which marks the recursive positions:

type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

For both OCons and ONode, we can “apply the same recursive call at each recursive position” by using a map function:

def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons match {
    case None               => None
    case Some((head, tail)) => Some((head, fun(tail)))
  }

def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode match {
    case None                => None
    case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
  }

These are very similar to the map function for the Option type.
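A small runnable check of the map functions: only the positions typed R are transformed, while the Int payload is left alone. The mapOC2 variant is an illustrative rewrite that makes the link with Option's own map explicit.

```scala
type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons match {
    case None               => None
    case Some((head, tail)) => Some((head, fun(tail)))
  }

def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode match {
    case None                => None
    case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
  }

// Illustrative variant: Option.map with a function that only touches position R
def mapOC2[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons.map { case (head, tail) => (head, fun(tail)) }

val len = (s: String) => s.length
println(mapOC(len, Some((7, "abc"))))        // Some((7,3))
println(mapON(len, Some(("ab", 7, "abcd")))) // Some((2,7,4))
println(mapOC2(len, Some((7, "abc"))))       // Some((7,3))
println(mapOC(len, None))                    // None
```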

Step 8: Aligning Folds

We can express the one-to-one translation from the cases of each data type to the cases of its Option alias with an open function for each type:

def open(list: List): OCons[List] =
  list match {
    case Nil              => None
    case Cons(head, tail) => Some((head, tail))
  }

def open(tree: Tree): ONode[Tree] =
  tree match {
    case Leaf              => None
    case Node(ll, top, rr) => Some((ll, top, rr))
  }

Using the open and map functions, we can now write each fold function in a single line:

def fold(out: OCons[Int] => Int)(list: List): Int =
  out( mapOC(fold(out), open(list)) )

def fold(out: ONode[Int] => Int)(tree: Tree): Int =
  out( mapON(fold(out), open(tree)) )

These definitions of fold use the map function to apply fold recursively at each position marked by the R type parameter.
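Step 8 put together as a self-contained script. In this sketch the functions are named openList, openTree, foldList and foldTree instead of being overloaded, and the map functions are written with Option.map; both are choices made for this example only. The algebras sumOC and sumON are hypothetical.

```scala
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons.map { case (head, tail) => (head, fun(tail)) }
def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }

// open: expose the outermost layer of a value as its Option alias
def openList(list: List): OCons[List] = list match {
  case Nil              => None
  case Cons(head, tail) => Some((head, tail))
}
def openTree(tree: Tree): ONode[Tree] = tree match {
  case Leaf              => None
  case Node(ll, top, rr) => Some((ll, top, rr))
}

// Open one layer, fold every R position, then apply out
def foldList(out: OCons[Int] => Int)(list: List): Int =
  out(mapOC(foldList(out), openList(list)))
def foldTree(out: ONode[Int] => Int)(tree: Tree): Int =
  out(mapON(foldTree(out), openTree(tree)))

// Hypothetical algebras that add up one layer
val sumOC: OCons[Int] => Int = {
  case None              => 0
  case Some((head, acc)) => head + acc
}
val sumON: ONode[Int] => Int = {
  case None              => 0
  case Some((l, top, r)) => l + top + r
}

println(foldList(sumOC)(Cons(1, Cons(2, Cons(3, Nil)))))                    // 6
println(foldTree(sumON)(Node(Node(Leaf, 1, Leaf), 2, Node(Leaf, 3, Leaf)))) // 6
```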

Step 9: Aligning Unfolds

Conversely, we can write the one-to-one translation from the cases of the alias type back to the cases of the data type as a close function, for lists and trees:

def close(ocons: OCons[List]): List =
  ocons match {
    case None               => Nil
    case Some((head, tail)) => Cons(head, tail)
  }

def close(onode: ONode[Tree]): Tree =
  onode match {
    case None                => Leaf
    case Some((ll, top, rr)) => Node(ll, top, rr)
  }

Using the close and map functions, we can now write each unfold function in a single line:

def unfold(into: Int => OCons[Int])(seed: Int): List =
  close( mapOC( unfold(into), into(seed) ) )

def unfold(into: Int => ONode[Int])(seed: Int): Tree =
  close( mapON( unfold(into), into(seed) ) )

Again, we use map to apply the recursive call to unfold at each appearance of the R type parameter.
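A runnable sketch of the close-based unfolds. As in the fold sketch, the overloaded names are split into closeList/closeTree and unfoldList/unfoldTree for this example, and digitStep and halves are hypothetical coalgebras: halves builds a tree in which a node n has two children n / 2, which terminates because n / 2 < n for positive n.

```scala
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List

sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree

type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons.map { case (head, tail) => (head, fun(tail)) }
def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }

// close: rebuild one layer of the data type from its Option alias
def closeList(ocons: OCons[List]): List = ocons match {
  case None               => Nil
  case Some((head, tail)) => Cons(head, tail)
}
def closeTree(onode: ONode[Tree]): Tree = onode match {
  case None                => Leaf
  case Some((ll, top, rr)) => Node(ll, top, rr)
}

// Split the seed one step, unfold every R position, then close
def unfoldList(into: Int => OCons[Int])(seed: Int): List =
  closeList(mapOC(unfoldList(into), into(seed)))
def unfoldTree(into: Int => ONode[Int])(seed: Int): Tree =
  closeTree(mapON(unfoldTree(into), into(seed)))

// Hypothetical coalgebras
val digitStep: Int => OCons[Int] = seed =>
  if (seed == 0) None else Some((seed % 10, seed / 10))
val halves: Int => ONode[Int] = n =>
  if (n == 0) None else Some((n / 2, n, n / 2))

println(unfoldList(digitStep)(123)) // Cons(3,Cons(2,Cons(1,Nil)))
println(unfoldTree(halves)(3))      // Node(Node(Leaf,1,Leaf),3,Node(Leaf,1,Leaf))
```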

Step 10: Indirect Recursion for each Data Type

So far, we have only focused on hiding direct recursion from functions. Can we do the same with data types? Yes, if we use a form of indirection.

In fold and unfold, the out and into functions represent one recursive step. Likewise, we can use the types OCons[R] and ONode[R] to represent one link in a list or tree. The type parameter R gives us the indirection that we need to cut the recursion in each data type:

case class List_Ind(opt: OCons[List_Ind])
case class Tree_Ind(opt: ONode[Tree_Ind])

The List_Ind (resp. Tree_Ind) data type now consists of a single case class with a single parameter opt. The type of opt is the type operator OCons (or ONode) applied to the data type List_Ind (or Tree_Ind) itself. Thus, we have shifted recursion from values to types.

For the List_Ind and Tree_Ind classes, we can write the following fold and unfold functions:

def fold(out: OCons[Int] => Int)(list: List_Ind): Int =
  out( mapOC( fold(out), list.opt) )

def fold(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
  out( mapON( fold(out), tree.opt) )

def unfold(into: Int => OCons[Int])(seed: Int): List_Ind =
  List_Ind( mapOC(unfold(into), into(seed) ))

def unfold(into: Int => ONode[Int])(seed: Int): Tree_Ind =
  Tree_Ind( mapON(unfold(into), into(seed) ))

The two fold functions (and likewise the two unfold functions) are almost identical: they differ only in the type operator (OCons vs ONode) and in its map function (mapOC vs mapON). Note that, since List_Ind and Tree_Ind already use OCons and ONode, we no longer need the auxiliary functions open and close.
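With List_Ind and Tree_Ind, a value is written as nested Option layers. This sketch builds such values by hand and runs the Step 10 fold and unfold on them; the distinct function names and the step functions sumOC, sumON and digitStep are illustrative choices, not part of the post.

```scala
type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons.map { case (head, tail) => (head, fun(tail)) }
def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }

// Indirectly recursive data types from Step 10
case class List_Ind(opt: OCons[List_Ind])
case class Tree_Ind(opt: ONode[Tree_Ind])

def foldList(out: OCons[Int] => Int)(list: List_Ind): Int =
  out(mapOC(foldList(out), list.opt))
def foldTree(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
  out(mapON(foldTree(out), tree.opt))
def unfoldList(into: Int => OCons[Int])(seed: Int): List_Ind =
  List_Ind(mapOC(unfoldList(into), into(seed)))

// Hypothetical step functions
val sumOC: OCons[Int] => Int = {
  case None              => 0
  case Some((head, acc)) => head + acc
}
val sumON: ONode[Int] => Int = {
  case None              => 0
  case Some((l, top, r)) => l + top + r
}
val digitStep: Int => OCons[Int] = seed =>
  if (seed == 0) None else Some((seed % 10, seed / 10))

// The list 1, 2 and a two-node tree, written layer by layer
val nilL   = List_Ind(None)
val oneTwo = List_Ind(Some((1, List_Ind(Some((2, nilL))))))
val leafT  = Tree_Ind(None)
val small  = Tree_Ind(Some((leafT, 5, Tree_Ind(Some((leafT, 7, leafT))))))

println(foldList(sumOC)(oneTwo))             // 3
println(foldTree(sumON)(small))              // 12
println(unfoldList(digitStep)(21) == oneTwo) // true
```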

Step 11: Indirect recursion for all data types

Now we can unify List_Ind and Tree_Ind into a single Ind case class, which generalizes indirect recursion in the data types.

The classes List_Ind and Tree_Ind above differ only in the type operator (OCons or ONode) used in the type of the opt field. To unify them, we extract that type operator as a type parameter, which we call Rec.

case class Ind[ Rec[_] ]( opt: Rec[ Ind[Rec] ] )

The Rec[_] notation means that Rec is itself a type operator, like OCons or ONode, which takes a type parameter R.

We can now turn List_Ind and Tree_Ind each into a special case of Ind:

type List_Ind = Ind[OCons]
type Tree_Ind = Ind[ONode]
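The generic Ind class can be exercised directly. In this sketch, nil, cons, leaf and node are hypothetical smart constructors that hide the Option layer; they show that a List_Ind is just an Ind[OCons] value, and a Tree_Ind just an Ind[ONode] value.

```scala
type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

// Generic indirect recursion: Rec describes one layer, Ind ties the knot
case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

type List_Ind = Ind[OCons]
type Tree_Ind = Ind[ONode]

// Hypothetical smart constructors that hide the Option layer
val nil: List_Ind = Ind[OCons](None)
def cons(head: Int, tail: List_Ind): List_Ind = Ind[OCons](Some((head, tail)))

val leaf: Tree_Ind = Ind[ONode](None)
def node(l: Tree_Ind, top: Int, r: Tree_Ind): Tree_Ind = Ind[ONode](Some((l, top, r)))

val oneTwo: List_Ind = cons(1, cons(2, nil))
val small: Tree_Ind  = node(leaf, 5, node(leaf, 7, leaf))

// Peeling one layer gives back an ordinary Option
println(oneTwo.opt.map(_._1)) // Some(1)
println(small.opt.map(_._2))  // Some(5)
```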

Step 12: Unified fold and unfold

Now, let us unify the fold and unfold functions for List_Ind and Tree_Ind, from Step 10, into a single pair of fold-unfold functions for the Ind[Rec[_]] type. Let us start from the code of these functions:

def fold( out: OCons[Int] => Int)(list: List_Ind): Int =
  out( mapOC( fold(out), list.opt) )

def fold( out: ONode[Int] => Int)(tree: Tree_Ind): Int =
  out( mapON( fold(out), tree.opt) )

def unfold( into: Int => OCons[Int])(seed: Int): List_Ind =
  List_Ind( mapOC(unfold(into), into(seed) ))

def unfold( into: Int => ONode[Int])(seed: Int): Tree_Ind =
  Tree_Ind( mapON(unfold(into), into(seed) ))

To join the two fold (or unfold) functions into a single one, we have to 1) add the generic parameter Rec[_] to the function; 2) replace each appearance of OCons and ONode by Rec; and 3) replace each appearance of List_Ind and Tree_Ind by Ind[Rec]. This yields the following:

def fold[ Rec[_] ](out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
  out( map( fold(out), ind.opt) )

def unfold[ Rec[_] ](into: Int => Rec[Int])(seed: Int): Ind[Rec] =
  Ind( map(unfold(into), into(seed)) )

We also have to unify mapOC and mapON into a function map for Rec[_]. However, since Rec[_] is generic, the map must be provided as another parameter of the function. A common way to do so is to wrap the map inside a trait, usually called a Functor, which is also generic on Rec.

trait Functor[F[_]] {
  def map[A, B](fun: A => B, from: F[A]): F[B]
}

def fold[Rec[_]](
  ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]
): Int =
  out( ff.map( fold(ff, out), ind.opt) )

def unfold[Rec[_]](
  ff: Functor[Rec], into: Int => Rec[Int])(seed: Int
): Ind[Rec] =
  Ind( ff.map( unfold(ff, into), into(seed)) )

Now we can use these functions as a fold and unfold for lists, trees, or any other data structure whose one-layer type operator Rec has a Functor instance.
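Finally, a self-contained script that puts the generic fold and unfold to work on both data types. The Functor instances ocFunctor and onFunctor are the old mapOC and mapON packaged as values, and the step functions sumOC, sumON, digitStep and halves are hypothetical examples. The Functor trait here is the post's own, not the one from Cats.

```scala
// The post's Functor trait (not the one from Cats)
trait Functor[F[_]] {
  def map[A, B](fun: A => B, from: F[A]): F[B]
}

case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

// Generic fold and unfold from Step 12
def fold[Rec[_]](ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
  out(ff.map(fold(ff, out), ind.opt))

def unfold[Rec[_]](ff: Functor[Rec], into: Int => Rec[Int])(seed: Int): Ind[Rec] =
  Ind[Rec](ff.map(unfold(ff, into), into(seed)))

type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]

// mapOC and mapON, packaged as Functor instances
val ocFunctor: Functor[OCons] = new Functor[OCons] {
  def map[A, B](fun: A => B, from: OCons[A]): OCons[B] =
    from.map { case (head, tail) => (head, fun(tail)) }
}
val onFunctor: Functor[ONode] = new Functor[ONode] {
  def map[A, B](fun: A => B, from: ONode[A]): ONode[B] =
    from.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }
}

// Hypothetical algebras (out) and coalgebras (into)
val sumOC: OCons[Int] => Int = {
  case None              => 0
  case Some((head, acc)) => head + acc
}
val sumON: ONode[Int] => Int = {
  case None              => 0
  case Some((l, top, r)) => l + top + r
}
val digitStep: Int => OCons[Int] = seed =>
  if (seed == 0) None else Some((seed % 10, seed / 10))
val halves: Int => ONode[Int] = n =>
  if (n == 0) None else Some((n / 2, n, n / 2))

// One fold and one unfold, two data types (type arguments written out for clarity)
println(fold[OCons](ocFunctor, sumOC)(unfold[OCons](ocFunctor, digitStep)(1234))) // 10
println(fold[ONode](onFunctor, sumON)(unfold[ONode](onFunctor, halves)(3)))       // 5
```

Adding support for another data type only requires a new one-layer type operator and its Functor instance; fold and unfold stay untouched.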

End

After all of the previous steps, what we have is the following:

  • We joined two recursive data types, lists and trees, into a single case class Ind[Rec[_]] for indirect recursion.
  • We joined two fold functions, for lists and trees, into a single fold function for the Ind data type.
  • We joined two unfold functions, for lists and trees, into a single unfold function for the Ind data type.

The class Ind[Rec[_]] is generic on the Rec[_] type constructor, which means that it can represent not just lists or binary trees, but any directly recursive data type whose single layer can be described by such a type constructor.

Jargon

The notions and intuitions we used above have some special names.

  • A generic type like Ind[Rec[_]], that takes another generic type Rec[_] as a type parameter, is called a higher-kinded type.
  • The trait Functor, used to provide a map function, is a fundamental type class in the Cats library, where map takes its arguments in the order map(fa)(f).
  • The Ind[Rec[_]] case class, that captures the recursive application of Rec[_] type constructor, is called the fixpoint data type. In Matryoshka, it is called Fix.
  • The out function, of the form Rec[A] => A, is called an algebra.
  • The into function, of the form A => Rec[A], is called a coalgebra.
  • The fold function, which reiterates the algebra Rec[A] => A into a function Ind[Rec] => A, is called a catamorphism.
  • A function like unfold, which reiterates the coalgebra A => Rec[A] into a function A => Ind[Rec], is called an anamorphism.

Some of this jargon can be traced back to the academic research on recursion schemes, and by convention it is used in libraries, like Matryoshka, that implement recursion schemes.

Additional reading

  • Jeremy Gibbons’s Origami Programming introduces the main recursion schemes in addition to folds and unfolds, using diverse examples such as sorting algorithms.
  • Zainab Ali’s talk, which inspired this post, applies the techniques we used above to another data type, decision trees, which are like binary trees but with more detail in each node or leaf.
  • Rob Norris’s presentation (Sep. 2016) on how to store academic genealogy trees (another recursive data type) in a relational database.
