Removed one more test using ternary comparison
[fur] / normalization.py
1 import collections
2
3 import parsing
4 import util
5
# Node types of the normalized AST. The normalize_* functions below lower
# parsed Fur nodes into flat sequences of these statements; compiler-created
# temporaries are named '$<n>'.

# Reference to a variable by name.
NormalVariableExpression = collections.namedtuple(
    'NormalVariableExpression',
    'variable',
)

# Integer literal value.
NormalIntegerLiteralExpression = collections.namedtuple(
    'NormalIntegerLiteralExpression',
    'integer',
)

# String literal value.
NormalStringLiteralExpression = collections.namedtuple(
    'NormalStringLiteralExpression',
    'string',
)

# Symbol literal value.
NormalSymbolExpression = collections.namedtuple(
    'NormalSymbolExpression',
    'symbol',
)

# Negation of internal_expression (built over a variable reference).
NormalNegationExpression = collections.namedtuple(
    'NormalNegationExpression',
    'internal_expression',
)

# Field access: instance.field, where field is the symbol name.
NormalDotExpression = collections.namedtuple(
    'NormalDotExpression',
    'instance field',
)

# Binary operator. Its operands are supplied by the NormalPushStatements
# emitted just before it; it carries only parser metadata, the precedence
# order and the operator.
NormalInfixExpression = collections.namedtuple(
    'NormalInfixExpression',
    'metadata order operator',
)

# Pushes an expression (operator operands, call arguments).
NormalPushStatement = collections.namedtuple(
    'NormalPushStatement',
    'expression',
)

# Calls function_expression with the argument_count values pushed before it.
NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    'function_expression argument_count',
)

# Initializes variable with an array of the given item expressions.
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    'variable items',
)

# Initializes variable with an array of symbols.
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    'variable symbol_list',
)

# First assignment of expression to a (temporary) variable.
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    'variable expression',
)

# Assigns a new value to a variable that was set earlier.
NormalVariableReassignmentStatement = collections.namedtuple(
    'NormalVariableReassignmentStatement',
    'variable expression',
)

# Expression evaluated as a statement.
NormalExpressionStatement = collections.namedtuple(
    'NormalExpressionStatement',
    'expression',
)

# Source-level assignment to target.
NormalAssignmentStatement = collections.namedtuple(
    'NormalAssignmentStatement',
    'target expression',
)

# Two-way branch on condition_expression.
NormalIfElseStatement = collections.namedtuple(
    'NormalIfElseStatement',
    'condition_expression if_statement_list else_statement_list',
)

# Function definition with an already-normalized body.
NormalFunctionDefinitionStatement = collections.namedtuple(
    'NormalFunctionDefinitionStatement',
    'name argument_name_list statement_list',
)

# Whole normalized program.
NormalProgram = collections.namedtuple(
    'NormalProgram',
    'statement_list',
)
144
def fake_normalization(counter, thing):
    """Identity normalizer for nodes that are already in normal form.

    Returns ``(counter, (), thing)``: the counter is unchanged, no
    prestatements are emitted, and the node itself is the result.
    """
    no_prestatements = ()
    return counter, no_prestatements, thing
