Reuse normalize_basic_infix_operation in implementation of normalize_comparison_expression
[fur] / normalization.py
1 import collections
2
3 import parsing
4 import util
5
# Node types of the normalized ("Normal*") representation produced by this
# module. Field names are given as space-separated strings, which
# collections.namedtuple splits into individual fields.

# Expression nodes.
NormalVariableExpression = collections.namedtuple('NormalVariableExpression', 'variable')
NormalIntegerLiteralExpression = collections.namedtuple('NormalIntegerLiteralExpression', 'integer')
NormalStringLiteralExpression = collections.namedtuple('NormalStringLiteralExpression', 'string')
NormalSymbolExpression = collections.namedtuple('NormalSymbolExpression', 'symbol')
NormalNegationExpression = collections.namedtuple('NormalNegationExpression', 'internal_expression')
NormalDotExpression = collections.namedtuple('NormalDotExpression', 'instance field')
NormalInfixExpression = collections.namedtuple('NormalInfixExpression', 'order operator left right')

# Pushes a single expression (used for function-call arguments).
NormalPushStatement = collections.namedtuple('NormalPushStatement', 'expression')

# A call whose arguments were pushed beforehand; only their count is stored.
NormalFunctionCallExpression = collections.namedtuple(
    'NormalFunctionCallExpression',
    'function_expression argument_count',
)

# Statement nodes.
NormalArrayVariableInitializationStatement = collections.namedtuple(
    'NormalArrayVariableInitializationStatement',
    'variable items',
)
NormalSymbolArrayVariableInitializationStatement = collections.namedtuple(
    'NormalSymbolArrayVariableInitializationStatement',
    'variable symbol_list',
)
NormalVariableInitializationStatement = collections.namedtuple(
    'NormalVariableInitializationStatement',
    'variable expression',
)
NormalVariableReassignmentStatement = collections.namedtuple(
    'NormalVariableReassignmentStatement',
    'variable expression',
)
NormalExpressionStatement = collections.namedtuple('NormalExpressionStatement', 'expression')
NormalAssignmentStatement = collections.namedtuple('NormalAssignmentStatement', 'target expression')
NormalIfElseStatement = collections.namedtuple(
    'NormalIfElseStatement',
    'condition_expression if_statement_list else_statement_list',
)
NormalFunctionDefinitionStatement = collections.namedtuple(
    'NormalFunctionDefinitionStatement',
    'name argument_name_list statement_list',
)

# Root node of a normalized program.
NormalProgram = collections.namedtuple('NormalProgram', 'statement_list')
145
def fake_normalization(counter, thing):
    """Pass an already-normalized node through unchanged.

    Emits no prestatements and leaves the temporary counter untouched.
    Returns (counter, (), thing).
    """
    no_prestatements = ()
    return counter, no_prestatements, thing
148
def normalize_integer_literal_expression(counter, expression):
    """Bind an integer literal to a fresh temporary.

    Returns (counter + 1, prestatements, result). The single prestatement
    initializes temporary `$<counter>` with the literal, and `result`
    references that temporary.
    """
    temporary = '${}'.format(counter)
    literal = NormalIntegerLiteralExpression(integer=expression.integer)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=literal,
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
161
# List-related normalized nodes.

# Creates a new list; `allocate` is the number of items the literal holds.
NormalListConstructExpression = collections.namedtuple('NormalListConstructExpression', 'allocate')

# Appends `item_expression` to the list referenced by `list_expression`.
NormalListAppendStatement = collections.namedtuple(
    'NormalListAppendStatement',
    'list_expression item_expression',
)

