Continuation Passing Style + Tail Call Opt. via Trampolining

Code-Stücke können hier veröffentlicht werden.
Antworten
Benutzeravatar
pillmuncher
User
Beiträge: 1482
Registriert: Samstag 21. März 2009, 22:59
Wohnort: Pfaffenwinkel

Seit einiger Zeit bastel ich mir einen Spielzeug-Prolog-Interpreter, und da ich mir keine WAM bauen wollte (mangels Ahnung), mach ich es ganz naiv-rekursiv mittels Robinson-Algorithmus. Die max. Rekursionstiefe wirft allerdings bei komplexen Prolog-Programmen ein Problem auf. Also dacht ich, bau ich mir selber eine TCO und verwende CPS. Das ganze sieht dann ungefähr so aus:

Erstmal eine kleine Datenstruktur zum Testen und ein kleiner rekursiver Beispiel-Algorithmus, der der Struktur vom Robinson recht nahe kommt (jedoch ohne Resolving und Unifikation, und alles noch ohne TCO/CPS):

Code: Alles auswählen

from random import randint

class Node(object):
    def __init__(self, value, children):
        self.value = value
        self.children = children
    def __repr__(self):
        return "%d" % self.value

def maketrees(n, *ns):
    if ns:
        return [Node(randint(1,9), maketrees(*ns)) for each in xrange(n)]
    return [Node(randint(1,9), []) for each in xrange(n)]

def validpaths(nodes, path=[]):
    if nodes == []:
        yield path
    else:
        for node in nodes:
            if isvalid(node):
                for each in validpaths(node.children, path + [node]):
                    yield each

def isvalid(node):
    return node.value > 3 

for path in validpaths(maketrees(3,2,5,4)):
    print path
validpaths liefert alle Pfade von oben nach unten, deren Knoten alle einen value > 3 haben. Man stelle sich die Routen-Planung eines Schwertransports vor, der drei Meter breit ist und stelle sich vor, die oben generierten Bäume repräsentierten Straßen, die value Meter breit sind. Dann zeigt die Ausgabe des Algorithmus alle möglichen Routen mit den jeweilingen Werten von value an, ungefähr so:

Code: Alles auswählen

[6, 8, 9, 7]
[6, 8, 9, 7]
[6, 8, 9, 8]
[6, 8, 9, 6]
[9, 4, 5, 8]
[9, 4, 5, 9]
[9, 4, 5, 9]
[9, 4, 8, 8]
Die TCO/CPS Version verlangt nun nach einem Trampolin:

Code: Alles auswählen

def trampoline(bouncing, *args, **kwargs):
    while bouncing:
        throwing, bouncing, args, kwargs = bouncing(*args, **kwargs)
        if throwing:
            yield throwing()

def bouncy(f):
    return lambda *args, **kwargs:(None, f, args, kwargs)

def land(*args, **kwargs):
    return None, None, args, kwargs

def bounce(function, *args, **kwargs):
    return None, function, args, kwargs

def throw(function, thrown, *args, **kwargs):
    return lambda:thrown, function, args, kwargs
und das angepasste validpaths sieht dann so aus:

Code: Alles auswählen

@bouncy
def validpathsCT(nodes, path=[], result=throw, cont=land):
    if nodes == []:
        return result(cont, path)
    def trail(*unused):
        return validpathsCT(nodes[1:], path, cont, cont)
    node = nodes[0]
    if isvalid(node):
        return validpathsCT(node.children, path + [node], throw, trail)
    return trail()

def validpaths(trees):
    return trampoline(validpathsCT, trees)
Der Aufruf geschieht genauso wie oben, und beide Versionen erzeugen bei gleichen Ausgangsdaten dieselben Ergebnisse.

