Mal als Gegenüberstellung: ein AST, bei dem man die DATA-Werte direkt von den Knoten abfragen kann und bei dem jeder Knoten selbst weiß, wie er sich auszuführen hat – und dasselbe dann noch einmal mit komplett „dummen“ AST-Knoten und zwei Besuchern (Visitors), die jeweils die DATA-Werte abfragen bzw. den AST ausführen:
„Schlaue“ AST-Knoten:
Code: Alles auswählen
#!/usr/bin/env python3
import re
from operator import add, lt as is_less_than, sub as subtract
from attrs import define, field, frozen
from prettyprinter import cpprint as pprint, install_extras
# Let prettyprinter render attrs classes (used for the AST dumps in main()).
install_extras(["attrs"])

# Demo program in the small Lisp-like BASIC dialect interpreted below. It
# READs a row count n and then the triangle values (from the DATA statements)
# into the 1-based 2-D arrays t and o, then folds t bottom-up: each
# t[i-1][j] gets the larger of t[i][j] and t[i][j+1] added. It finally
# PRINTs t[1][1], the maximum top-to-bottom path sum. o keeps an unmodified
# copy of the input.
SOURCE = """\
; Read in the data.
(READ n)
(DIM (t n n) (o n n))
(FOR i = 1 TO n
(FOR j = 1 TO i
(DO
(READ k)
((t i j) = k)
((o i j) = k))))
; Calculate the result.
(FOR i = n DOWNTO 2
(FOR j = 1 TO (i - 1)
(DO
(k = j)
(IF ((t i j) < (t i (j + 1))) THEN (k = (k + 1)))
((t (i - 1) j) = ((t (i - 1) j) + (t i k))))))
(PRINT (t 1 1))
(DATA 15) ; Row count to follow.
; The triangle values.
(DATA 75)
(DATA 95 64)
(DATA 17 47 82)
(DATA 18 35 87 10)
(DATA 20 4 82 47 65)
(DATA 19 1 23 75 3 34)
(DATA 88 2 77 73 7 63 67)
(DATA 99 65 4 28 6 16 70 92)
(DATA 41 41 26 56 83 40 80 70 33)
(DATA 41 48 72 33 47 32 37 16 94 29)
(DATA 53 71 44 65 25 43 91 52 97 51 14)
(DATA 70 11 33 28 77 73 17 78 39 68 17 57)
(DATA 91 71 52 38 17 14 91 43 58 50 27 29 48)
(DATA 63 66 4 68 89 53 67 30 73 16 69 87 40 31)
(DATA 4 62 98 27 23 9 70 98 73 93 38 53 60 4 23)
"""
@frozen
class Node:
    """Base class of the "smart" AST nodes.

    Each node can report the DATA values it contains and can execute itself
    against an interpreter context. Subclasses override what applies to them.
    """

    def iter_data_values(self):
        # Default: the node contributes no DATA values. ``yield from ()``
        # makes this a generator function that yields nothing.
        yield from ()

    def execute(self, _context):
        # Every executable node type must override this.
        raise NotImplementedError
@frozen
class Literal(Node):
    """Integer constant (from the source text, or a FOR step of 1 / -1)."""

    value = field()

    def execute(self, _context):
        # A literal evaluates to itself; the context is not consulted.
        return self.value
@frozen
class ScalarVariable(Node):
    """Reference to a scalar (non-array) variable by name."""

    name = field()

    def execute(self, context):
        # Raises KeyError if the variable was never assigned.
        return context.scalars[self.name]
@frozen
class ArrayElement(Node):
    """Indexed access ``(name index ...)`` into an array created by DIM.

    Each index expression selects one nesting level. Passing fewer indices
    than the array has dimensions yields the remaining sub-list.
    """

    name = field()
    indices = field()

    def execute(self, context):
        current = context.arrays[self.name]
        # Lazily evaluate each index right before it is used, level by level.
        for position in (index.execute(context) for index in self.indices):
            current = current[position]
        return current
@frozen
class BinaryOperation(Node):
    """Infix operation ``(left op right)`` with op one of ``+``, ``-``, ``<``."""

    # Maps the operator symbol from the source to the Python function.
    SYMBOL_TO_FUNCTION = {"+": add, "-": subtract, "<": is_less_than}

    left_operand = field()
    operator = field()
    right_operand = field()

    def execute(self, context):
        # Look up the operator first, then evaluate the operands left to right.
        operation = self.SYMBOL_TO_FUNCTION[self.operator]
        left = self.left_operand.execute(context)
        right = self.right_operand.execute(context)
        return operation(left, right)
