Move desugaring ternary comparison operators into the normalization step
[fur] / normalization.py
1 import collections
2
3 import parsing
4 import util
5
# Node types produced by normalization. Field names are given in
# namedtuple's space-separated string form; field order is significant
# because these are tuples.

# --- Expression nodes ---------------------------------------------------
NormalVariableExpression = collections.namedtuple('NormalVariableExpression', 'variable')
NormalIntegerLiteralExpression = collections.namedtuple('NormalIntegerLiteralExpression', 'integer')
NormalStringLiteralExpression = collections.namedtuple('NormalStringLiteralExpression', 'string')
NormalSymbolExpression = collections.namedtuple('NormalSymbolExpression', 'symbol')
NormalNegationExpression = collections.namedtuple('NormalNegationExpression', 'internal_expression')
NormalDotExpression = collections.namedtuple('NormalDotExpression', 'instance field')
NormalInfixExpression = collections.namedtuple('NormalInfixExpression', 'order operator left right')

# Pushes a (variable) expression; emitted once per call argument.
NormalPushStatement = collections.namedtuple('NormalPushStatement', 'expression')

NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    'function_expression argument_count',
)

# --- Statement nodes ----------------------------------------------------
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    'variable items',
)
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    'variable symbol_list',
)
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    'variable expression',
)
NormalVariableReassignmentStatement = collections.namedtuple(
    'NormalVariableReassignmentStatement',
    'variable expression',
)
NormalExpressionStatement = collections.namedtuple('NormalExpressionStatement', 'expression')
NormalAssignmentStatement = collections.namedtuple('NormalAssignmentStatement', 'target expression')
NormalIfElseStatement = collections.namedtuple(
    'NormalIfElseStatement',
    'condition_expression if_statement_list else_statement_list',
)
NormalFunctionDefinitionStatement = collections.namedtuple(
    'NormalFunctionDefinitionStatement',
    'name argument_name_list statement_list',
)

# --- Top level ------------------------------------------------------------
NormalProgram = collections.namedtuple('NormalProgram', 'statement_list')
145
def fake_normalization(counter, thing):
    """Pass a node through normalization untouched.

    No prestatements are produced and the counter is not advanced; the
    node itself is returned as the normalized expression.
    """
    no_prestatements = ()
    return counter, no_prestatements, thing
148
def normalize_integer_literal_expression(counter, expression):
    """Bind an integer literal to a fresh `$N` temporary.

    Returns (counter + 1, prestatements, expression). The single
    prestatement initializes the temporary with the literal, and the
    returned expression refers to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalIntegerLiteralExpression(integer=expression.integer),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
161
# Node types for list literals and indexing. Field names use namedtuple's
# space-separated string form.

# `allocate` receives the number of items in the list literal.
NormalListConstructExpression = collections.namedtuple('NormalListConstructExpression', 'allocate')

NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    'list_expression item_expression',
)

NormalListGetExpression = collections.namedtuple(
    'NormalListGetExpression',
    'list_expression index_expression',
)
184
def normalize_list_literal_expression(counter, expression):
    """Lower a list literal to a construct statement followed by appends.

    A fresh `$N` variable is initialized with a NormalListConstructExpression
    sized to the number of items. Each item is then normalized in order, its
    prestatements are emitted, and a NormalListAppendStatement adds its value
    to the list.

    Returns (counter, prestatements, NormalVariableExpression for the list).
    """
    items = expression.item_expression_list

    list_name = '${}'.format(counter)
    counter += 1
    list_reference = NormalVariableExpression(variable=list_name)

    statements = [
        NormalVariableInitializationStatement(
            variable=list_name,
            expression=NormalListConstructExpression(allocate=len(items)),
        ),
    ]

    for item in items:
        counter, item_prestatements, item_value = normalize_expression(counter, item)
        statements.extend(item_prestatements)
        statements.append(
            NormalListAppendStatement(
                list_expression=list_reference,
                item_expression=item_value,
            )
        )

    return counter, tuple(statements), list_reference
219
def normalize_list_item_expression(counter, expression):
    """Lower `list[index]` to a NormalListGetExpression bound to a `$N` variable.

    The list operand is normalized before the index operand, so its
    prestatements come first.
    """
    counter, list_prestatements, normalized_list = normalize_expression(counter, expression.list_expression)
    counter, index_prestatements, normalized_index = normalize_expression(counter, expression.index_expression)

    target = '${}'.format(counter)
    get_statement = NormalVariableInitializationStatement(
        variable=target,
        expression=NormalListGetExpression(
            list_expression=normalized_list,
            index_expression=normalized_index,
        ),
    )

    prestatements = list_prestatements + index_prestatements + (get_statement,)
    return counter + 1, prestatements, NormalVariableExpression(variable=target)
238
def normalize_string_literal_expression(counter, expression):
    """Bind a string literal to a fresh `$N` temporary.

    Returns (counter + 1, prestatements, expression). The single
    prestatement initializes the temporary with the literal, and the
    returned expression refers to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalStringLiteralExpression(string=expression.string),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