Hier ist die Schleife von oben (for node in nodes) durch einen weiteren rekursiven Aufruf ersetzt worden (return trail()). trail ist eine Closure, in der zweimal die zweite der beiden aktuellen Continuations (cont) weitergereicht wird. Beim rekursiven Abstieg dagegen werden als erste Continuation die Rückgabe-Funktion throw (für den Fall, dass ein gültiger Pfad gefunden wurde) übergeben, und als zweite die Backtracking-Funktion trail (für den Fall, dass es von hier aus keinen solchen Pfad gibt). Der eine Trick ist nun, dass in trail() cont == trail selbst ist, sofern node kein Root-Knoten ist, und das heisst, dass, wenn result ebenfalls == trail ist, return result(cont, value) "zurückkehrt" und der Algorithmus per Backtracking dort weitermacht, wo er soll, solange, bis er am Schluss "land"et. Oder, result != trail, dann ist aber result == throw und cont == trail. result ruft nun cont auf (das, wie gesagt, == trail ist), wodurch dann alle Lösungen generiert werden. trail wird also zweitverwertet, einmal als Abbruchs-Funktion nach einem Scheitern, und das andere mal als Fortsetzungs-Funktion nach der Generierung einer Route.

[edit]
Ich weiß nicht, ob die vorhergehende Erklärung was taugt, deswegen nochmal anders:

Es gibt vier mögliche Kombinationen von result und cont:

a) result == throw und cont == land. Das besteht nur beim allerersten Aufruf.
b) result == throw und cont == trail. Das besteht nur direkt beim rekursiven Abstieg.
c) result == trail und cont == trail. Das besteht beim Trailing, außer wenn
d) result == land und cont == land, wenn das Trailing entlang der rechten Außenkante verläuft.

Da einzig throw eine Ausgabe erzeugt, und nur trail die Abarbeitung fortsetzt, geschieht also immer folgendes, wenn nodes == [] erreicht ist:

Bei a) wird nochmal eine Ausgabe erzeugt und mit land fortgesetzt. (Termination)
bei b) wird eine Ausgabe erzeugt, aber mit trail fortgesezt. (Backtracking after success)
bei c) wird keine Ausgabe erzeugt, und cont wird ignoriert. Da result aber == trail ist, wird sowieso mit trail fortgesetzt. (Backtracking after failure)
bei d) der Algorithmus terminiert ohne weitere Ausgabe.
[/edit]

Lustig ist noch folgendes: wenn man @bouncy weglässt läuft das ganze immer noch, nur nicht mehr tail call optimized, was man einfach feststellen kann, wenn man eine Exception schmeisst:

Code: Alles auswählen

########## HIER NIX @bouncy
def validpathsCT(nodes, path=[], result=throw, cont=land):
    if nodes == []:
        raise Exception, "w/o tail call optimization"
        ...
Ergebnis:

Code: Alles auswählen

$ python bouncing.py
Traceback (most recent call last):
  File "bouncing.py", line 185, in <module>
    for path in validpaths(maketrees(3,2,3,4)):
  File "bouncing.py", line 145, in trampoline
    throwing, bouncing, args, kwargs = bouncing(*args, **kwargs)
  File "bouncing.py", line 175, in validpathsCT
    return trail()
  File "bouncing.py", line 171, in trail
    return validpathsCT(nodes[1:], path, cont, cont)
  File "bouncing.py", line 175, in validpathsCT
    return trail()
  File "bouncing.py", line 171, in trail
    return validpathsCT(nodes[1:], path, cont, cont)
  File "bouncing.py", line 174, in validpathsCT
    return validpathsCT(node.children, path + [node], throw, trail)
  File "bouncing.py", line 174, in validpathsCT
    return validpathsCT(node.children, path + [node], throw, trail)
  File "bouncing.py", line 174, in validpathsCT
    return validpathsCT(node.children, path + [node], throw, trail)
  File "bouncing.py", line 174, in validpathsCT
    return validpathsCT(node.children, path + [node], throw, trail)
  File "bouncing.py", line 168, in validpathsCT
    raise Exception, "w/o tail call optimization"
Exception: w/o tail call optimization
und zum Vergleich:

Code: Alles auswählen

@bouncy
def validpathsCT(nodes, path=[], result=throw, cont=land):
    if nodes == []:
        raise Exception, "with tail call optimization"
        ...

Code: Alles auswählen

$ python bouncing.py
Traceback (most recent call last):
  File "bouncing.py", line 185, in <module>
    for path in validpaths(maketrees(3,2,3,4)):
  File "bouncing.py", line 145, in trampoline
    throwing, bouncing, args, kwargs = bouncing(*args, **kwargs)
  File "bouncing.py", line 168, in validpathsCT
    raise Exception, "with tail call optimization"
Exception: with tail call optimization
Was haltet ihr davon?

