b55d42c696ef7bf2d8aa040ddd6f13c297b0629f
[fur] / normalization.py
1 import collections
2
3 import desugaring
4 import util
5
# Reference to a named variable. Used for compiler-generated temporaries
# ('$0', '$1', ...) as well as other names such as 'builtin$nil'.
NormalVariableExpression = collections.namedtuple('NormalVariableExpression', 'variable')
12
# Integer literal value; normalize_integer_literal_expression always binds it
# to a fresh temporary variable.
NormalIntegerLiteralExpression = collections.namedtuple('NormalIntegerLiteralExpression', 'integer')
19
# Lambda with its name, argument names and already-normalized body.
NormalLambdaExpression = collections.namedtuple(
    'NormalLambdaExpression',
    'name argument_name_list statement_list',
)
28
# String literal value. Unlike integer and symbol literals, it is used
# directly as a normalized expression rather than bound to a temporary.
NormalStringLiteralExpression = collections.namedtuple('NormalStringLiteralExpression', 'string')
35
# Symbol literal; normalize_symbol_expression always binds it to a fresh
# temporary variable.
NormalSymbolExpression = collections.namedtuple('NormalSymbolExpression', 'symbol')
42
# Pushes one normalized argument expression ahead of a function call.
NormalPushStatement = collections.namedtuple('NormalPushStatement', 'expression')
49
# Call of function_expression with the last argument_count pushed arguments.
# metadata is copied unchanged from the desugared call expression.
NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    'metadata function_expression argument_count',
)
58
# Declares `variable` as an array holding the given normalized item expressions.
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    'variable items',
)
66
# Declares `variable` as an array of symbols (structure field names).
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    'variable symbol_list',
)
74
# Introduces `variable` and initializes it with a normalized expression.
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    'variable expression',
)
82
# Stores a normalized expression into an already-initialized variable; used to
# write a block's final value into its result variable.
NormalVariableReassignmentStatement = collections.namedtuple(
    'NormalVariableReassignmentStatement',
    'variable expression',
)
90
# A normalized expression appearing in statement position.
NormalExpressionStatement = collections.namedtuple('NormalExpressionStatement', 'expression')
97
# Assignment of a normalized expression to a user-level target, carried over
# from a desugared assignment statement.
NormalAssignmentStatement = collections.namedtuple(
    'NormalAssignmentStatement',
    'target expression',
)
105
# Conditional with a normalized condition and two normalized statement lists.
NormalIfElseStatement = collections.namedtuple(
    'NormalIfElseStatement',
    'condition_expression if_statement_list else_statement_list',
)
114
# Root of the normalized IR: the program's top-level statement list.
NormalProgram = collections.namedtuple('NormalProgram', 'statement_list')
121
def normalize_integer_literal_expression(counter, expression):
    """Bind an integer literal to a fresh temporary variable.

    Returns a ``(counter, prestatements, expression)`` triple: the counter
    advanced by one, a one-element tuple initializing the temporary, and a
    reference to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalIntegerLiteralExpression(integer=expression.integer),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
134
def normalize_lambda_expression(counter, expression):
    """Normalize a lambda and bind it to a fresh temporary variable.

    The body is normalized with its own counter starting at 0; that counter's
    final value is discarded. The body's final expression value is requested
    to be assigned to the variable named 'result'.

    Returns a ``(counter, prestatements, expression)`` triple: the outer
    counter advanced by one, a one-element tuple initializing the temporary
    with the lambda, and a reference to that temporary.
    """
    temporary = '${}'.format(counter)

    # Temporaries inside the lambda body are numbered from zero.
    _, body = normalize_statement_list(
        0,
        expression.statement_list,
        assign_result_to='result',
    )

    lambda_expression = NormalLambdaExpression(
        name=expression.name,
        argument_name_list=expression.argument_name_list,
        statement_list=body,
    )
    prestatements = (
        NormalVariableInitializationStatement(variable=temporary, expression=lambda_expression),
    )
    return counter + 1, prestatements, NormalVariableExpression(variable=temporary)
158
# Creates a new list. `allocate` is set to the literal's item count by
# normalize_list_literal_expression (presumably a capacity hint -- confirm
# against the code generator).
NormalListConstructExpression = collections.namedtuple('NormalListConstructExpression', 'allocate')
165
# Appends a normalized item expression to the list referenced by list_expression.
NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    'list_expression item_expression',
)
173
def normalize_list_literal_expression(counter, expression):
    """Normalize a list literal into construct-then-append statements.

    A fresh temporary is initialized with a NormalListConstructExpression.
    Then, for each item in order, the item's prestatements are emitted,
    followed by a NormalListAppendStatement.

    Returns a ``(counter, prestatements, expression)`` triple whose expression
    is a reference to the list temporary.
    """
    list_variable = '${}'.format(counter)
    counter += 1
    list_reference = NormalVariableExpression(variable=list_variable)

    statements = [
        NormalVariableInitializationStatement(
            variable=list_variable,
            expression=NormalListConstructExpression(
                allocate=len(expression.item_expression_list),
            ),
        ),
    ]

    for item in expression.item_expression_list:
        counter, item_prestatements, item_value = normalize_expression(counter, item)
        statements.extend(item_prestatements)
        statements.append(NormalListAppendStatement(
            list_expression=list_reference,
            item_expression=item_value,
        ))

    return counter, tuple(statements), list_reference
208
def normalize_string_literal_expression(counter, expression):
    """Normalize a string literal without introducing a temporary.

    The counter is returned unchanged, there are no prestatements, and the
    normalized expression is the literal itself.
    """
    normalized = NormalStringLiteralExpression(string=expression.string)
    no_prestatements = ()
    return counter, no_prestatements, normalized
215
# Structure literal built from two parallel arrays held in variables: the
# field-name symbols and the field values, both field_count long.
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    'field_count symbol_list_variable value_list_variable',
)
224
def normalize_structure_literal_expression(counter, expression):
    """Normalize a structure literal into parallel symbol and value arrays.

    Field value expressions are normalized in order, and their prestatements
    are emitted first. Three temporaries are then allocated consecutively:
      - a symbol array holding the field names;
      - a value array holding the normalized field values;
      - the structure itself.

    Returns a ``(counter, prestatements, expression)`` triple whose expression
    references the structure temporary.
    """
    prestatements = []
    symbols = []
    values = []

    for pair in expression.fields:
        counter, pair_prestatements, value = normalize_expression(counter, pair.expression)
        prestatements.extend(pair_prestatements)
        symbols.append(pair.symbol)
        values.append(value)

    symbol_array_variable = '${}'.format(counter)
    value_array_variable = '${}'.format(counter + 1)
    structure_variable = '${}'.format(counter + 2)

    prestatements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_array_variable,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_array_variable,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_variable,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_array_variable,
                value_list_variable=value_array_variable,
            ),
        ),
    ))

    return counter + 3, tuple(prestatements), NormalVariableExpression(variable=structure_variable)
280
281
def normalize_symbol_expression(counter, expression):
    """Bind a symbol literal to a fresh temporary variable.

    Returns a ``(counter, prestatements, expression)`` triple: the counter
    advanced by one, a one-element tuple initializing the temporary, and a
    reference to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalSymbolExpression(symbol=expression.symbol),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
