b75b3a0fa522df7d504c814896e05be635bc03ef
[fur] / transformation.py
1 import collections
2
3 import parsing
4
# --- Literal / constant nodes of the C intermediate representation ---

# Integer literal; `value` is taken directly from the Fur integer literal node.
CIntegerLiteral = collections.namedtuple('CIntegerLiteral', 'value')

# String literal; `index` is the literal's position in the program-wide
# (deduplicated) string literal list, `value` is the literal text itself.
CStringLiteral = collections.namedtuple('CStringLiteral', 'index value')

# Constant emitted by name; transform_expression uses it for `true`/`false`.
CConstantExpression = collections.namedtuple('CConstantExpression', 'value')
26
# --- Expression nodes of the C intermediate representation ---

# Reference to a program symbol; `symbol_list_index` is the symbol's position
# in the program-wide symbol list.
CSymbolExpression = collections.namedtuple('CSymbolExpression', 'symbol symbol_list_index')

# Negation of an already-transformed operand `value`.
CNegationExpression = collections.namedtuple('CNegationExpression', 'value')

# Call of the function implementing a Fur infix operator (`name`, e.g. 'add',
# 'lessThan', 'and') applied to two transformed operands.
CFunctionCallForFurInfixOperator = collections.namedtuple(
    'CFunctionCallForFurInfixOperator',
    'name left right',
)

# Call of a named function with a tuple of transformed `arguments`.
CFunctionCallExpression = collections.namedtuple('CFunctionCallExpression', 'name arguments')
57 )
58
# --- Statement and program nodes of the C intermediate representation ---

# Assignment of a transformed `expression` to the symbol `target`, whose
# position in the program-wide symbol list is `target_symbol_list_index`.
CAssignmentStatement = collections.namedtuple(
    'CAssignmentStatement',
    'target target_symbol_list_index expression',
)

# Whole transformed program together with the data gathered while transforming
# it: builtins used, required standard headers, string literals and symbols.
CProgram = collections.namedtuple(
    'CProgram',
    'builtin_set statements standard_libraries string_literal_list symbol_list',
)
78
# Fur comparison (equality-level) operators, each mapped to the name given to
# the CFunctionCallForFurInfixOperator node that evaluates it.
EQUALITY_LEVEL_OPERATOR_TO_FUNCTION_NAME_MAPPING = dict([
    ('==', 'equals'),
    ('!=', 'notEquals'),
    ('<=', 'lessThanOrEqual'),
    ('>=', 'greaterThanOrEqual'),
    ('<', 'lessThan'),
    ('>', 'greaterThan'),
])
87
def _transform_equality_chain(accumulators, expression):
    """Transform an equality-level expression and also return its rightmost operand.

    Returns a pair ``(c_expression, c_rightmost_operand)``. The second element
    is the already-transformed right operand. The next comparison in a chain
    such as ``a < b < c`` must reuse it as its left operand.
    """
    if isinstance(expression.left, parsing.FurInfixExpression) and expression.left.order == 'equality_level':
        # Chained comparison: ((a < b) < c) becomes (a < b) and (b < c).
        # `middle` must be the rightmost *operand* of the left chain. Reading
        # `left.right` directly is only correct for a two-operand chain; for
        # longer chains `left` is an 'and' node and `left.right` is a comparison.
        left, middle = _transform_equality_chain(accumulators, expression.left)

        right = transform_expression(
            accumulators,
            expression.right,
        )

        # TODO Don't evaluate the middle expression twice
        chained = CFunctionCallForFurInfixOperator(
            name='and',
            left=left,
            right=CFunctionCallForFurInfixOperator(
                name=EQUALITY_LEVEL_OPERATOR_TO_FUNCTION_NAME_MAPPING[expression.operator],
                left=middle,
                right=right,
            ),
        )
        return chained, right

    name = EQUALITY_LEVEL_OPERATOR_TO_FUNCTION_NAME_MAPPING[expression.operator]
    left = transform_expression(accumulators, expression.left)
    right = transform_expression(accumulators, expression.right)

    comparison = CFunctionCallForFurInfixOperator(
        name=name,
        left=left,
        right=right,
    )
    return comparison, right