251
# A structure literal, described by the number of fields plus the names of
# the variables holding its field symbols and its field values.
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    'field_count symbol_list_variable value_list_variable',
)
260
def normalize_structure_literal_expression(counter, expression):
    """Lower a structure literal to symbol/value arrays plus a structure node.

    Each field value is normalized in order. Three consecutive `$N`
    variables are then allocated:
    - the symbol array (one symbol per field, in source order),
    - the value array (the normalized field values), and
    - the structure itself, a NormalStructureLiteralExpression over the
      two arrays.

    Returns (counter, prestatements, NormalVariableExpression for the structure).
    """
    prestatements = []
    symbols = []
    values = []

    for pair in expression.fields:
        counter, value_prestatements, value = normalize_expression(counter, pair.expression)
        prestatements.extend(value_prestatements)
        symbols.append(pair.symbol)
        values.append(value)

    symbol_array_name = '${}'.format(counter)
    value_array_name = '${}'.format(counter + 1)
    structure_name = '${}'.format(counter + 2)

    prestatements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_array_name,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_array_name,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_name,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_array_name,
                value_list_variable=value_array_name,
            ),
        ),
    ))

    return counter + 3, tuple(prestatements), NormalVariableExpression(variable=structure_name)
316
317
def normalize_symbol_expression(counter, expression):
    """Bind a symbol to a fresh `$N` temporary.

    Returns (counter + 1, prestatements, expression). The single
    prestatement initializes the temporary with the symbol, and the
    returned expression refers to that temporary.
    """
    temporary = '${}'.format(counter)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=NormalSymbolExpression(symbol=expression.symbol),
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
330
def normalize_function_call_expression(counter, expression):
    """Lower a function call to argument pushes followed by a call node.

    Arguments are normalized left to right. Each argument is bound to its
    own `$N` variable and pushed with a NormalPushStatement. The callee is
    normalized next; if it is not already a variable reference, it is bound
    to a `$N` variable too. Finally the call's result is bound to a fresh
    `$N` variable, which is returned.
    """
    assert isinstance(expression, parsing.FurFunctionCallExpression)

    prestatements = []

    for argument in expression.arguments:
        counter, argument_prestatements, argument_value = normalize_expression(counter, argument)
        prestatements.extend(argument_prestatements)

        argument_name = '${}'.format(counter)
        counter += 1
        prestatements.append(
            NormalVariableInitializationStatement(
                variable=argument_name,
                expression=argument_value,
            )
        )
        prestatements.append(
            NormalPushStatement(
                expression=NormalVariableExpression(variable=argument_name),
            )
        )

    counter, callee_prestatements, callee = normalize_expression(counter, expression.function)
    prestatements.extend(callee_prestatements)

    # The call node must refer to the callee through a variable.
    if not isinstance(callee, NormalVariableExpression):
        callee_name = '${}'.format(counter)
        counter += 1
        prestatements.append(
            NormalVariableInitializationStatement(
                variable=callee_name,
                expression=callee,
            )
        )
        callee = NormalVariableExpression(variable=callee_name)

    result_name = '${}'.format(counter)
    prestatements.append(
        NormalVariableInitializationStatement(
            variable=result_name,
            expression=NormalFunctionCallExpression(
                function_expression=callee,
                argument_count=len(expression.arguments),
            ),
        )
    )

    return counter + 1, tuple(prestatements), NormalVariableExpression(variable=result_name)