# Reads the item at `index_expression` from the list at `list_expression`.
NormalListGetExpression = collections.namedtuple(
    'NormalListGetExpression',
    'list_expression index_expression',
)
184
def normalize_list_literal_expression(counter, expression):
    """Normalize a list literal into construct-then-append prestatements.

    A temporary `$<counter>` is initialized with a list construction sized to
    the literal's item count. Each item is then normalized in order, and its
    prestatements are followed by an append onto the temporary. The result
    expression references the temporary.
    """
    items = expression.item_expression_list

    list_variable = '${}'.format(counter)
    counter += 1
    list_reference = NormalVariableExpression(variable=list_variable)

    prestatements = [
        NormalVariableInitializationStatement(
            variable=list_variable,
            expression=NormalListConstructExpression(allocate=len(items)),
        ),
    ]

    for item in items:
        counter, item_prestatements, normalized_item = normalize_expression(counter, item)
        prestatements.extend(item_prestatements)
        prestatements.append(
            NormalListAppendStatement(
                list_expression=list_reference,
                item_expression=normalized_item,
            )
        )

    return counter, tuple(prestatements), list_reference
219
def normalize_list_item_expression(counter, expression):
    """Normalize `list[index]` into a temporary holding the fetched item.

    The list expression is normalized first, then the index expression. The
    item read is stored in temporary `$<counter>` (the counter value after
    both operands), and the result references it.
    """
    counter, list_prestatements, normalized_list = normalize_expression(
        counter,
        expression.list_expression,
    )
    counter, index_prestatements, normalized_index = normalize_expression(
        counter,
        expression.index_expression,
    )

    item_variable = '${}'.format(counter)
    fetch = NormalVariableInitializationStatement(
        variable=item_variable,
        expression=NormalListGetExpression(
            list_expression=normalized_list,
            index_expression=normalized_index,
        ),
    )

    prestatements = list_prestatements + index_prestatements + (fetch,)
    return counter + 1, prestatements, NormalVariableExpression(variable=item_variable)
238
def normalize_string_literal_expression(counter, expression):
    """Bind a string literal to a fresh temporary.

    Returns (counter + 1, prestatements, result). The single prestatement
    initializes temporary `$<counter>` with the literal, and `result`
    references that temporary.
    """
    temporary = '${}'.format(counter)
    literal = NormalStringLiteralExpression(string=expression.string)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=literal,
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
251
# A structure literal built from two parallel arrays held in temporaries:
# one of field symbols and one of field values.
NormalStructureLiteralExpression = collections.namedtuple(
    'NormalStructureLiteralExpression',
    'field_count symbol_list_variable value_list_variable',
)
260
def normalize_structure_literal_expression(counter, expression):
    """Normalize a structure literal into symbol/value arrays plus a constructor.

    Each field value is normalized in order. Three consecutive temporaries are
    then allocated:
      - `$<n>` holds the array of field symbols;
      - `$<n+1>` holds the array of normalized field values;
      - `$<n+2>` holds the structure built from them.
    Returns (n + 3, prestatements, reference to `$<n+2>`).
    """
    prestatements = []
    symbols = []
    values = []

    for pair in expression.fields:
        counter, value_prestatements, normalized_value = normalize_expression(
            counter,
            pair.expression,
        )
        prestatements.extend(value_prestatements)
        symbols.append(pair.symbol)
        values.append(normalized_value)

    symbol_array_variable = '${}'.format(counter)
    value_array_variable = '${}'.format(counter + 1)
    structure_variable = '${}'.format(counter + 2)

    prestatements.extend((
        NormalSymbolArrayVariableInitializationStatement(
            variable=symbol_array_variable,
            symbol_list=tuple(symbols),
        ),
        NormalArrayVariableInitializationStatement(
            variable=value_array_variable,
            items=tuple(values),
        ),
        NormalVariableInitializationStatement(
            variable=structure_variable,
            expression=NormalStructureLiteralExpression(
                field_count=len(expression.fields),
                symbol_list_variable=symbol_array_variable,
                value_list_variable=value_array_variable,
            ),
        ),
    ))

    return counter + 3, tuple(prestatements), NormalVariableExpression(variable=structure_variable)