147
def normalize_integer_literal_expression(counter, expression):
    """Bind an integer literal to the fresh temporary ``$<counter>``.

    Returns:
        (counter + 1, one-statement prestatements tuple,
         NormalVariableExpression referring to the temporary)
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalIntegerLiteralExpression(integer=expression.integer),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
160
# Creates a new list. ``allocate`` is the number of items in the list literal
# (presumably a capacity hint for the code generator -- TODO confirm).
NormalListConstructExpression = collections.namedtuple(
    'NormalListConstructExpression',
    'allocate',
)

# Appends item_expression to the list referenced by list_expression.
NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    'list_expression item_expression',
)

# Reads list_expression[index_expression].
NormalListGetExpression = collections.namedtuple(
    'NormalListGetExpression',
    'list_expression index_expression',
)
183
def normalize_list_literal_expression(counter, expression):
    """Lower a list literal into construct-then-append statements.

    The temporary ``$<counter>`` is initialized with a
    NormalListConstructExpression sized to the number of items. Each item
    is then normalized in order and appended with a NormalListAppendStatement.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the list)
    """
    items = expression.item_expression_list
    list_variable = '${}'.format(counter)
    counter += 1
    list_reference = NormalVariableExpression(variable=list_variable)

    emitted = [
        NormalVariableInitializationStatement(
            variable=list_variable,
            expression=NormalListConstructExpression(allocate=len(items)),
        ),
    ]

    for item in items:
        counter, item_prestatements, normalized_item = normalize_expression(counter, item)
        emitted.extend(item_prestatements)
        emitted.append(
            NormalListAppendStatement(
                list_expression=list_reference,
                item_expression=normalized_item,
            )
        )

    return counter, tuple(emitted), list_reference
218
def normalize_list_item_expression(counter, expression):
    """Lower ``list[index]`` into a NormalListGetExpression bound to a temporary.

    The list expression is normalized before the index expression, so their
    prestatements run in source order.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the fetched item)
    """
    counter, prestatements_for_list, normalized_list = normalize_expression(counter, expression.list_expression)
    counter, prestatements_for_index, normalized_index = normalize_expression(counter, expression.index_expression)

    item_variable = '${}'.format(counter)
    get_item = NormalVariableInitializationStatement(
        variable=item_variable,
        expression=NormalListGetExpression(
            list_expression=normalized_list,
            index_expression=normalized_index,
        ),
    )

    prestatements = prestatements_for_list + prestatements_for_index + (get_item,)
    return counter + 1, prestatements, NormalVariableExpression(variable=item_variable)
237
def normalize_string_literal_expression(counter, expression):
    """Bind a string literal to the fresh temporary ``$<counter>``.

    Returns:
        (counter + 1, one-statement prestatements tuple,
         NormalVariableExpression referring to the temporary)
    """
    variable = '${}'.format(counter)
    literal = NormalStringLiteralExpression(string=expression.string)
    prestatements = (
        NormalVariableInitializationStatement(variable=variable, expression=literal),
    )
    return counter + 1, prestatements, NormalVariableExpression(variable=variable)
250
# Builds a structure with field_count fields. The field names and values live
# in arrays held by the variables named symbol_list_variable and
# value_list_variable (both are variable names, not expressions).
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    'field_count symbol_list_variable value_list_variable',
)
259
def normalize_structure_literal_expression(counter, expression):
    """Lower a structure literal into parallel symbol/value arrays.

    Field values are normalized in declaration order. Three temporaries are
    then allocated in sequence:

    1. an array of the field symbols;
    2. an array of the normalized field values;
    3. the structure itself, a NormalStructureLiteralExpression built from
       those two arrays.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the structure)
    """
    prestatements = []
    symbols = []
    values = []

    for pair in expression.fields:
        counter, value_prestatements, value = normalize_expression(counter, pair.expression)
        prestatements.extend(value_prestatements)
        symbols.append(pair.symbol)
        values.append(value)

    symbol_array_variable = '${}'.format(counter)
    value_array_variable = '${}'.format(counter + 1)
    structure_variable = '${}'.format(counter + 2)

    prestatements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_array_variable,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_array_variable,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_variable,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_array_variable,
                value_list_variable=value_array_variable,
            ),
        ),
    ))

    return counter + 3, tuple(prestatements), NormalVariableExpression(variable=structure_variable)
315
316
def normalize_symbol_expression(counter, expression):
    """Bind a symbol literal to the fresh temporary ``$<counter>``.

    Returns:
        (counter + 1, one-statement prestatements tuple,
         NormalVariableExpression referring to the temporary)
    """
    symbol_variable = '${}'.format(counter)
    bind_symbol = NormalVariableInitializationStatement(
        variable=symbol_variable,
        expression=NormalSymbolExpression(symbol=expression.symbol),
    )
    reference = NormalVariableExpression(variable=symbol_variable)
    return counter + 1, (bind_symbol,), reference
329
def normalize_function_call_expression(counter, expression):
    """Lower a function call into push-arguments-then-call statements.

    Each argument is normalized left to right, copied into its own temporary
    and pushed. The callee is then normalized; if the result is not already
    a plain variable reference, it is bound to a temporary. The call result
    is stored in a final temporary.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the call result)
    """
    assert isinstance(expression, parsing.FurFunctionCallExpression)

    emitted = []

    for argument in expression.arguments:
        counter, argument_prestatements, argument_value = normalize_expression(counter, argument)
        emitted.extend(argument_prestatements)

        argument_variable = '${}'.format(counter)
        counter += 1
        emitted.append(
            NormalVariableInitializationStatement(
                variable=argument_variable,
                expression=argument_value,
            )
        )
        emitted.append(
            NormalPushStatement(
                expression=NormalVariableExpression(variable=argument_variable),
            )
        )

    counter, callee_prestatements, callee = normalize_expression(counter, expression.function)
    emitted.extend(callee_prestatements)

    # Make sure the callee handed to the call node is a plain variable reference.
    if not isinstance(callee, NormalVariableExpression):
        callee_variable = '${}'.format(counter)
        counter += 1
        emitted.append(
            NormalVariableInitializationStatement(
                variable=callee_variable,
                expression=callee,
            )
        )
        callee = NormalVariableExpression(variable=callee_variable)

    result_variable = '${}'.format(counter)
    emitted.append(
        NormalVariableInitializationStatement(
            variable=result_variable,
            expression=NormalFunctionCallExpression(
                function_expression=callee,
                argument_count=len(expression.arguments),
            ),
        )
    )

    return counter + 1, tuple(emitted), NormalVariableExpression(variable=result_variable)
395
def normalize_basic_infix_operation(counter, expression):
    """Lower a binary operator into two pushes and a NormalInfixExpression.

    Both operands are normalized (left first), their prestatements are
    emitted, and the operands are pushed. The operator node, which carries
    only metadata, order and operator, is then bound to a fresh temporary.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the result)
    """
    counter, left_prestatements, left_operand = normalize_expression(counter, expression.left)
    counter, right_prestatements, right_operand = normalize_expression(counter, expression.right)

    result_variable = '${}'.format(counter)
    operation = NormalInfixExpression(
        metadata=expression.metadata,
        order=expression.order,
        operator=expression.operator,
    )

    prestatements = left_prestatements + right_prestatements + (
        NormalPushStatement(expression=left_operand),
        NormalPushStatement(expression=right_operand),
        NormalVariableInitializationStatement(variable=result_variable, expression=operation),
    )
    return counter + 1, prestatements, NormalVariableExpression(variable=result_variable)
421
def desugar_ternary_comparison(counter, expression):
    """Rewrite ``a OP1 b OP2 c`` as ``(a OP1 b) and (b OP2 c)``.

    ``expression`` is the outer comparison; its ``left`` is the inner
    comparison ``a OP1 b``. ``a`` and ``b`` are normalized once and stored in
    temporaries, so the middle operand is evaluated a single time. ``c`` is
    passed to normalize_boolean_expression unnormalized, which places its
    prestatements in the short-circuited branch.

    Returns:
        (next counter, prestatements tuple, normalized boolean expression)
    """
    inner = expression.left

    counter, left_prestatements, left_expression = normalize_expression(counter, inner.left)
    counter, middle_prestatements, middle_expression = normalize_expression(counter, inner.right)

    left_variable = '${}'.format(counter)
    middle_variable = '${}'.format(counter + 1)
    counter += 2

    # TODO Is there a memory leak if the middle expression throws an exception because the first expression result hasn't been added to the stack?
    juncture_prestatements = (
        NormalVariableInitializationStatement(variable=left_variable, expression=left_expression),
        NormalVariableInitializationStatement(variable=middle_variable, expression=middle_expression),
    )

    first_comparison = parsing.FurInfixExpression(
        metadata=inner.metadata,
        order='comparison_level',
        operator=inner.operator,
        left=NormalVariableExpression(variable=left_variable),
        right=NormalVariableExpression(variable=middle_variable),
    )
    second_comparison = parsing.FurInfixExpression(
        metadata=expression.metadata,
        order='comparison_level',
        operator=expression.operator,
        left=NormalVariableExpression(variable=middle_variable),
        right=expression.right,
    )
    conjunction = parsing.FurInfixExpression(
        metadata=inner.metadata,
        order='and_level',
        operator='and',
        left=first_comparison,
        right=second_comparison,
    )

    counter, conjunction_prestatements, conjunction_result = normalize_boolean_expression(counter, conjunction)

    return (
        counter,
        left_prestatements + middle_prestatements + juncture_prestatements + conjunction_prestatements,
        conjunction_result,
    )
471
def normalize_comparison_expression(counter, expression):
    """Normalize a comparison, desugaring chained comparisons.

    Only a comparison whose left operand is itself a comparison
    (``a < b < c``) is handed to desugar_ternary_comparison. Any other left
    operand, including an infix expression at another precedence level such
    as ``a + b < c`` or ``x.y < z``, is an ordinary binary comparison.

    Returns:
        (next counter, prestatements tuple, normalized expression)
    """
    left = expression.left

    # The left operand's own order must be checked. This function is only
    # dispatched for comparison_level expressions, so checking
    # expression.order alone is always true. That made e.g. `a + b < c`
    # desugar into `(a + b) and (b < c)`.
    is_chained = (
        expression.order == 'comparison_level'
        and isinstance(left, parsing.FurInfixExpression)
        and left.order == 'comparison_level'
    )

    if is_chained:
        return desugar_ternary_comparison(counter, expression)

    return normalize_basic_infix_operation(counter, expression)
477
def normalize_boolean_expression(counter, expression):
    """Lower ``and``/``or`` into a short-circuiting if/else.

    The left operand's value is stored in a result temporary. The right
    operand's prestatements, followed by a reassignment of that temporary,
    go into the branch that runs only when the left value does not decide
    the outcome: the if-branch for ``and``, the else-branch for ``or``.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the result)

    Raises:
        Exception: if the operator is neither 'and' nor 'or'.
    """
    counter, left_prestatements, left_value = normalize_expression(counter, expression.left)
    counter, right_prestatements, right_value = normalize_expression(counter, expression.right)

    result_variable = '${}'.format(counter)
    counter += 1

    store_left = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=left_value,
    )
    evaluate_right = right_prestatements + (
        NormalVariableReassignmentStatement(variable=result_variable, expression=right_value),
    )

    operator = expression.operator
    if operator == 'and':
        if_branch, else_branch = evaluate_right, ()
    elif operator == 'or':
        if_branch, else_branch = (), evaluate_right
    else:
        raise Exception('Unable to handle operator "{}"'.format(operator))

    branch = NormalIfElseStatement(
        condition_expression=NormalVariableExpression(variable=result_variable),
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        left_prestatements + (store_left, branch),
        NormalVariableExpression(variable=result_variable),
    )
514
def normalize_dot_expression(counter, expression):
    """Lower ``instance.field`` into a NormalDotExpression bound to a temporary.

    The right-hand side must be a bare symbol; its name becomes the field.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the field value)
    """
    assert isinstance(expression.right, parsing.FurSymbolExpression)

    counter, prestatements, instance = normalize_expression(counter, expression.left)
    field_variable = '${}'.format(counter)

    lookup = NormalVariableInitializationStatement(
        variable=field_variable,
        expression=NormalDotExpression(
            instance=instance,
            field=expression.right.symbol,
        ),
    )

    return counter + 1, prestatements + (lookup,), NormalVariableExpression(variable=field_variable)
535
def normalize_infix_expression(counter, expression):
    """Dispatch an infix expression to its normalizer by precedence level.

    Raises:
        KeyError: if ``expression.order`` is not a known precedence level.
    """
    order = expression.order

    if order in ('multiplication_level', 'addition_level'):
        normalizer = normalize_basic_infix_operation
    elif order == 'comparison_level':
        normalizer = normalize_comparison_expression
    elif order == 'dot_level':
        normalizer = normalize_dot_expression
    elif order in ('and_level', 'or_level'):
        normalizer = normalize_boolean_expression
    else:
        raise KeyError(order)

    return normalizer(counter, expression)
545
def normalize_if_expression(counter, expression):
    """Lower an if-expression into a nil-initialized temporary plus an if/else.

    The condition is normalized first; its prestatements run before the
    branch. The result temporary starts as ``builtin$nil``. Both branches are
    normalized with ``assign_result_to`` naming that temporary, so a branch
    whose final expression statement yields a variable reference writes that
    value into it.

    Returns:
        (next counter, prestatements tuple,
         NormalVariableExpression referring to the result)
    """
    counter, condition_prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )

    result_variable = '${}'.format(counter)
    counter += 1

    counter, if_branch = normalize_statement_list(
        counter,
        expression.if_statement_list,
        assign_result_to=result_variable,
    )
    counter, else_branch = normalize_statement_list(
        counter,
        expression.else_statement_list,
        assign_result_to=result_variable,
    )

    initialize_result = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalVariableExpression(variable='builtin$nil'),
    )
    branch = NormalIfElseStatement(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        condition_prestatements + (initialize_result, branch),
        NormalVariableExpression(variable=result_variable),
    )
581
def normalize_negation_expression(counter, expression):
    """Lower a negation: store the operand in a temporary and negate it.

    Unlike most normalizers here, the returned expression is a
    NormalNegationExpression over a variable reference, not a variable
    reference itself.

    Returns:
        (next counter, prestatements tuple, NormalNegationExpression)
    """
    counter, prestatements, operand = normalize_expression(counter, expression.value)

    operand_variable = '${}'.format(counter)
    store_operand = NormalVariableInitializationStatement(
        variable=operand_variable,
        expression=operand,
    )
    negation = NormalNegationExpression(
        internal_expression=NormalVariableExpression(variable=operand_variable),
    )

    return counter + 1, prestatements + (store_operand,), negation
598
def normalize_expression(counter, expression):
    """Normalize any expression node, dispatching on its exact type.

    Nodes that are already normal (NormalInfixExpression and
    NormalVariableExpression) pass through unchanged.

    Returns:
        (next counter, prestatements tuple, normalized expression)

    Raises:
        KeyError: if the expression's type has no normalizer.
    """
    expression_type = type(expression)

    if expression_type in (NormalInfixExpression, NormalVariableExpression):
        return fake_normalization(counter, expression)

    normalizer = {
        parsing.FurFunctionCallExpression: normalize_function_call_expression,
        parsing.FurIfExpression: normalize_if_expression,
        parsing.FurInfixExpression: normalize_infix_expression,
        parsing.FurIntegerLiteralExpression: normalize_integer_literal_expression,
        parsing.FurListLiteralExpression: normalize_list_literal_expression,
        parsing.FurListItemExpression: normalize_list_item_expression,
        parsing.FurNegationExpression: normalize_negation_expression,
        parsing.FurStringLiteralExpression: normalize_string_literal_expression,
        parsing.FurStructureLiteralExpression: normalize_structure_literal_expression,
        parsing.FurSymbolExpression: normalize_symbol_expression,
    }[expression_type]

    return normalizer(counter, expression)
614
def normalize_expression_statement(counter, statement):
    """Normalize an expression used as a statement.

    The expression's prestatements are emitted ahead of a
    NormalExpressionStatement that wraps the normalized expression.

    Returns:
        (next counter, prestatements tuple, NormalExpressionStatement)
    """
    # TODO The normalized expression will usually be a NormalVariableExpression,
    # which goes unused for expression statements in every case except when it
    # is a return statement. This causes warnings on C compilation. We should
    # only generate this variable when it will be used on return.
    counter, prestatements, normalized = normalize_expression(counter, statement.expression)
    wrapped = NormalExpressionStatement(expression=normalized)
    return counter, prestatements, wrapped
627
def normalize_function_definition_statement(counter, statement):
    """Normalize a function definition.

    The body gets its own temporary numbering starting at 0, independent of
    ``counter``. Its final expression value, when normalize_statement_list
    redirects it, is assigned to ``result``. The definition itself emits no
    prestatements and leaves ``counter`` unchanged.

    Returns:
        (counter, (), NormalFunctionDefinitionStatement)
    """
    body_counter_start = 0
    _, body = normalize_statement_list(
        body_counter_start,
        statement.statement_list,
        assign_result_to='result',
    )

    definition = NormalFunctionDefinitionStatement(
        name=statement.name,
        argument_name_list=statement.argument_name_list,
        statement_list=body,
    )
    return counter, (), definition
643
def normalize_assignment_statement(counter, statement):
    """Normalize ``target = expression``; the target is kept as parsed.

    Returns:
        (next counter, prestatements tuple, NormalAssignmentStatement)
    """
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(
        target=statement.target,
        expression=value,
    )
    return counter, prestatements, assignment
654
def normalize_statement(counter, statement):
    """Normalize one parsed statement, dispatching on its exact type.

    Returns:
        (next counter, prestatements tuple, normalized statement)

    Raises:
        KeyError: if the statement's type has no normalizer.
    """
    statement_type = type(statement)

    if statement_type is parsing.FurAssignmentStatement:
        normalizer = normalize_assignment_statement
    elif statement_type is parsing.FurExpressionStatement:
        normalizer = normalize_expression_statement
    elif statement_type is parsing.FurFunctionDefinitionStatement:
        normalizer = normalize_function_definition_statement
    else:
        raise KeyError(statement_type)

    return normalizer(counter, statement)
661
# NOTE(review): force_generator(tuple) wraps a function that returns a
# (counter, list) pair rather than yielding; confirm the intended effect.
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list, **kwargs):
    """Normalize statements in order, splicing each one's prestatements before it.

    Args:
        counter: next free temporary index.
        statement_list: parsed statements to normalize.
        assign_result_to: optional keyword. When given, and the last
            normalized statement is an expression statement whose expression
            is a NormalVariableExpression, that statement is replaced by a
            NormalVariableReassignmentStatement into this variable.

    Returns:
        (next counter, list of normalized statements)

    Raises:
        TypeError: if an unexpected keyword argument is passed.
    """
    assign_result_to = kwargs.pop('assign_result_to', None)

    # Raise explicitly: an assert would be stripped under `python -O`.
    if kwargs:
        raise TypeError(
            'Unexpected keyword arguments: {}'.format(', '.join(sorted(kwargs)))
        )

    result_statement_list = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        result_statement_list.extend(prestatements)
        result_statement_list.append(normalized)

    # Redirect the final expression value into assign_result_to. An empty list
    # (e.g. an empty branch, body or program) has no final statement, so it is
    # skipped instead of indexing [-1] and raising IndexError.
    if assign_result_to is not None and result_statement_list:
        last_statement = result_statement_list[-1]
        if (
            isinstance(last_statement, NormalExpressionStatement)
            and isinstance(last_statement.expression, NormalVariableExpression)
        ):
            result_statement_list[-1] = NormalVariableReassignmentStatement(
                variable=assign_result_to,
                expression=last_statement.expression,
            )

    return (
        counter,
        result_statement_list,
    )
693
def normalize(program):
    """Normalize a parsed program into a NormalProgram.

    Top-level temporaries are numbered from 0, and the final counter is
    discarded. No result redirection is requested for the top level.
    """
    _final_counter, normalized_statements = normalize_statement_list(
        0,
        program.statement_list,
    )
    return NormalProgram(statement_list=normalized_statements)
698         statement_list=statement_list,
699     )