Remove unnecessary pushing/popping for function calls
[fur] / normalization.py
1 import collections
2
3 import desugaring
4 import util
5
# Node types produced by normalization. Normalization flattens nested
# desugared expressions into a sequence of statements that keep intermediate
# values in generated temporaries named '$<n>'. Every node is an immutable
# namedtuple; field names are spelled out as tuples.

# A reference to a variable by name (a generated '$<n>' temporary, 'result'
# inside a lambda body, or a name such as 'builtin$nil').
NormalVariableExpression = collections.namedtuple(
    'NormalVariableExpression',
    ('variable',),
)

# An integer literal value.
NormalIntegerLiteralExpression = collections.namedtuple(
    'NormalIntegerLiteralExpression',
    ('integer',),
)

# A lambda whose body has already been normalized into `statement_list`.
NormalLambdaExpression = collections.namedtuple(
    'NormalLambdaExpression',
    ('name', 'argument_name_list', 'statement_list'),
)

# A string literal value.
NormalStringLiteralExpression = collections.namedtuple(
    'NormalStringLiteralExpression',
    ('string',),
)

# A symbol value.
NormalSymbolExpression = collections.namedtuple(
    'NormalSymbolExpression',
    ('symbol',),
)

# Pushes one function-call argument. One is emitted per argument, ahead of
# the NormalFunctionCallExpression for that call.
NormalPushStatement = collections.namedtuple(
    'NormalPushStatement',
    ('expression',),
)

# A call of `function_expression`. `argument_count` equals the number of
# NormalPushStatements emitted for this call; `metadata` is carried over
# unchanged from the desugared call.
NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    ('metadata', 'function_expression', 'argument_count'),
)

# Initializes `variable` with an array built from already-normalized `items`.
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    ('variable', 'items'),
)

# Initializes `variable` with an array of symbols from `symbol_list`.
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    ('variable', 'symbol_list'),
)

# Initializes `variable` with a normalized expression.
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    ('variable', 'expression'),
)

# Writes a new value into an already-initialized variable. Used to store
# the final value of a statement list into its assign_result_to target.
NormalVariableReassignmentStatement = collections.namedtuple(
    'NormalVariableReassignmentStatement',
    ('variable', 'expression'),
)

# Evaluates a normalized expression as a statement.
NormalExpressionStatement = collections.namedtuple(
    'NormalExpressionStatement',
    ('expression',),
)

# An assignment whose `target` is carried over from the desugared
# assignment statement and whose `expression` is normalized.
NormalAssignmentStatement = collections.namedtuple(
    'NormalAssignmentStatement',
    ('target', 'expression'),
)

# A conditional whose branches are already-normalized statement lists.
NormalIfElseStatement = collections.namedtuple(
    'NormalIfElseStatement',
    ('condition_expression', 'if_statement_list', 'else_statement_list'),
)