(und wer baut mir dazu ein call/cc? ;-) )

Gruß und "May You Bounce In Peace!" (Brain Ball #1, Futurama, "War Is The H-Word"),
Mick.
sma
User
Beiträge: 3018
Registriert: Montag 19. November 2007, 19:57
Wohnort: Kiel

Interessant. Ich muss mich dem jedoch schrittweise mit einem einfacheren Beispiel nähern.

Hier ist die klassische Fakultätsfunktion:

Code: Alles auswählen

def fac(n):
    if n == 0:
        return 1
    return fac(n - 1) * n
Diese kann ich umschreiben, dass nur Endrekursion benutzt wird:

Code: Alles auswählen

def fac_tc(n, acc):
    if n == 0:
        return acc
    return fac_tc(n - 1, n * acc)
Ich beherrsche auch CPS. Dazu reiche ich in die Funktion eine Funktion hinein, die aufgerufen wird, wenn es weitergehen soll. CPS nutzt (automatisch?) Endrekursion und man kann CPS relativ einfach aus der endrekursiven Form ableiten.

Code: Alles auswählen

def fac_cps(n, next):
    if n == 0:
        return next(1)
    return fac_cps(n - 1, lambda r: next(r * n))
Doch beide Varianten sind immer noch rekursiv und damit durch den Python-Stack beschränkt. Ich nutze stattdessen eine while-Schleife, die die Funktionsaufrufe treibt:

Code: Alles auswählen

def driver(func, n, a):
    while func:
        func, n, a = func(n, a)
    return a

def fac_d(n, a):
    if n == 0:
        return None, None, a
    return fac_d, n - 1, a * n

print driver(fac_d, 10, 1)
Ich übergebe dem "Treiber" die aufzurufende Funktion, das Argument und den Akkumulator für das Ergebnis. Die übergebene Funktion muss ein 3-Tupel mit der nächsten aufzurufenden Funktion, deren Argument und dem neuen Akkumulatorwert liefern. Das kann man natürlich auch für mehr als nur ein Argument implementieren.

So habe ich das dann auf trampoline/bouncy abgebildet:

Code: Alles auswählen

@bouncy
def fac_t(n, a):
    if n == 0:
        return throw(land, a)
    return fac_t(n - 1, a * n)

for r in trampoline(fac_t, 10, 1):
    print r
Stefan
Benutzeravatar
pillmuncher
User
Beiträge: 1482
Registriert: Samstag 21. März 2009, 22:59
Wohnort: Pfaffenwinkel

Hallo Stefan.

Dein Beispiel ist besser als meines, um den Sinn der ganzen TCO/Trampolining-Veranstaltung zu veranschaulichen. Man versuche nur mal das @bouncy bei fac_t wegzulassen es mit n=10000 aufzurufen:

Code: Alles auswählen

########## HIER NIX @bouncy
def fac_t(n, a):
    if n == 0:
        return throw(land, a)
    return fac_t(n - 1, a * n)

for r in trampoline(fac_t, 10000, 1):
    print r
Im Vergleich dazu ist das Ergebnis desselben Aufrufs mit @bouncy irgendwie besser...

Wie du zeigst, lohnt es nicht, Algorithmen wie die Fakultäts-Berechnung auf CPS umzubauen. Deswegen mein doch recht kompliziertes Beispiel mit der doppelten Rekursion und komplexen Fallunterscheidungen beim Trampolining, wo CPS zu leuchten beginnt.

Übrigens, für alle, denen CPS fremd ist: wenn man Continuations verstehen will, ist Denys Duchiers Artikel "Continuations Made Simple and Illustrated" mitsamt dem dort gelinkten CPS-Prolog in Python echt lesenswert.

Ein anderer Punkt, der mir aufgefallen ist: als ich trampoline(...) gebaut habe, hatte ich nur All-Solutions-Suche im Sinn. Bei fac_t, wo nur ein einzelner Wert berechnet wird, sieht die for-Schleife irgendwie bescheuert aus. Auf die Schnelle fällt mir nichts ein, außer dass man dann eben zwei Trampoline haben müsste, trampoline und itrampoline, oder xtrampoline, oder whatever. Oder gibt es eine Lösung die may not be obvious at first unless you're Dutch?

Gruß,
Mick.
Antworten