@frozen
class AssignScalar(Node):
    """Scalar assignment ``(name = expression)``."""

    variable = field()
    expression = field()

    def execute(self, context):
        value = self.expression.execute(context)
        context.scalars[self.variable.name] = value
@frozen
class AssignArrayElement(Node):
    """Array element assignment ``((name index ...) = expression)``."""

    array_element = field()
    expression = field()

    def execute(self, context):
        element = self.array_element
        # Descend through all but the last index to reach the innermost list.
        container = context.arrays[element.name]
        for index in element.indices[:-1]:
            container = container[index.execute(context)]
        # Value first, then the final index (same order as a plain
        # ``container[i] = value`` statement evaluates them).
        value = self.expression.execute(context)
        container[element.indices[-1].execute(context)] = value
@frozen
class Data(Node):
    """DATA statement: constants to be consumed by READ.

    DATA does nothing when executed in place. Its values reach READ through
    ``iter_data_values``, which ``run()`` wires into the context before
    execution starts.
    """

    values = field()

    def iter_data_values(self):
        return iter(self.values)

    def execute(self, _context):
        # Intentionally a no-op (see class docstring).
        pass
@frozen
class Read(Node):
    """READ statement: assign the next DATA value to a scalar variable.

    Raises:
        RuntimeError: if all DATA values have already been consumed.
    """

    variable = field()

    def execute(self, context):
        try:
            value = next(context.data_values)
        except StopIteration:
            # Report running out of DATA explicitly instead of leaking
            # StopIteration. Python treats that as an iteration-control
            # signal and turns it into a RuntimeError inside generators.
            raise RuntimeError(
                f"READ {self.variable.name}: no DATA values left"
            ) from None
        context.scalars[self.variable.name] = value
@frozen
class Dim(Node):
    """DIM statement ``(DIM (name bound ...) ...)``: allocate zeroed arrays.

    Each ``ArrayElement`` in ``array_elements`` declares one array. Its index
    expressions are evaluated as upper bounds, so a bound ``n`` gives valid
    indices ``0..n``. Every element starts as ``0``.

    Raises:
        RuntimeError: if an array is declared without any bound.
    """

    array_elements = field(factory=list)

    def execute(self, context):
        def new_array(sizes):
            # Allocate every nesting level afresh. A shallow ``list.copy()``
            # of a template level would share the innermost lists between
            # rows once there are three or more dimensions, so one
            # assignment would show up in several elements.
            if len(sizes) == 1:
                return [0] * sizes[0]
            return [new_array(sizes[1:]) for _ in range(sizes[0])]

        for array_element in self.array_elements:
            sizes = [
                index.execute(context) + 1 for index in array_element.indices
            ]
            if not sizes:
                raise RuntimeError(
                    f"DIM {array_element.name}: at least one bound is required"
                )
            context.arrays[array_element.name] = new_array(sizes)
@frozen
class If(Node):
    """Conditional ``(IF condition THEN statement)`` without an ELSE branch."""

    condition = field()
    statement = field()

    def iter_data_values(self):
        # Only the guarded statement can contain DATA.
        return self.statement.iter_data_values()

    def execute(self, context):
        if not self.condition.execute(context):
            return
        self.statement.execute(context)
@frozen
class For(Node):
    """Counting loop ``(FOR var = start TO|DOWNTO end statement)``.

    ``step`` is an expression node (build_ast supplies ``Literal(1)`` or
    ``Literal(-1)``). The bound is checked only after each pass, so the body
    always runs at least once. The loop variable is re-read from the context
    on every pass, so the body may modify it. A step of 0 never terminates.
    """

    variable = field()
    start = field()
    end = field()
    step = field()
    statement = field()

    def iter_data_values(self):
        return self.statement.iter_data_values()

    def execute(self, context):
        name = self.variable.name
        context.scalars[name] = self.start.execute(context)
        limit = self.end.execute(context)
        increment = self.step.execute(context)
        running = True
        while running:
            self.statement.execute(context)
            candidate = context.scalars[name] + increment
            if increment > 0:
                running = not candidate > limit
            elif increment < 0:
                running = not candidate < limit
            if running:
                context.scalars[name] = candidate