396
def normalize_basic_infix_operation(counter, expression):
    """Lower an arithmetic infix expression (multiplication/addition level).

    Both operands are normalized (left first) and bound to consecutive `$N`
    variables. The infix result is bound to a third `$N` variable, which is
    returned.
    """
    counter, left_prestatements, left_value = normalize_expression(counter, expression.left)
    counter, right_prestatements, right_value = normalize_expression(counter, expression.right)

    left_name = '${}'.format(counter)
    right_name = '${}'.format(counter + 1)
    result_name = '${}'.format(counter + 2)
    counter += 3

    operation = NormalInfixExpression(
        order=expression.order,
        operator=expression.operator,
        left=NormalVariableExpression(variable=left_name),
        right=NormalVariableExpression(variable=right_name),
    )

    bindings = (
        NormalVariableInitializationStatement(variable=left_name, expression=left_value),
        NormalVariableInitializationStatement(variable=right_name, expression=right_value),
        NormalVariableInitializationStatement(variable=result_name, expression=operation),
    )

    return (
        counter,
        left_prestatements + right_prestatements + bindings,
        NormalVariableExpression(variable=result_name),
    )
433
def desugar_ternary_comparison(counter, expression):
    """Desugar a chained comparison `a OP1 b OP2 c` into `a OP1 b and b OP2 c`.

    `a` and `b` are normalized eagerly and bound to fresh `$N` variables, so
    `b` is evaluated once even though it appears in both comparisons. `c` is
    left un-normalized inside the second comparison. normalize_boolean_expression
    places the right-hand operand's prestatements in the short-circuit branch,
    so `c` is only evaluated when the first comparison holds.

    NOTE(review): only `expression.left.left` and `expression.left.right` are
    split out. In a longer chain such as `a < b < c < d`, `expression.left.left`
    is itself the comparison `a < b` and is normalized as a single operand.
    Confirm whether chains longer than three operands need support.

    Returns a (counter, prestatements, expression) triple.
    """
    first_operand = expression.left.left
    middle_operand = expression.left.right
    last_operand = expression.right

    counter, first_prestatements, first_value = normalize_expression(counter, first_operand)
    counter, middle_prestatements, middle_value = normalize_expression(counter, middle_operand)

    first_name = '${}'.format(counter)
    middle_name = '${}'.format(counter + 1)
    counter += 2

    binding_prestatements = (
        NormalVariableInitializationStatement(variable=first_name, expression=first_value),
        NormalVariableInitializationStatement(variable=middle_name, expression=middle_value),
    )

    leading_comparison = parsing.FurInfixExpression(
        order='comparison_level',
        operator=expression.left.operator,
        left=NormalVariableExpression(variable=first_name),
        right=NormalVariableExpression(variable=middle_name),
    )
    trailing_comparison = parsing.FurInfixExpression(
        order='comparison_level',
        operator=expression.operator,
        left=NormalVariableExpression(variable=middle_name),
        right=last_operand,
    )
    conjunction = parsing.FurInfixExpression(
        order='and_level',
        operator='and',
        left=leading_comparison,
        right=trailing_comparison,
    )

    counter, conjunction_prestatements, result = normalize_boolean_expression(counter, conjunction)

    return (
        counter,
        first_prestatements + middle_prestatements + binding_prestatements + conjunction_prestatements,
        result,
    )
479
def normalize_comparison_expression(counter, expression):
    """Normalize a comparison-level infix expression.

    A comparison whose left operand is itself a comparison (e.g.
    `a < b < c`) is a chained comparison and is delegated to
    desugar_ternary_comparison. Otherwise both operands are normalized
    (left first) and bound to consecutive `$N` variables. A
    NormalInfixExpression over those variables is returned directly; it is
    not bound to a result variable.

    Returns a (counter, prestatements, expression) triple.
    """
    # Only a *comparison* nested on the left forms a chain. The previous
    # check tested `expression.order`, which is always 'comparison_level'
    # here (this function is only dispatched for that order). That made
    # operands such as `a + b < c` or `x.y < c` be mis-desugared as chained
    # comparisons.
    if (
        isinstance(expression.left, parsing.FurInfixExpression)
        and expression.left.order == 'comparison_level'
    ):
        return desugar_ternary_comparison(counter, expression)

    counter, left_prestatements, left_expression = normalize_expression(counter, expression.left)
    counter, right_prestatements, right_expression = normalize_expression(counter, expression.right)

    left_variable = '${}'.format(counter)
    counter += 1
    right_variable = '${}'.format(counter)
    counter += 1

    root_prestatements = (
        NormalVariableInitializationStatement(
            variable=left_variable,
            expression=left_expression,
        ),
        NormalVariableInitializationStatement(
            variable=right_variable,
            expression=right_expression,
        ),
    )

    return (
        counter,
        left_prestatements + right_prestatements + root_prestatements,
        NormalInfixExpression(
            order=expression.order,
            operator=expression.operator,
            left=NormalVariableExpression(variable=left_variable),
            right=NormalVariableExpression(variable=right_variable),
        ),
    )