316
317
def normalize_symbol_expression(counter, expression):
    """Bind a symbol literal to a fresh temporary.

    Returns (counter + 1, prestatements, result). The single prestatement
    initializes temporary `$<counter>` with the symbol, and `result`
    references that temporary.
    """
    temporary = '${}'.format(counter)
    symbol = NormalSymbolExpression(symbol=expression.symbol)
    initialization = NormalVariableInitializationStatement(
        variable=temporary,
        expression=symbol,
    )
    return counter + 1, (initialization,), NormalVariableExpression(variable=temporary)
330
def normalize_function_call_expression(counter, expression):
    """Normalize a function call into argument pushes followed by a call.

    Processing order:
      1. Each argument, left to right: its prestatements, then initialization
         of its own temporary, then a push of that temporary.
      2. The function expression's prestatements.
      3. If the normalized function expression is not already a variable
         reference, it is stored in a temporary first.
      4. The call result is stored in a temporary, which is returned.
    """
    assert isinstance(expression, parsing.FurFunctionCallExpression)

    prestatements = []

    for argument in expression.arguments:
        counter, argument_prestatements, normalized_argument = normalize_expression(counter, argument)
        prestatements.extend(argument_prestatements)

        argument_variable = '${}'.format(counter)
        counter += 1
        prestatements.extend((
            NormalVariableInitializationStatement(
                variable=argument_variable,
                expression=normalized_argument,
            ),
            NormalPushStatement(
                expression=NormalVariableExpression(variable=argument_variable),
            ),
        ))

    counter, function_prestatements, callee = normalize_expression(counter, expression.function)
    prestatements.extend(function_prestatements)

    # Ensure the callee is a plain variable reference before building the call node.
    if not isinstance(callee, NormalVariableExpression):
        callee_variable = '${}'.format(counter)
        counter += 1
        prestatements.append(
            NormalVariableInitializationStatement(
                variable=callee_variable,
                expression=callee,
            )
        )
        callee = NormalVariableExpression(variable=callee_variable)

    result_variable = '${}'.format(counter)
    prestatements.append(
        NormalVariableInitializationStatement(
            variable=result_variable,
            expression=NormalFunctionCallExpression(
                function_expression=callee,
                argument_count=len(expression.arguments),
            ),
        )
    )

    return counter + 1, tuple(prestatements), NormalVariableExpression(variable=result_variable)
396
def normalize_basic_infix_operation(counter, expression):
    """Normalize a binary infix operation whose operands are evaluated eagerly.

    Both operands are normalized, left first. Three consecutive temporaries
    are then allocated: one for the left value, one for the right value, and
    one for the operation's result. The result reference is returned with the
    counter advanced past all three.
    """
    counter, left_prestatements, normalized_left = normalize_expression(counter, expression.left)
    counter, right_prestatements, normalized_right = normalize_expression(counter, expression.right)

    left_variable, right_variable, center_variable = [
        '${}'.format(counter + offset) for offset in range(3)
    ]

    operation = NormalInfixExpression(
        order=expression.order,
        operator=expression.operator,
        left=NormalVariableExpression(variable=left_variable),
        right=NormalVariableExpression(variable=right_variable),
    )

    root_prestatements = (
        NormalVariableInitializationStatement(variable=left_variable, expression=normalized_left),
        NormalVariableInitializationStatement(variable=right_variable, expression=normalized_right),
        NormalVariableInitializationStatement(variable=center_variable, expression=operation),
    )

    return (
        counter + 3,
        left_prestatements + right_prestatements + root_prestatements,
        NormalVariableExpression(variable=center_variable),
    )