@frozen
class Print(Node):
    """PRINT statement: output the values of its expressions, space separated."""

    expressions = field(factory=list)

    def execute(self, context):
        texts = [str(expression.execute(context)) for expression in self.expressions]
        print(" ".join(texts))
@frozen
class Block(Node):
    """Sequence ``(DO statement ...)`` executed in order."""

    statements = field(factory=list)

    def iter_data_values(self):
        # DATA values of all children, in source order.
        for child in self.statements:
            yield from child.iter_data_values()

    def execute(self, context):
        for child in self.statements:
            child.execute(context)
# Tokenizer for the source language. Rules are tried in this order at each
# position, so "-5" becomes an int before "-" can match as an operator.
SCANNER = re.Scanner(
    [
        # Whitespace and ";" comments (up to end of line) yield no token.
        (r"\s+|;.*$", None),
        # Parentheses are emitted as the strings "(" and ")".
        (r"\(", "("),
        (r"\)", ")"),
        # Optionally negative integer literals become Python ints.
        (r"-?\d+", lambda _scanner, token: int(token)),
        # Words and the symbols - + = < are kept as strings.
        (r"\w+|[-+=<]", lambda _scanner, token: token),
    ],
    flags=re.MULTILINE,
)
def parse(text):
    """Turn source text into a nested-list S-expression.

    Tokens come from SCANNER, and every parenthesised group becomes a list.
    A program made of exactly one top-level expression is returned as that
    expression. Otherwise the top-level expressions are wrapped as
    ``["DO", ...]``.

    Raises:
        ValueError: on text the scanner cannot tokenize or on unbalanced
            parentheses.
    """
    tokens, rest = SCANNER.scan(text)
    if rest:
        raise ValueError(f"expected end of source text, found:\n{rest}")
    open_groups = []
    current = []
    for token in tokens:
        if token == "(":
            # Start a new group and remember the enclosing one.
            open_groups.append(current)
            current = []
        elif token == ")":
            if not open_groups:
                raise ValueError("unexpected closing parenthesis")
            enclosing = open_groups.pop()
            enclosing.append(current)
            current = enclosing
        else:
            current.append(token)
    if open_groups:
        raise ValueError(f"{len(open_groups)} unclosed parenthesis")
    if len(current) == 1:
        return current[0]
    return ["DO", *current]
def build_ast(s_expression):
match s_expression:
case "DO", *statements:
return Block(list(map(build_ast, statements)))
case "DIM", *arrays:
return Dim(list(map(build_ast, arrays)))
case (
"FOR",
str(name),
"=",
start,
"TO" | "DOWNTO" as direction,
end,
statement,
):
return For(
ScalarVariable(name),
build_ast(start),
build_ast(end),
Literal(1 if direction == "TO" else -1),
build_ast(statement),
)
case "IF", condition, "THEN", statement:
return If(build_ast(condition), build_ast(statement))
case "PRINT", *expressions:
return Print(list(map(build_ast, expressions)))
case "DATA", *values:
return Data(values)
case "READ", str(name):
return Read(ScalarVariable(name))
case [str(name), *indices], "=", expression:
return AssignArrayElement(
ArrayElement(name, list(map(build_ast, indices))),
build_ast(expression),
)
case str(name), "=", expression:
return AssignScalar(ScalarVariable(name), build_ast(expression))
case left_operand, "+" | "-" | "<" as operator, right_operand:
return BinaryOperation(
build_ast(left_operand), operator, build_ast(right_operand)
)
case str(name), *indices:
return ArrayElement(name, list(map(build_ast, indices)))
case str(name):
return ScalarVariable(name)
case int(value):
return Literal(value)
case _:
raise ValueError(f"can not handle {s_expression!r}")
@define
class Context:
    """Mutable interpreter state shared by all nodes during execution."""

    # Iterator over the program's DATA values; each READ takes the next one.
    data_values = field(factory=lambda: iter([]))
    # Scalar variable name -> current value.
    scalars = field(factory=dict)
    # Array name -> nested lists created by DIM.
    arrays = field(factory=dict)
def run(ast):
    """Execute *ast* and return the Context holding the final state.

    The AST's DATA values become the context's READ source.
    """
    context = Context(data_values=ast.iter_data_values())
    ast.execute(context)
    return context
