Added support for comparison operators
[fur] / transformation.py
index e3658bc..cd9d18d 100644 (file)
@@ -16,6 +16,13 @@ CStringLiteral = collections.namedtuple(
     ],
 )
 
+CConstantExpression = collections.namedtuple(
+    'CConstantExpression',
+    [
+        'value'
+    ],
+)
+
 CSymbolExpression = collections.namedtuple(
     'CSymbolExpression',
     [
@@ -63,6 +70,63 @@ CIntegerDivisionExpression = collections.namedtuple(
     ],
 )
 
+CEqualityExpression = collections.namedtuple(
+    'CEqualityExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CInequalityExpression = collections.namedtuple(
+    'CInequalityExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CGreaterThanOrEqualExpression = collections.namedtuple(
+    'CGreaterThanOrEqualExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CLessThanOrEqualExpression = collections.namedtuple(
+    'CLessThanOrEqualExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CGreaterThanExpression = collections.namedtuple(
+    'CGreaterThanExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CLessThanExpression = collections.namedtuple(
+    'CLessThanExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+CAndExpression = collections.namedtuple(
+    'CAndExpression',
+    [
+        'left',
+        'right',
+    ],
+)
+
+
 CModularDivisionExpression = collections.namedtuple(
     'CModularDivisionExpression',
     [
@@ -98,12 +162,58 @@ CProgram = collections.namedtuple(
     ],
 )
 
+EQUALITY_LEVEL_TYPE_MAPPING = {
+    parsing.FurEqualityExpression: CEqualityExpression,
+    parsing.FurInequalityExpression: CInequalityExpression,
+    parsing.FurLessThanOrEqualExpression: CLessThanOrEqualExpression,
+    parsing.FurGreaterThanOrEqualExpression: CGreaterThanOrEqualExpression,
+    parsing.FurLessThanExpression: CLessThanExpression,
+    parsing.FurGreaterThanExpression: CGreaterThanExpression,
+}
+
+def transform_equality_level_expression(builtin_dependencies, symbol_list, expression):
+    # Transform expressions like 1 < 2 < 3 into expressions like 1 < 2 && 2 < 3
+    if type(expression.left) in EQUALITY_LEVEL_TYPE_MAPPING:
+        left = transform_equality_level_expression(
+            builtin_dependencies,
+            symbol_list,
+            expression.left
+        )
+
+        middle = left.right
+
+        right = transform_expression(
+            builtin_dependencies,
+            symbol_list,
+            expression.right,
+        )
+
+        # TODO Don't evaluate the middle expression twice
+        return CAndExpression(
+            left=left,
+            right=EQUALITY_LEVEL_TYPE_MAPPING[type(expression)](
+                left=middle,
+                right=right,
+            ),
+        )
+
+    return EQUALITY_LEVEL_TYPE_MAPPING[type(expression)](
+        left=transform_expression(builtin_dependencies, symbol_list, expression.left),
+        right=transform_expression(builtin_dependencies, symbol_list, expression.right),
+    )
+
 BUILTINS = {
+    'false':    [],
     'pow':      ['math.h'],
     'print':    ['stdio.h'],
+    'true':     [],
 }
 
 def transform_expression(builtin_dependencies, symbol_list, expression):
+    if isinstance(expression, parsing.FurParenthesizedExpression):
+        # Parentheses can be removed because everything in the C output is explicitly parenthesized
+        return transform_expression(builtin_dependencies, symbol_list, expression.internal)
+
     if isinstance(expression, parsing.FurNegationExpression):
         return transform_negation_expression(builtin_dependencies, symbol_list, expression)
 
@@ -111,6 +221,9 @@ def transform_expression(builtin_dependencies, symbol_list, expression):
         return transform_function_call_expression(builtin_dependencies, symbol_list, expression)
 
     if isinstance(expression, parsing.FurSymbolExpression):
+        if expression.value in ['true', 'false']:
+            return CConstantExpression(value=expression.value)
+
         if expression.value not in symbol_list:
             symbol_list.append(expression.value)
 
@@ -127,6 +240,9 @@ def transform_expression(builtin_dependencies, symbol_list, expression):
     if type(expression) in LITERAL_TYPE_MAPPING:
         return LITERAL_TYPE_MAPPING[type(expression)](value=expression.value)
 
+    if type(expression) in EQUALITY_LEVEL_TYPE_MAPPING:
+        return transform_equality_level_expression(builtin_dependencies, symbol_list, expression)
+
     INFIX_TYPE_MAPPING = {
         parsing.FurAdditionExpression: CAdditionExpression,
         parsing.FurSubtractionExpression: CSubtractionExpression,
@@ -162,6 +278,7 @@ def transform_negation_expression(builtin_dependencies, symbol_list, negation_ex
 
 def transform_function_call_expression(builtin_dependencies, symbol_list, function_call):
     if function_call.function.value in BUILTINS.keys():
+        # TODO Check that the builtin is actually callable
         builtin_dependencies.add(function_call.function.value)
 
         return CFunctionCallExpression(