04 - tail recursion

Mondjuk, szeretnénk írni olyan függvényt, ami ciklusban összeadja az első n pozitív egész számot (visszaadhatnánk persze csak simán a számtani összegképletből jövő n*(n+1)/2 értéket is, de most a ciklussal kapcsolatos problémákkal szeretnénk foglalkozni).

Rekurzió és a call stack

Az első ötlet az előző poszt alapján lehet kb. ilyen:

1
2
3
def sum( n: Int ): Int = {
  if( n <= 0 ) 0 else n + sum( n-1 )
}

Ki is próbálhatjuk a ▹ szemantikánkkal:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
sum(5)  if( 5 <= 0 ) 0 else 5 + sum( 5-1 )
        if( false ) 0 else 5 + sum( 5-1 )
        5 + sum( 5-1 )
        5 + sum( 4 )
        5 + { if( 4 <= 0 ) 0 else 4 + sum( 4-1 ) }
        5 + { if( false ) 0 else 4 + sum( 4-1 ) }
        5 + { 4 + sum( 4-1 ) } //note: a kapcsos marad, amíg nincs kiértékelve a belső kifejezés!
        5 + { 4 + sum( 3 ) }
        5 + { 4 + { if( 3 <= 0 ) 0 else 3 + sum( 3-1 ) } }
        5 + { 4 + { if( false ) 0 else 3 + sum( 3-1 ) } }
        5 + { 4 + { 3 + sum( 3-1 ) } } //látványosan ,,hízik'' a kifejezés
        5 + { 4 + { 3 + sum( 2 ) } }
        5 + { 4 + { 3 + { if( 2 <= 0 ) 0 else 2 + sum( 2-1 ) } } }
        5 + { 4 + { 3 + { if( false ) 0 else 2 + sum( 2-1 ) } } }
        5 + { 4 + { 3 + { 2 + sum( 2-1 ) } } }
        5 + { 4 + { 3 + { 2 + sum( 1 ) } } }
        5 + { 4 + { 3 + { 2 + { if( 1 <= 0 ) 0 else 1 + sum( 1-1 ) } } } }
        5 + { 4 + { 3 + { 2 + { if( false ) 0 else 1 + sum( 1-1 ) } } } }
        5 + { 4 + { 3 + { 2 + { 1 + sum( 1-1 ) } } } }
        5 + { 4 + { 3 + { 2 + { 1 + sum( 0 ) } } } }
        5 + { 4 + { 3 + { 2 + { 1 + { if( 0 <= 0 ) 0 else 1 + sum( 0-1 ) } } } } }
        5 + { 4 + { 3 + { 2 + { 1 + { if( true ) 0 else 1 + sum( 0-1 ) } } } } }
        5 + { 4 + { 3 + { 2 + { 1 + { 0 } } } } } // finally, térhetünk vissza
        5 + { 4 + { 3 + { 2 + { 1 + 0 } } } }
        5 + { 4 + { 3 + { 2 + { 1 } } } }
        5 + { 4 + { 3 + { 2 + 1 } } }
        5 + { 4 + { 3 + { 3 } } }
        5 + { 4 + { 3 + 3 } }
        5 + { 4 + { 6 } }
        5 + { 4 + 6 }
        5 + { 10 }
        5 + 10
        15

Amit ebből a ,,papíron'' futtatásbóól látunk:

  • azért tesszük ki függvényhívásnál mindig a kapcsost, és oldjuk fel csak akkor, mikor már egy érték van benne, mert ez így jól modellezi azt, hogy valójában ,,mi is történik'' függvényhíváskor: például látja mindenki, hogy az 5 + { 4 + sum(3) } kifejezést nem tudjuk átírni 9 + sum(3)-ra, mert a függvény kiértékelése a kapcsoson belül még zajlik.
  • Aggasztóan hízik a függvénykifejezés mérete.

Hihetnénk persze azt, hogy ez csak papíron van így, de ha megpróbáljuk kiprintelni sum(20000)-et:

butasum-1

butasum-2

A függvényhívások ilyen mélységű egymásba ágyazása az, amit a call stack már nem bír el. Persze meg lehet növelni JVM argumentummal, de az nem megoldás, alapvetően egy több tízezres mélységű rekurzív hívás olyasmi, amit mindenképp el szeretnénk kerülni.

Itt jön be a képbe a tail call optimization, amit már tulajdonképpen láttunk is.

Tail call optimization, tailrec

Egy rekurzív függvényhívást tail pozíción lévő hívásnak mondunk, ha a hívás értékével eztán már nem csinálunk semmit: imperatív nyelven (java, C stb) ez annyit tesz, hogy return f(...);-fel egyből visszaadjuk, funkcionális nyelven meg kb. azt, hogy se nem helyettesítjük be függvénybe, se nem jön utána további kifejezés-kiértékelés. Maga a függvény tail rekurzív akkor, ha rekurzív és az összes rekurzív hívása tail pozin van.