def main():
    """Parse, build and run SOURCE, pretty-printing each stage."""

    def show(value):
        pprint(value, indent=2, compact=True)

    s_expression = parse(SOURCE)
    show(s_expression)
    ast = build_ast(s_expression)
    show(ast)
    context = run(ast)
    show(context.scalars)
    show(context.arrays["t"])
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
„Dumme“ AST-Knoten mit zwei Besuchern:
Code: Alles auswählen
#!/usr/bin/env python3
import re
from operator import add, lt as is_less_than, sub as subtract
from attrs import define, field, frozen
from prettyprinter import cpprint as pprint, install_extras
# Let prettyprinter render attrs classes (used for the AST dumps in main()).
install_extras(["attrs"])

# Demo program in the small Lisp-like BASIC dialect interpreted below. It
# READs a row count n and then the triangle values (from the DATA statements)
# into the 1-based 2-D arrays t and o, then folds t bottom-up: each
# t[i-1][j] gets the larger of t[i][j] and t[i][j+1] added. It finally
# PRINTs t[1][1], the maximum top-to-bottom path sum. o keeps an unmodified
# copy of the input.
SOURCE = """\
; Read in the data.
(READ n)
(DIM (t n n) (o n n))
(FOR i = 1 TO n
(FOR j = 1 TO i
(DO
(READ k)
((t i j) = k)
((o i j) = k))))
; Calculate the result.
(FOR i = n DOWNTO 2
(FOR j = 1 TO (i - 1)
(DO
(k = j)
(IF ((t i j) < (t i (j + 1))) THEN (k = (k + 1)))
((t (i - 1) j) = ((t (i - 1) j) + (t i k))))))
(PRINT (t 1 1))
(DATA 15) ; Row count to follow.
; The triangle values.
(DATA 75)
(DATA 95 64)
(DATA 17 47 82)
(DATA 18 35 87 10)
(DATA 20 4 82 47 65)
(DATA 19 1 23 75 3 34)
(DATA 88 2 77 73 7 63 67)
(DATA 99 65 4 28 6 16 70 92)
(DATA 41 41 26 56 83 40 80 70 33)
(DATA 41 48 72 33 47 32 37 16 94 29)
(DATA 53 71 44 65 25 43 91 52 97 51 14)
(DATA 70 11 33 28 77 73 17 78 39 68 17 57)
(DATA 91 71 52 38 17 14 91 43 58 50 27 29 48)
(DATA 63 66 4 68 89 53 67 30 73 16 69 87 40 31)
(DATA 4 62 98 27 23 9 70 98 73 93 38 53 60 4 23)
"""
# Plain ("dumb") AST node records: they only hold data. Executing them and
# collecting DATA values is done by the visitor classes. The field
# declaration order is the positional constructor order used by build_ast.
@frozen
class Literal:
    """Integer constant (from the source text, or a FOR step of 1 / -1)."""

    value = field()


@frozen
class ScalarVariable:
    """Reference to a scalar (non-array) variable by name."""

    name = field()


@frozen
class ArrayElement:
    """Indexed array access ``(name index ...)``; indices are nodes."""

    name = field()
    indices = field()


@frozen
class BinaryOperation:
    """Infix operation ``(left op right)``; ``operator`` is the symbol string."""

    left_operand = field()
    operator = field()
    right_operand = field()


@frozen
class AssignScalar:
    """Scalar assignment ``(name = expression)``."""

    variable = field()
    expression = field()


@frozen
class AssignArrayElement:
    """Array element assignment ``((name index ...) = expression)``."""

    array_element = field()
    expression = field()


@frozen
class Data:
    """DATA statement holding the raw constant values."""

    values = field()


@frozen
class Read:
    """READ statement targeting a scalar variable."""

    variable = field()


@frozen
class Dim:
    """DIM statement; each entry is an ArrayElement whose indices are bounds."""

    array_elements = field(factory=list)


@frozen
class If:
    """Conditional ``(IF condition THEN statement)`` without ELSE."""

    condition = field()
    statement = field()


@frozen
class For:
    """Counting loop; ``step`` is an expression node."""

    variable = field()
    start = field()
    end = field()
    step = field()
    statement = field()


@frozen
class Print:
    """PRINT statement with the expressions to output."""

    expressions = field(factory=list)


@frozen
class Block:
    """Sequence ``(DO statement ...)``."""

    statements = field(factory=list)
