Tail Recursion
That little compiler optimization for better recursion

28 March 2017

For the last month, I’ve been learning more about functional programming. To get a better idea of what it’s all about, I’ve been working my way through the Coursera course Functional Programming Principles in Scala. One of the things that requires a bit of a mindset shift for me is that functional programming tends to use recursion extensively.

Do you want to sort a list? Recursively sort the left and right halves and merge the results. Do you want to iterate over a list of elements? Do whatever you need to with the first element, then call the function recursively on the rest of the list.
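Both of those ideas can be sketched in a few lines of Scala. This is a minimal illustration written for this post, not code from the course; the names sumList, mergeSort and merge are made up for the example.

```scala
// Iterate by handling the first element, then recursing on the rest of the list.
def sumList(xs: List[Int]): Int = xs match {
  case Nil          => 0
  case head :: tail => head + sumList(tail)
}

// Sort by recursively sorting each half, then merging the two sorted halves.
def mergeSort(xs: List[Int]): List[Int] = {
  // Merge two already-sorted lists into a single sorted list.
  def merge(left: List[Int], right: List[Int]): List[Int] = (left, right) match {
    case (Nil, _) => right
    case (_, Nil) => left
    case (l :: ls, r :: rs) =>
      if (l <= r) l :: merge(ls, right)
      else r :: merge(left, rs)
  }

  if (xs.length <= 1) xs
  else {
    val (left, right) = xs.splitAt(xs.length / 2)
    merge(mergeSort(left), mergeSort(right))
  }
}

println(sumList(List(1, 2, 3, 4)))      // 10
println(mergeSort(List(5, 3, 1, 4, 2))) // List(1, 2, 3, 4, 5)
```

Neither function contains a loop: all of the "iteration" comes from each call handing a smaller list to the next one.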

Why Would You Use Recursion to Iterate?

Conceptually speaking, functional programming is trying to take more inspiration from pure maths than other programming paradigms do. This means that you should write your function more like it would appear in a mathematics textbook. Mathematics textbooks rarely show normal iteration, but it is common to see something like this:

\[ f(n) = \begin{cases} 1 &\mbox{if } n \leq 1\\ n \times f(n-1) &\mbox{otherwise} \end{cases} \]

It’s not much of a stretch to see why it would be considered more readable to implement this function like so:

def f(n:Int): Long = {
    if (n <= 1) 1
    else n * f(n-1)
}

This closer coupling of ideas between maths and standard implementations in functional programming is, in my opinion, one of the reasons that functional programming can be so useful in math-heavy domains like science and engineering.

Stack Frames, and the Dreaded Stack Overflow

Stack Overflow, as a website, is lovely. Stack overflows as a programming concept are less fun.

To understand what a stack overflow is, and why it’s relevant here, you first need to understand what the call stack is. The call stack is a data structure hiding in the background of most programming languages, keeping track of all of a function’s local variables, and where the program needs to return to once the function is finished. You’ll often encounter it when you’re debugging, because it can be convenient to step up and down the call stack after you’ve hit a breakpoint to see how you got there.

The call stack is easiest to explain with an example. Say you have the following program:

1  def f(n:Int): Long = {
2    if (n <= 1) 1
3    else n * f(n-1)
4  }
5
6  val result = f(3)

The program starts, and you get the initial stack frame for your main function. It knows that it needs memory for val result.

-------------------
MAIN
Long result
Return to line 0
-------------------

We get to line 6, and it calls f for the first time, so we add another frame onto the stack with the information about the call. The data we pass through, 3 in this case, is also copied into the new stack frame.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Return to line 6
-------------------

This continues, with line 3 saying we should call f again. Calling a new function means a new stack frame, since it needs to handle its values independently from the previous call.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Return to line 6
-------------------
f
Int n = 2
Return to line 3
-------------------

And once more.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Return to line 6
-------------------
f
Int n = 2
Return to line 3
-------------------
f
Int n = 1
Return to line 3
-------------------

When n=1, we hit line 2, and the stack starts decreasing in size again by returning a value of 1.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Return to line 6
-------------------
f
Int n = 2
Return to line 3
-------------------
Last return value = 1

Now that we’ve returned 1, we have enough information to return from the next level up. 2 times 1 is 2, so we pop another frame off the call stack and return 2.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Return to line 6
-------------------
Last return value = 2

And so on.

-------------------
MAIN
Long result = 6
Return to line 0
-------------------
Last return value = 6

The last stack frame ends, and so the program ends. The value of result is 6 when the program ends, but we don’t do anything with it so it’s swallowed into the void as the operating system takes back the memory it gave to our program.
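If you’d like to watch the stack grow on the JVM, one rough way is to ask the current thread for its stack trace on every call. This is a sketch for experimenting, not part of the original program: the absolute numbers it prints include frames that the JVM and the Scala runtime add around your code, so only the increase of one frame per nested call lines up with the diagrams above.

```scala
// The same non-tail-recursive f, instrumented to report the size of the call stack.
def f(n: Int): Long = {
  // getStackTrace() returns one element per frame currently on this thread's stack
  val frames = Thread.currentThread().getStackTrace().length
  println(s"f($n) called with $frames frames on the stack")
  if (n <= 1) 1
  else n * f(n - 1)
}

val result = f(3)
println(s"result = $result")
```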

In this example, our call stack got to be 4 stack frames high when it was at its highest. The stack can generally get as large as it needs to be in normal programs, but it does have a limited size determined by your programming language, compiler, runtime, and operating system (on the JVM, for example, a thread’s stack size can be set with the -Xss option). Experimentally, with this particular program in Scala, I can call it with a depth of about 12000 nested calls before I start getting stack overflow problems. For a real program with more than one variable per frame, expect this number to be much lower than 12000, but probably still fairly high.

When you run out of space on the stack, your program crashes, and we call it a stack overflow. As you can imagine, if recursion is your main way of expressing algorithms and iterating over lists, this may become a problem if you have any particularly long lists.
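On the JVM, running out of stack space shows up as a java.lang.StackOverflowError. The sketch below is one way to poke at the limit yourself; the names depth and overflows are made up for illustration. The exact depth at which it fails depends on the JVM, its stack size setting, and how much each frame holds, so treat any number you find as a rough guide.

```scala
// Not tail recursive: the `1 +` has to wait for the recursive call to return,
// so every call keeps its own stack frame alive.
def depth(n: Int): Int =
  if (n <= 0) 0
  else 1 + depth(n - 1)

// True if evaluating depth(n) runs out of stack space.
def overflows(n: Int): Boolean =
  try {
    depth(n)
    false
  } catch {
    case _: StackOverflowError => true
  }

println(overflows(1000))     // false: a thousand frames fit easily
println(overflows(10000000)) // true with default JVM stack sizes
```

Catching the error like this is only useful for experiments; in a real program, the stack overflow is the bug.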

Tail Recursion to the Rescue!

Tail recursion is a particular way of writing recursive functions so that the compiler can reuse stack frames. For a function to be considered tail recursive, there is only one criterion: the recursive call must be the very last thing the function does, with no work left to do after it returns.
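In Scala you can ask the compiler to check this criterion for you with the @tailrec annotation from scala.annotation: if an annotated method’s recursive call isn’t in tail position, compilation fails instead of quietly producing a method that uses one stack frame per call. Here is a small sketch using list length as a made-up example; length and lengthAcc are illustrative names.

```scala
import scala.annotation.tailrec

// Not tail recursive: after length(tail) returns, we still have to add 1.
// Putting @tailrec on this method would be a compile error.
def length(xs: List[Int]): Int = xs match {
  case Nil       => 0
  case _ :: tail => 1 + length(tail)
}

// Tail recursive: the recursive call is the very last thing that happens,
// and the running count travels along in the `count` parameter instead.
@tailrec
def lengthAcc(xs: List[Int], count: Int): Int = xs match {
  case Nil       => count
  case _ :: tail => lengthAcc(tail, count + 1)
}

println(length(List(1, 2, 3)))       // 3
println(lengthAcc(List(1, 2, 3), 0)) // 3
```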

Our example function is not tail recursive, because after it calls f, it still has to multiply the result by n. You could think of the evaluation above mathematically like so:

f(3)
= 3 * f(2)
= 3 * (2 * f(1))
= 3 * (2 * 1)
= 3 * 2
= 6

The 3 needs to stay around right until the end. Let’s say we rewrote this function such that the 3 was actually passed in as a parameter to the function.

1  def f(n:Int, acc:Long): Long = {
2    if (n <= 1) acc
3    else f(n-1, n*acc)
4  }
5
6  val result = f(3, 1)
7  //acc is short for accumulator

This is functionally equivalent, but if you evaluate it mathematically you don’t get the same telescoping that you got earlier. This is because the last thing the function does is call another function, and nothing else.

f(3, 1)
= f(2, 3)
= f(1, 6)
= 6
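One quick way to convince yourself that the two versions really are equivalent is to run them side by side. A minimal sketch: factorial and factorialAcc are names used here only to tell the two versions of f apart, and the check stops at 20 because 20! is the largest factorial that fits in a Long.

```scala
import scala.annotation.tailrec

// Original version: the multiplication by n happens after the recursive call returns.
def factorial(n: Int): Long =
  if (n <= 1) 1
  else n * factorial(n - 1)

// Accumulator version: the pending multiplication is done up front and passed along.
@tailrec
def factorialAcc(n: Int, acc: Long): Long =
  if (n <= 1) acc
  else factorialAcc(n - 1, n * acc)

// Both should agree for every n whose factorial fits in a Long.
val allAgree = (1 to 20).forall(n => factorial(n) == factorialAcc(n, 1))
println(allAgree)           // true
println(factorialAcc(3, 1)) // 6
```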

This is what the compiler optimization relies on: once the next call is made, nothing is left to do in the previous stack frame, so that frame does not need to be kept around, and we can avoid running into stack overflow issues.

As before, our stack starts with the stack frame for the main function.

-------------------
MAIN
Long result
Return to line 0
-------------------

And also, as before, when we call f for the first time, we get a new stack frame with all of the parameters passed in.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 3
Long acc = 1
Return to line 6
-------------------

However, this time when we get to line 3, we are done using all of the variables in our stack frame except for the return address. We keep the same return address we had before, but update the existing stack frame to have our new values for n and acc.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 2
Long acc = 3
Return to line 6
-------------------

This repeats until we get to n = 1.

-------------------
MAIN
Long result
Return to line 0
-------------------
f
Int n = 1
Long acc = 6
Return to line 6
-------------------

By the time we get back to the main stack frame, we have the same result as with our original implementation.

-------------------
MAIN
Long result = 6
Return to line 0
-------------------
Last return value = 6

This function won’t need to use a larger and larger call stack as the value for n grows. That’s how functional programmers get away with writing code that appears to recurse so deeply without getting stack overflows.
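To see the difference in practice, here is a sketch that pushes an accumulator-style function far past the roughly 12000 calls that caused trouble earlier. It sums 1 to n rather than multiplying, because n! overflows a Long long before the stack would become a problem; sumTo and sumToLoop are made-up names. The second function is a hand-written loop that does what the stack frame diagrams describe: overwrite n and acc in place, then go round again. Treat it as a mental model; the bytecode the Scala compiler actually emits may look different.

```scala
import scala.annotation.tailrec

// Tail recursive: the recursive call is the last action, so the compiler
// can reuse the current stack frame instead of pushing a new one.
@tailrec
def sumTo(n: Int, acc: Long): Long =
  if (n <= 0) acc
  else sumTo(n - 1, acc + n)

// Roughly what the optimization amounts to: one frame whose values
// are updated in place on every "call".
def sumToLoop(start: Int): Long = {
  var n = start
  var acc = 0L
  while (n > 0) {
    acc += n // new value for acc
    n -= 1   // new value for n
  }
  acc
}

println(sumTo(1000000, 0))  // 500000500000
println(sumToLoop(1000000)) // 500000500000
```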

Which Languages Do This?

While I would like to say that this is everywhere, and you can go use it, the reality is that some languages support this optimization and some do not. More generally, the optimization is called “tail call optimization”, so if you want to find out whether your language of choice supports it, Google some combination of “tail call optimization” and your language name and you’ll get there.

It’s common in Scala, which is where I learned about it. JavaScript added support for it in the ES2015 standard, but it looks like the actual implementation in browsers is lagging a bit behind.

Many imperative languages deliberately avoid supporting tail call optimization for various reasons. Python’s creator argues that it makes the call stack more difficult for programmers to reason about. Rust’s creator argues that, while he loves tail recursion, it doesn’t work well with Rust’s model of memory management, and it doesn’t work well with tools originally made to be used with C that Rust users are also using, such as debuggers and profilers.

So that’s tail recursion. It might not be necessary for the type of programs you’re writing, but maybe someday you’ll hit a problem that can be solved very cleanly using deep recursion, and then you’ll be glad to know it exists.


You can send me comments on this post at justin@worthe-it.co.za, or @JWorthe.