294
def normalize_function_call_expression(counter, expression):
    """Normalize a function call into pushes followed by a call.

    The statements are emitted in this order:
      - each argument, left to right: its prestatements, then a
        NormalPushStatement;
      - the callee's prestatements;
      - if the callee did not normalize to a variable reference, an extra
        temporary binding it;
      - a temporary initialized with the NormalFunctionCallExpression.

    Returns a ``(counter, prestatements, expression)`` triple whose expression
    references the call's result temporary.
    """
    prestatements = []

    for argument in expression.argument_list:
        counter, argument_prestatements, argument_value = normalize_expression(counter, argument)
        prestatements.extend(argument_prestatements)
        prestatements.append(NormalPushStatement(expression=argument_value))

    counter, callee_prestatements, callee = normalize_expression(counter, expression.function)
    prestatements.extend(callee_prestatements)

    # Ensure the callee is referenced through a variable.
    if not isinstance(callee, NormalVariableExpression):
        callee_variable = '${}'.format(counter)
        counter += 1
        prestatements.append(NormalVariableInitializationStatement(
            variable=callee_variable,
            expression=callee,
        ))
        callee = NormalVariableExpression(variable=callee_variable)

    result_variable = '${}'.format(counter)
    prestatements.append(NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalFunctionCallExpression(
            metadata=expression.metadata,
            function_expression=callee,
            argument_count=len(expression.argument_list),
        ),
    ))

    return counter + 1, tuple(prestatements), NormalVariableExpression(variable=result_variable)