def transform_equality_level_expression(accumulators, expression):
    """Transform a Fur comparison (possibly chained) into C function-call nodes.

    Chains such as ``1 < 2 < 3 < 4`` become ``1 < 2 && 2 < 3 && 3 < 4``,
    expressed as nested 'and' CFunctionCallForFurInfixOperator nodes.

    Args:
        accumulators: Accumulators collecting symbols, string literals and
            builtins; operands are transformed left to right.
        expression: a parsing.FurInfixExpression whose order is 'equality_level'.

    Returns:
        A CFunctionCallForFurInfixOperator node.

    Raises:
        KeyError: if the operator is not an equality-level operator.
    """
    c_expression, _ = _transform_equality_chain(accumulators, expression)
    return c_expression
119
# Fur builtins, each mapped to the C standard headers it needs. transform()
# unions the header lists of every builtin the program used, and
# transform_function_call_expression only accepts calls to names in this table.
# NOTE(review): 'true'/'false' are listed here too, although they are also
# handled as constants; whether they are callable is not checked (see TODO).
BUILTINS = dict(
    false=[],
    pow=['math.h'],
    print=['stdio.h'],
    true=[],
)
126
def transform_expression(accumulators, expression):
    """Transform a Fur expression node into its C intermediate-representation node.

    Side effects on ``accumulators``:
        * new symbols are appended once to ``symbol_list``;
        * new string literals are appended once to ``string_literal_list``;
        * builtins called are added to ``builtin_set`` (via
          transform_function_call_expression).

    Raises:
        KeyError: if an infix expression uses an operator with no mapping.
        Exception: if the expression type is not supported.
    """
    if isinstance(expression, parsing.FurParenthesizedExpression):
        # Parentheses can be removed because everything in the C output is explicitly parenthesized
        return transform_expression(accumulators, expression.internal)

    if isinstance(expression, parsing.FurNegationExpression):
        return transform_negation_expression(accumulators, expression)

    if isinstance(expression, parsing.FurFunctionCallExpression):
        return transform_function_call_expression(accumulators, expression)

    if isinstance(expression, parsing.FurSymbolExpression):
        if expression.value in ['true', 'false']:
            return CConstantExpression(value=expression.value)

        # Register each symbol exactly once so its index stays stable.
        # (Previously this referenced an undefined `symbol_list` -> NameError.)
        if expression.value not in accumulators.symbol_list:
            accumulators.symbol_list.append(expression.value)

        return CSymbolExpression(
            symbol=expression.value,
            symbol_list_index=accumulators.symbol_list.index(expression.value),
        )

    if isinstance(expression, parsing.FurStringLiteralExpression):
        value = expression.value

        # Deduplicate string literals: reuse the index of an identical literal.
        try:
            index = accumulators.string_literal_list.index(value)
        except ValueError:
            index = len(accumulators.string_literal_list)
            accumulators.string_literal_list.append(value)

        return CStringLiteral(index=index, value=value)

    LITERAL_TYPE_MAPPING = {
        parsing.FurIntegerLiteralExpression: CIntegerLiteral,
    }

    # Exact-type lookup: subclasses of the literal types are not matched here.
    if type(expression) in LITERAL_TYPE_MAPPING:
        return LITERAL_TYPE_MAPPING[type(expression)](value=expression.value)

    if isinstance(expression, parsing.FurInfixExpression):
        if expression.order == 'equality_level':
            return transform_equality_level_expression(accumulators, expression)

        INFIX_OPERATOR_TO_FUNCTION_NAME = {
            '+':    'add',
            '-':    'subtract',
            '*':    'multiply',
            '//':   'integerDivide',
            '%':    'modularDivide',
            'and':  'and',
            'or':   'or',
        }

        return CFunctionCallForFurInfixOperator(
            name=INFIX_OPERATOR_TO_FUNCTION_NAME[expression.operator],
            left=transform_expression(accumulators, expression.left),
            right=transform_expression(accumulators, expression.right),
        )

    raise Exception('Could not transform expression "{}"'.format(expression))