# Tokenizer for the source language. Rules are tried in this order at each
# position, so "-5" becomes an int before "-" can match as an operator.
SCANNER = re.Scanner(
    [
        # Whitespace and ";" comments (up to end of line) yield no token.
        (r"\s+|;.*$", None),
        # Parentheses are emitted as the strings "(" and ")".
        (r"\(", "("),
        (r"\)", ")"),
        # Optionally negative integer literals become Python ints.
        (r"-?\d+", lambda _scanner, token: int(token)),
        # Words and the symbols - + = < are kept as strings.
        (r"\w+|[-+=<]", lambda _scanner, token: token),
    ],
    flags=re.MULTILINE,
)
def parse(text):
    """Turn source text into a nested-list S-expression.

    Tokens come from SCANNER, and every parenthesised group becomes a list.
    A program made of exactly one top-level expression is returned as that
    expression. Otherwise the top-level expressions are wrapped as
    ``["DO", ...]``.

    Raises:
        ValueError: on text the scanner cannot tokenize or on unbalanced
            parentheses.
    """
    tokens, rest = SCANNER.scan(text)
    if rest:
        raise ValueError(f"expected end of source text, found:\n{rest}")
    open_groups = []
    current = []
    for token in tokens:
        if token == "(":
            # Start a new group and remember the enclosing one.
            open_groups.append(current)
            current = []
        elif token == ")":
            if not open_groups:
                raise ValueError("unexpected closing parenthesis")
            enclosing = open_groups.pop()
            enclosing.append(current)
            current = enclosing
        else:
            current.append(token)
    if open_groups:
        raise ValueError(f"{len(open_groups)} unclosed parenthesis")
    if len(current) == 1:
        return current[0]
    return ["DO", *current]
def build_ast(s_expression):
match s_expression:
case "DO", *statements:
return Block(list(map(build_ast, statements)))
case "DIM", *arrays:
return Dim(list(map(build_ast, arrays)))
case (
"FOR",
str(name),
"=",
start,
"TO" | "DOWNTO" as direction,
end,
statement,
):
return For(
ScalarVariable(name),
build_ast(start),
build_ast(end),
Literal(1 if direction == "TO" else -1),
build_ast(statement),
)
case "IF", condition, "THEN", statement:
return If(build_ast(condition), build_ast(statement))
case "PRINT", *expressions:
return Print(list(map(build_ast, expressions)))
case "DATA", *values:
return Data(values)
case "READ", str(name):
return Read(ScalarVariable(name))
case [str(name), *indices], "=", expression:
return AssignArrayElement(
ArrayElement(name, list(map(build_ast, indices))),
build_ast(expression),
)
case str(name), "=", expression:
return AssignScalar(ScalarVariable(name), build_ast(expression))
case left_operand, "+" | "-" | "<" as operator, right_operand:
return BinaryOperation(
build_ast(left_operand), operator, build_ast(right_operand)
)
case str(name), *indices:
return ArrayElement(name, list(map(build_ast, indices)))
case str(name):
return ScalarVariable(name)
case int(value):
return Literal(value)
case _:
raise ValueError(f"can not handle {s_expression!r}")
@define
class Visitor:
    """Minimal visitor base that dispatches on the node's class name.

    ``visit(node)`` calls ``self.visit_<ClassName>(node)`` when such a method
    exists, otherwise ``self.visit_any(node)``, and returns its result.
    """

    def visit_any(self, _node):
        # Default fallback: ignore the node.
        return None

    def visit(self, node):
        method_name = f"visit_{type(node).__name__}"
        handler = getattr(self, method_name, self.visit_any)
        return handler(node)
@define
class DataVisitor(Visitor):
    """Visitor that yields all DATA values of an AST, in source order.

    Only Block, For and If can contain nested statements, so only they are
    descended into. Every other non-DATA node contributes nothing.
    """

    def visit_any(self, _node):
        return iter([])

    def visit_Data(self, node):
        return iter(node.values)

    def visit_Block(self, node):
        for child in node.statements:
            yield from self.visit(child)

    def _visit_nested_statement(self, node):
        # For and If wrap exactly one statement in ``node.statement``.
        return self.visit(node.statement)

    visit_For = visit_If = _visit_nested_statement
