To tame recursion, Scala catches it by its tail.
Recursion, a powerful technique in programming, allows one to build up the solution to a problem using the solutions to sub-problems. Recursive functions call themselves, an apparent paradox. Recursion can be a bit bewildering if you’ve never encountered it, but once we get a handle on it, its expressive power can serve us well. In this article we’ll look at an example of writing recursive methods in Scala, discuss issues with simply using this powerful feature, and then look at the solutions Scala offers to the problems.
Computing the factorial, which is a product of all integers from 1 to a given number, has a familiar recursive solution. We can easily write a recursive function in Scala to compute the factorial of a number n.
def factorial(n : BigInt) : BigInt =
  if(n == 1) 1 else n * factorial(n - 1)
Since the factorial of a number can be pretty large, we use a BigInt rather than an Int. Given a number, the function computes its factorial using a recursive call to itself. The factorial of 1 is simply 1. For any number greater than 1, the value is the number times the factorial of the number minus one. In this example we assume the given number is greater than or equal to 1.
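To see why BigInt matters, here is a small sketch that runs the same recursion over a 32-bit Int alongside the BigInt version. The object and method names are hypothetical. 13! is 6,227,020,800, which exceeds Int.MaxValue (2,147,483,647), so the Int version silently wraps around instead of failing.

```scala
object FactorialOverflow {
  // Same recursive shape as the article's factorial, but with a 32-bit Int.
  // Int multiplication wraps around once the product exceeds Int.MaxValue.
  def factorialInt(n: Int): Int =
    if (n == 1) 1 else n * factorialInt(n - 1)

  // The BigInt version: the result grows as large as it needs to.
  def factorialBig(n: BigInt): BigInt =
    if (n == 1) 1 else n * factorialBig(n - 1)

  def main(args: Array[String]): Unit = {
    println(factorialInt(13)) // 1932053504, a wrapped-around (wrong) value
    println(factorialBig(13)) // 6227020800, the correct 13!
  }
}
```

The overflow goes unreported, which is why choosing the wider type up front is safer than checking results afterward.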
Let’s use this function to compute the factorial of a small number, like 5.
println(factorial(5)) //120
That was easy. And the code is concise and expressive. At first sight, the recursive solution looks quite desirable, so why would we not use solutions like this wherever possible? But appearances can be deceiving. While this code works well for this small argument, it will unfortunately not yield the desired result for a large argument. For example, try invoking the function with 10,000 as a parameter.
try {
  factorial(10000)
} catch {
  case ex : StackOverflowError => println(ex) //java.lang.StackOverflowError
}
As the code runs, each call to the function pushes a new frame, holding its parameter and the multiplication it still has to perform, onto the call stack. The stack builds up rapidly because each call keeps its frame while it waits for the subsequent call to complete. We soon run into a StackOverflowError.
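To see that the limit really is the depth of the call stack, here is a small probe. It is a sketch: the names StackDepthProbe, descend, and maxDepth are made up for illustration, and the depth it reports varies with the JVM, its -Xss stack-size setting, and the size of each frame.

```scala
object StackDepthProbe {
  // Deepest nesting level reached before the stack ran out.
  private var deepest = 0

  // Deliberately NOT tail recursive: the addition runs after the call returns,
  // so every level keeps its frame and the stack grows until the JVM gives up.
  private def descend(depth: Int): Int = {
    deepest = depth
    1 + descend(depth + 1)
  }

  def maxDepth(): Int = {
    deepest = 0
    // The try's value is discarded; we only need the recorded depth.
    try { descend(1) } catch { case _: StackOverflowError => 0 }
    deepest
  }

  def main(args: Array[String]): Unit =
    println(s"Stack overflowed after roughly ${maxDepth()} nested calls")
}
```

Running it with a larger -Xss value should push the reported depth up, which supports the point that the stack, not the arithmetic, is what fails.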
Recursive solutions are fine for small parameters but quickly turn unusable as the parameter size grows. This can force us to fall back on iterative solutions, like:
def factorial(n : BigInt) : BigInt = {
  var result = BigInt(1)
  for(i <- BigInt(1) to n) result *= i
  result
}
The iterative solution will yield the right results for both small and large parameters:
println(factorial(5))
println(factorial(10000))
The iterative version, however, does not have the charm of recursion. It’s relatively verbose and makes use of a mutable local variable that the recursive solution nicely avoided. You can avoid the mutable variable in the iterative solution by using other functions like foldLeft. In any case, now that we have a taste of recursion, it’d be nice to use it without the negative consequences.
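Here is a minimal sketch of the foldLeft approach, assuming we are fine with building a range of BigInt values; the object name FoldFactorial is hypothetical. The fold carries the running product from one element to the next, so there is neither a var nor any recursion.

```scala
object FoldFactorial {
  // Multiplies 1 * 2 * ... * n without a mutable variable:
  // foldLeft threads the running product through the range.
  def factorial(n: BigInt): BigInt =
    (BigInt(1) to n).foldLeft(BigInt(1))(_ * _)

  def main(args: Array[String]): Unit = {
    println(factorial(5))                     // 120
    println(factorial(10000).toString.length) // digit count of 10000!
  }
}
```

As a side effect of the fold, factorial(0) yields 1, the conventional value of 0!, because the range from 1 to 0 is empty and the fold returns its starting value.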
Scala makes that possible, as we will now see.
The solution is highlighted in a fantastic book called Structure and Interpretation of Computer Programs (SICP) by Abelson and Sussman (The MIT Press, 1996, Cambridge, Massachusetts). The key is to think differently about the code that we write and the code that we execute.
The first solution we saw is a recursive procedure (structured as a recursion in code) and a recursive process (exercised as a recursion at runtime). The second solution is an iterative procedure and an iterative process: it is written as iteration and runs as iteration.
It would be great if we could have the best of both worlds. That is, if we could write a solution as a recursive procedure, and then, using compilation techniques, transform the code so it runs as an iterative process. This is the computing equivalent of having a metabolism we can only dream of, letting us eat all we want without adding an extra pound around the midsection. The midsection here is the call stack created during the execution of code.
Let’s trim the fat from the recursive procedure’s midsection.
Let’s examine that original recursive call more closely.
if(n == 1) 1 else n * factorial(n - 1)
The obvious detail here is that the function calls itself, but the thing to notice is that the last operation performed in the call is not the call to the function itself. The multiply operation has to be performed once the recursive call returns. As a result, the stack has to grow while we wait for the recursive call to complete and for the multiply operation to execute. That's where the fat comes from.
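Tracing a small call by hand shows the fat accumulating. Every multiplication is deferred until the call to its right returns, so the pending expression (and with it the stack) widens until the base case is reached, and only then collapses:

```latex
\begin{aligned}
\mathit{factorial}(4) &= 4 \times \mathit{factorial}(3) \\
&= 4 \times (3 \times \mathit{factorial}(2)) \\
&= 4 \times (3 \times (2 \times \mathit{factorial}(1))) \\
&= 4 \times (3 \times (2 \times 1)) \\
&= 4 \times (3 \times 2) \\
&= 4 \times 6 = 24
\end{aligned}
```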
If we can structure the recursive call so that the final operation we perform is not the multiply but the call to the function itself, then a smart compiler (like Scala's, of course) can use that to transform the recursion into a simple iteration under the covers and eliminate the fat. Let's examine this with another example before we return to the factorial problem.
object Sample {
  def madMethod1(n : Int) : Int =
    if(n == 1) throw new RuntimeException("mad") else 1 * madMethod1(n - 1)

  def madMethod2(n : Int) : Int =
    if(n == 1) throw new RuntimeException("mad") else madMethod2(n - 1)

  def main(args : Array[String]) : Unit = {
    try {
      madMethod1(5)
    } catch {
      case ex : RuntimeException => ex.printStackTrace()
    }

    try {
      madMethod2(5)
    } catch {
      case ex : RuntimeException => ex.printStackTrace()
    }
  }
}
The Sample singleton has two methods, madMethod1 and madMethod2, both of which throw an exception. The last operation in madMethod1 is the multiply operation, but the last operation in madMethod2 is a call to itself, making it tail recursive. Tail recursive methods get special treatment from the Scala compiler: as long as the method cannot be overridden (which holds for methods of a singleton object like this one), the compiler transforms the recursion into a simple iteration under the covers, just what we were looking for. If we run the sample, we can see the difference in execution.
The call to madMethod1 results in a stack trace that is five levels deep, as we’d expect.
java.lang.RuntimeException: mad
  at Sample$.madMethod1(Sample.scala:3)
  at Sample$.madMethod1(Sample.scala:3)
  at Sample$.madMethod1(Sample.scala:3)
  at Sample$.madMethod1(Sample.scala:3)
  at Sample$.madMethod1(Sample.scala:3)
  at Sample$.main(Sample.scala:10)
  at Sample.main(Sample.scala)
The recursive calls to madMethod2, on the other hand, stay at one level deep, as its stack trace reveals.
java.lang.RuntimeException: mad
  at Sample$.madMethod2(Sample.scala:6)
  at Sample$.main(Sample.scala:16)
  at Sample.main(Sample.scala)
This is due to the special treatment I mentioned: the Scala compiler transforms madMethod2 into a simple iteration while it leaves madMethod1 as a recursion. You can take a peek at this by compiling the code using scalac Sample.scala and running a javap -c Sample$ command on the generated bytecode.
public int madMethod1(int);
  Code:
   0: iload_1
   1: iconst_1
   2: if_icmpne 15
   5: new #17; //class java/lang/RuntimeException
   8: dup
   9: ldc #19; //String mad
   11: invokespecial #22; //Method java/lang/RuntimeException."<init>":(Ljava/lang/String;)V
   14: athrow
   15: iconst_1
   16: aload_0
   17: iload_1
   18: iconst_1
   19: isub
   20: invokevirtual #24; //Method madMethod1:(I)I
   23: imul
   24: ireturn
public int madMethod2(int);
  Code:
   0: iload_1
   1: iconst_1
   2: if_icmpne 15
   5: new #17; //class java/lang/RuntimeException
   8: dup
   9: ldc #19; //String mad
   11: invokespecial #22; //Method java/lang/RuntimeException."<init>":(Ljava/lang/String;)V
   14: athrow
   15: iload_1
   16: iconst_1
   17: isub
   18: istore_1
   19: goto 0
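The invokevirtual instruction in madMethod1 is the recursive call. In madMethod2 that call has been replaced: the method decrements n, stores it back into the parameter slot, and uses goto to jump back to the start of the method, turning the recursion into a loop.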
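Written back out as Scala, the goto-based bytecode for madMethod2 corresponds roughly to the loop below. This is a hand-written approximation for illustration (the names Desugared and madMethod2Loop are hypothetical), not code the compiler actually emits as source.

```scala
object Desugared {
  // Approximates the bytecode scalac produced for madMethod2: instead of
  // calling itself, the method overwrites its parameter (isub, istore_1)
  // and jumps back to the top (goto 0).
  def madMethod2Loop(start: Int): Int = {
    var n = start
    while (true) {
      if (n == 1) throw new RuntimeException("mad")
      n = n - 1
    }
    // while (true) never completes normally; this line only satisfies the type checker.
    throw new IllegalStateException("unreachable")
  }

  def main(args: Array[String]): Unit =
    try { madMethod2Loop(5) } catch { case ex: RuntimeException => ex.printStackTrace() }
}
```

Because the loop never leaves the method, the resulting stack trace shows a single madMethod2Loop frame, matching the one-level trace of madMethod2.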
Writing a recursion as tail recursion gives us the maximum benefit: expressive, concise code that does not tax the stack and so won’t result in a StackOverflowError.
Let’s shift our focus to the factorial example. We have the given number on hand and we need to multiply that value by the result of the recursive call to the method with a parameter of the given number minus one. The challenge now is to write it as a tail recursion instead of a regular recursion.
A straightforward way to do this is to add a parameter. Rather than holding onto the partial result in the current method invocation while waiting for the recursive call, we can pass the partial result along as a parameter, so that each call performs its multiplication before making the recursive call, like so:
def factorial(fact : BigInt, n : BigInt) : BigInt =
  if(n == 1) fact else factorial(fact * n, n - 1)
This version of the factorial method takes two parameters instead of one. The first parameter is the partial computation of the factorial, and the second parameter is the number for which we want to compute the factorial. If the given number is 1, we simply return the value of the partial result fact. Otherwise, we multiply the value of fact by the given number and pass that product, along with the number minus one, to the recursive call. Since the last operation we perform is the call to the method itself, this implementation is tail recursive and benefits from the tail call optimization.
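Traced by hand, the accumulator version never builds up a pending expression. Each step fully describes the remaining work in its two arguments, which is exactly why the call can be replaced by a jump:

```latex
\begin{aligned}
\mathit{factorial}(1, 4) &= \mathit{factorial}(1 \times 4,\ 3) = \mathit{factorial}(4, 3) \\
&= \mathit{factorial}(4 \times 3,\ 2) = \mathit{factorial}(12, 2) \\
&= \mathit{factorial}(12 \times 2,\ 1) = \mathit{factorial}(24, 1) \\
&= 24
\end{aligned}
```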
Let’s invoke this method for the two parameter values we tried with the earlier implementations.
println(factorial(1, 5)) //120

factorial(1, 10000)
println("no exception") //no exception
The tail recursive version produces the desired result and is able to handle parameters of large size as well.
The absence of a StackOverflowError when we run this example shows that the optimization took effect. For a compile-time confirmation of such optimization, Scala provides a special annotation that you can place on the recursive method.
@scala.annotation.tailrec
def factorial(fact : BigInt, n : BigInt) : BigInt =
  if(n == 1) fact else factorial(fact * n, n - 1)
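If the compiler cannot apply the optimization, for example because the recursive call is not really in tail position or because the method could be overridden, this annotation directs the Scala compiler to give us a compilation error.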
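The sketch below shows one method the annotation accepts and, commented out, two that it would reject. The class and method names are hypothetical, and the exact wording of the compiler error differs between Scala versions, so none is quoted here.

```scala
import scala.annotation.tailrec

class Countdown {
  // Accepted: the self-call is the last operation, and a final method
  // cannot be overridden, so scalac can compile the call as a jump.
  @tailrec final def countDown(n: Int): Int =
    if (n == 0) 0 else countDown(n - 1)

  // Rejected if uncommented: a public, non-final method could be overridden
  // in a subclass, so scalac refuses to rewrite the self-call.
  // @tailrec def countDownOpen(n: Int): Int =
  //   if (n == 0) 0 else countDownOpen(n - 1)

  // Rejected if uncommented: the multiplication runs after the call returns,
  // so the recursive call is not in tail position.
  // @tailrec final def product(n: BigInt): BigInt =
  //   if (n == 1) 1 else n * product(n - 1)
}

object CountdownDemo {
  def main(args: Array[String]): Unit =
    println(new Countdown().countDown(1000000)) // 0, with no StackOverflowError
}
```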
We benefited from the tail call optimization but we lost a bit of clarity in code. Now, to make use of this method, the caller has to pass two parameters instead of one. In addition to this burden, this opens the door for errors. Someone could send an incorrect value, like 0 as the first parameter, and jeopardize the correct execution of the method.
The need to send two parameters instead of one is an implementation detail to facilitate tail recursion. In programming we’ve learned to encapsulate or hide implementation details. We can achieve a lightweight encapsulation by defining this function within another function that takes the necessary single parameter.
def factorial(n : BigInt) = {

  @scala.annotation.tailrec
  def factorial(fact : BigInt, n : BigInt) : BigInt =
    if(n == 1) fact else factorial(fact * n, n - 1)

  factorial(1, n)
}
The two-parameter method is hidden from the view of callers. The users of this version can invoke the single-parameter method, which in turn invokes the two-parameter tail recursive implementation of the factorial method.
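The nested version still relies on the assumption, stated at the start, that the argument is at least 1. With 0 or a negative number the helper counts down past 1 forever (an endless loop now, rather than a stack overflow). Here is a minimal sketch that makes the assumption explicit with require; the names SafeFactorial and loop are hypothetical, and the guard is an addition, not part of the original example.

```scala
import scala.annotation.tailrec

object SafeFactorial {
  // Single-parameter entry point; rejects inputs the algorithm does not handle.
  def factorial(n: BigInt): BigInt = {
    require(n >= 1, s"factorial expects n >= 1, got $n")

    // Hidden helper: `acc` carries the partial product, so the recursive
    // call is the last operation and scalac turns it into a loop.
    @tailrec
    def loop(acc: BigInt, k: BigInt): BigInt =
      if (k == 1) acc else loop(acc * k, k - 1)

    loop(1, n)
  }

  def main(args: Array[String]): Unit = {
    println(factorial(5))                     // 120
    println(factorial(10000).toString.length) // digit count of 10000!
  }
}
```

require throws an IllegalArgumentException, so a bad argument fails fast at the entry point instead of spinning inside the helper.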
Recursion is a pretty cool programming technique. Scala not only supports recursion but also compiles tail-recursive calls into loops that run in constant stack space. This allows us to make good use of this powerful technique without worrying about stack overflows, provided we write our recursion in tail form.
Dr. Venkat Subramaniam is an award-winning author, founder of Agile Developer, Inc., and an adjunct faculty at the University of Houston.
He has trained and mentored thousands of software developers in the US, Canada, Europe, and Asia, and is a regularly invited speaker at several international conferences. Venkat helps his clients effectively apply and succeed with agile practices on their software projects.
Venkat is the author of .NET Gotchas, the coauthor of 2007 Jolt Productivity Award winning Practices of an Agile Developer, the author of Programming Groovy: Dynamic Productivity for the Java Developer and Programming Scala: Tackle Multi-Core Complexity on the Java Virtual Machine. His latest book is Programming Concurrency on the JVM: Mastering Synchronization, STM, and Actors.
This series started in the September 2011 issue and has been running continuously since then. If you’d like to read the whole series, here are the links to the articles published so far:
9/11: The Elegance of Scala
11/11: Cute Classes and Pure OO
1/12: Working with Collections
3/12: Pattern Matching
Send the author your feedback or discuss the article in the magazine forum.