433
def desugar_ternary_comparison(counter, expression):
    """Rewrite a chained comparison `a OP1 b OP2 c` as `(a OP1 b) and (b OP2 c)`.

    `expression.left` is expected to be the inner comparison `a OP1 b`
    (presumably a left-associative parse of the chain; confirm against the
    parser). `a` and `b` are normalized once and copied into temporaries, so
    `b` is evaluated only once. The two comparisons are then combined through
    normalize_boolean_expression with operator 'and'.
    """
    inner = expression.left

    counter, left_prestatements, normalized_left = normalize_expression(counter, inner.left)
    counter, middle_prestatements, normalized_middle = normalize_expression(counter, inner.right)

    left_variable = '${}'.format(counter)
    middle_variable = '${}'.format(counter + 1)
    counter += 2

    juncture_prestatements = (
        NormalVariableInitializationStatement(variable=left_variable, expression=normalized_left),
        NormalVariableInitializationStatement(variable=middle_variable, expression=normalized_middle),
    )

    first_comparison = parsing.FurInfixExpression(
        order='comparison_level',
        operator=inner.operator,
        left=NormalVariableExpression(variable=left_variable),
        right=NormalVariableExpression(variable=middle_variable),
    )
    second_comparison = parsing.FurInfixExpression(
        order='comparison_level',
        operator=expression.operator,
        left=NormalVariableExpression(variable=middle_variable),
        right=expression.right,
    )
    conjunction_source = parsing.FurInfixExpression(
        order='and_level',
        operator='and',
        left=first_comparison,
        right=second_comparison,
    )

    counter, conjunction_prestatements, conjunction = normalize_boolean_expression(
        counter,
        conjunction_source,
    )

    all_prestatements = (
        left_prestatements
        + middle_prestatements
        + juncture_prestatements
        + conjunction_prestatements
    )
    return counter, all_prestatements, conjunction
479
def normalize_comparison_expression(counter, expression):
    """Normalize a comparison-level infix expression.

    A comparison whose left operand is itself a comparison (the shape of a
    chained comparison such as `a < b < c`) is desugared by
    desugar_ternary_comparison. Every other comparison is normalized as a
    plain binary infix operation.
    """
    left = expression.left

    # The chain test must look at the LEFT operand's order. Within this file,
    # this function is only reached via the 'comparison_level' dispatch, so
    # testing `expression.order` was always true. That made e.g. `a + b < c`
    # be desugared as `(a + b) and (b < c)`.
    is_chained = (
        isinstance(left, parsing.FurInfixExpression)
        and left.order == 'comparison_level'
    )

    if is_chained:
        return desugar_ternary_comparison(counter, expression)

    return normalize_basic_infix_operation(counter, expression)
485
def normalize_boolean_expression(counter, expression):
    """Normalize a short-circuiting 'and' / 'or' expression.

    Both operands are normalized, left first. A result temporary is
    initialized with the left value. The right operand's prestatements,
    followed by a reassignment of the result to the right value, are placed
    in one branch of an if/else conditioned on the result temporary:
      - 'and': in the if-branch (right evaluated only when left is truthy);
      - 'or':  in the else-branch (right evaluated only when left is falsy).
    Any other operator raises Exception.
    """
    counter, left_prestatements, normalized_left = normalize_expression(counter, expression.left)
    counter, right_prestatements, normalized_right = normalize_expression(counter, expression.right)

    result_variable = '${}'.format(counter)
    counter += 1

    initialize_result = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=normalized_left,
    )
    evaluate_right = right_prestatements + (
        NormalVariableReassignmentStatement(variable=result_variable, expression=normalized_right),
    )

    operator = expression.operator
    if operator == 'and':
        if_branch, else_branch = evaluate_right, ()
    elif operator == 'or':
        if_branch, else_branch = (), evaluate_right
    else:
        raise Exception('Unable to handle operator "{}"'.format(expression.operator))

    short_circuit = NormalIfElseStatement(
        condition_expression=NormalVariableExpression(variable=result_variable),
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        left_prestatements + (initialize_result, short_circuit),
        NormalVariableExpression(variable=result_variable),
    )
522
def normalize_dot_expression(counter, expression):
    """Normalize `instance.field` into a temporary holding the field access.

    The right-hand side must be a symbol expression; its symbol names the
    field. The left-hand side is normalized, and the access is stored in
    temporary `$<counter>` (after normalizing the left side).
    """
    assert isinstance(expression.right, parsing.FurSymbolExpression)
    field_name = expression.right.symbol

    counter, prestatements, instance = normalize_expression(counter, expression.left)

    result_variable = '${}'.format(counter)
    access = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalDotExpression(instance=instance, field=field_name),
    )

    return counter + 1, prestatements + (access,), NormalVariableExpression(variable=result_variable)