@define
class ExecutionVisitor(Visitor):
    """Visitor that interprets an AST made of plain record nodes.

    The interpreter state lives on the visitor. ``data_values`` is the
    iterator consumed by READ, ``scalars`` maps scalar names to values, and
    ``arrays`` maps array names to the nested lists built by DIM. Visiting
    an expression node returns its value; visiting a statement node performs
    its side effects and returns ``None``.
    """

    SYMBOL_TO_FUNCTION = {"+": add, "-": subtract, "<": is_less_than}

    data_values = field(factory=lambda: iter([]))
    scalars = field(factory=dict)
    arrays = field(factory=dict)

    def visit_any(self, node):
        # Reached for node types without a visit_* method of their own.
        raise RuntimeError(f"can not execute {node!r}")

    def visit_Literal(self, node):
        return node.value

    def visit_ScalarVariable(self, node):
        return self.scalars[node.name]

    def visit_ArrayElement(self, node):
        # One nesting level per index; fewer indices yield a sub-list.
        result = self.arrays[node.name]
        for index in node.indices:
            result = result[self.visit(index)]
        return result

    def visit_BinaryOperation(self, node):
        return self.SYMBOL_TO_FUNCTION[node.operator](
            self.visit(node.left_operand), self.visit(node.right_operand)
        )

    def visit_AssignScalar(self, node):
        self.scalars[node.variable.name] = self.visit(node.expression)

    def visit_AssignArrayElement(self, node):
        # Walk to the innermost list, then store with the last index.
        array = self.arrays[node.array_element.name]
        for index in node.array_element.indices[:-1]:
            array = array[self.visit(index)]
        array[self.visit(node.array_element.indices[-1])] = self.visit(
            node.expression
        )

    def visit_Data(self, _node):
        # DATA values reach READ via ``data_values``; nothing to do here.
        pass

    def visit_Read(self, node):
        try:
            value = next(self.data_values)
        except StopIteration:
            # Report running out of DATA explicitly instead of leaking
            # StopIteration. Python treats that as an iteration-control
            # signal and turns it into a RuntimeError inside generators.
            raise RuntimeError(
                f"READ {node.variable.name}: no DATA values left"
            ) from None
        self.scalars[node.variable.name] = value

    def visit_Dim(self, node):
        def new_array(sizes):
            # Allocate every nesting level afresh. A shallow ``list.copy()``
            # of a template level would share the innermost lists between
            # rows once there are three or more dimensions, so one
            # assignment would show up in several elements.
            if len(sizes) == 1:
                return [0] * sizes[0]
            return [new_array(sizes[1:]) for _ in range(sizes[0])]

        for array_element in node.array_elements:
            # Bounds are the highest valid indices, so bound n gives n + 1
            # slots (indices 0..n).
            sizes = [self.visit(index) + 1 for index in array_element.indices]
            if not sizes:
                raise RuntimeError(
                    f"DIM {array_element.name}: at least one bound is required"
                )
            self.arrays[array_element.name] = new_array(sizes)

    def visit_If(self, node):
        if self.visit(node.condition):
            self.visit(node.statement)

    def visit_For(self, node):
        # The bound is checked only after each pass, so the body runs at
        # least once. The loop variable is re-read on every pass, so the
        # body may change it.
        name = node.variable.name
        self.scalars[name] = self.visit(node.start)
        end_value = self.visit(node.end)
        step_value = self.visit(node.step)
        while True:
            self.visit(node.statement)
            value = self.scalars[name] + step_value
            if (
                step_value > 0
                and value > end_value
                or step_value < 0
                and value < end_value
            ):
                break
            self.scalars[name] = value

    def visit_Print(self, node):
        print(
            " ".join(
                str(self.visit(expression)) for expression in node.expressions
            )
        )

    def visit_Block(self, node):
        for statement in node.statements:
            self.visit(statement)
def run(ast):
    """Execute *ast* and return the ExecutionVisitor holding the final state.

    DATA values are collected with a DataVisitor and handed to the
    interpreter as its READ source.
    """
    data_values = DataVisitor().visit(ast)
    executioner = ExecutionVisitor(data_values=data_values)
    executioner.visit(ast)
    return executioner
def main():
    """Parse, build and run SOURCE, pretty-printing each stage."""

    def show(value):
        pprint(value, indent=2, compact=True)

    s_expression = parse(SOURCE)
    show(s_expression)
    ast = build_ast(s_expression)
    show(ast)
    executioner = run(ast)
    show(executioner.scalars)
    show(executioner.arrays["t"])
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()