Fix/simplify the generation of if/else statements
[fur] / normalization.py
1 import collections
2
3 import desugaring
4 import util
5
# Reference to a named variable, e.g. a '$N' temporary created during
# normalization.
NormalVariableExpression = collections.namedtuple(
    'NormalVariableExpression',
    'variable',
)

# An integer literal carried through unchanged.
NormalIntegerLiteralExpression = collections.namedtuple(
    'NormalIntegerLiteralExpression',
    'integer',
)

# A lambda whose body has already been normalized into `statement_list`.
NormalLambdaExpression = collections.namedtuple(
    'NormalLambdaExpression',
    'name argument_name_list statement_list',
)

# A string literal carried through unchanged.
NormalStringLiteralExpression = collections.namedtuple(
    'NormalStringLiteralExpression',
    'string',
)

# A symbol literal carried through unchanged.
NormalSymbolExpression = collections.namedtuple(
    'NormalSymbolExpression',
    'symbol',
)

# Pushes one normalized expression; emitted once per function-call argument.
NormalPushStatement = collections.namedtuple(
    'NormalPushStatement',
    'expression',
)

# A call of `function_expression`; `argument_count` is the number of
# NormalPushStatements emitted for the call's arguments.
NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    'metadata function_expression argument_count',
)
58
# Initializes `variable` with an array of already-normalized `items`
# (used for structure field values).
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    'variable items',
)

# Initializes `variable` with an array of symbols (structure field names).
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    'variable symbol_list',
)

# Initializes `variable` from a normalized `expression`.
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    'variable expression',
)

# Evaluates a normalized expression as a statement.
NormalExpressionStatement = collections.namedtuple(
    'NormalExpressionStatement',
    'expression',
)

# Assigns a normalized `expression` to `target`.
NormalAssignmentStatement = collections.namedtuple(
    'NormalAssignmentStatement',
    'target expression',
)

# A conditional whose branches hold normalized statement lists.
NormalIfElseExpression = collections.namedtuple(
    'NormalIfElseExpression',
    'condition_expression if_statement_list else_statement_list',
)

# Result of `normalize`: the program's normalized top-level statements.
NormalProgram = collections.namedtuple(
    'NormalProgram',
    'statement_list',
)
113
def normalize_integer_literal_expression(counter, expression):
    """Normalize an integer literal.

    No prestatements are needed and the counter is passed through unchanged.
    Returns (counter, (), NormalIntegerLiteralExpression).
    """
    normalized = NormalIntegerLiteralExpression(integer=expression.integer)
    return counter, (), normalized
120
def normalize_lambda_expression(counter, expression):
    """Normalize a lambda by binding it to a fresh '$N' temporary.

    The lambda's body is normalized with its own counter starting at 0, and
    the body's final counter is discarded. Temporaries inside the body are
    therefore numbered independently of the enclosing code.

    Returns:
        (counter + 1, prestatements, expression), where the single
        prestatement initializes the temporary with the normalized lambda and
        the expression refers to that temporary.
    """
    _, body = normalize_statement_list(0, expression.statement_list)

    temporary = '${}'.format(counter)
    normalized_lambda = NormalLambdaExpression(
        name=expression.name,
        argument_name_list=expression.argument_name_list,
        statement_list=body,
    )
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=normalized_lambda,
    )

    return (
        counter + 1,
        (initialization,),
        NormalVariableExpression(variable=temporary),
    )
143
# Constructs a new list; `allocate` is the number of items the list literal
# will append (see normalize_list_literal_expression).
NormalListConstructExpression = collections.namedtuple(
    'NormalListConstructExpression',
    'allocate',
)

# Appends the normalized `item_expression` to the list named by
# `list_expression`.
NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    'list_expression item_expression',
)
158
def normalize_list_literal_expression(counter, expression):
    """Normalize a list literal into construct-then-append statements.

    A '$N' temporary is initialized with a NormalListConstructExpression
    sized to the number of items. Each item is then normalized, its
    prestatements are emitted first, and it is appended to the temporary in
    source order.

    Returns (counter, prestatements, NormalVariableExpression for the list).
    """
    list_variable = '${}'.format(counter)
    counter += 1
    list_expression = NormalVariableExpression(variable=list_variable)

    item_count = len(expression.item_expression_list)
    prestatements = [
        NormalVariableInitializationStatement(
            variable=list_variable,
            expression=NormalListConstructExpression(allocate=item_count),
        ),
    ]

    for item in expression.item_expression_list:
        counter, item_prestatements, item_expression = normalize_expression(counter, item)
        prestatements.extend(item_prestatements)
        prestatements.append(
            NormalListAppendStatement(
                list_expression=list_expression,
                item_expression=item_expression,
            ),
        )

    return counter, tuple(prestatements), list_expression