515
def normalize_boolean_expression(counter, expression):
    """Normalize a short-circuiting `and` / `or` expression.

    The left operand's value initializes a fresh `$N` result variable. The
    right operand's prestatements, followed by a reassignment of the result
    variable to the right operand's value, go into one branch of a
    NormalIfElseStatement conditioned on the result variable:
    - `and`: the if branch (right side evaluated only when left is truthy);
    - `or`: the else branch (right side evaluated only when left is falsy).

    Returns (counter, prestatements, NormalVariableExpression for the result).
    Raises Exception for any other operator.
    """
    counter, left_prestatements, left_value = normalize_expression(counter, expression.left)
    counter, right_prestatements, right_value = normalize_expression(counter, expression.right)

    result_name = '${}'.format(counter)
    counter += 1

    result_initialization = NormalVariableInitializationStatement(
        variable=result_name,
        expression=left_value,
    )
    deferred_statements = right_prestatements + (
        NormalVariableReassignmentStatement(variable=result_name, expression=right_value),
    )

    operator = expression.operator
    if operator == 'and':
        if_branch, else_branch = deferred_statements, ()
    elif operator == 'or':
        if_branch, else_branch = (), deferred_statements
    else:
        raise Exception('Unable to handle operator "{}"'.format(expression.operator))

    short_circuit = NormalIfElseStatement(
        condition_expression=NormalVariableExpression(variable=result_name),
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        left_prestatements + (result_initialization, short_circuit),
        NormalVariableExpression(variable=result_name),
    )
552
def normalize_dot_expression(counter, expression):
    """Lower `instance.field` to a NormalDotExpression bound to a `$N` variable.

    The right-hand side must be a symbol expression; its symbol names the field.
    """
    assert isinstance(expression.right, parsing.FurSymbolExpression)

    counter, prestatements, instance = normalize_expression(counter, expression.left)
    target = '${}'.format(counter)

    field_access = NormalVariableInitializationStatement(
        variable=target,
        expression=NormalDotExpression(
            instance=instance,
            field=expression.right.symbol,
        ),
    )

    return counter + 1, prestatements + (field_access,), NormalVariableExpression(variable=target)
573
def normalize_infix_expression(counter, expression):
    """Dispatch an infix expression to the normalizer for its precedence level.

    Raises KeyError for an unrecognized `expression.order`.
    """
    normalizers_by_order = {
        'multiplication_level': normalize_basic_infix_operation,
        'addition_level': normalize_basic_infix_operation,
        'comparison_level': normalize_comparison_expression,
        'dot_level': normalize_dot_expression,
        'and_level': normalize_boolean_expression,
        'or_level': normalize_boolean_expression,
    }
    normalizer = normalizers_by_order[expression.order]
    return normalizer(counter, expression)
583
def normalize_if_expression(counter, expression):
    """Lower an if-expression to a result variable plus a NormalIfElseStatement.

    The condition is normalized first. A fresh `$N` result variable is then
    initialized to `builtin$nil`. Both branches are normalized with
    `assign_result_to` set to that variable, so each branch's final value
    can be assigned to it. The result variable is returned.
    """
    counter, condition_prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )

    result_name = '${}'.format(counter)
    counter += 1

    counter, if_branch = normalize_statement_list(
        counter,
        expression.if_statement_list,
        assign_result_to=result_name,
    )
    counter, else_branch = normalize_statement_list(
        counter,
        expression.else_statement_list,
        assign_result_to=result_name,
    )

    default_result = NormalVariableInitializationStatement(
        variable=result_name,
        expression=NormalVariableExpression(variable='builtin$nil'),
    )
    branch = NormalIfElseStatement(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        condition_prestatements + (default_result, branch),
        NormalVariableExpression(variable=result_name),
    )
619
def normalize_negation_expression(counter, expression):
    """Lower a negation by binding its operand to `$N` and negating that variable.

    The returned NormalNegationExpression is not itself bound to a variable.
    """
    counter, prestatements, operand = normalize_expression(counter, expression.value)

    operand_name = '${}'.format(counter)
    binding = NormalVariableInitializationStatement(
        variable=operand_name,
        expression=operand,
    )
    negation = NormalNegationExpression(
        internal_expression=NormalVariableExpression(variable=operand_name),
    )

    return counter + 1, prestatements + (binding,), negation