349
def normalize_if_expression(counter, expression):
    """Normalize an if expression into an if/else statement plus a result variable.

    The condition is normalized first. A result temporary is then allocated
    and initialized to 'builtin$nil'. Both branches are normalized with the
    shared counter, and normalize_statement_list is asked to assign each
    branch's final value to the result temporary.

    Returns a ``(counter, prestatements, expression)`` triple whose expression
    references the result temporary.
    """
    counter, condition_prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )

    result_variable = '${}'.format(counter)
    counter += 1

    counter, if_branch = normalize_statement_list(
        counter,
        expression.if_statement_list,
        assign_result_to=result_variable,
    )
    counter, else_branch = normalize_statement_list(
        counter,
        expression.else_statement_list,
        assign_result_to=result_variable,
    )

    result_initialization = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalVariableExpression(variable='builtin$nil'),
    )
    branch = NormalIfElseStatement(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    prestatements = condition_prestatements + (result_initialization, branch)
    return counter, prestatements, NormalVariableExpression(variable=result_variable)
385
def normalize_expression(counter, expression):
    """Normalize a desugared expression by dispatching on its exact type.

    Returns the ``(counter, prestatements, normalized_expression)`` triple
    produced by the matching normalize_* function. The lookup uses
    ``type(expression)``, so subclasses and unknown types raise KeyError.
    """
    normalizers = {
        desugaring.DesugaredFunctionCallExpression: normalize_function_call_expression,
        desugaring.DesugaredIfExpression: normalize_if_expression,
        desugaring.DesugaredIntegerLiteralExpression: normalize_integer_literal_expression,
        desugaring.DesugaredLambdaExpression: normalize_lambda_expression,
        desugaring.DesugaredListLiteralExpression: normalize_list_literal_expression,
        desugaring.DesugaredStringLiteralExpression: normalize_string_literal_expression,
        desugaring.DesugaredStructureLiteralExpression: normalize_structure_literal_expression,
        desugaring.DesugaredSymbolExpression: normalize_symbol_expression,
    }
    normalizer = normalizers[type(expression)]
    return normalizer(counter, expression)
397
def normalize_expression_statement(counter, statement):
    """Normalize an expression statement.

    Returns ``(counter, prestatements, NormalExpressionStatement)``, with the
    statement wrapping the normalized expression.
    """
    # TODO The normalized expression is usually a NormalVariableExpression that
    # goes unused, except when the statement's value is returned. This causes
    # warnings during C compilation. We should only generate this variable when
    # it will be used on return.
    counter, prestatements, normalized = normalize_expression(counter, statement.expression)
    wrapped = NormalExpressionStatement(expression=normalized)
    return counter, prestatements, wrapped
410
def normalize_assignment_statement(counter, statement):
    """Normalize an assignment's right-hand side, keeping its target unchanged.

    Returns ``(counter, prestatements, NormalAssignmentStatement)``.
    """
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(
        target=statement.target,
        expression=value,
    )
    return counter, prestatements, assignment
421
def normalize_statement(counter, statement):
    """Normalize a desugared statement by dispatching on its exact type.

    Returns a ``(counter, prestatements, normalized_statement)`` triple.
    Statement types other than assignment and expression statements raise
    KeyError.
    """
    normalizers = {
        desugaring.DesugaredAssignmentStatement: normalize_assignment_statement,
        desugaring.DesugaredExpressionStatement: normalize_expression_statement,
    }
    normalizer = normalizers[type(statement)]
    return normalizer(counter, statement)
427
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list, **kwargs):
    """Normalize a sequence of desugared statements.

    Each statement is normalized in order, and its prestatements are emitted
    immediately before it.

    Keyword arguments:
        assign_result_to: optional variable name. When given and the final
            normalized statement is a NormalExpressionStatement, that
            statement is replaced by a NormalVariableReassignmentStatement
            storing the expression's value into this variable.

    Returns:
        A ``(counter, statement_list)`` pair: the advanced counter and the
        flat list of normalized statements (empty if statement_list is empty).

    Raises:
        TypeError: if an unexpected keyword argument is passed.
    """
    assign_result_to = kwargs.pop('assign_result_to', None)

    # Real validation instead of `assert`, which is stripped under `python -O`.
    if kwargs:
        raise TypeError(
            'normalize_statement_list() got unexpected keyword arguments: {}'.format(
                ', '.join(sorted(kwargs)),
            )
        )

    result_statement_list = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        result_statement_list.extend(prestatements)
        result_statement_list.append(normalized)

    # Redirect the block's final value into assign_result_to.
    # - An empty statement list has no final statement. Indexing [-1] used to
    #   raise IndexError here.
    # - The check is on the statement kind only. String literals normalize to
    #   NormalStringLiteralExpression rather than a variable reference, and
    #   their value must not be dropped.
    if assign_result_to is not None and result_statement_list:
        last_statement = result_statement_list[-1]

        if isinstance(last_statement, NormalExpressionStatement):
            result_statement_list[-1] = NormalVariableReassignmentStatement(
                variable=assign_result_to,
                expression=last_statement.expression,
            )

    return (
        counter,
        result_statement_list,
    )
459
def normalize(program):
    """Normalize a whole desugared program into a NormalProgram.

    Top-level temporaries are numbered from zero, and the final counter is
    discarded. No result variable is requested for the program's last
    statement.
    """
    _final_counter, normalized_statements = normalize_statement_list(0, program.statement_list)
    return NormalProgram(statement_list=normalized_statements)