193
def normalize_string_literal_expression(counter, expression):
    """Normalize a string literal.

    No prestatements are needed and the counter is passed through unchanged.
    Returns (counter, (), NormalStringLiteralExpression).
    """
    normalized = NormalStringLiteralExpression(string=expression.string)
    return counter, (), normalized
200
# Builds a structure from `field_count` fields. The field names and values
# live in the arrays named by `symbol_list_variable` and `value_list_variable`.
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    'field_count symbol_list_variable value_list_variable',
)
209
def normalize_structure_literal_expression(counter, expression):
    """Normalize a structure literal into array initializations.

    Each field value is normalized in order, and its prestatements are hoisted
    first. Three consecutive '$N' temporaries are then allocated:
      * counter     -> array of field symbols,
      * counter + 1 -> array of normalized field values,
      * counter + 2 -> the structure built from those two arrays.

    Returns (counter + 3, prestatements, NormalVariableExpression for the
    structure).
    """
    prestatements = []
    symbols = []
    values = []

    for field in expression.fields:
        counter, field_prestatements, value = normalize_expression(
            counter,
            field.expression,
        )
        prestatements.extend(field_prestatements)
        symbols.append(field.symbol)
        values.append(value)

    symbol_list_variable = '${}'.format(counter)
    value_list_variable = '${}'.format(counter + 1)
    structure_variable = '${}'.format(counter + 2)

    prestatements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_list_variable,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_list_variable,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_variable,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_list_variable,
                value_list_variable=value_list_variable,
            ),
        ),
    ))

    return (
        counter + 3,
        tuple(prestatements),
        NormalVariableExpression(variable=structure_variable),
    )
265
266
def normalize_symbol_expression(counter, expression):
    """Normalize a symbol literal.

    No prestatements are needed and the counter is passed through unchanged.
    Returns (counter, (), NormalSymbolExpression).
    """
    normalized = NormalSymbolExpression(symbol=expression.symbol)
    return counter, (), normalized
273
def normalize_function_call_expression(counter, expression):
    """Normalize a function call into pushes followed by a call.

    Emitted order:
      1. For each argument, left to right: the argument's prestatements, then
         a NormalPushStatement for its normalized value.
      2. The prestatements of the function expression itself.
      3. Initialization of a '$N' temporary with the
         NormalFunctionCallExpression, whose argument_count equals the
         number of pushes.

    Returns (counter + 1, prestatements, NormalVariableExpression for the
    result temporary).
    """
    prestatements = []

    for argument in expression.argument_list:
        counter, argument_prestatements, argument_expression = normalize_expression(
            counter,
            argument,
        )
        prestatements.extend(argument_prestatements)
        prestatements.append(NormalPushStatement(expression=argument_expression))

    counter, function_prestatements, function_expression = normalize_expression(
        counter,
        expression.function,
    )
    prestatements.extend(function_prestatements)

    result_variable = '${}'.format(counter)
    call = NormalFunctionCallExpression(
        metadata=expression.metadata,
        function_expression=function_expression,
        argument_count=len(expression.argument_list),
    )
    prestatements.append(
        NormalVariableInitializationStatement(
            variable=result_variable,
            expression=call,
        ),
    )

    return (
        counter + 1,
        tuple(prestatements),
        NormalVariableExpression(variable=result_variable),
    )
