Thursday, April 23, 2009

Tail Recursion in Scala

Earlier this week, I gave a talk about Scala with Bill Venners and David Pollak. In my slides, I had an example of a factorial function in Scala and in Java. I made the point that not only is Scala usually almost as fast as Java, but sometimes it is even faster. One reason for this is that the Scala compiler can optimize tail-recursive calls in your code. This is something that has been proposed for Java as well as other languages, but it has its downsides as well.

Back to Scala... Bill also had a slide that showed a factorial function in Scala, though it was much simpler because it made use of Scala's implicit conversions between integers and BigInts. Afterwards, Bill and I were talking, and he pointed out that he did not think the Scala compiler would be able to optimize either my code or his. Here is code that is very similar to Bill's (I don't have the exact code handy, so I recreated it from memory):


def factorial(n: BigInt): BigInt = {
  if (n == 0) 1
  else n * factorial(n - 1)
}
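To see why the last step matters, it can help to write out the expression this version builds before any multiplication happens. Here is a minimal sketch using a hypothetical helper, `expansion`, that has the same recursion shape as factorial but returns the deferred multiplications as a string instead of computing them:

```scala
// Hypothetical helper for illustration: same recursion shape as factorial,
// but it returns the chain of pending multiplications as text.
def expansion(n: Int): String =
  if (n == 0) "1"
  else s"$n * (${expansion(n - 1)})" // the outer "n * (...)" can only be finished after the inner call returns
```

For example, `expansion(3)` returns `3 * (2 * (1 * (1)))`. Each open parenthesis stands for a stack frame that is still waiting to do a multiplication, so the recursive call is not the last thing the function does.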

Why can't this be optimized as a tail call? Instead of explaining it myself, I will introduce the hero of this story, Daniel Spiewak (check out his blog or follow him on Twitter). Bill emailed Daniel about his suspicion, and Daniel confirmed it:
You're quite correct: this version of factorial is not tail recursive.
I find it sometimes helps to split the last expression up into atomic
operations, much as the Scala parser would:


else {
  val t1 = n - 1
  val t2 = factorial(t1)
  n * t2
}

It is possible to devise a tail-recursive factorial. In fact, it
might be possible to convert any recursive function into a
tail-recursive equivalent, but I don't have a proof for that handy.
Back to factorial:


def factorial(n: Int) = {
  def loop(n: Int, acc: Int): Int = if (n <= 0) acc else loop(n - 1, acc * n)
  loop(n, 1)
}

I'm cheating a bit by taking advantage of the fact that multiplication
is commutative, but this is the general idea. Some functions are
naturally tail recursive, but for those which are not, this
"accumulator pattern" is a fairly easy way to achieve the equivalent
result. Technically, the accumulator must be combined with an
inversion of the computation direction (bottom-up rather than
top-down) to generate a truly-equivalent tail-recursive function, but
it is not necessary for examples like this one.
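Daniel's remark about inverting the direction of the computation matters once the combining operation is order-sensitive, unlike multiplication. Building a list is a simple case. The sketch below is my own illustration, not Daniel's code, and `naiveMap` and `accMap` are hypothetical names:

```scala
import scala.annotation.tailrec

// Naive map: the cons (::) runs after the recursive call returns, so it is not tail recursive.
def naiveMap[A, B](xs: List[A])(f: A => B): List[B] = xs match {
  case Nil    => Nil
  case h :: t => f(h) :: naiveMap(t)(f)
}

// Accumulator version: each result is prepended, so acc is built back to front
// and must be reversed once at the end to match naiveMap.
def accMap[A, B](xs: List[A])(f: A => B): List[B] = {
  @tailrec
  def loop(rest: List[A], acc: List[B]): List[B] = rest match {
    case Nil    => acc.reverse
    case h :: t => loop(t, f(h) :: acc)
  }
  loop(xs, Nil)
}
```

With multiplication the order in which the accumulator collects the factors does not change the result, so no fix-up is needed; with a list it does, and the final `reverse` is the price of switching direction.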

This accumulator pattern is a great tip. To drive home the difference it makes to the code the Scala compiler generates, it helps to turn to our good ol' friend javap. For the original code, here is the output:

public scala.BigInt factorial(scala.BigInt);
Code:
0: aload_1
1: iconst_0
2: invokestatic #27; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
5: invokestatic #31; //Method scala/runtime/BoxesRunTime.equals:(Ljava/lang/Object;Ljava/lang/Object;)Z
8: ifeq 21
11: getstatic #36; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
14: iconst_1
15: invokevirtual #40; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
18: goto 40
21: aload_1
22: aload_0
23: aload_1
24: getstatic #36; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
27: iconst_1
28: invokevirtual #40; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
31: invokevirtual #45; //Method scala/BigInt.$minus:(Lscala/BigInt;)Lscala/BigInt;
34: invokevirtual #47; //Method factorial:(Lscala/BigInt;)Lscala/BigInt;
37: invokevirtual #50; //Method scala/BigInt.$times:(Lscala/BigInt;)Lscala/BigInt;
40: areturn

I know this looks scary if you have not used javap much, but if you have, you will recognize it as a straightforward compilation. You can see the de-sugaring (try scalac -print for a better view of the de-sugaring the compiler performs), such as the implicit conversions of the literal 1 to a BigInt (#15 and #28 above) and the operators compiled as method invocations (#31 and #37). Especially important is the recursive call to factorial at #34, followed by a call to BigInt.$times, i.e. multiplication for BigInts. Because that multiplication runs after the recursive call returns, the call is not in tail position.
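To make the de-sugaring concrete, here is a rough hand-written approximation of what the listing does, with the implicit conversions and operator methods spelled out. It is a sketch, not scalac's actual output; `factorialDesugared` is a hypothetical name, while `BigInt.int2bigInt` is the conversion the listing calls at #15 and #28:

```scala
// Hand-written approximation of the compiled original factorial.
def factorialDesugared(n: BigInt): BigInt =
  if (n == 0) BigInt.int2bigInt(1) // in the listing, n == 0 becomes a BoxesRunTime.equals call (#2, #5)
  else n.*(factorialDesugared(n.-(BigInt.int2bigInt(1)))) // int2bigInt, $minus, recursive call, then $times last
```

Reading the `else` branch inside out shows the same order as the bytecode: the recursive call has to return before `*` can run.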

Now let's rewrite the method using Daniel's suggestions:

def factorial(n: BigInt): BigInt = {
  def loop(n: BigInt, acc: BigInt): BigInt = if (n <= 0) acc else loop(n - 1, acc * n)
  loop(n, 1)
}
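As the first comment on this post points out, Scala 2.8 adds a @tailrec annotation that asks the compiler to check this for you. A minimal sketch, assuming Scala 2.8 or later (it will not compile on Scala 2.7, which predates the annotation):

```scala
import scala.annotation.tailrec

def factorial(n: BigInt): BigInt = {
  // @tailrec makes scalac reject loop at compile time if the recursive call is not in tail position.
  @tailrec
  def loop(n: BigInt, acc: BigInt): BigInt =
    if (n <= 0) acc else loop(n - 1, acc * n)
  loop(n, 1)
}
```

If `loop` were changed to end in something like `n * loop(n - 1, acc)`, compilation would fail with an error instead of silently leaving a recursive call in place.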

And now when we run javap, things look quite different (make sure you pass the -private flag to javap, along with -c, to see all of the details):

private final scala.BigInt loop$1(scala.BigInt, scala.BigInt);
Code:
0: aload_1
1: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
4: iconst_0
5: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
8: invokevirtual #36; //Method scala/BigInt.$less$eq:(Lscala/BigInt;)Z
11: ifeq 16
14: aload_2
15: areturn
16: aload_1
17: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
20: iconst_1
21: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
24: invokevirtual #40; //Method scala/BigInt.$minus:(Lscala/BigInt;)Lscala/BigInt;
27: aload_2
28: aload_1
29: invokevirtual #43; //Method scala/BigInt.$times:(Lscala/BigInt;)Lscala/BigInt;
32: astore_2
33: astore_1
34: goto 0

public scala.BigInt factorial(scala.BigInt);
Code:
0: aload_0
1: aload_1
2: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
5: iconst_1
6: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
9: invokespecial #51; //Method loop$1:(Lscala/BigInt;Lscala/BigInt;)Lscala/BigInt;
12: areturn

At the bottom is the disassembly of the factorial function. Notice that it calls something named loop$1. That is the compiled version of the inner function loop, shown at the top of the javap output. The most important thing to notice here is offset #34: the goto instruction. This is a goto loop. Notice that there is no recursive call back to loop$1 inside the function, just as there is no call back to factorial inside the factorial function. The compiler removed the tail recursion and replaced it with a goto loop, just as Daniel said it would.
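Written back out as Scala, the loop$1 bytecode behaves roughly like the while loop below. This is a hand translation for illustration (`loopAsWhile` is a hypothetical name); the comments map each step to the offsets in the listing above:

```scala
// Hand translation of loop$1: overwrite the parameters, then jump back to the test.
def loopAsWhile(n0: BigInt, acc0: BigInt): BigInt = {
  var n = n0
  var acc = acc0
  while (n > 0) {         // offsets 0-11: test n <= 0 and leave the loop when it is true
    val nextN = n - 1     // offsets 16-24: compute n - 1
    val nextAcc = acc * n // offsets 27-29: compute acc * n using the old n
    acc = nextAcc         // offset 32: astore_2
    n = nextN             // offset 33: astore_1
  }                       // offset 34: goto 0
  acc                     // offsets 14-15: return acc
}
```

Both new values are computed before either variable is overwritten, which is why the bytecode keeps them on the operand stack until offsets 32 and 33.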

For more proof, you can run the code. On my MacBook, the original version blows up with a stack overflow when trying to calculate 60,000! I'm sure it blows up well before then, but you get the point. The optimized version has no problem with 60,000!, though it does take a while to calculate.
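A small harness along these lines can reproduce the experiment. It is a sketch: `naiveFactorial`, `tailFactorial`, and `digitsOrOverflow` are hypothetical names, and the point at which the naive version overflows depends on the JVM's thread stack size (-Xss), not on a fixed limit:

```scala
import scala.annotation.tailrec

// The original, non-tail-recursive version.
def naiveFactorial(n: BigInt): BigInt =
  if (n == 0) 1 else n * naiveFactorial(n - 1)

// The accumulator version, which scalac compiles to a loop.
def tailFactorial(n: BigInt): BigInt = {
  @tailrec def loop(n: BigInt, acc: BigInt): BigInt =
    if (n <= 0) acc else loop(n - 1, acc * n)
  loop(n, 1)
}

// Number of decimal digits in f(n), or None if the call overflowed the stack.
def digitsOrOverflow(f: BigInt => BigInt, n: Int): Option[Int] =
  try Some(f(BigInt(n)).toString.length)
  catch { case _: StackOverflowError => None }
```

On a default-sized stack, `digitsOrOverflow(naiveFactorial, 60000)` is likely to return None, while `digitsOrOverflow(tailFactorial, 60000)` should return a digit count after a noticeable wait.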

One last mystery in all of this, for me, goes back to my slides on Scala. In the slides, I used tail recursion as the reason the factorial code ran faster in Scala than in Java. Clearly I was wrong. The code definitely ran faster on Scala, but I do not know exactly why. Perhaps javap can shed some light on this, too...

3 comments:

  1. In Scala 2.8 you will be able to assert that your code is tail recursive using the @tailrec annotation. I recently wrote a blog entry that shows how you would use this annotation when writing a tail-recursive factorial.

  2. These are the nice codes..really too brilliant Thanks for this useful post.

  3. Anonymous, 12:17 AM

    int fac(int n, int res)
    {
        if (!n) return res;   /* return the accumulator, not 1 */
        res = res * n;
        return fac(n - 1, res);
    }
