A lightweight introduction to Recursion Schemes in Scala
by Diego Alonso Blas
February 28, 2018
Tags: functional programming, scala, matryoshka, recursion schemes

Last December, I attended a talk by Zainab Ali entitled Topiary and the art of origami. This talk focused on how to write a decision-tree-learning program in Scala. Zainab showed how, from some non-recursive functions, one can derive a decision tree data type, a learning algorithm, and a predicting algorithm, each in one line. She could do this by using Matryoshka, a library that implements recursion schemes in Scala.
Recursion schemes are relevant in functional programming. Contrary to what the jargon suggests, they are not too difficult to understand. Here is a small-steps introduction to basic recursion schemes in Scala.
Lists
A simple and commonplace recursive data type is the List. Scala provides a List[A] type, which is generic on the type A of elements. Since recursion schemes work just as well with non-generic collections, for convenience I am using a numbers-only list type instead:
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List
This List type is recursive because the Cons class has a value parameter which is a List itself. Here is the full code of this blog post example (with minor changes).
Step 1: list fold
How do you add up all the numbers in a list? Since most functional languages have no while loops, one has to use recursion:
def sum(list: List): Int =
list match {
case Nil => 0
case Cons(head, tail) => head + sum(tail)
}
However, recursion should not be abused: it is too powerful and confusing, and can make programs difficult to read. This is why functional programming aims at hiding it behind higher-order functions like map or fold. Here is a fold for our List type:
def fold( zero: Int, op: (Int, Int) => Int)(list: List): Int =
list match {
case Nil => zero
case Cons(head, tail) => op( head, fold(zero, op)(tail) )
}
Apart from the list, fold takes as parameters the result for Nil, and a function op used to combine the head of a Cons with the result of the recursive fold on the tail. We can now write the sum as a special case of fold, like this:
def add(a: Int, b: Int) = a + b
def sum(list: List): Int = fold(0, add)(list)
Now there is no recursion in the code of sum: it has been moved out of sum and into fold.
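As a minimal sketch of why this matters, the same fold can back several non-recursive functions. The block below repeats the List and fold definitions from this step so that it stands alone; the enclosing FoldDemo object and the product and length functions are illustrative additions, not part of the original example.

```scala
object FoldDemo {
  // The numbers-only list from this post (it shadows Scala's own List inside this object).
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  // The only recursive function in this object.
  def fold(zero: Int, op: (Int, Int) => Int)(list: List): Int =
    list match {
      case Nil              => zero
      case Cons(head, tail) => op(head, fold(zero, op)(tail))
    }

  // Three non-recursive functions built from the same fold.
  def sum(list: List): Int     = fold(0, (a, b) => a + b)(list)
  def product(list: List): Int = fold(1, (a, b) => a * b)(list)   // illustrative addition
  def length(list: List): Int  = fold(0, (_, n) => n + 1)(list)   // illustrative addition

  val sample: List = Cons(2, Cons(3, Cons(4, Nil)))
}
```

For the sample list 2, 3, 4, the three functions give 9, 24, and 3 respectively.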
Step 2: list unfold
We can write a digits function, to get the list of digits ([0-9]) of a number, also using recursion:
def digits(seed: Int): List =
if (seed == 0)
Nil
else
Cons(seed % 10, digits(seed / 10) )
Unlike in sum, in digits we use recursion to build a list. Nevertheless, as we did for sum with fold, we can also take the recursion out of digits and put it into a higher-order function called unfold:
def unfold(
isEnd: Int => Boolean, op: Int => (Int, Int) )( seed: Int
): List =
if (isEnd(seed))
Nil
else {
val (head, next) = op(seed)
Cons(head, unfold(isEnd, op)(next) )
}
def isZero(x: Int): Boolean = x == 0
def div10(x: Int): (Int, Int) = (x % 10, x / 10)
def digits(seed: Int): List = unfold(isZero, div10)(seed)
Apart from the seed, unfold takes as parameters a predicate to mark the end, and a function Int => (Int, Int) to split a seed into the value at the head, and the seed from which to build the tail.
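The following sketch shows the unfold from this step in use. It repeats the definitions so that it runs on its own; the UnfoldDemo wrapper and the countdown function are hypothetical additions for illustration, and both functions assume a non-negative seed.

```scala
object UnfoldDemo {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  // Grows a list from a seed; stops when isEnd holds for the current seed.
  def unfold(isEnd: Int => Boolean, op: Int => (Int, Int))(seed: Int): List =
    if (isEnd(seed)) Nil
    else {
      val (head, next) = op(seed)
      Cons(head, unfold(isEnd, op)(next))
    }

  // Digits come out least-significant first, because the head is seed % 10.
  def digits(seed: Int): List = unfold(_ == 0, x => (x % 10, x / 10))(seed)

  // Hypothetical second use: the list n, n-1, ..., 1 (assumes n >= 0).
  def countdown(n: Int): List = unfold(_ == 0, x => (x, x - 1))(n)
}
```

With these definitions, digits(472) builds the list 2, 7, 4, and countdown(3) builds 3, 2, 1. Note that digits(0) is the empty list, since the end predicate holds immediately.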
Step 3: The optional cons type
fold and unfold are opposites: fold reduces a list to a result, whereas unfold grows a list from a seed. However, there is a sort of symmetry between their definitions. To see this symmetry, we use an OCons type alias:
type OCons = Option[(Int, Int)]
Intuitively, this OCons alias describes an optional Cons: either there is a recursive case (Some((Int, Int))), or there is a base case (None). We can rewrite fold and unfold around this type:
def fold(out: OCons => Int)(list: List): Int =
list match {
case Nil => out( None)
case Cons(head, tail) => out( Some( (head, fold(out)(tail)) ))
}
def unfold( into: Int => OCons)(seed: Int): List =
into(seed) match {
case None => Nil
case Some((head, next)) => Cons(head, unfold(into)(next) )
}
Both fold and unfold now take as parameter a function on OCons, but with reversed types: in fold the function gets a number out of an OCons, whereas in unfold it splits a number into an OCons.
The new types of fold and unfold show another similarity:
fold: (OCons => Int) => (List => Int)
unfold: (Int => OCons) => (Int => List)
So, we can see fold and unfold each as lifting a single-step function, out or into, to a loop that performs that step many times.
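To make the single-step reading concrete, here is a small sketch with one out function and one into function for the OCons-based fold and unfold. The names sumStep, digitStep and digitSum, and the OConsDemo wrapper, are assumptions made for this illustration.

```scala
object OConsDemo {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  type OCons = Option[(Int, Int)]

  def fold(out: OCons => Int)(list: List): Int =
    list match {
      case Nil              => out(None)
      case Cons(head, tail) => out(Some((head, fold(out)(tail))))
    }

  def unfold(into: Int => OCons)(seed: Int): List =
    into(seed) match {
      case None               => Nil
      case Some((head, next)) => Cons(head, unfold(into)(next))
    }

  // One step of summing: combine the head with the already-folded tail.
  val sumStep: OCons => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }

  // One step of digit extraction: split off the last digit, or stop at zero.
  val digitStep: Int => OCons =
    n => if (n == 0) None else Some((n % 10, n / 10))

  // Unfold a number into its digits, then fold the digits into their sum.
  def digitSum(n: Int): Int = fold(sumStep)(unfold(digitStep)(n))
}
```

Neither sumStep nor digitStep is recursive; all the looping lives in fold and unfold. For example, digitSum(472) is 13.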
Step 4: Tree Folds
Another recursive data type is the type of binary trees of numbers:
sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree
Unlike lists, the Tree type appears recursively twice in the Node subclass, for the left and right subtrees.
As we did for lists, we can write a sum function for trees, using recursion:
def sum(tree: Tree): Int =
tree match {
case Leaf => 0
case Node(ll,top,rr) => sum(ll) + top + sum(rr)
}
Unlike the sum for lists, here we have two recursive calls, one per subtree.
Despite this, we can also take the recursion out of sum and into a tree-fold function:
def fold( zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
tree match {
case Leaf =>
zero
case Node(ll,top,rr) =>
op( fold(zero, op)(ll), top, fold(zero, op)(rr) )
}
def add3(a: Int, b: Int, c: Int) = a + b + c
def sum(tree: Tree): Int = fold(0, add3)(tree)
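As with lists, the tree fold can serve more than one purpose. The sketch below repeats the Tree type (with Node extending Tree) and the fold, and adds two illustrative functions, size and depth, which are not in the original example; the TreeFoldDemo wrapper is likewise an assumption.

```scala
object TreeFoldDemo {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  def fold(zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
    tree match {
      case Leaf              => zero
      case Node(ll, top, rr) => op(fold(zero, op)(ll), top, fold(zero, op)(rr))
    }

  // op receives: result for the left subtree, the node's number, result for the right subtree.
  def sum(tree: Tree): Int   = fold(0, (l, top, r) => l + top + r)(tree)
  def size(tree: Tree): Int  = fold(0, (l, _, r) => l + 1 + r)(tree)          // number of nodes
  def depth(tree: Tree): Int = fold(0, (l, _, r) => 1 + math.max(l, r))(tree) // longest path

  val sample: Tree =
    Node(Node(Leaf, 1, Leaf), 2, Node(Leaf, 3, Node(Leaf, 4, Leaf)))
}
```

For the sample tree, sum is 10, size is 4, and depth is 3.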
Step 5: Tree unfolds
We can write a function digits to lay the digits of a large number out in a binary tree, much like the digits function from Step 2.
def digits(seed: Int): Tree =
if (seed ==0) Leaf
else {
val (pref, mid, suff) = splitNumber(seed)
Node( digits(pref), mid, digits(suff) )
}
// splitNumber: split a number's digits in the middle,
// for example, splitNumber(56784197) = (567, 8, 4197)
def splitNumber(seed: Int): (Int, Int, Int) = /*---*/
As in the case of a list, we can extract recursion from the digits function into an unfold function for trees:
def unfold(
isEnd: Int => Boolean, op: Int => (Int, Int, Int) )(seed: Int
): Tree =
if (isEnd(seed))
Leaf
else {
val (ll, top, rr) = op(seed)
Node( unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr) )
}
def digits(seed: Int): Tree = unfold(isZero, splitNumber)(seed)
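The post leaves the body of splitNumber out. Purely as an assumption, here is one possible implementation that agrees with the example in the comment (56784197 splits into 567, 8 and 4197), so that the tree unfold can be run. It assumes a positive seed, uses a zero check as the end predicate, and is lossy: a prefix or suffix made only of zeros becomes a Leaf, and leading zeros in the suffix are dropped.

```scala
object TreeUnfoldDemo {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  def unfold(isEnd: Int => Boolean, op: Int => (Int, Int, Int))(seed: Int): Tree =
    if (isEnd(seed)) Leaf
    else {
      val (ll, top, rr) = op(seed)
      Node(unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr))
    }

  // Hypothetical implementation (the original elides it): split the decimal
  // text of a positive number around its middle digit.
  def splitNumber(seed: Int): (Int, Int, Int) = {
    val text = seed.toString
    val mid  = (text.length - 1) / 2
    def toInt(s: String): Int = if (s.isEmpty) 0 else s.toInt
    (toInt(text.take(mid)), text.charAt(mid).asDigit, toInt(text.drop(mid + 1)))
  }

  def digits(seed: Int): Tree = unfold(_ == 0, splitNumber)(seed)
}
```

Under these assumptions, digits(123) yields a node with 2 on top, a left subtree holding 1, and a right subtree holding 3. The recursion ends because both the prefix and the suffix always have fewer digits than the seed.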
Step 6: The optional node type
Like the fold and unfold for lists, the fold and unfold for trees are opposed but similar. As we did with OCons, we can show the symmetry between the fold and unfold functions for trees using an ONode type:
type ONode = Option[(Int, Int, Int)]
def fold(out: ONode => Int)(tree: Tree): Int =
tree match {
case Leaf =>
out( None )
case Node(ll, top, rr) =>
out( Some( (fold(out)(ll), top, fold(out)(rr) ) ))
}
def unfold(in: Int => ONode)(seed: Int): Tree =
in(seed) match {
case None =>
Leaf
case Some( (ll, top, rr) ) =>
Node( unfold(in)(ll), top, unfold(in)(rr) )
}
Here, the intuition behind ONode is whether or not there is a recursive case. It has three numbers instead of two because in a tree node there are two recursive occurrences, not one.
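A short sketch of the ONode-based functions in use follows. The step functions sumStep and halve, the halvingSum function, and the ONodeDemo wrapper are hypothetical names for this illustration; the unfold parameter is called into here, whereas the listing above calls it in.

```scala
object ONodeDemo {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  type ONode = Option[(Int, Int, Int)]

  def fold(out: ONode => Int)(tree: Tree): Int =
    tree match {
      case Leaf              => out(None)
      case Node(ll, top, rr) => out(Some((fold(out)(ll), top, fold(out)(rr))))
    }

  def unfold(into: Int => ONode)(seed: Int): Tree =
    into(seed) match {
      case None                => Leaf
      case Some((ll, top, rr)) => Node(unfold(into)(ll), top, unfold(into)(rr))
    }

  // One step of summing: the two outer numbers are results for the subtrees.
  val sumStep: ONode => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }

  // Hypothetical step: put n on top and seed both subtrees with (n - 1) / 2.
  val halve: Int => ONode =
    n => if (n <= 0) None else Some(((n - 1) / 2, n, (n - 1) / 2))

  def halvingSum(n: Int): Int = fold(sumStep)(unfold(halve)(n))
}
```

For instance, unfold(halve)(3) builds a node with 3 on top and two subtrees that each hold 1, so halvingSum(3) is 5.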
Step 7: Using maps
Let us compare the fold-unfold functions for lists based on the OCons alias, with those for trees that use the ONode alias. We can see some similarities between them:
- The fold functions for lists and trees both take a function to get the number out of the OCons or ONode, respectively.
- The unfold functions both take a function to split a number into an OCons or ONode.
- There is a one-to-one translation between the cases of the data structure, and those of the non-recursive type alias.
- The same recursive call, fold(out) or unfold(into), is applied to each recursive occurrence of the data type.
To highlight this similarity, we make OCons and ONode generic on a data type R, to mark the recursive positions.
type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]
At each type OCons or ONode, we can “apply the same recursive call to each recursive position” by using a map function:
def mapOC[A, B]( fun: A => B, ocons: OCons[A]): OCons[B] =
ocons match {
case None => None
case Some((head, tail)) => Some((head, fun(tail)))
}
def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
onode match {
case None => None
case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
}
These are very similar to the map function for the Option type.
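A brief sketch of what these map functions do to a value: only the positions marked by R are transformed, while the Int payload is left alone. The sample values and the MapDemo wrapper are illustrative assumptions; the type arguments are written out explicitly to keep type inference out of the picture.

```scala
object MapDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons match {
      case None               => None
      case Some((head, tail)) => Some((head, fun(tail)))
    }

  def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
    onode match {
      case None                => None
      case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
    }

  // The head (1) is untouched; only the recursive position is mapped.
  val link: OCons[String] = Some((1, "rest"))
  val mappedLink: OCons[Int] = mapOC[String, Int](_.length, link)

  // The top (2) is untouched; both recursive positions are mapped.
  val mappedNode: ONode[Int] = mapON[Int, Int](_ * 10, Some((1, 2, 3)))
}
```

Here mappedLink is Some((1, 4)) and mappedNode is Some((10, 2, 30)); mapping over None gives None in both cases.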
Step 8: Aligning Folds
We can express the one-to-one translation from the cases of each data type to the cases of the Option alias with an open function for each type:
def open(list: List): OCons[List] =
list match {
case Nil => None
case Cons(head, tail) => Some((head, tail))
}
def open(tree: Tree): ONode[Tree] =
tree match {
case Leaf => None
case Node(ll, top, rr) => Some((ll, top, rr))
}
Using the open and map functions, we can now express each fold function in a single line:
def fold(out: OCons[Int] => Int)(list: List): Int =
out( mapOC(fold(out), open(list)) )
def fold(out: ONode[Int] => Int)(tree: Tree): Int =
out( mapON(fold(out), open(tree)) )
These definitions of fold use the map function to apply fold recursively at each appearance of the R type parameter.
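Here is a self-contained sketch of the one-line list fold, assembled from open and mapOC. It only covers the list case, and it passes the recursive call as an explicit lambda rather than relying on the compiler to turn fold(out) into a function; the sumStep function and the OpenFoldDemo wrapper are assumptions for illustration.

```scala
object OpenFoldDemo {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  type OCons[R] = Option[(Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons match {
      case None               => None
      case Some((head, tail)) => Some((head, fun(tail)))
    }

  // Expose one layer of the list as an OCons whose recursive position is a List.
  def open(list: List): OCons[List] =
    list match {
      case Nil              => None
      case Cons(head, tail) => Some((head, tail))
    }

  // Open one layer, fold the tail via map, then apply the single step.
  def fold(out: OCons[Int] => Int)(list: List): Int =
    out(mapOC[List, Int](tail => fold(out)(tail), open(list)))

  val sumStep: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }
}
```

Folding the list 2, 3, 4 with sumStep gives 9, the same result as the pattern-matching fold from Step 3.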
Step 9: Align Unfolds
Conversely, we can express the one-to-one translation from the cases of the alias type to the cases of the data type with a close function, for lists and trees.
def close(ocons: OCons[List]): List =
ocons match {
case None => Nil
case Some((head, tail)) => Cons(head, tail)
}
def close(onode: ONode[Tree]): Tree =
onode match {
case None => Leaf
case Some((ll, top, rr)) => Node(ll, top, rr)
}
Using the close and map functions, we can write each unfold function in a single line:
def unfold(into: Int => OCons[Int])(seed: Int): List =
close( mapOC( unfold(into), into(seed) ) )
def unfold(into: Int => ONode[Int])(seed: Int): Tree =
close( mapON( unfold(into), into(seed) ) )
Again, we use map to apply the recursive call to unfold at each appearance of the R type parameter.
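The list unfold can be sketched in the same self-contained way, this time with close and mapOC. As before, the recursive call is written as an explicit lambda, and the digitStep function and the CloseUnfoldDemo wrapper are assumptions for illustration.

```scala
object CloseUnfoldDemo {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  type OCons[R] = Option[(Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons match {
      case None               => None
      case Some((head, tail)) => Some((head, fun(tail)))
    }

  // Pack one layer, whose recursive position is already a List, back into a List.
  def close(ocons: OCons[List]): List =
    ocons match {
      case None               => Nil
      case Some((head, tail)) => Cons(head, tail)
    }

  // Split the seed, unfold the next seed via map, then close the layer.
  def unfold(into: Int => OCons[Int])(seed: Int): List =
    close(mapOC[Int, List](next => unfold(into)(next), into(seed)))

  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))
}
```

Unfolding 90 with digitStep yields the list 0, 9, with the least significant digit first.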
Step 10: Indirect Recursion for each Data Type
So far, we have only focused on hiding direct recursion from functions. Can we do the same with data types? Yes, if we use a form of indirection.
In fold and unfold, the into and out functions each represent one recursive step. Likewise, we can use the types OCons[R] and ONode[R] to represent one link in a list or tree. The type parameter R gives us the indirection that we need to cut recursion in each data type:
case class List_Ind(opt: OCons[List_Ind])
case class Tree_Ind(opt: ONode[Tree_Ind])
The List_Ind (resp. Tree_Ind) data type now consists of a single case class with a single parameter opt. The type of opt is the type operator OCons (or ONode) applied to the data type List_Ind (or Tree_Ind) itself. Thus, we have shifted recursion from values to types.
For the List_Ind and Tree_Ind classes, we can write the following fold and unfold functions:
def fold(out: OCons[Int] => Int)(list: List_Ind): Int =
out( mapOC( fold(out), list.opt) )
def fold(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
out( mapON( fold(out), tree.opt) )
def unfold(into: Int => OCons[Int])(seed: Int): List_Ind =
List_Ind( mapOC(unfold(into), into(seed) ))
def unfold(into: Int => ONode[Int])(seed: Int): Tree_Ind =
Tree_Ind( mapON(unfold(into), into(seed) ))
The two fold functions (resp. unfold) are almost the same, and they only differ in the type operator (OCons vs ONode), and in the map function (mapOC vs mapON) for it.
Note that, since List_Ind and Tree_Ind already use OCons and ONode, we no longer need the auxiliary functions open and close.
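To see what a List_Ind value looks like, and to check that the fold and unfold above behave as expected, here is a sketch restricted to the list case. The step functions, the digitSum function and the IndirectDemo wrapper are illustrative assumptions.

```scala
object IndirectDemo {
  type OCons[R] = Option[(Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons match {
      case None               => None
      case Some((head, tail)) => Some((head, fun(tail)))
    }

  // A list is a wrapper around one optional link, whose tail is again a List_Ind.
  case class List_Ind(opt: OCons[List_Ind])

  def fold(out: OCons[Int] => Int)(list: List_Ind): Int =
    out(mapOC[List_Ind, Int](tail => fold(out)(tail), list.opt))

  def unfold(into: Int => OCons[Int])(seed: Int): List_Ind =
    List_Ind(mapOC[Int, List_Ind](next => unfold(into)(next), into(seed)))

  val sumStep: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }

  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))

  // The list 7, 4 written out by hand: each link wraps an Option.
  val byHand: List_Ind = List_Ind(Some((7, List_Ind(Some((4, List_Ind(None)))))))

  def digitSum(n: Int): Int = fold(sumStep)(unfold(digitStep)(n))
}
```

The value byHand equals unfold(digitStep)(47), and digitSum(47) is 11. The empty list is List_Ind(None), which plays the role that Nil played before.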
Step 11: Indirect recursion for all data types
Now we can unify List_Ind and Tree_Ind into a single Ind case class, which generalizes indirect recursion in the data types.
The classes List_Ind and Tree_Ind above only differ in the type operator (OCons or ONode) used in the type of the opt field. To unify them, we need to extract that type operator as a type parameter, which we call Rec.
case class Ind[ Rec[_] ]( opt: Rec[ Ind[Rec] ] )
The Rec[_] symbol means that Rec generalizes type operators, like OCons[R] or ONode[R], which themselves take a type parameter R.
We can now turn List_Ind and Tree_Ind each into a special case of Ind:
type List_Ind = Ind[OCons]
type Tree_Ind = Ind[ONode]
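As a small sketch of how values of these aliases can be built, the block below defines a few helper constructors. The helpers emptyList, cons, leaf and node, and the IndDemo wrapper, are hypothetical names; the type argument to Ind is written explicitly because the compiler may not infer a type alias such as OCons from an Option value on its own.

```scala
object IndDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

  type List_Ind = Ind[OCons]
  type Tree_Ind = Ind[ONode]

  // Hypothetical helper constructors for lists.
  val emptyList: List_Ind = Ind[OCons](None)
  def cons(head: Int, tail: List_Ind): List_Ind = Ind[OCons](Some((head, tail)))

  // Hypothetical helper constructors for trees.
  val leaf: Tree_Ind = Ind[ONode](None)
  def node(ll: Tree_Ind, top: Int, rr: Tree_Ind): Tree_Ind =
    Ind[ONode](Some((ll, top, rr)))

  val oneTwo: List_Ind = cons(1, cons(2, emptyList))
  val small: Tree_Ind  = node(leaf, 7, leaf)
}
```

Both oneTwo and small are values of the same case class Ind; only the type operator passed as Rec differs.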
Step 12: Unified fold and unfold
Now, let us unify the fold and unfold functions for List_Ind and Tree_Ind, from Step 10, into a single pair of fold-unfold functions for the Ind[Rec[_]] type. Let us start from the code of these functions.
def fold( out: OCons[Int] => Int)(list: List_Ind): Int =
out( mapOC( fold(out), list.opt) )
def fold( out: ONode[Int] => Int)(tree: Tree_Ind): Int =
out( mapON( fold(out), tree.opt) )
def unfold( into: Int => OCons[Int])(seed: Int): List_Ind =
List_Ind( mapOC(unfold(into), into(seed) ))
def unfold( into: Int => ONode[Int])(seed: Int): Tree_Ind =
Tree_Ind( mapON(unfold(into), into(seed) ))
To join the two fold (or unfold) functions into a single one, we have to 1) add the generic parameter Rec[_] to the function; 2) replace each appearance of OCons and ONode by Rec; and 3) replace each appearance of List_Ind and Tree_Ind by Ind[Rec]. This yields the following:
def fold[ Rec[_] ](out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
out( map( fold(out), ind.opt) )
def unfold[ Rec[_] ](into: Int => Rec[Int])(seed: Int): Ind[Rec] =
Ind( map(unfold(into), into(seed)) )
We also have to unify mapOC and mapON into a function map for Rec[_]. However, since Rec[_] is generic, the map must be provided as another parameter of the function. A common way to do so is to wrap the map inside a trait, usually called a Functor, which is also generic on Rec.
trait Functor[F[_]] {
def map[A, B](fun: A => B, from: F[A]): F[B]
}
def fold[Rec[_]](
ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]
): Int =
out( ff.map( fold(ff, out), ind.opt) )
def unfold[Rec[_]](
ff: Functor[Rec], into: Int => Rec[Int])(seed: Int
): Ind[Rec] =
Ind( ff.map( unfold(ff, into), into(seed)) )
Now we can use these functions as a fold and unfold for lists, trees, or any data structure we need.
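Putting the pieces together, the following sketch runs the unified fold and unfold on both lists and trees. The Functor instances oconsFunctor and onodeFunctor, the step functions, and the SchemeDemo wrapper are assumptions made for this illustration; type arguments and lambdas are written out explicitly to avoid depending on type inference.

```scala
object SchemeDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

  trait Functor[F[_]] {
    def map[A, B](fun: A => B, from: F[A]): F[B]
  }

  // Hypothetical instances wrapping mapOC and mapON.
  val oconsFunctor: Functor[OCons] = new Functor[OCons] {
    def map[A, B](fun: A => B, from: OCons[A]): OCons[B] =
      from match {
        case None               => None
        case Some((head, tail)) => Some((head, fun(tail)))
      }
  }
  val onodeFunctor: Functor[ONode] = new Functor[ONode] {
    def map[A, B](fun: A => B, from: ONode[A]): ONode[B] =
      from match {
        case None                => None
        case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
      }
  }

  def fold[Rec[_]](ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
    out(ff.map[Ind[Rec], Int](r => fold[Rec](ff, out)(r), ind.opt))

  def unfold[Rec[_]](ff: Functor[Rec], into: Int => Rec[Int])(seed: Int): Ind[Rec] =
    Ind[Rec](ff.map[Int, Ind[Rec]](s => unfold[Rec](ff, into)(s), into(seed)))

  // List steps: extract digits, then add them up.
  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))
  val sumList: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }

  // Tree steps: put n on top with two subtrees seeded by (n - 1) / 2, then add up.
  val halve: Int => ONode[Int] =
    n => if (n <= 0) None else Some(((n - 1) / 2, n, (n - 1) / 2))
  val sumTree: ONode[Int] => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }

  def digitSum(n: Int): Int =
    fold[OCons](oconsFunctor, sumList)(unfold[OCons](oconsFunctor, digitStep)(n))
  def halvingSum(n: Int): Int =
    fold[ONode](onodeFunctor, sumTree)(unfold[ONode](onodeFunctor, halve)(n))
}
```

The same fold and unfold serve both data types: digitSum(472) is 13 and halvingSum(3) is 5. Only the Functor instance and the step functions change between the list and the tree versions.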
End
After all of the previous steps, what we have is the following:
- We joined two recursive data types, lists and trees, into a single case class Ind[Rec[_]] for indirect recursion.
- We joined two fold functions, for lists and trees, into a single fold function for the Ind type.
- We joined two unfold functions, for lists and trees, into a single unfold function for the Ind type.
The class Ind[Rec[_]] is generic on the Rec[_] type constructor, which means that it can represent not just lists or binary trees, but also any directly recursive data type.
Jargon
The notions and intuitions we used above have some special names.
- A generic type like Ind[Rec[_]], that takes another generic type Rec[_] as a type parameter, is called a higher-kinded type.
- The trait Functor, used to provide a map function, is a fundamental type-class in cats.
- The Ind[Rec[_]] case class, that captures the recursive application of the Rec[_] type constructor, is called the fixpoint data type. In Matryoshka, it is called Fix.
- The out function, of the form Rec[A] => A, is called an algebra.
- The into function, of the form A => Rec[A], is called a coalgebra.
- The fold function, which reiterates the algebra Rec[A] => A into a function Ind[Rec] => A, is called a catamorphism.
- A function like unfold, which reiterates the coalgebra A => Rec[A] into a function A => Ind[Rec], is called an anamorphism.
Some of this jargon can be traced back to the academic research on recursion schemes, and by convention it is used in libraries, like Matryoshka, that implement recursion schemes.
Additional reading
- Gibbons’s Origami Programming introduces the main recursion schemes, apart from folds and unfolds, using diverse examples such as sorting algorithms.
- Zainab Ali’s talk, which inspired this post, applies the techniques we used above to another data type, decision trees, which are like binary trees but with more detail in each node or leaf.
- Rob Norris’s presentation, about how to store academic genealogy trees (another recursive data type) in a relational database (Sep. 2016).