# Root node of a normalized program.
NormalProgram = collections.namedtuple(
    'NormalProgram',
    ('statement_list',),
)
121
def normalize_integer_literal_expression(counter, expression):
    """Bind an integer literal to a fresh temporary.

    Returns (next_counter, prestatements, value). The single prestatement
    initializes the temporary '$<counter>' to the literal, and the value is a
    reference to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalIntegerLiteralExpression(integer=expression.integer),
    )
    reference = NormalVariableExpression(variable=temporary)
    return counter + 1, (initialization,), reference
134
def normalize_lambda_expression(counter, expression):
    """Normalize a lambda and bind it to a fresh temporary.

    The body is normalized as its own scope: its temporary numbering
    restarts at 0, and normalize_statement_list is asked to assign the
    body's final value to 'result'. The outer counter advances by one, for
    the temporary that holds the lambda.

    Returns (next_counter, prestatements, value).
    """
    lambda_variable = '${}'.format(counter)

    # The counter the body ends with is local to the body, so it is dropped.
    _, body = normalize_statement_list(
        0,
        expression.statement_list,
        assign_result_to='result',
    )

    normalized_lambda = NormalLambdaExpression(
        name=expression.name,
        argument_name_list=expression.argument_name_list,
        statement_list=body,
    )
    initialization = NormalVariableInitializationStatement(
        variable=lambda_variable,
        expression=normalized_lambda,
    )

    return (
        counter + 1,
        (initialization,),
        NormalVariableExpression(variable=lambda_variable),
    )
158
# Constructs the list for a list literal. `allocate` is set to the literal's
# item count.
NormalListConstructExpression = collections.namedtuple(
    'NormalListConstructExpression',
    ('allocate',),
)

# Appends `item_expression` to the list referenced by `list_expression`.
NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    ('list_expression', 'item_expression'),
)
173
def normalize_list_literal_expression(counter, expression):
    """Normalize a list literal into a construct-then-append sequence.

    The temporary '$<counter>' is initialized with a list construction whose
    `allocate` hint is the item count. Each item is then normalized in source
    order: its own prestatements come first, followed by an append of its
    value to the list.

    Returns (next_counter, prestatements, value), where value references the
    list temporary.
    """
    list_variable = '${}'.format(counter)
    counter += 1
    list_reference = NormalVariableExpression(variable=list_variable)

    statements = [
        NormalVariableInitializationStatement(
            variable=list_variable,
            expression=NormalListConstructExpression(
                allocate=len(expression.item_expression_list),
            ),
        ),
    ]

    for item in expression.item_expression_list:
        counter, item_prestatements, item_value = normalize_expression(counter, item)
        statements.extend(item_prestatements)
        statements.append(
            NormalListAppendStatement(
                list_expression=list_reference,
                item_expression=item_value,
            ),
        )

    return counter, tuple(statements), list_reference
208
def normalize_string_literal_expression(counter, expression):
    """Convert a string literal in place.

    No temporary is used and there are no prestatements, so the counter is
    returned unchanged.
    """
    literal = NormalStringLiteralExpression(string=expression.string)
    return counter, (), literal
215
# A structure literal built from two previously initialized array variables.
# `symbol_list_variable` and `value_list_variable` hold the NAMES of those
# variables (field symbols and field values respectively), and `field_count`
# is the number of fields.
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    ('field_count', 'symbol_list_variable', 'value_list_variable'),
)
224
def normalize_structure_literal_expression(counter, expression):
    """Normalize a structure literal into two arrays plus a structure value.

    Each field value is normalized in order, with its prestatements emitted
    first. Three consecutive temporaries are then used:

    - '$<c>'   holds the array of field symbols;
    - '$<c+1>' holds the array of normalized field values;
    - '$<c+2>' holds the structure built from the two arrays.

    Here c is the counter after the fields are normalized.

    Returns (c + 3, prestatements, reference to '$<c+2>').
    """
    statements = []
    symbols = []
    values = []

    for field in expression.fields:
        counter, field_prestatements, value = normalize_expression(counter, field.expression)
        statements.extend(field_prestatements)
        symbols.append(field.symbol)
        values.append(value)

    symbol_list_variable = '${}'.format(counter)
    value_list_variable = '${}'.format(counter + 1)
    structure_variable = '${}'.format(counter + 2)

    statements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_list_variable,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_list_variable,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_variable,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_list_variable,
                value_list_variable=value_list_variable,
            ),
        ),
    ))

    return (
        counter + 3,
        tuple(statements),
        NormalVariableExpression(variable=structure_variable),
    )
280
281
def normalize_symbol_expression(counter, expression):
    """Convert a symbol in place.

    No temporary is used and there are no prestatements, so the counter is
    returned unchanged.
    """
    symbol = NormalSymbolExpression(symbol=expression.symbol)
    return counter, (), symbol
288
def normalize_function_call_expression(counter, expression):
    """Normalize a call into argument pushes followed by the call itself.

    Arguments are normalized left to right. For each one, its prestatements
    are emitted and then a NormalPushStatement for its value. The callee is
    normalized after all arguments. The call's result is stored in a fresh
    temporary, and the call records how many arguments were pushed.

    Returns (next_counter, prestatements, reference to the result temporary).
    """
    statements = []

    for argument in expression.argument_list:
        counter, argument_prestatements, argument_value = normalize_expression(counter, argument)
        statements.extend(argument_prestatements)
        statements.append(NormalPushStatement(expression=argument_value))

    counter, callee_prestatements, callee = normalize_expression(counter, expression.function)
    statements.extend(callee_prestatements)

    result_variable = '${}'.format(counter)
    call = NormalFunctionCallExpression(
        metadata=expression.metadata,
        function_expression=callee,
        argument_count=len(expression.argument_list),
    )
    statements.append(
        NormalVariableInitializationStatement(
            variable=result_variable,
            expression=call,
        ),
    )

    return (
        counter + 1,
        tuple(statements),
        NormalVariableExpression(variable=result_variable),
    )
330
def normalize_if_expression(counter, expression):
    """Normalize an if-expression into an if/else statement writing a temporary.

    The condition is normalized first. A temporary is then reserved and
    initialized to 'builtin$nil'. Both branches are normalized with
    assign_result_to set to that temporary, and a reference to the temporary
    is returned as the value of the whole expression.

    Returns (next_counter, prestatements, value).
    """
    counter, prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )

    result_variable = '${}'.format(counter)

    # counter + 1 skips past the temporary reserved just above.
    counter, if_branch = normalize_statement_list(
        counter + 1,
        expression.if_statement_list,
        assign_result_to=result_variable,
    )
    counter, else_branch = normalize_statement_list(
        counter,
        expression.else_statement_list,
        assign_result_to=result_variable,
    )

    initialize_result = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalVariableExpression(variable='builtin$nil'),
    )
    branch = NormalIfElseStatement(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        prestatements + (initialize_result, branch),
        NormalVariableExpression(variable=result_variable),
    )
366
def normalize_expression(counter, expression):
    """Dispatch a desugared expression to its normalize_*_expression handler.

    The lookup uses the exact type, so subclasses are not matched. An
    unsupported type raises KeyError.

    Returns the handler's (counter, prestatements, normalized_expression).
    """
    handlers = {
        desugaring.DesugaredFunctionCallExpression: normalize_function_call_expression,
        desugaring.DesugaredIfExpression: normalize_if_expression,
        desugaring.DesugaredIntegerLiteralExpression: normalize_integer_literal_expression,
        desugaring.DesugaredLambdaExpression: normalize_lambda_expression,
        desugaring.DesugaredListLiteralExpression: normalize_list_literal_expression,
        desugaring.DesugaredStringLiteralExpression: normalize_string_literal_expression,
        desugaring.DesugaredStructureLiteralExpression: normalize_structure_literal_expression,
        desugaring.DesugaredSymbolExpression: normalize_symbol_expression,
    }
    handler = handlers[type(expression)]
    return handler(counter, expression)
378
def normalize_expression_statement(counter, statement):
    """Normalize an expression used as a statement.

    Returns (counter, prestatements, NormalExpressionStatement).
    """
    # TODO The normalized value will be a NormalVariableExpression that goes
    # unused for expression statements in every case except when it's a
    # return statement. This causes warnings on C compilation. We should only
    # generate this variable when it will be used on return.
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    return counter, prestatements, NormalExpressionStatement(expression=value)
391
def normalize_assignment_statement(counter, statement):
    """Normalize the right-hand side of an assignment.

    The target is carried over unchanged.

    Returns (counter, prestatements, NormalAssignmentStatement).
    """
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(target=statement.target, expression=value)
    return counter, prestatements, assignment
402
def normalize_statement(counter, statement):
    """Dispatch a desugared statement to its normalize_*_statement handler.

    The lookup uses the exact type. An unsupported type raises KeyError.

    Returns the handler's (counter, prestatements, normalized_statement).
    """
    handlers = {
        desugaring.DesugaredAssignmentStatement: normalize_assignment_statement,
        desugaring.DesugaredExpressionStatement: normalize_expression_statement,
    }
    handler = handlers[type(statement)]
    return handler(counter, statement)
408
# NOTE(review): this function returns a 2-tuple rather than yielding. Confirm
# that util.force_generator(tuple) is intended to wrap it.
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list, **kwargs):
    """Normalize a sequence of statements, flattening their prestatements.

    Args:
        counter: next free temporary-variable number.
        statement_list: desugared statements, normalized in order.
        assign_result_to: optional keyword argument naming a variable. When
            it is given and the final normalized statement is an expression
            statement whose value is a NormalVariableExpression, that
            statement is replaced by a NormalVariableReassignmentStatement
            that writes the value into this variable.

    Returns:
        (counter, statements): the next free temporary number and a list of
        normalized statements. An empty statement_list yields an empty list.

    Raises:
        TypeError: if any keyword argument other than assign_result_to is
            passed.
    """
    assign_result_to = kwargs.pop('assign_result_to', None)

    # Raise rather than assert so the check survives `python -O`.
    if kwargs:
        raise TypeError(
            'unexpected keyword arguments: {}'.format(', '.join(sorted(kwargs))),
        )

    result_statement_list = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        result_statement_list.extend(prestatements)
        result_statement_list.append(normalized)

    # TODO The way we fix the last statement is really confusing
    # Guard against an empty list: indexing [-1] would raise IndexError for
    # an empty program or an empty branch.
    if assign_result_to is not None and result_statement_list:
        last_statement = result_statement_list[-1]
        if (
            isinstance(last_statement, NormalExpressionStatement)
            and isinstance(last_statement.expression, NormalVariableExpression)
        ):
            result_statement_list[-1] = NormalVariableReassignmentStatement(
                variable=assign_result_to,
                expression=last_statement.expression,
            )

    return (
        counter,
        result_statement_list,
    )
440
def normalize(program):
    """Normalize a desugared program into a NormalProgram.

    Temporary numbering starts at 0. The final counter value is not needed
    and is discarded.
    """
    statement_list = normalize_statement_list(0, program.statement_list)[1]
    return NormalProgram(statement_list=statement_list)