# A lightweight introduction to Recursion Schemes in Scala

- by Diego Alonso Blas
- February 28, 2018
- functional programming, scala, matryoshka, recursion schemes

Last December, I attended a talk by Zainab Ali entitled “Topiary and the art of origami”. The talk focused on how to write a decision-tree-learning program in Scala. Zainab showed how, from a few non-recursive functions, one can build a decision tree data type, a learning algorithm, and a prediction algorithm, each in a single line. She could do this by using Matryoshka, a library that implements recursion schemes in Scala.

Recursion schemes are an important tool in functional programming. Contrary to what the jargon suggests, they are not hard to understand. Here is a step-by-step introduction to the basic recursion schemes in Scala.

### Lists

A simple and commonplace recursive data type is the list. Scala provides a `List[A]` type, which is generic on the type `A` of its elements. Since recursion schemes work just as well with non-generic collections, for convenience I am using a numbers-only list type instead:

```
sealed trait List
case object Nil extends List
case class Cons(head: Int, tail: List) extends List
```

This `List` type is recursive because the `Cons` class has a value parameter, `tail`, which is itself a `List`. The full code of this post's example is also available (with minor changes).

#### Step 1: list fold

How do you add up all the numbers in a list? Since functional programming avoids mutable variables and `while` loops, we have to use recursion:

```
def sum(list: List): Int =
  list match {
    case Nil => 0
    case Cons(head, tail) => head + sum(tail)
  }
```

However, recursion should not be abused: it is so powerful that it easily leads to confusing code and programs that are difficult to read. This is why functional programming aims at hiding it behind higher-order functions like `map` or `fold`. Here is a `fold` for our `List` type:

```
def fold(zero: Int, op: (Int, Int) => Int)(list: List): Int =
  list match {
    case Nil => zero
    case Cons(head, tail) => op(head, fold(zero, op)(tail))
  }
```

Apart from the list, `fold` takes as parameters the result for `Nil`, and a function `op` used to combine the head of a `Cons` with the result of the recursive `fold` on the tail. We can now write `sum` as a special case of `fold`, like this:

```
def add(a: Int, b: Int): Int = a + b
def sum(list: List): Int = fold(0, add)(list)
```

Now there is no recursion in the code of `sum`: it has been moved out of `sum` and into `fold`.
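Because `fold` fixes only the shape of the recursion, other reductions over the list come for free by choosing different `zero` and `op` arguments. The sketch below copies the list type and `fold` from above into an object (so that `List` and `Nil` do not clash with Scala's own); `product`, `length`, and the `of` helper are hypothetical additions for illustration.

```scala
object ListFold {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  // Hypothetical helper: builds a List from its arguments, in order.
  def of(xs: Int*): List = xs.foldRight(Nil: List)(Cons(_, _))

  // The fold from Step 1: `zero` replaces Nil, `op` replaces each Cons.
  def fold(zero: Int, op: (Int, Int) => Int)(list: List): Int =
    list match {
      case Nil              => zero
      case Cons(head, tail) => op(head, fold(zero, op)(tail))
    }

  def sum(list: List): Int     = fold(0, _ + _)(list)
  def product(list: List): Int = fold(1, _ * _)(list)
  // The head is ignored: each Cons adds one to the length of its tail.
  def length(list: List): Int  = fold(0, (_, acc) => acc + 1)(list)
}
```

For example, `ListFold.product(ListFold.of(1, 2, 3, 4))` evaluates `1 * (2 * (3 * (4 * 1)))`, which is 24.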

#### Step 2: list `unfold`

We can write a `digits` function to get the list of digits (`[0-9]`) of a number, least significant digit first, also using recursion:

```
def digits(seed: Int): List =
  if (seed == 0)
    Nil
  else
    Cons(seed % 10, digits(seed / 10))
```

Unlike in `sum`, in `digits` we use recursion to build a list. Nevertheless, just as we did with `sum` and `fold`, we can take the recursion out of `digits` and put it into a higher-order function called `unfold`:

```
def unfold(isEnd: Int => Boolean, op: Int => (Int, Int))(seed: Int): List =
  if (isEnd(seed))
    Nil
  else {
    val (head, next) = op(seed)
    Cons(head, unfold(isEnd, op)(next))
  }
def isZero(x: Int): Boolean = x == 0
def div10(x: Int): (Int, Int) = (x % 10, x / 10)
def digits(seed: Int): List = unfold(isZero, div10)(seed)
```

Apart from the seed, `unfold` takes as parameters a predicate that marks the end, and a function `Int => (Int, Int)` that splits a seed into the value at the head and the seed from which to build the tail.
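The following self-contained sketch runs the `unfold` above. It confirms that `digits` yields the least significant digit first, and shows a second, hypothetical coalgebra-like use (`countdown`) that grows the list n, n-1, ..., 1 from a single seed.

```scala
object ListUnfold {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  // The unfold from Step 2: stop when `isEnd` holds, otherwise split the seed.
  def unfold(isEnd: Int => Boolean, op: Int => (Int, Int))(seed: Int): List =
    if (isEnd(seed)) Nil
    else {
      val (head, next) = op(seed)
      Cons(head, unfold(isEnd, op)(next))
    }

  // Digits come out least significant first: 1234 becomes [4, 3, 2, 1].
  def digits(seed: Int): List = unfold(_ == 0, x => (x % 10, x / 10))(seed)

  // Hypothetical second use: the head is the seed, the next seed is one less.
  def countdown(n: Int): List = unfold(_ <= 0, x => (x, x - 1))(n)
}
```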

#### Step 3: The optional cons type

`fold` and `unfold` are opposites: `unfold` grows a list from a seed, whereas `fold` reduces a list to a result. However, there is a kind of symmetry between their definitions. To see it, we use an `OCons` type alias:

```
type OCons = Option[(Int, Int)]
```

Intuitively, this `OCons` alias describes an optional `Cons`: either a recursive case (`Some((head, rest))`, a pair of numbers) or a base case (`None`). We can write `fold` and `unfold` in terms of this type:

```
def fold(out: OCons => Int)(list: List): Int =
  list match {
    case Nil => out(None)
    case Cons(head, tail) => out(Some((head, fold(out)(tail))))
  }
def unfold(into: Int => OCons)(seed: Int): List =
  into(seed) match {
    case None => Nil
    case Some((head, next)) => Cons(head, unfold(into)(next))
  }
```

Both `fold` and `unfold` now take as parameter a function on `OCons`, but with reversed types: in `fold` the function gets a number *out* of an `OCons`, whereas in `unfold` it splits a number *into* an `OCons`.

The new types of `fold` and `unfold` show another similarity:

```
fold: (OCons => Int) => (List => Int)
unfold: (Int => OCons) => (Int => List)
```

So, we can see `fold` and `unfold` each as lifting a single-step function, `out` or `into`, into a loop that performs that step many times.
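To make the "single step" reading concrete, here is a sketch that defines one `out` step and one `into` step and lets `fold` and `unfold` do the looping. Composing them computes the sum of a number's digits. The names `sumStep`, `digitStep`, and `digitSum` are hypothetical.

```scala
object OConsDemo {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  type OCons = Option[(Int, Int)]

  def fold(out: OCons => Int)(list: List): Int =
    list match {
      case Nil              => out(None)
      case Cons(head, tail) => out(Some((head, fold(out)(tail))))
    }

  def unfold(into: Int => OCons)(seed: Int): List =
    into(seed) match {
      case None               => Nil
      case Some((head, next)) => Cons(head, unfold(into)(next))
    }

  // One step of summing: the base case gives 0, a Cons adds its head.
  val sumStep: OCons => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }

  // One step of splitting a number into its last digit and the rest.
  val digitStep: Int => OCons =
    n => if (n == 0) None else Some((n % 10, n / 10))

  // Grow the list of digits, then reduce it: the sum of the digits.
  def digitSum(n: Int): Int = fold(sumStep)(unfold(digitStep)(n))
}
```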

#### Step 4: Tree Folds

Another recursive data type is the type of binary trees of numbers:

```
sealed trait Tree
case object Leaf extends Tree
case class Node(left: Tree, top: Int, right: Tree) extends Tree
```

Unlike lists, the `Tree` type appears recursively twice in the `Node` subclass, for the left and right subtrees.

As we did for lists, we can write a `sum` function for trees, using recursion:

```
def sum(tree: Tree): Int =
  tree match {
    case Leaf => 0
    case Node(ll, top, rr) => sum(ll) + top + sum(rr)
  }
```

Unlike the `sum` for lists, here we have two recursive calls, one per subtree. Despite this, we can also take the recursion out of `sum` and into a tree-`fold` function:

```
def fold(zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
  tree match {
    case Leaf =>
      zero
    case Node(ll, top, rr) =>
      op(fold(zero, op)(ll), top, fold(zero, op)(rr))
  }
def add3(a: Int, b: Int, c: Int): Int = a + b + c
def sum(tree: Tree): Int = fold(0, add3)(tree)
```
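As with lists, swapping the `op` argument gives other tree reductions without writing any new recursion. In this sketch, `size` (number of nodes) and `height` are hypothetical extra uses of the tree `fold` above.

```scala
object TreeFold {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  def fold(zero: Int, op: (Int, Int, Int) => Int)(tree: Tree): Int =
    tree match {
      case Leaf              => zero
      case Node(ll, top, rr) => op(fold(zero, op)(ll), top, fold(zero, op)(rr))
    }

  def sum(tree: Tree): Int    = fold(0, (l, top, r) => l + top + r)(tree)
  // Each node counts as 1, whatever number it holds.
  def size(tree: Tree): Int   = fold(0, (l, _, r) => l + 1 + r)(tree)
  // A node is one level taller than its taller subtree.
  def height(tree: Tree): Int = fold(0, (l, _, r) => 1 + math.max(l, r))(tree)
}
```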

#### Step 5: Tree unfolds

We can write a function `digits` to lay out the digits of a large number in a binary tree, much like the `digits` function from Step 2.

```
def digits(seed: Int): Tree =
  if (seed == 0) Leaf
  else {
    val (pref, mid, suff) = splitNumber(seed)
    Node(digits(pref), mid, digits(suff))
  }
// splitNumber: split a number's digits in the middle,
// for example, splitNumber(56784197) = (567, 8, 4197)
def splitNumber(seed: Int): (Int, Int, Int) = ??? // implementation omitted
```

As in the case of lists, we can extract the recursion from the `digits` function into an `unfold` function for trees:

```
def unfold(isEnd: Int => Boolean, op: Int => (Int, Int, Int))(seed: Int): Tree =
  if (isEnd(seed))
    Leaf
  else {
    val (ll, top, rr) = op(seed)
    Node(unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr))
  }
// isZero as defined in Step 2
def digits(seed: Int): Tree = unfold(isZero, splitNumber)(seed)
```
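The post leaves `splitNumber` unimplemented. Below is one hypothetical implementation, for positive numbers only, that reproduces the documented example: the middle digit sits at index `(n - 1) / 2` of the `n` decimal digits. Note one limitation: zeros at the start of the suffix are lost (for instance, 1000 splits into `(1, 0, 0)`), so this is only a sketch for exploring the tree `unfold`.

```scala
object TreeUnfold {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  def unfold(isEnd: Int => Boolean, op: Int => (Int, Int, Int))(seed: Int): Tree =
    if (isEnd(seed)) Leaf
    else {
      val (ll, top, rr) = op(seed)
      Node(unfold(isEnd, op)(ll), top, unfold(isEnd, op)(rr))
    }

  // Hypothetical: split the decimal digits of a positive number around the
  // digit at index (n - 1) / 2; an empty prefix or suffix becomes 0.
  def splitNumber(seed: Int): (Int, Int, Int) = {
    val s   = seed.toString
    val k   = (s.length - 1) / 2
    val pre = if (k == 0) 0 else s.take(k).toInt
    val suf = if (k + 1 == s.length) 0 else s.drop(k + 1).toInt
    (pre, s(k).asDigit, suf)
  }

  def digits(seed: Int): Tree = unfold(_ == 0, splitNumber)(seed)
}
```

Both parts of each split are strictly smaller than the seed, so the unfold always reaches 0 and terminates.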

#### Step 6: The optional node type

Like the `fold` and `unfold` for lists, the `fold` and `unfold` for trees are opposite but similar. As we did with `OCons`, we can show the symmetry between the `fold` and `unfold` functions for trees using an `ONode` type:

```
type ONode = Option[(Int, Int, Int)]
def fold(out: ONode => Int)(tree: Tree): Int =
  tree match {
    case Leaf =>
      out(None)
    case Node(ll, top, rr) =>
      out(Some((fold(out)(ll), top, fold(out)(rr))))
  }
def unfold(into: Int => ONode)(seed: Int): Tree =
  into(seed) match {
    case None =>
      Leaf
    case Some((ll, top, rr)) =>
      Node(unfold(into)(ll), top, unfold(into)(rr))
  }
```

Here, the intuition behind `ONode` is again whether there is recursion or not. It holds three numbers instead of two, because a tree node has two recursive occurrences, not one.
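The same single-step reading works for trees. This sketch folds a tree built by unfolding a number, again computing its digit sum. `sumStep` and `splitStep` are hypothetical names, and `splitStep` reuses the hypothetical middle-digit split (positive numbers only); because zeros add nothing to a sum, the lost leading zeros of a suffix do not affect the result here.

```scala
object ONodeDemo {
  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  type ONode = Option[(Int, Int, Int)]

  def fold(out: ONode => Int)(tree: Tree): Int =
    tree match {
      case Leaf              => out(None)
      case Node(ll, top, rr) => out(Some((fold(out)(ll), top, fold(out)(rr))))
    }

  def unfold(into: Int => ONode)(seed: Int): Tree =
    into(seed) match {
      case None                => Leaf
      case Some((ll, top, rr)) => Node(unfold(into)(ll), top, unfold(into)(rr))
    }

  // One step of the sum: None is a Leaf, Some carries both subtree sums.
  val sumStep: ONode => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }

  // One step of the split: prefix, middle digit, suffix (hypothetical rule).
  val splitStep: Int => ONode = n =>
    if (n == 0) None
    else {
      val s   = n.toString
      val k   = (s.length - 1) / 2
      val pre = if (k == 0) 0 else s.take(k).toInt
      val suf = if (k + 1 == s.length) 0 else s.drop(k + 1).toInt
      Some((pre, s(k).asDigit, suf))
    }

  def digitSum(n: Int): Int = fold(sumStep)(unfold(splitStep)(n))
}
```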

#### Step 7: Using maps

Let us compare the `fold`-`unfold` functions for lists, based on the `OCons` alias, with those for trees, which use the `ONode` alias. We can see some similarities between them:

- The `fold` functions for lists and trees both take a function to get a number `out` of an `OCons` or an `ONode`, respectively.
- The `unfold` functions both take a function to split a number `into` an `OCons` or an `ONode`.
- There is a one-to-one translation between the cases of the data structure and those of the non-recursive type alias.
- The same recursive call, `fold(out)` or `unfold(into)`, is applied to each recursive occurrence of the data type.

To highlight this similarity, we make `OCons` and `ONode` generic on a type parameter `R` that marks the recursive positions:

```
type OCons[R] = Option[(Int, R)]
type ONode[R] = Option[(R, Int, R)]
```

For each of `OCons` and `ONode`, we can “*apply the same recursive call to each recursive position*” by using a `map` function:

```
def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
  ocons match {
    case None => None
    case Some((head, tail)) => Some((head, fun(tail)))
  }
def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
  onode match {
    case None => None
    case Some((ll, top, rr)) => Some((fun(ll), top, fun(rr)))
  }
```

These are very similar to the `map` function of the `Option` type.
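In fact, each of them can be written with `Option`'s own `map`, as the following sketch shows: `None` stays `None`, the `Int` fields are copied, and only the `R` positions go through `fun`. The behavior should be the same as the pattern-matching versions above.

```scala
object MapDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  // Only the recursive position (the tail) is transformed.
  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons.map { case (head, tail) => (head, fun(tail)) }

  // Both recursive positions (the subtrees) are transformed.
  def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
    onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }
}
```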

#### Step 8: Aligning Folds

We can express the one-to-one translation from the cases of each data type to the cases of its `Option` alias with an `open` function for each type:

```
def open(list: List): OCons[List] =
  list match {
    case Nil => None
    case Cons(head, tail) => Some((head, tail))
  }
def open(tree: Tree): ONode[Tree] =
  tree match {
    case Leaf => None
    case Node(ll, top, rr) => Some((ll, top, rr))
  }
```

Using the `open` and `map` functions, we can now write each `fold` function in a single line:

```
def fold(out: OCons[Int] => Int)(list: List): Int =
  out(mapOC(fold(out), open(list)))
def fold(out: ONode[Int] => Int)(tree: Tree): Int =
  out(mapON(fold(out), open(tree)))
```

These definitions of `fold` use the `map` function of each alias to apply the recursive call to `fold` at each occurrence of the `R` type parameter.
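Here is a runnable sketch of the aligned folds. It uses distinct names (`openList`, `openTree`, `foldList`, `foldTree`) instead of overloading, and explicit lambdas for the recursive calls, to keep Scala's type inference simple; the algebras `sumList` and `sumTree` are hypothetical.

```scala
object AlignedFolds {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  sealed trait Tree
  case object Leaf extends Tree
  case class Node(left: Tree, top: Int, right: Tree) extends Tree

  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons.map { case (head, tail) => (head, fun(tail)) }
  def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
    onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }

  // One-to-one translation from each data type to its Option alias.
  def openList(list: List): OCons[List] = list match {
    case Nil              => None
    case Cons(head, tail) => Some((head, tail))
  }
  def openTree(tree: Tree): ONode[Tree] = tree match {
    case Leaf              => None
    case Node(ll, top, rr) => Some((ll, top, rr))
  }

  // Open one layer, fold every recursive position, then apply `out`.
  def foldList(out: OCons[Int] => Int)(list: List): Int =
    out(mapOC((tail: List) => foldList(out)(tail), openList(list)))
  def foldTree(out: ONode[Int] => Int)(tree: Tree): Int =
    out(mapON((sub: Tree) => foldTree(out)(sub), openTree(tree)))

  val sumList: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }
  val sumTree: ONode[Int] => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }
}
```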

#### Step 9: Aligning Unfolds

Conversely, we can write the one-to-one translation from the cases of the alias type back to the cases of the data type as a `close` function, for lists and trees:

```
def close(ocons: OCons[List]): List =
  ocons match {
    case None => Nil
    case Some((head, tail)) => Cons(head, tail)
  }
def close(onode: ONode[Tree]): Tree =
  onode match {
    case None => Leaf
    case Some((ll, top, rr)) => Node(ll, top, rr)
  }
```

Using the `close` and `map` functions, we can now write each `unfold` function in a single line:

```
def unfold(into: Int => OCons[Int])(seed: Int): List =
  close(mapOC(unfold(into), into(seed)))
def unfold(into: Int => ONode[Int])(seed: Int): Tree =
  close(mapON(unfold(into), into(seed)))
```

Again, we use *map* to apply the recursive call to `unfold` at each occurrence of the `R` type parameter.
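A runnable sketch for the list case follows (the tree case is analogous). It also checks that `close` undoes `open` on a concrete value, which is what "one-to-one translation" means in practice. `digitStep` is a hypothetical name for the digit-splitting step.

```scala
object AlignedUnfolds {
  sealed trait List
  case object Nil extends List
  case class Cons(head: Int, tail: List) extends List

  type OCons[R] = Option[(Int, R)]

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons.map { case (head, tail) => (head, fun(tail)) }

  // One-to-one translations between the data type and its Option alias.
  def close(ocons: OCons[List]): List = ocons match {
    case None               => Nil
    case Some((head, tail)) => Cons(head, tail)
  }
  def open(list: List): OCons[List] = list match {
    case Nil              => None
    case Cons(head, tail) => Some((head, tail))
  }

  // Split the seed one step, unfold every new seed, then close the layer.
  def unfold(into: Int => OCons[Int])(seed: Int): List =
    close(mapOC((next: Int) => unfold(into)(next), into(seed)))

  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))
}
```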

#### Step 10: Indirect Recursion for *each* Data Type

So far, we have only focused on hiding direct recursion from functions. Can we do the same with data types? Yes, if we use a form of indirection.

In `fold` and `unfold`, the functions `into` and `out` represent one recursive step. Likewise, we can use the types `OCons[R]` and `ONode[R]` to represent one link in a list or tree. The type parameter `R` gives us the indirection that we need to cut the recursion in each data type:

```
case class List_Ind(opt: OCons[List_Ind])
case class Tree_Ind(opt: ONode[Tree_Ind])
```

The `List_Ind` (resp. `Tree_Ind`) data type now consists of a single case class with a single parameter `opt`. The type of `opt` is the type operator `OCons` (resp. `ONode`) applied to the data type `List_Ind` (resp. `Tree_Ind`) itself. Thus, we have shifted recursion from values to types.

For the `List_Ind` and `Tree_Ind` classes, we can write the following `fold` and `unfold` functions:

```
def fold(out: OCons[Int] => Int)(list: List_Ind): Int =
  out(mapOC(fold(out), list.opt))
def fold(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
  out(mapON(fold(out), tree.opt))
def unfold(into: Int => OCons[Int])(seed: Int): List_Ind =
  List_Ind(mapOC(unfold(into), into(seed)))
def unfold(into: Int => ONode[Int])(seed: Int): Tree_Ind =
  Tree_Ind(mapON(unfold(into), into(seed)))
```

The two `fold` functions (resp. `unfold`) are almost the same: they only differ in the type operator (`OCons` vs `ONode`) and in the map function for it (`mapOC` vs `mapON`). Note that, since `List_Ind` and `Tree_Ind` already use `OCons` and `ONode`, we no longer need the auxiliary functions `open` and `close`.
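The following self-contained sketch puts Step 10 together. Distinct names (`foldList`, `unfoldList`, `foldTree`, `unfoldTree`) replace the overloads for simpler type inference, and `sumList`, `sumTree`, and `digitStep` are hypothetical helpers. Values are now nested `Option` layers wrapped in `List_Ind` or `Tree_Ind`.

```scala
object IndirectDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  // Each value wraps exactly one layer; R is filled with the class itself.
  case class List_Ind(opt: OCons[List_Ind])
  case class Tree_Ind(opt: ONode[Tree_Ind])

  def mapOC[A, B](fun: A => B, ocons: OCons[A]): OCons[B] =
    ocons.map { case (head, tail) => (head, fun(tail)) }
  def mapON[A, B](fun: A => B, onode: ONode[A]): ONode[B] =
    onode.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }

  // No open/close needed: `opt` already is the Option layer.
  def foldList(out: OCons[Int] => Int)(list: List_Ind): Int =
    out(mapOC((t: List_Ind) => foldList(out)(t), list.opt))
  def unfoldList(into: Int => OCons[Int])(seed: Int): List_Ind =
    List_Ind(mapOC((n: Int) => unfoldList(into)(n), into(seed)))

  def foldTree(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
    out(mapON((t: Tree_Ind) => foldTree(out)(t), tree.opt))
  def unfoldTree(into: Int => ONode[Int])(seed: Int): Tree_Ind =
    Tree_Ind(mapON((n: Int) => unfoldTree(into)(n), into(seed)))

  val sumList: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }
  val sumTree: ONode[Int] => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }
  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))
}
```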

#### Step 11: Indirect recursion for *all* data types

Now we can unify `List_Ind` and `Tree_Ind` into a single `Ind` case class, which generalizes indirect recursion in data types.

The classes `List_Ind` and `Tree_Ind` above only differ in the type operator (`OCons` or `ONode`) used in the type of the `opt` field. To unify them, we need to extract that type operator as a type parameter, which we call `Rec`.

```
case class Ind[ Rec[_] ]( opt: Rec[ Ind[Rec] ] )
```

The `Rec[_]` notation means that `Rec` stands for a type operator, like `OCons` or `ONode`, which itself takes a type parameter `R`.

We can now turn `List_Ind` and `Tree_Ind` each into a special case of `Ind`:

```
type List_Ind = Ind[OCons]
type Tree_Ind = Ind[ONode]
```
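To see what values of these types look like, this sketch builds the list `[1, 2]` and a one-node tree by hand, layer by layer. The values `list12` and `tree5` are hypothetical examples; the type argument of `Ind` is written explicitly because Scala cannot infer `Rec` from a bare `Some(...)`.

```scala
object IndDemo {
  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  // One case class for every data type: Rec[_] describes a single layer,
  // and Ind puts Ind[Rec] in every recursive position.
  case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

  type List_Ind = Ind[OCons]
  type Tree_Ind = Ind[ONode]

  // The list [1, 2]: None marks the end of the list.
  val list12: List_Ind =
    Ind[OCons](Some((1, Ind[OCons](Some((2, Ind[OCons](None)))))))

  // A single node holding 5, with two empty subtrees.
  val tree5: Tree_Ind =
    Ind[ONode](Some((Ind[ONode](None), 5, Ind[ONode](None))))
}
```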

#### Step 12: Unified fold and unfold

Now, let us unify the `fold` and `unfold` functions for `List_Ind` and `Tree_Ind` from Step 10 into a single pair of `fold`-`unfold` functions for the `Ind[Rec[_]]` type. Let us start from the code of those functions:

```
def fold(out: OCons[Int] => Int)(list: List_Ind): Int =
  out(mapOC(fold(out), list.opt))
def fold(out: ONode[Int] => Int)(tree: Tree_Ind): Int =
  out(mapON(fold(out), tree.opt))
def unfold(into: Int => OCons[Int])(seed: Int): List_Ind =
  List_Ind(mapOC(unfold(into), into(seed)))
def unfold(into: Int => ONode[Int])(seed: Int): Tree_Ind =
  Tree_Ind(mapON(unfold(into), into(seed)))
```

To join the two `fold` (or `unfold`) functions into a single one, we have to (1) add the generic parameter `Rec[_]` to the function, (2) replace each appearance of `OCons` and `ONode` by `Rec`, and (3) replace each appearance of `List_Ind` and `Tree_Ind` by `Ind[Rec]`. This yields the following:

```
// `map` is not yet defined for a generic Rec[_]
def fold[Rec[_]](out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
  out(map(fold(out), ind.opt))
def unfold[Rec[_]](into: Int => Rec[Int])(seed: Int): Ind[Rec] =
  Ind(map(unfold(into), into(seed)))
```

We also have to unify `mapOC` and `mapON` into a single `map` function for `Rec[_]`. However, since `Rec[_]` is generic, the `map` must be provided as another parameter of the function. A common way to do so is to wrap the `map` inside a trait, usually called a `Functor`, which is also generic on `Rec`.

```
trait Functor[F[_]] {
  def map[A, B](fun: A => B, from: F[A]): F[B]
}
def fold[Rec[_]](ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
  out(ff.map(fold(ff, out), ind.opt))
def unfold[Rec[_]](ff: Functor[Rec], into: Int => Rec[Int])(seed: Int): Ind[Rec] =
  Ind(ff.map(unfold(ff, into), into(seed)))
```

Now we can use these functions as the `fold` and `unfold` for lists, trees, or any other data structure, as long as we provide a `Functor` for its type operator.
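Here is a self-contained sketch of the whole construction, with `Functor` instances for `OCons` and `ONode`. The same generic `fold` and `unfold` compute the digit sum of a number through a list and through a tree. The instance and helper names are hypothetical, `splitStep` is the hypothetical middle-digit split from earlier (positive numbers only), and type arguments are written explicitly to avoid relying on higher-kinded inference.

```scala
object Generic {
  trait Functor[F[_]] {
    def map[A, B](fun: A => B, from: F[A]): F[B]
  }

  case class Ind[Rec[_]](opt: Rec[Ind[Rec]])

  // Map the recursive call over one layer, then apply `out`.
  def fold[Rec[_]](ff: Functor[Rec], out: Rec[Int] => Int)(ind: Ind[Rec]): Int =
    out(ff.map((child: Ind[Rec]) => fold[Rec](ff, out)(child), ind.opt))

  // Split the seed into one layer, then unfold every new seed.
  def unfold[Rec[_]](ff: Functor[Rec], into: Int => Rec[Int])(seed: Int): Ind[Rec] =
    Ind[Rec](ff.map((next: Int) => unfold[Rec](ff, into)(next), into(seed)))

  type OCons[R] = Option[(Int, R)]
  type ONode[R] = Option[(R, Int, R)]

  val oconsFunctor: Functor[OCons] = new Functor[OCons] {
    def map[A, B](fun: A => B, from: OCons[A]): OCons[B] =
      from.map { case (head, tail) => (head, fun(tail)) }
  }
  val onodeFunctor: Functor[ONode] = new Functor[ONode] {
    def map[A, B](fun: A => B, from: ONode[A]): ONode[B] =
      from.map { case (ll, top, rr) => (fun(ll), top, fun(rr)) }
  }

  val sumList: OCons[Int] => Int = {
    case None              => 0
    case Some((head, acc)) => head + acc
  }
  val sumTree: ONode[Int] => Int = {
    case None              => 0
    case Some((l, top, r)) => l + top + r
  }
  val digitStep: Int => OCons[Int] =
    n => if (n == 0) None else Some((n % 10, n / 10))
  val splitStep: Int => ONode[Int] = n =>
    if (n == 0) None
    else {
      val s   = n.toString
      val k   = (s.length - 1) / 2
      val pre = if (k == 0) 0 else s.take(k).toInt
      val suf = if (k + 1 == s.length) 0 else s.drop(k + 1).toInt
      Some((pre, s(k).asDigit, suf))
    }

  def digitSumList(n: Int): Int =
    fold[OCons](oconsFunctor, sumList)(unfold[OCons](oconsFunctor, digitStep)(n))
  def digitSumTree(n: Int): Int =
    fold[ONode](onodeFunctor, sumTree)(unfold[ONode](onodeFunctor, splitStep)(n))
}
```

Only the `Functor` instance and the single-step functions change between lists and trees; the recursion lives once, in `fold` and `unfold`.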

#### End

After all of the previous steps, what we have is the following:

- We joined two recursive data types, lists and trees, into a single case class `Ind[Rec[_]]` for indirect recursion.
- We joined two `fold` functions, for lists and trees, into a single `fold` function for the `Ind` type.
- We joined two `unfold` functions, for lists and trees, into a single `unfold` function for the `Ind` type.

The class `Ind[Rec[_]]` is generic on the `Rec[_]` type constructor, which means that it can represent not just lists or binary trees, but also any directly recursive data type.

### Jargon

The notions and intuitions we used above have some special names.

- A generic type like `Ind[Rec[_]]`, which takes another generic type `Rec[_]` as a type parameter, is called a **higher-kinded** type.
- The trait `Functor`, used to provide a `map` function, is a fundamental type class in `cats`.
- The `Ind[Rec[_]]` case class, which captures the recursive application of the `Rec[_]` type constructor, is called the **fixpoint** data type. In Matryoshka, it is called `Fix`.
- The `out` function, of the form `Rec[A] => A`, is called an **algebra**.
- The `into` function, of the form `A => Rec[A]`, is called a **coalgebra**.
- The `fold` function, which repeatedly applies the algebra `Rec[A] => A` to build a function `Ind[Rec] => A`, is called a **catamorphism**.
- A function like `unfold`, which repeatedly applies the coalgebra `A => Rec[A]` to build a function `A => Ind[Rec]`, is called an **anamorphism**.

Much of this jargon comes from academic research on recursion schemes, and by convention it is also used in libraries that implement recursion schemes, like Matryoshka.
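To connect these names with the code, here is a sketch that restates Step 12 using the jargon as identifiers and generalizes the result type from `Int` to any `A`, matching the `Rec[A] => A` form above. The names `Algebra`, `Coalgebra`, `Fix`, `unfix`, `cata`, `ana`, `render`, and `countdown` are chosen for illustration; they are not claimed to match Matryoshka's actual API.

```scala
object Jargon {
  trait Functor[F[_]] { def map[A, B](fun: A => B, from: F[A]): F[B] }

  // Hypothetical aliases mirroring the jargon.
  type Algebra[F[_], A]   = F[A] => A
  type Coalgebra[F[_], A] = A => F[A]
  case class Fix[F[_]](unfix: F[Fix[F]])

  // Catamorphism: the generic fold, with any result type A.
  def cata[F[_], A](ff: Functor[F], alg: Algebra[F, A])(fix: Fix[F]): A =
    alg(ff.map((child: Fix[F]) => cata[F, A](ff, alg)(child), fix.unfix))

  // Anamorphism: the generic unfold, with any seed type A.
  def ana[F[_], A](ff: Functor[F], coalg: Coalgebra[F, A])(seed: A): Fix[F] =
    Fix[F](ff.map((next: A) => ana[F, A](ff, coalg)(next), coalg(seed)))

  type OCons[R] = Option[(Int, R)]
  val oconsFunctor: Functor[OCons] = new Functor[OCons] {
    def map[A, B](fun: A => B, from: OCons[A]): OCons[B] =
      from.map { case (head, tail) => (head, fun(tail)) }
  }

  // An algebra with A = String: renders a list instead of summing it.
  val render: Algebra[OCons, String] = {
    case None              => "Nil"
    case Some((head, acc)) => s"$head :: $acc"
  }
  // A coalgebra with A = Int: grows n, n - 1, ..., 1.
  val countdown: Coalgebra[OCons, Int] =
    n => if (n <= 0) None else Some((n, n - 1))
}
```

With these, `cata(oconsFunctor, render)(ana(oconsFunctor, countdown)(3))` (with explicit type arguments) should produce the string `3 :: 2 :: 1 :: Nil`.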

### Additional reading

- Gibbons’ *Origami Programming* introduces the main recursion schemes, in addition to folds and unfolds, using diverse examples such as sorting algorithms.
- Zainab Ali’s talk, which inspired this post, applies the techniques we used above to another data type, decision trees, which are like binary trees but with more detail in each node or leaf.
- Rob Norris’s presentation about how to store academic genealogy trees (another recursive data type) in a relational database (Sep. 2016).