Back to Scala... Bill also had a slide that showed a factorial function in Scala, though it was much simpler because it relied on Scala's implicit conversions between Ints and BigInts. Afterwards, Bill and I were talking, and he pointed out that he did not think the Scala compiler would be able to optimize the recursion in either my code or his. Here is code very similar to Bill's (I don't have the exact code handy, so I recreated it from memory):

    def factorial(n: BigInt): BigInt = {
      if (n == 0) 1
      else n * factorial(n - 1)
    }

Why can't this be optimized as a tail call? Instead of explaining it myself, I will introduce the hero of this story, Daniel Spiewak (check out his blog or follow him on Twitter). Bill emailed Daniel about his suspicion, and Daniel confirmed it:

*You're quite correct: this version of factorial is not tail recursive.*

*I find it sometimes helps to split the last expression up into atomic operations, much as the Scala parser would:*


    else {
      val t1 = n - 1
      val t2 = factorial(t1)
      n * t2
    }

*It is possible to devise a tail-recursive factorial. In fact, it might be possible to convert any recursive function into a tail-recursive equivalent, but I don't have a proof for that handy.*

*Back to factorial:*


    def factorial(n: Int) = {
      def loop(n: Int, acc: Int): Int = if (n <= 0) acc else loop(n - 1, acc * n)
      loop(n, 1)
    }

*I'm cheating a bit by taking advantage of the fact that multiplication is commutative, but this is the general idea. Some functions are naturally tail recursive, but for those which are not, this "accumulator pattern" is a fairly easy way to achieve the equivalent result. Technically, the accumulator must be combined with an inversion of the computation direction (bottom-up rather than top-down) to generate a truly-equivalent tail-recursive function, but it is not necessary for examples like this one.*
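Daniel's caveat about commutativity is easiest to see with an operation where order matters. In this sketch (our own example, not from the email; `AccumulatorOrder`, `doubleAll` and `doubleAllAcc` are hypothetical names), prepending to a list plays the role that multiplication plays in factorial. Because the accumulator builds the result back to front, the tail-recursive version has to reverse it once at the end, a simple instance of the fix-up he alludes to.

```scala
import scala.annotation.tailrec

object AccumulatorOrder {
  // Direct recursion: the cons (::) runs after the recursive call returns,
  // so this is not tail recursive and uses one stack frame per element.
  def doubleAll(xs: List[Int]): List[Int] = xs match {
    case Nil    => Nil
    case h :: t => (h * 2) :: doubleAll(t)
  }

  // Accumulator version: the recursive call is in tail position.
  // Prepending is order-sensitive, so acc ends up reversed and we
  // reverse it once when the input is exhausted.
  def doubleAllAcc(xs: List[Int]): List[Int] = {
    @tailrec
    def loop(rest: List[Int], acc: List[Int]): List[Int] = rest match {
      case Nil    => acc.reverse
      case h :: t => loop(t, (h * 2) :: acc)
    }
    loop(xs, Nil)
  }
}
```

Without the final `reverse`, `doubleAllAcc(List(1, 2, 3))` would return `List(6, 4, 2)`; with multiplication the analogous reordering is invisible, which is why the factorial version gets away without it.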

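Daniel's hunch that any recursive function can be made tail recursive is usually illustrated with continuation-passing style (CPS): instead of returning, each call hands its result to a continuation function. The sketch below is our own illustration, not code from the email, and `CpsFactorial` is a hypothetical name. `loop` is tail recursive, so `@tailrec` accepts it, but the deferred multiplications now live in a chain of nested closures, and invoking that chain still takes a JVM stack frame per level. On the JVM, CPS alone does not buy stack safety; that needs a trampoline such as `scala.util.control.TailCalls`.

```scala
import scala.annotation.tailrec

object CpsFactorial {
  def factorial(n: BigInt): BigInt = {
    // k is the continuation: "what to do with the factorial of n".
    // The recursive call is the last thing loop does, so it is a tail call.
    @tailrec
    def loop(n: BigInt, k: BigInt => BigInt): BigInt =
      if (n <= 0) k(BigInt(1))
      else loop(n - 1, r => k(n * r)) // the "multiply by n" step is deferred into a closure
    loop(n, r => r)
  }
}
```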

This accumulator pattern is a great tip. To drive home the difference it makes in what the Scala compiler produces, it helps to turn to our good ol' friend javap, the JDK's class file disassembler. For the original code, here is the output:

    public scala.BigInt factorial(scala.BigInt);
      Code:
         0: aload_1
         1: iconst_0
         2: invokestatic #27; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
         5: invokestatic #31; //Method scala/runtime/BoxesRunTime.equals:(Ljava/lang/Object;Ljava/lang/Object;)Z
         8: ifeq 21
        11: getstatic #36; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
        14: iconst_1
        15: invokevirtual #40; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
        18: goto 40
        21: aload_1
        22: aload_0
        23: aload_1
        24: getstatic #36; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
        27: iconst_1
        28: invokevirtual #40; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
        31: invokevirtual #45; //Method scala/BigInt.$minus:(Lscala/BigInt;)Lscala/BigInt;
        34: invokevirtual #47; //Method factorial:(Lscala/BigInt;)Lscala/BigInt;
        37: invokevirtual #50; //Method scala/BigInt.$times:(Lscala/BigInt;)Lscala/BigInt;
        40: areturn

I know this looks scary if you have not used javap much, but if you have, you will recognize this as a straightforward compilation. You can see the desugaring: the implicit conversions from Int to BigInt (offsets 15 and 28) and the operators compiled as method invocations (offsets 31 and 37). For a more readable view of the desugaring the compiler performs, try scalac -print. Most important, the recursive call to factorial at offset 34 is followed by a call to BigInt.$times, i.e. BigInt multiplication, at offset 37. Because that multiplication still has to run after the recursive call returns, the call is not in tail position, exactly as Daniel's atomic-operations rewrite shows.

Now let's rewrite the method using Daniel's suggestions:

    def factorial(n: BigInt): BigInt = {
      def loop(n: BigInt, acc: BigInt): BigInt = if (n <= 0) acc else loop(n - 1, acc * n)
      loop(n, 1)
    }

Now when we run javap, things look quite different (pass the -private flag to javap so that it also lists private members such as the compiled inner function):

    private final scala.BigInt loop$1(scala.BigInt, scala.BigInt);
      Code:
         0: aload_1
         1: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
         4: iconst_0
         5: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
         8: invokevirtual #36; //Method scala/BigInt.$less$eq:(Lscala/BigInt;)Z
        11: ifeq 16
        14: aload_2
        15: areturn
        16: aload_1
        17: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
        20: iconst_1
        21: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
        24: invokevirtual #40; //Method scala/BigInt.$minus:(Lscala/BigInt;)Lscala/BigInt;
        27: aload_2
        28: aload_1
        29: invokevirtual #43; //Method scala/BigInt.$times:(Lscala/BigInt;)Lscala/BigInt;
        32: astore_2
        33: astore_1
        34: goto 0

    public scala.BigInt factorial(scala.BigInt);
      Code:
         0: aload_0
         1: aload_1
         2: getstatic #26; //Field scala/BigInt$.MODULE$:Lscala/BigInt$;
         5: iconst_1
         6: invokevirtual #30; //Method scala/BigInt$.int2bigInt:(I)Lscala/BigInt;
         9: invokespecial #51; //Method loop$1:(Lscala/BigInt;Lscala/BigInt;)Lscala/BigInt;
        12: areturn

At the bottom is the disassembly of the factorial function. Notice that it calls something called loop$1: that is the compiled version of the inner function loop, shown at the top of the javap output. The most important thing to notice is the goto at offset 34. After computing n - 1 and acc * n, the method stores them back into its parameter slots (astore_2 and astore_1 at offsets 32 and 33) and jumps back to offset 0. That is a loop. There is no recursive call to loop$1 inside loop$1, and factorial no longer calls itself either. The compiler removed the tail recursion and replaced it with a goto loop, just as Daniel said it would.
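Since Scala 2.8 (as a commenter notes below) you can ask the compiler to check this for you with the `@tailrec` annotation. Here is a minimal sketch of the same BigInt factorial with the annotation added; the wrapper name `TailrecFactorial` is ours. The annotation does not change the generated code; it turns a silent failure to optimize into a compile error. It also explains the `private final` modifiers in the javap output: scalac only rewrites self-calls in methods that cannot be overridden, so annotating a public, non-final method with `@tailrec` is rejected.

```scala
import scala.annotation.tailrec

object TailrecFactorial {
  def factorial(n: BigInt): BigInt = {
    // @tailrec makes scalac fail the build if loop cannot be
    // compiled into a jump; the bytecode is the same either way.
    @tailrec
    def loop(n: BigInt, acc: BigInt): BigInt =
      if (n <= 0) acc else loop(n - 1, acc * n)
    loop(n, BigInt(1))
  }
}
```

Putting the same annotation on Bill's original `factorial` would not compile, because its recursive call is followed by the multiplication.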

For more proof, you can run the code. On my MacBook, the original version blows up with a stack overflow when calculating 60,000! (I'm sure it fails well before that, but you get the point). The optimized version has no problem with 60,000!, though it does take a while to compute.
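To reproduce the experiment without waiting on huge BigInt products, here is a sketch that swaps factorial for a plain Long sum. The substitution is ours, and `StackDepthDemo`, `sumTo`, `sumToAcc` and `overflows` are hypothetical names. At a depth of ten million, the non-tail-recursive version should throw `StackOverflowError` on a default-sized JVM stack, while the accumulator version runs in constant stack space. The exact depth at which the first one fails depends on the thread stack size, which you can change with the JVM's `-Xss` option.

```scala
import scala.annotation.tailrec

object StackDepthDemo {
  // Not tail recursive: the addition runs after the recursive call returns,
  // so every level keeps a stack frame alive.
  def sumTo(n: Long): Long =
    if (n <= 0) 0L else n + sumTo(n - 1)

  // Accumulator version: scalac compiles loop into a jump, so stack use is constant.
  def sumToAcc(n: Long): Long = {
    @tailrec
    def loop(n: Long, acc: Long): Long =
      if (n <= 0) acc else loop(n - 1, acc + n)
    loop(n, 0L)
  }

  // True if evaluating the by-name argument throws StackOverflowError.
  def overflows(f: => Any): Boolean =
    try { f; false }
    catch { case _: StackOverflowError => true }
}
```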

The one remaining mystery for me goes back to my slides on Scala. In them, I gave tail recursion as the reason the factorial code ran faster in Scala than in Java. Clearly I was wrong. The code definitely ran faster in Scala, but I do not know exactly why. Perhaps javap can shed some light on this, too...

## 3 comments:

In Scala 2.8 you will be able to assert that your code is tail recursive using the @tailrec annotation. I recently wrote a blog entry that shows how you would use this annotation when writing a tail-recursive factorial.


    int fac(int n, int res)
    {
        if (!n) return res;      /* base case: res already holds the product */
        res = res * n;
        return fac(n - 1, res);  /* tail call; start with fac(n, 1) */
    }