189
def transform_assignment_statement(accumulators, assignment_statement):
    """Transform a Fur assignment into a CAssignmentStatement.

    The target symbol is registered in ``accumulators.symbol_list`` (once)
    before the right-hand side is transformed. The resulting index is stored
    as ``target_symbol_list_index``.
    """
    # TODO Check that target is not a builtin
    target = assignment_statement.target
    symbols = accumulators.symbol_list

    try:
        target_index = symbols.index(target)
    except ValueError:
        # First time this symbol is seen: it goes at the end of the list.
        target_index = len(symbols)
        symbols.append(target)

    c_expression = transform_expression(accumulators, assignment_statement.expression)

    return CAssignmentStatement(
        target=target,
        target_symbol_list_index=target_index,
        expression=c_expression,
    )
203
def transform_negation_expression(accumulators, negation_expression):
    """Transform a Fur negation by wrapping its transformed operand in a CNegationExpression."""
    operand = transform_expression(accumulators, negation_expression.value)
    return CNegationExpression(value=operand)
208
def transform_function_call_expression(accumulators, function_call):
    """Transform a Fur call to a builtin into a CFunctionCallExpression.

    Records the builtin in ``accumulators.builtin_set`` so transform() can
    include the C headers it needs. The emitted name is prefixed with
    ``builtin$``, and arguments are transformed left to right into a tuple.

    Raises:
        Exception: if the called function is not a known builtin.
    """
    function_name = function_call.function.value

    if function_name not in BUILTINS:
        # Previously a bare Exception() with no message; name the culprit.
        raise Exception('Could not transform call to unknown function "{}"'.format(function_name))

    # TODO Check that the builtin is actually callable
    accumulators.builtin_set.add(function_name)

    return CFunctionCallExpression(
        name='builtin$' + function_name,
        arguments=tuple(
            transform_expression(accumulators, argument)
            for argument in function_call.arguments
        ),
    )
223
def transform_statement(accumulators, statement):
    """Dispatch a top-level Fur statement to the matching transform function.

    Dispatch is keyed on the statement's exact type, so subclasses are not
    matched. An unsupported statement type raises KeyError.
    """
    handlers_by_type = {
        parsing.FurAssignmentStatement: transform_assignment_statement,
        parsing.FurFunctionCallExpression: transform_function_call_expression,
    }
    handler = handlers_by_type[type(statement)]
    return handler(accumulators, statement)
229
230
# Mutable state threaded through the transform functions: the builtins called,
# the symbols seen (index = position), and the deduplicated string literals.
Accumulators = collections.namedtuple(
    'Accumulators',
    'builtin_set symbol_list string_literal_list',
)
239
def transform(program):
    """Transform a parsed Fur program into a CProgram.

    Each statement in ``program.statement_list`` is transformed in order while
    builtins, symbols and string literals are accumulated. The C standard
    headers required are the union of BUILTINS entries for every builtin used.
    """
    state = Accumulators(
        builtin_set=set(),
        symbol_list=[],
        string_literal_list=[],
    )

    transformed_statements = [
        transform_statement(state, fur_statement)
        for fur_statement in program.statement_list
    ]

    required_headers = {
        header
        for builtin_name in state.builtin_set
        for header in BUILTINS[builtin_name]
    }

    return CProgram(
        builtin_set=state.builtin_set,
        statements=transformed_statements,
        standard_libraries=required_headers,
        string_literal_list=state.string_literal_list,
        symbol_list=state.symbol_list,
    )
263
264
if __name__ == '__main__':
    # When run as a script, hand control to unittest's test runner for this
    # module. NOTE(review): no TestCase is visible in this section.
    from unittest import main as run_unittest_main

    run_unittest_main()