636
def normalize_expression(counter, expression):
    """Normalize any expression node by dispatching on its exact type.

    NormalInfixExpression and NormalVariableExpression nodes are already
    normal and pass through unchanged. Returns a
    (counter, prestatements, expression) triple. Raises KeyError for an
    unsupported node type.
    """
    handlers = {
        NormalInfixExpression: fake_normalization,
        NormalVariableExpression: fake_normalization,
        parsing.FurFunctionCallExpression: normalize_function_call_expression,
        parsing.FurIfExpression: normalize_if_expression,
        parsing.FurInfixExpression: normalize_infix_expression,
        parsing.FurIntegerLiteralExpression: normalize_integer_literal_expression,
        parsing.FurListLiteralExpression: normalize_list_literal_expression,
        parsing.FurListItemExpression: normalize_list_item_expression,
        parsing.FurNegationExpression: normalize_negation_expression,
        parsing.FurStringLiteralExpression: normalize_string_literal_expression,
        parsing.FurStructureLiteralExpression: normalize_structure_literal_expression,
        parsing.FurSymbolExpression: normalize_symbol_expression,
    }
    handler = handlers[type(expression)]
    return handler(counter, expression)
652
def normalize_expression_statement(counter, statement):
    """Normalize an expression statement.

    The expression's prestatements are returned unchanged, and its
    normalized value is wrapped in a NormalExpressionStatement.
    """
    # TODO The normalized value is typically a NormalVariableExpression,
    # which goes unused for every expression statement except when its
    # value is returned. The unused variable causes warnings on C
    # compilation; it should only be generated when it will be used.
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    wrapped = NormalExpressionStatement(expression=value)
    return counter, prestatements, wrapped
665
def normalize_function_definition_statement(counter, statement):
    """Normalize a function definition.

    The body is normalized with a fresh counter starting at 0, so
    temporaries are numbered per function. A trailing expression statement
    in the body is rewritten to assign to `result`. The enclosing counter
    is returned unchanged, and no prestatements are emitted.
    """
    _body_counter, body = normalize_statement_list(
        0,
        statement.statement_list,
        assign_result_to='result',
    )
    definition = NormalFunctionDefinitionStatement(
        name=statement.name,
        argument_name_list=statement.argument_name_list,
        statement_list=body,
    )
    return counter, (), definition
681
def normalize_assignment_statement(counter, statement):
    """Normalize `target = expression`: normalize the right-hand side, keep the target."""
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(
        target=statement.target,
        expression=value,
    )
    return counter, prestatements, assignment
692
def normalize_statement(counter, statement):
    """Normalize one statement by dispatching on its exact type.

    Returns a (counter, prestatements, statement) triple. Raises KeyError
    for an unsupported statement type.
    """
    handlers = {
        parsing.FurAssignmentStatement: normalize_assignment_statement,
        parsing.FurExpressionStatement: normalize_expression_statement,
        parsing.FurFunctionDefinitionStatement: normalize_function_definition_statement,
    }
    handler = handlers[type(statement)]
    return handler(counter, statement)
699
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list, **kwargs):
    """Normalize a sequence of statements.

    Each statement is normalized in order, and its prestatements are
    emitted immediately before it.

    Keyword arguments:
        assign_result_to: optional variable name. When given, and the last
            normalized statement is an expression statement over a
            variable, that statement is replaced by a reassignment of
            `assign_result_to`. This is how if-expression branches and
            function bodies hand their final value back.

    Returns (counter, list of normalized statements).
    Raises AssertionError if any other keyword argument is passed.
    """
    assign_result_to = kwargs.pop('assign_result_to', None)

    assert len(kwargs) == 0

    result_statement_list = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        result_statement_list.extend(prestatements)
        result_statement_list.append(normalized)

    # An empty statement list (e.g. an empty program or body) has no last
    # statement to rewrite; unconditionally indexing [-1] raised IndexError.
    if assign_result_to is not None and result_statement_list:
        last_statement = result_statement_list[-1]

        if (
            isinstance(last_statement, NormalExpressionStatement)
            and isinstance(last_statement.expression, NormalVariableExpression)
        ):
            result_statement_list[-1] = NormalVariableReassignmentStatement(
                variable=assign_result_to,
                expression=last_statement.expression,
            )

    return (
        counter,
        result_statement_list,
    )
731
def normalize(program):
    """Normalize a parsed program into a NormalProgram.

    Temporaries are numbered starting from `$0`; the final counter is discarded.
    """
    _final_counter, normalized_statements = normalize_statement_list(0, program.statement_list)
    return NormalProgram(statement_list=normalized_statements)