Normalize function expressions
[fur] / normalization.py
1 import collections
2
3 import parsing
4 import util
5
6 NormalVariableExpression = collections.namedtuple(
7     'NormalVariableExpression',
8     [
9         'variable',
10     ],
11 )
12
13 NormalIntegerLiteralExpression = collections.namedtuple(
14     'NormalIntegerLiteralExpression',
15     [
16         'integer',
17     ],
18 )
19
20 NormalStringLiteralExpression = collections.namedtuple(
21     'NormalStringLiteralExpression',
22     [
23         'string',
24     ],
25 )
26
27 NormalSymbolExpression = collections.namedtuple(
28     'NormalSymbolExpression',
29     [
30         'symbol',
31     ],
32 )
33
34 NormalNegationExpression = collections.namedtuple(
35     'NormalNegationExpression',
36     [
37         'internal_expression',
38     ],
39 )
40
41 NormalInfixExpression = collections.namedtuple(
42     'NormalInfixExpression',
43     [
44         'order',
45         'operator',
46         'left',
47         'right',
48     ],
49 )
50
51 NormalFunctionCallExpression = collections.namedtuple(
52     'NormalFunctionCallExpression',
53     [
54         'function_expression',
55         'argument_count',
56         'argument_items',
57     ],
58 )
59
60 NormalArrayVariableInitializationStatement = collections.namedtuple(
61     'NormalArrayVariableInitializationStatement',
62     [
63         'variable',
64         'items',
65     ],
66 )
67
68 NormalVariableInitializationStatement = collections.namedtuple(
69     'NormalVariableInitializationStatement',
70     [
71         'variable',
72         'expression',
73     ],
74 )
75
76 NormalVariableReassignmentStatement = collections.namedtuple(
77     'NormalVariableReassignmentStatement',
78     [
79         'variable',
80         'expression',
81     ],
82 )
83
84 NormalExpressionStatement = collections.namedtuple(
85     'NormalExpressionStatement',
86     [
87         'expression',
88     ],
89 )
90
91 NormalAssignmentStatement = collections.namedtuple(
92     'NormalAssignmentStatement',
93     [
94         'target',
95         'expression',
96     ],
97 )
98
99 NormalIfElseStatement = collections.namedtuple(
100     'NormalIfElseStatement',
101     [
102         'condition_expression',
103         'if_statements',
104         'else_statements',
105     ],
106 )
107
108 NormalFunctionDefinitionStatement = collections.namedtuple(
109     'NormalFunctionDefinitionStatement',
110     [
111         'name',
112         'argument_name_list',
113         'statement_list',
114     ],
115 )
116
117 NormalProgram = collections.namedtuple(
118     'NormalProgram',
119     [
120         'statement_list',
121     ],
122 )
123
124 def fake_normalization(counter, thing):
125     return (counter, (), thing)
126
127 def normalize_integer_literal_expression(counter, expression):
128     # TODO Store this in a C variable
129     return (
130         counter,
131         (),
132         NormalIntegerLiteralExpression(integer=expression.integer),
133     )
134
135 def normalize_string_literal_expression(counter, expression):
136     # TODO Store this in a C variable
137     return (
138         counter,
139         (),
140         NormalStringLiteralExpression(string=expression.string),
141     )
142
143 def normalize_symbol_expression(counter, expression):
144     variable = '${}'.format(counter)
145     return (
146         counter + 1,
147         (NormalVariableInitializationStatement(
148             variable=variable,
149             expression=NormalSymbolExpression(symbol=expression.symbol),
150         ),),
151         NormalVariableExpression(variable=variable),
152     )
153
154 def normalize_function_call_expression(counter, expression):
155     assert isinstance(expression, parsing.FurFunctionCallExpression)
156
157     prestatements = []
158     arguments = []
159
160     for argument in expression.arguments:
161         counter, argument_prestatements, normalized_argument = normalize_expression(counter, argument)
162
163         for s in argument_prestatements:
164             prestatements.append(s)
165
166         variable = '${}'.format(counter)
167         prestatements.append(NormalVariableInitializationStatement(
168             variable=variable,
169             expression=normalized_argument,
170         ))
171         arguments.append(NormalVariableExpression(
172             variable=variable,
173         ))
174         counter += 1
175
176     arguments_variable = '${}'.format(counter)
177     counter += 1
178
179     prestatements.append(NormalArrayVariableInitializationStatement(
180         variable=arguments_variable,
181         items=tuple(arguments),
182     ))
183
184     counter, function_prestatements, function_expression = normalize_expression(
185         counter,
186         expression.function,
187     )
188
189     for ps in function_prestatements:
190         prestatements.append(ps)
191
192     if not isinstance(function_expression, NormalVariableExpression):
193         function_variable = '${}'.format(counter)
194
195         prestatements.append(NormalVariableInitializationStatement(
196             variable=function_variable,
197             expression=function_expression,
198         ))
199
200         function_expression = NormalVariableExpression(variable=function_variable)
201         counter += 1
202
203     return (
204         counter,
205         tuple(prestatements),
206         NormalFunctionCallExpression(
207             function_expression=function_expression,
208             argument_count=len(arguments),
209             argument_items=NormalVariableExpression(variable=arguments_variable),
210         ),
211     )
212
213 def normalize_basic_infix_operation(counter, expression):
214     counter, left_prestatements, left_expression = normalize_expression(counter, expression.left)
215     counter, right_prestatements, right_expression = normalize_expression(counter, expression.right)
216
217     left_variable = '${}'.format(counter)
218     counter += 1
219     right_variable = '${}'.format(counter)
220     counter += 1
221
222     root_prestatements = (
223         NormalVariableInitializationStatement(
224             variable=left_variable,
225             expression=left_expression,
226         ),
227         NormalVariableInitializationStatement(
228             variable=right_variable,
229             expression=right_expression,
230         ),
231     )
232
233     return (
234         counter,
235         left_prestatements + right_prestatements + root_prestatements,
236         NormalInfixExpression(
237             order=expression.order,
238             operator=expression.operator,
239             left=NormalVariableExpression(variable=left_variable),
240             right=NormalVariableExpression(variable=right_variable),
241         ),
242     )
243
244 def normalize_comparison_expression(counter, expression):
245     stack = []
246
247     while isinstance(expression.left, parsing.FurInfixExpression) and expression.order == 'comparison_level':
248         stack.append((expression.operator, expression.order, expression.right))
249         expression = expression.left
250
251     counter, left_prestatements, left_expression = normalize_expression(counter, expression.left)
252     counter, right_prestatements, right_expression = normalize_expression(counter, expression.right)
253
254     left_variable = '${}'.format(counter)
255     counter += 1
256     right_variable = '${}'.format(counter)
257     counter += 1
258
259     root_prestatements = (
260         NormalVariableInitializationStatement(
261             variable=left_variable,
262             expression=left_expression,
263         ),
264         NormalVariableInitializationStatement(
265             variable=right_variable,
266             expression=right_expression,
267         ),
268     )
269
270     counter, result_prestatements, result_expression = (
271         counter,
272         left_prestatements + right_prestatements + root_prestatements,
273         NormalInfixExpression(
274             order=expression.order,
275             operator=expression.operator,
276             left=NormalVariableExpression(variable=left_variable),
277             right=NormalVariableExpression(variable=right_variable),
278         ),
279     )
280
281     while len(stack) > 0:
282         right_operator, right_order, right_expression = stack.pop()
283         and_right_expression = parsing.FurInfixExpression(
284             operator=right_operator,
285             order=right_order,
286             left=NormalVariableExpression(variable=right_variable),
287             right=right_expression,
288         )
289
290         and_expression = parsing.FurInfixExpression(
291             operator='and',
292             order='and_level',
293             left=result_expression,
294             right=and_right_expression,
295         )
296
297         counter, and_prestatements, result_expression = normalize_boolean_expression(
298             counter,
299             and_expression,
300         )
301
302         result_prestatements = result_prestatements + and_prestatements
303
304     return (counter, result_prestatements, result_expression)
305
306 def normalize_boolean_expression(counter, expression):
307     counter, left_prestatements, left_expression = normalize_expression(counter, expression.left)
308     counter, right_prestatements, right_expression = normalize_expression(counter, expression.right)
309
310     result_variable = '${}'.format(counter)
311     if_else_prestatment = NormalVariableInitializationStatement(variable=result_variable, expression=left_expression)
312     counter += 1
313
314     condition_expression=NormalVariableExpression(variable=result_variable)
315     short_circuited_statements = right_prestatements + (NormalVariableReassignmentStatement(variable=result_variable, expression=right_expression),)
316
317     if expression.operator == 'and':
318         if_else_statement = NormalIfElseStatement(
319             condition_expression=condition_expression,
320             if_statements=short_circuited_statements,
321             else_statements=(),
322         )
323
324     elif expression.operator == 'or':
325         if_else_statement = NormalIfElseStatement(
326             condition_expression=condition_expression,
327             if_statements=(),
328             else_statements=short_circuited_statements,
329         )
330
331     else:
332         raise Exception('Unable to handle operator "{}"'.format(expression.operator))
333
334     return (
335         counter,
336         left_prestatements + (if_else_prestatment, if_else_statement),
337         NormalVariableExpression(variable=result_variable),
338     )
339
340
341 def normalize_infix_expression(counter, expression):
342     return {
343         'multiplication_level': normalize_basic_infix_operation,
344         'addition_level': normalize_basic_infix_operation,
345         'comparison_level': normalize_comparison_expression,
346         'and_level': normalize_boolean_expression,
347         'or_level': normalize_boolean_expression,
348     }[expression.order](counter, expression)
349
350 def normalize_negation_expression(counter, expression):
351     counter, prestatements, internal_expression = normalize_expression(counter, expression.value)
352
353     internal_variable = '${}'.format(counter)
354     counter += 1
355
356     return (
357         counter,
358         prestatements + (NormalVariableInitializationStatement(variable=internal_variable, expression=internal_expression),),
359         NormalNegationExpression(internal_expression=NormalVariableExpression(variable=internal_variable)),
360     )
361
362 def normalize_expression(counter, expression):
363     return {
364         NormalInfixExpression: fake_normalization,
365         NormalVariableExpression: fake_normalization,
366         parsing.FurFunctionCallExpression: normalize_function_call_expression,
367         parsing.FurInfixExpression: normalize_infix_expression,
368         parsing.FurIntegerLiteralExpression: normalize_integer_literal_expression,
369         parsing.FurNegationExpression: normalize_negation_expression,
370         parsing.FurStringLiteralExpression: normalize_string_literal_expression,
371         parsing.FurSymbolExpression: normalize_symbol_expression,
372     }[type(expression)](counter, expression)
373
374 def normalize_expression_statement(counter, statement):
375     # TODO Normalized will be a NormalVariableExpression, which will go unused
376     # for expression statements in every case except when it's a return
377     # statement. This cases warnings on C compilation. We should only generate
378     # this variable when it will be used on return.
379     counter, prestatements, normalized = normalize_expression(counter, statement.expression)
380
381     return (
382         counter,
383         prestatements,
384         NormalExpressionStatement(expression=normalized),
385     )
386
387 def normalize_function_definition_statement(counter, statement):
388     return (
389         counter,
390         (),
391         NormalFunctionDefinitionStatement(
392             name=statement.name,
393             argument_name_list=statement.argument_name_list,
394             statement_list=normalize_statement_list(statement.statement_list),
395         ),
396     )
397
398 def normalize_assignment_statement(counter, statement):
399     counter, prestatements, normalized_expression = normalize_expression(counter, statement.expression)
400     return (
401         counter,
402         prestatements,
403         NormalAssignmentStatement(
404             target=statement.target,
405             expression=normalized_expression,
406         ),
407     )
408
409 def normalize_statement(counter, statement):
410     return {
411         parsing.FurAssignmentStatement: normalize_assignment_statement,
412         parsing.FurExpressionStatement: normalize_expression_statement,
413         parsing.FurFunctionDefinitionStatement: normalize_function_definition_statement,
414     }[type(statement)](counter, statement)
415
416 @util.force_generator(tuple)
417 def normalize_statement_list(statement_list):
418     counter = 0
419
420     for statement in statement_list:
421         counter, prestatements, normalized = normalize_statement(counter, statement)
422         for s in prestatements:
423             yield s
424         yield normalized
425
426 def normalize(program):
427     return NormalProgram(
428         statement_list=normalize_statement_list(program.statement_list),
429     )