A sum függvényünk előző implementációja nem ilyen, mert a kifejezés értéke az egyik szálon n+sum(n-1) lesz, azaz miután kiértékeltük rekurzívan a függvényt, még az értékkel csinálunk valamit.

Viszont pl. a for ciklusunk implementációja ilyen volt: a rekurzív hívás, printTizig( i+1 ), az utolsó kifejezés volt a szálon, nem követte semmi. Coincidentally, az nem is töltötte a stacket, legalábbis papíron.

Ha egy rekurzív hívás tail call, akkor a fordítónak alkalma van optimalizálni, és nem generálni függvényhívást: egyszerűen csak be kell állítani az argumentumokat az aktuális függvényük példányában a memóriában arra, amivel most hívnánk, és csak ,,gotozni'' a függvény elejére.

Ennek a módszernek, amit tail call optimizationnak hívnak, vannak persze előnyei és hátrányai:

  • előny: nem tölti a stacket, és valamivel gyorsabb is, hiszen nem kell újrafoglalni memóriát a lokális változóknak, eltárolni a visszatérési címet stb
  • hátrány: nehezebb debugolni, ha arra lenne szükségünk, hogy melyik ,,föntebbi'' példányában a függvénynek mi az ottani lokális változók értéke, azt nem tudjuk előbányászni, mert felülírtuk őket.

Egy pure funkcionális programozási nyelvben alapelem a nagy mélységű rekurzió, ezért az ilyen nyelvek általában támogatják a tail call optimalizálást; a Scala is. (A Java egyébként nem, ott tail recursive függvények is simán dobnak egy stacktracet, ha túl mély a rekurzió.) Persze ehhez tail recursive alakra kell átalakítsuk a függvényünket, ami vagy sikerül, vagy nem. Nézzük meg pl. az összegző függvényt tail rekurzívan:

1
2
3
4
5
6
7
8
def tailSum(n: Int, s: Int): Int = {
  if( n<=0 ) s
  else tailSum(n-1, s+n)
}

def tailSum(n: Int): Int = tailSum(n,0)

println( tailSum(20000) ) //prints 200010000

OK, egyrészt tehát bármi is ez, lefut, nem dob stacktrace-t, tényleg tail rekurzív is (az else ág végén van az egyetlen rekurzív hívás). Amit pluszban tanultunk most a Scala nyelvről:

  • van benne function overloading: nevezhetünk el ugyanúgy két függvényt, ha a paraméterlistájuk különbözik. (vagy a paraméterek száma, vagy ha az egyenlő, akkor valahanyadik paraméter típusa eltér)

Hogy megértsük, hogyan működik, próbáljuk kiértékelni a tailSum(5)-öt ismét:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
tailSum(5)  tailSum(5,0)
   if( 5<=0 ) 0 else tailSum(5-1, 0+5)
   if( false ) 0 else tailSum(5-1, 0+5)
   tailSum(5-1, 0+5)
   tailSum(4, 0+5)    //első argumentum kiértékelve
   tailSum(4, 5)      //második argumentum kiértékelve
   if( 4<=0 ) 5 else tailSum(4-1, 5+4)
   if( false ) 5 else tailSum(4-1, 5+4)
   tailSum(4-1, 5+4)
   tailSum(3, 5+4)
   tailSum(3, 9)
   if( 3<=0 ) 9 else tailSum(3-1, 9+3)
   if( false ) 9 else tailSum(3-1, 9+3)
   tailSum(3-1, 9+3)
   tailSum(2, 12)
   if( 2<=0 ) 12 else tailSum(2-1, 12+2)
   if( false ) 12 else tailSum(2-1, 12+2)
   tailSum(2-1, 12+2)
   tailSum(1, 12+2)
   tailSum(1, 14)
   if( 1<=0 ) 14 else tailSum(1-1, 14+1)
   if( false ) 14 else tailSum(1-1, 14+1)
   tailSum(1-1, 14+1)
   tailSum(0, 14+1)
   tailSum(0, 15)
   if( 0<=0 ) 15 else tailSum(0-1, 15+0)
   if( true ) 15 else tailSum(0-1, 15+0)
   15

Nem töltődik a stack!

Amiért pedig ez jó függvény lesz: azt lehet bebizonyítani indukcióval, hogy tailSum(n,s) = s + 1 + 2 + ... + n lesz minden s,n >= 0 egészek esetén. Ezt a következő posztban meg is nézzük, hogy ne legyen ez hosszú.

Kérdések, feladatok

  • Ellenőrizzük, hogy a tail recursive for loop implementációnk tényleg nem tölti a stacket: írassuk ki vele mondjuk 1-től 20.000-ig a számokat!
  • Mi történik, ha a loop függvényünket meghívjuk? Miért? Próbáljuk ki, hogy Javában mi a helyzet egy ugyanilyen sémára épülő függénnyel.

Utolsó frissítés: 2020-12-22 21:04:26