543
def normalize_infix_expression(counter, expression):
    """Dispatch an infix expression to the normalizer for its precedence level.

    Raises KeyError for an `order` not listed below.
    """
    normalizers_by_order = {
        'multiplication_level': normalize_basic_infix_operation,
        'addition_level': normalize_basic_infix_operation,
        'comparison_level': normalize_comparison_expression,
        'dot_level': normalize_dot_expression,
        'and_level': normalize_boolean_expression,
        'or_level': normalize_boolean_expression,
    }
    normalizer = normalizers_by_order[expression.order]
    return normalizer(counter, expression)
553
def normalize_if_expression(counter, expression):
    """Normalize an if-expression into an if/else statement writing a temporary.

    The condition is normalized first. A result temporary is initialized to
    `builtin$nil`. Both branches are then normalized, if-branch first, with
    their final value redirected into that temporary. The temporary is
    returned as the expression's value.
    """
    counter, condition_prestatements, condition = normalize_expression(
        counter,
        expression.condition_expression,
    )

    result_variable = '${}'.format(counter)
    counter += 1

    normalized_branches = []
    for branch in (expression.if_statement_list, expression.else_statement_list):
        counter, normalized_branch = normalize_statement_list(
            counter,
            branch,
            assign_result_to=result_variable,
        )
        normalized_branches.append(normalized_branch)
    if_branch, else_branch = normalized_branches

    default_result = NormalVariableInitializationStatement(
        variable=result_variable,
        expression=NormalVariableExpression(variable='builtin$nil'),
    )
    branching = NormalIfElseStatement(
        condition_expression=condition,
        if_statement_list=if_branch,
        else_statement_list=else_branch,
    )

    return (
        counter,
        condition_prestatements + (default_result, branching),
        NormalVariableExpression(variable=result_variable),
    )
589
def normalize_negation_expression(counter, expression):
    """Normalize a negation by storing its operand in a temporary.

    Unlike most normalizers, the returned expression is the negation node
    itself (wrapping a reference to the temporary), not a variable reference.
    """
    counter, prestatements, operand = normalize_expression(counter, expression.value)

    operand_variable = '${}'.format(counter)
    store_operand = NormalVariableInitializationStatement(
        variable=operand_variable,
        expression=operand,
    )
    negation = NormalNegationExpression(
        internal_expression=NormalVariableExpression(variable=operand_variable),
    )

    return counter + 1, prestatements + (store_operand,), negation
606
def normalize_expression(counter, expression):
    """Normalize any expression by dispatching on its exact type.

    Returns (next_counter, prestatements, normalized_expression). Nodes that
    are already normalized (NormalInfixExpression, NormalVariableExpression)
    pass through unchanged. An unlisted type raises KeyError.
    """
    normalizers_by_type = {
        NormalInfixExpression: fake_normalization,
        NormalVariableExpression: fake_normalization,
        parsing.FurFunctionCallExpression: normalize_function_call_expression,
        parsing.FurIfExpression: normalize_if_expression,
        parsing.FurInfixExpression: normalize_infix_expression,
        parsing.FurIntegerLiteralExpression: normalize_integer_literal_expression,
        parsing.FurListLiteralExpression: normalize_list_literal_expression,
        parsing.FurListItemExpression: normalize_list_item_expression,
        parsing.FurNegationExpression: normalize_negation_expression,
        parsing.FurStringLiteralExpression: normalize_string_literal_expression,
        parsing.FurStructureLiteralExpression: normalize_structure_literal_expression,
        parsing.FurSymbolExpression: normalize_symbol_expression,
    }
    normalizer = normalizers_by_type[type(expression)]
    return normalizer(counter, expression)