315
def normalize_if_expression(counter, expression):
    """Normalize an if/else expression.

    Only the condition's prestatements are hoisted in front of the result.
    Both branches are normalized into their own statement lists and kept
    inside the returned NormalIfElseExpression. The counter is threaded
    through the condition, then the if branch, then the else branch, so
    temporary numbers are unique across all three.

    Returns (counter, condition_prestatements, NormalIfElseExpression).
    """
    # NOTE(review): the returned node carries statement lists yet is used in
    # expression position. If an if/else appears as a pushed argument or on
    # the right-hand side of an assignment, the code generator must handle
    # statements nested inside an expression. Confirm that case is covered.
    counter, condition_prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )
    counter, if_branch = normalize_statement_list(counter, expression.if_statement_list)
    counter, else_branch = normalize_statement_list(counter, expression.else_statement_list)

    normalized = NormalIfElseExpression(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )
    return counter, condition_prestatements, normalized
340
def normalize_expression(counter, expression):
    """Normalize a desugared expression by dispatching on its exact type.

    Returns (counter, prestatements, normalized_expression). Raises KeyError
    for any expression type without a normalizer; the match is on the exact
    type, so subclasses are not accepted.
    """
    normalizers = {
        desugaring.DesugaredFunctionCallExpression: normalize_function_call_expression,
        desugaring.DesugaredIfExpression: normalize_if_expression,
        desugaring.DesugaredIntegerLiteralExpression: normalize_integer_literal_expression,
        desugaring.DesugaredLambdaExpression: normalize_lambda_expression,
        desugaring.DesugaredListLiteralExpression: normalize_list_literal_expression,
        desugaring.DesugaredStringLiteralExpression: normalize_string_literal_expression,
        desugaring.DesugaredStructureLiteralExpression: normalize_structure_literal_expression,
        desugaring.DesugaredSymbolExpression: normalize_symbol_expression,
    }
    normalizer = normalizers[type(expression)]
    return normalizer(counter, expression)
352
def normalize_expression_statement(counter, statement):
    """Normalize an expression used as a statement.

    The expression's prestatements are hoisted ahead of the resulting
    NormalExpressionStatement. Returns (counter, prestatements, statement).
    """
    # TODO `normalized` is typically a NormalVariableExpression. For
    # expression statements it goes unused in every case except a return
    # statement, which causes warnings during C compilation. The variable
    # should only be generated when a return will use it.
    counter, prestatements, normalized = normalize_expression(counter, statement.expression)
    return counter, prestatements, NormalExpressionStatement(expression=normalized)
365
def normalize_assignment_statement(counter, statement):
    """Normalize `target = expression`.

    The right-hand side's prestatements are hoisted ahead of the assignment;
    the target is copied unchanged. Returns (counter, prestatements,
    NormalAssignmentStatement).
    """
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(target=statement.target, expression=value)
    return counter, prestatements, assignment
376
def normalize_statement(counter, statement):
    """Normalize a desugared statement by dispatching on its exact type.

    Returns (counter, prestatements, normalized_statement). Raises KeyError
    for any type other than assignment or expression statements.
    """
    handlers = {
        desugaring.DesugaredAssignmentStatement: normalize_assignment_statement,
        desugaring.DesugaredExpressionStatement: normalize_expression_statement,
    }
    handler = handlers[type(statement)]
    return handler(counter, statement)
382
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list):
    """Normalize a sequence of desugared statements.

    Each statement's prestatements are placed immediately before the
    normalized statement itself, preserving source order. The counter is
    threaded through every statement.

    Args:
        counter: next free '$N' temporary-variable number.
        statement_list: iterable of desugared statements.

    Returns:
        (counter, statements), where `statements` is a tuple. This matches the
        tuple-valued prestatements produced by the expression normalizers.
    """
    # NOTE(review): this function returns a pair rather than yielding, so the
    # force_generator(tuple) decorator looks vestigial from an earlier
    # generator version. Confirm util.force_generator's behavior before
    # removing it.
    statements = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        statements.extend(prestatements)
        statements.append(normalized)

    # Return an immutable tuple. The statement list is stored inside
    # namedtuples (NormalLambdaExpression, NormalIfElseExpression,
    # NormalProgram) and should not be a mutable list like it was before.
    return (
        counter,
        tuple(statements),
    )
397
def normalize(program):
    """Normalize a desugared program into a NormalProgram.

    Top-level temporaries are numbered starting at 0. The final counter is
    not needed and is discarded.
    """
    statements = normalize_statement_list(0, program.statement_list)[1]
    return NormalProgram(statement_list=statements)