622
def normalize_expression_statement(counter, statement):
    """Normalize an expression statement.

    The expression's prestatements are returned separately, and the
    statement wraps the normalized expression.
    """
    # TODO The normalized expression will be a NormalVariableExpression, which
    # goes unused for expression statements in every case except when it's a
    # return statement. This causes warnings on C compilation. We should only
    # generate this variable when it will be used on return.
    counter, prestatements, normalized = normalize_expression(counter, statement.expression)
    wrapped = NormalExpressionStatement(expression=normalized)
    return counter, prestatements, wrapped
635
def normalize_function_definition_statement(counter, statement):
    """Normalize a function definition.

    The body is normalized with its own temporary numbering starting at 0,
    and its final value is redirected into `result`. The enclosing counter
    is returned unchanged, and no prestatements are emitted.
    """
    _, body = normalize_statement_list(
        0,
        statement.statement_list,
        assign_result_to='result',
    )
    definition = NormalFunctionDefinitionStatement(
        name=statement.name,
        argument_name_list=statement.argument_name_list,
        statement_list=body,
    )
    return counter, (), definition
651
def normalize_assignment_statement(counter, statement):
    """Normalize `target = expression`, keeping the target as-is."""
    counter, prestatements, value = normalize_expression(counter, statement.expression)
    assignment = NormalAssignmentStatement(target=statement.target, expression=value)
    return counter, prestatements, assignment
662
def normalize_statement(counter, statement):
    """Normalize a statement by dispatching on its exact type.

    Returns (next_counter, prestatements, normalized_statement). An unlisted
    statement type raises KeyError.
    """
    normalizers_by_type = {
        parsing.FurAssignmentStatement: normalize_assignment_statement,
        parsing.FurExpressionStatement: normalize_expression_statement,
        parsing.FurFunctionDefinitionStatement: normalize_function_definition_statement,
    }
    normalizer = normalizers_by_type[type(statement)]
    return normalizer(counter, statement)
669
@util.force_generator(tuple)
def normalize_statement_list(counter, statement_list, **kwargs):
    """Normalize a sequence of statements.

    Each statement's prestatements are emitted immediately before the
    statement itself.

    Keyword arguments:
        assign_result_to: optional variable name. When given, and the final
            normalized statement is an expression statement whose value is a
            variable reference, that statement is replaced by a reassignment
            of the value to this variable. This is how a block yields its
            last value.

    Returns (next_counter, normalized_statement_list). An empty
    `statement_list` yields an empty list rather than raising IndexError.

    Raises:
        TypeError: if any keyword other than `assign_result_to` is passed.
    """
    # NOTE(review): this function returns a (counter, list) pair rather than
    # yielding; util.force_generator(tuple) presumably just re-tuples that
    # pair here. Confirm against util.
    assign_result_to = kwargs.pop('assign_result_to', None)

    # Raise instead of assert so the check survives `python -O`.
    if kwargs:
        raise TypeError('Unexpected keyword arguments: {}'.format(', '.join(sorted(kwargs))))

    result_statement_list = []

    for statement in statement_list:
        counter, prestatements, normalized = normalize_statement(counter, statement)
        result_statement_list.extend(prestatements)
        result_statement_list.append(normalized)

    # An empty input (for example an empty branch passed in by
    # normalize_if_expression, or an empty program) has no last statement
    # to redirect.
    if assign_result_to is not None and result_statement_list:
        last_statement = result_statement_list[-1]
        if (
            isinstance(last_statement, NormalExpressionStatement)
            and isinstance(last_statement.expression, NormalVariableExpression)
        ):
            result_statement_list[-1] = NormalVariableReassignmentStatement(
                variable=assign_result_to,
                expression=last_statement.expression,
            )

    return (
        counter,
        result_statement_list,
    )
701
def normalize(program):
    """Normalize a parsed program into a NormalProgram.

    Top-level temporaries are numbered from 0, and no result assignment is
    requested for the final statement.
    """
    normalized_statements = normalize_statement_list(0, program.statement_list)[1]
    return NormalProgram(statement_list=normalized_statements)