import typing
from z3 import *
from frame.tools.grammar import Grammar
from collections import namedtuple



class Z3PythonicTranslation(object):
    """
    This class is a DSL for writing Z3 constraints in a more Pythonic way.

    Each instance owns one z3 ``Solver``. Its base scope holds a single axiom
    stating that an integer is a member of ``NatSet`` exactly when it is
    non-negative. :meth:`prove` adds its constraints inside a push/pop scope
    that is always popped again, so constraints do not leak between proofs.
    """

    def __init__(self):
        self._z3_pythonic_dsl_solver = Solver()
        # name -> z3 Int created by ``Nat``; cleared at the end of ``prove``.
        self._z3_pythonic_dsl_variable_map = {}
        # ``NatVar`` names are this prefix followed by a running counter.
        self._z3_pythonic_dsl_nat_var_prefix = "z3_dsl_nat_var_"
        self._z3_pythonic_dsl_nat_var_counter = 0
        self._z3_pythonic_dsl_nat_set = SetSort(IntSort())
        self._NatSet = Const("z3_pythonic_dsl_nat_set", self._z3_pythonic_dsl_nat_set)
        # Bound variable of the NatSet axiom below.
        self._z3_pythonic_dsl_x = Int(self._z3_pythonic_dsl_nat_var_prefix + "_x")
        # Axiom: for every integer x, (x in NatSet) <-> (x >= 0).
        self._z3_pythonic_dsl_solver.add(
            ForAll(
                [self._z3_pythonic_dsl_x],
                And(
                    Implies(
                        IsMember(self._z3_pythonic_dsl_x, self._NatSet),
                        self._z3_pythonic_dsl_x >= 0,
                    ),
                    Implies(
                        self._z3_pythonic_dsl_x >= 0,
                        IsMember(self._z3_pythonic_dsl_x, self._NatSet),
                    ),
                ),
            )
        )
        # Open a scope on top of the axiom-only base scope.
        self._z3_pythonic_dsl_solver.push()

    def NatVar(self):
        """Create a fresh z3 ``Int`` constrained to be a natural number.

        The constant is named ``<prefix><counter>``. ``n >= 0`` and
        ``IsMember(n, NatSet)`` are added to the solver's current scope.
        """
        n = Int(
            self._z3_pythonic_dsl_nat_var_prefix
            + str(self._z3_pythonic_dsl_nat_var_counter)
        )
        self._z3_pythonic_dsl_solver.add(n >= 0)
        self._z3_pythonic_dsl_solver.add(IsMember(n, self._NatSet))
        self._z3_pythonic_dsl_nat_var_counter += 1
        return n

    def Nat(self, name):
        """Return the z3 ``Int`` called ``name``, creating it on first use.

        On first use, ``n >= 0`` and ``IsMember(n, NatSet)`` are added to the
        solver's current scope and the constant is cached by name.

        NOTE(review): the cache is cleared at the end of every ``prove`` call,
        while constraints added outside ``prove`` stay in the scope opened by
        ``__init__``.
        """
        if name in self._z3_pythonic_dsl_variable_map:
            return self._z3_pythonic_dsl_variable_map[name]
        n = Int(name)
        self._z3_pythonic_dsl_solver.add(n >= 0)
        self._z3_pythonic_dsl_solver.add(IsMember(n, self._NatSet))
        self._z3_pythonic_dsl_variable_map[name] = n
        return n

    def prove(
        self, contraints, theorem, timeout=None
    ) -> "typing.Tuple[bool, typing.Optional[ModelRef]]":
        """Try to prove ``theorem`` under the given constraints.

        The constraints are executed first; then ``Not(theorem)`` is added.
        The theorem is proved iff the solver reports ``unsat``.

        Args:
            contraints: Python source executed first. It can call ``NatVar()``,
                reference ``NatSet`` and add to ``z3_pythonic_dsl_solver``.
            theorem: Python source for a z3 boolean expression, evaluated in the
                same namespace as ``contraints``.
            timeout: optional value passed to ``Solver.set(timeout=...)``
                (z3 interprets it in milliseconds).

        Returns:
            ``(proof_found, counter_example)``. ``counter_example`` is the
            solver's model when the proof was not found, else ``None``.

        Raises:
            Any exception from executing the code or from z3. For example,
            ``model()`` may raise when ``check()`` returns ``unknown``. The
            solver scope and per-proof state are restored before it propagates.
        """
        exec_locals = {
            "NatVar": self.NatVar,
            "NatSet": self._NatSet,
            "z3_pythonic_dsl_solver": self._z3_pythonic_dsl_solver,
        }
        code = f"{contraints}\nz3_pythonic_dsl_solver.add(Not({theorem}))"
        solver = self._z3_pythonic_dsl_solver
        solver.push()
        try:
            # SECURITY: exec runs arbitrary Python with this module's globals;
            # only pass trusted (e.g. compiler-generated) source here.
            exec(code, globals(), exec_locals)
            if timeout is not None:
                # NOTE(review): Solver.set is not undone by pop(), so this
                # timeout presumably stays in effect for later calls - confirm.
                solver.set(timeout=timeout)
            z3_pythonic_dsl_result = solver.check()
            proof_found = z3_pythonic_dsl_result == unsat
            counter_example = None
            if not proof_found:
                counter_example = solver.model()
        finally:
            # Always drop this proof's constraints and per-proof state, even on
            # error; otherwise they would leak into every subsequent proof.
            solver.pop()
            self._z3_pythonic_dsl_variable_map.clear()
            self._z3_pythonic_dsl_nat_var_counter = 0
        return proof_found, counter_example


class Z3ComposableDsl(Grammar):
    # Immutable result record with the fields below (is_logical, is_arith,
    # proved, counter_example, timed_out, constraints, expr).
    Z3TranslationGrammarResult = namedtuple(
        "Z3TranslationGrammarResult",
        (
            "is_logical",
            "is_arith",
            "proved",
            "counter_example",
            "timed_out",
            "constraints",
            "expr",
        ),
    )
    # Immutable (simple_id, var_str_id, ctx_fun) record.
    SubExpr = namedtuple("SubExpr", "simple_id var_str_id ctx_fun")

    class Ctx:
        """Mutable compilation state for a (sub-)program or sub-expression.

        Variables are identified by an integer ``simple_id`` and rendered under a
        string name. Declarations and expressions are stored as callables that
        take the Ctx to render in and return source text.
        """

        def __init__(self):
            # simple_id <-> current variable name.
            self.simple_id_to_var_id: typing.Dict[int, str] = {}
            self.var_id_to_simple_id: typing.Dict[str, int] = {}
            # simple_id of a variable use -> simple_id it aliases.
            self.simple_id_to_alias_id: typing.Dict[int, int] = {}
            self.var_count = 0
            # Name of the variable declared by a single-Declaration context.
            self.main_var_id: typing.Optional[str] = None
            # Renderers for the returned expression (normally only one is set).
            self.logical_expr: typing.Optional[
                typing.Callable[["Z3ComposableDsl.Ctx"], str]
            ] = None
            self.arith_expr: typing.Optional[
                typing.Callable[["Z3ComposableDsl.Ctx"], str]
            ] = None
            # variable name -> renderer for that variable's declaration line.
            self.declarations: typing.Dict[
                str, typing.Callable[["Z3ComposableDsl.Ctx"], str]
            ] = {}
            self.sub_programs: typing.List["Z3ComposableDsl.Ctx"] = []
            self.variables_declared_in_order: typing.List[int] = []
            self.pending_arith_sub_expressions: typing.List["Z3ComposableDsl.Ctx"] = []
            self.pending_logical_sub_expressions: typing.List["Z3ComposableDsl.Ctx"] = []
            self.params: typing.List[int] = []
            # param simple_id -> simple_id it is bound to (None until bound).
            self.params_map: typing.Dict[int, typing.Optional[int]] = {}

        def __str__(self):
            # Resolve every name first, then render each declaration line.
            names_in_order = [
                self.simple_id_to_var_id[simple_id]
                for simple_id in self.variables_declared_in_order
            ]
            rendered_lines = [self.declarations[name](self) for name in names_in_order]
            body = "\n".join(rendered_lines)
            arith = None if self.arith_expr is None else self.arith_expr(self)
            logical = None if self.logical_expr is None else self.logical_expr(self)
            assert arith is not None or logical is not None
            returned = logical if arith is None else arith
            return f"""{body}\nReturn {returned}"""

        def __repr__(self):
            return str(self)

    # Grammar text for the composable DSL, handed to the Grammar base class in
    # __init__. The `{left, N}` annotations and the `terminals` section look like
    # parglare grammar syntax - confirm against frame.tools.grammar.
    # NOTE(review): the `Return` terminal matches lowercase "return", while
    # Ctx.__str__ renders "Return"; VarSeq and the Call/Dot/Nat/Quote terminals
    # are not referenced by any production shown here.
    grammar = """
Prog: 
    LogicalProg
|   ArithProg;

LogicalProg:
    Declarations Return LogicalExpr;

ArithProg:
    Declarations Return ArithExpr;

Declarations:
    Declaration Declarations
|   EMPTY;

ArithExpr:
  num
| var
| ArithExpr '+' ArithExpr {left, 1}
| ArithExpr '-' ArithExpr {left, 1}
| ArithExpr '*' ArithExpr {left, 2}
| ArithExpr '/' ArithExpr {left, 2}
| ArithExpr '%' ArithExpr {left, 3}
| '('ArithExpr')';

LogicalExpr:
    ArithExpr '==' ArithExpr {left, 1}
|   ArithExpr '!=' ArithExpr {left, 1}
|   ArithExpr '>' ArithExpr  {left, 1}
|   ArithExpr '<' ArithExpr  {left, 1}
|   ArithExpr '>=' ArithExpr {left, 1}
|   ArithExpr '<=' ArithExpr {left, 1}
|  'And' '(' LogicalExpr ',' LogicalExpr ')'
|  'Or' '(' LogicalExpr ',' LogicalExpr ')'
|  'And' '(' var ',' LogicalExpr ')'
|  'Or' '(' var ',' LogicalExpr ')'
|  'And' '(' LogicalExpr ',' var ')'
|  'Or' '(' LogicalExpr ',' var ')'
|  'And' '(' var ',' var ')'
|  'Or' '(' var ',' var ')'
|  'Implies' '(' LogicalExpr ',' LogicalExpr ')'
|  'Implies' '(' var ',' LogicalExpr ')'
|  'Implies' '(' LogicalExpr ',' var ')'
|  'Implies' '(' var ',' var ')'
|  'ForAll' '(' var ',' LogicalExpr ')'
|  'Exists' '(' var ',' LogicalExpr ')'
|  'IsMember' '(' ArithExpr ',' NatSet ')'
|  'Not' '(' LogicalExpr ')'
|  'Not' '(' var ')'
|   True
|   False;

Declaration:
   var ':=' NatVar '(' ')' EndLine
|  var ':=' Param EndLine
|  var ':=' ArithExpr EndLine
|  var ':=' LogicalExpr EndLine
|  var ':=' Exec'(' Prog ')' EndLine;

VarSeq:
    var
|   var ',' VarSeq;


terminals
Param: "Param";
NatSet: "NatSet";
NatVar: "NatVar";
EndLine: ";";
Call: "Call";
True: "True";
False: "False";
Dot: ".";
Nat: "Nat";
Exec: "Exec";
Quote: "'";
Return: "return";
var: /[a-zA-Z_][a-zA-Z0-9_]*/;
num: /[-]*[0-9]+/;
"""
    # Reserved words handed to the Grammar base class together with the grammar.
    # Besides DSL tokens, this lists names and prefixes used by the generated
    # Python/z3 code (e.g. the "local_"/"promoted_" rename prefixes used by
    # _merge_program_context). Presumably user identifiers must not collide with
    # these - confirm in frame.tools.grammar.
    keywords = [
        "Exec",
        "Return",
        ";",
        # Keep these as separate entries: without the comma, Python's implicit
        # string-literal concatenation would fuse them into the single "'And".
        "'",
        "And",
        "Or",
        "Implies",
        "ForAll",
        "Exists",
        "NatSet",
        "NatVar",
        "Nat",
        "IsMember",
        "Not",
        "Param",
        "z3_dsl_nat_var_",
        "z3_pythonic_dsl_solver",
        "z3_pythonic_dsl_nat_var_counter",
        "z3_pythonic_dsl_nat_var_prefix",
        "z3_pythonic_dsl_variable_map",
        "z3_pythonic_dsl_nat_set",
        "z3_pythonic_dsl_result",
        "local_",
        "promoted_",
        "Call",
        "call",
        "True",
        "False",
        "true",
        "false",
        "Int",
        "int",
        "Bool",
        "bool",
    ]

    def __init__(self):
        """Initialise the Grammar base with this DSL's grammar text and keywords."""
        grammar_text = Z3ComposableDsl.grammar
        reserved_words = Z3ComposableDsl.keywords
        super().__init__(grammar_text, reserved_words)

    def _parse_prog(self, nodes, context: Ctx) -> Ctx:
        """Validate the context built for a top-level ``Prog`` and return it.

        ``Prog`` is ``LogicalProg | ArithProg``, and ``nodes[0]`` is the Ctx
        already produced for that alternative. This action only runs internal
        consistency checks:

        * no sub-programs or pending sub-expressions remain;
        * every name in scope has a declaration;
        * the id/name maps are consistent.

        Returns:
            ``nodes[0]``, unchanged.

        Raises:
            AssertionError / Exception: on an internal inconsistency
                ("Compiler Bug").
        """
        new_context = nodes[0]
        assert isinstance(new_context, Z3ComposableDsl.Ctx)
        assert (
            new_context.logical_expr is not None or new_context.arith_expr is not None
        )
        assert (
            len(new_context.sub_programs) == 0
        ), "Compiler Bug: Sub programs should be empty at this point."
        assert (
            len(new_context.pending_arith_sub_expressions) == 0
        ), "Compiler Bug: Pending sub expressions should be empty at this point."
        assert (
            len(new_context.pending_logical_sub_expressions) == 0
        ), "Compiler Bug: Pending sub expressions should be empty at this point."
        # Every variable name in scope must have a declaration. Iterate the
        # insertion-ordered dict instead of a set, so that when several names
        # are missing, the one reported does not depend on string hashing.
        for var in new_context.var_id_to_simple_id:
            if var not in new_context.declarations:
                raise Exception(
                    f"Compiler Bug: Variable '{var}' is neither declared nor an alias. "
                    "Some bug with registering the variable."
                )
        # Every id in declaration order must map to a name that has a declaration.
        for var in new_context.variables_declared_in_order:
            if var not in new_context.simple_id_to_var_id:
                raise Exception(
                    f"Compiler Bug: Variable '{var}' is not declared in the program. "
                    "Some bug with registering the variable."
                )
            var_str_id = new_context.simple_id_to_var_id[var]
            if var_str_id not in new_context.declarations:
                raise Exception(
                    f"Compiler Bug: Variable '{var_str_id}' is not declared in the program. "
                    "Some bug with registering the variable."
                )
        # Every name in the id -> name map must also be in the name -> id map.
        for var_str_id in new_context.simple_id_to_var_id.values():
            if var_str_id not in new_context.var_id_to_simple_id:
                raise Exception(
                    f"Compiler Bug: Variable '{var_str_id}' is not declared in the program. "
                    "Some bug with registering the variable."
                )

        if new_context.logical_expr is not None:
            assert new_context.arith_expr is None
            return new_context
        if new_context.arith_expr is not None:
            return new_context
        # Only reachable when assertions are disabled (python -O).
        raise Exception(
            "Something went wrong in parsing both logical and arithmetic expressions are None"
        )

    def _merge_sub_expr_without_declaration_ctx(
        self, declaration_ctx: Ctx, sub_expr: Ctx
    ) -> None:
        """Alias every variable used by ``sub_expr`` to its declaration.

        For each ``(simple_id, name)`` in ``sub_expr.simple_id_to_var_id``, this
        looks up the id under which ``name`` is declared in ``declaration_ctx``.
        It then records ``simple_id -> declared id`` in
        ``declaration_ctx.simple_id_to_alias_id``.

        ``declaration_ctx`` is mutated in place; nothing is returned.

        Raises:
            ValueError: if ``sub_expr`` uses a name that ``declaration_ctx``
                does not declare. Aliases recorded before the failure are kept.
        """
        declared_ids = declaration_ctx.var_id_to_simple_id
        aliases = declaration_ctx.simple_id_to_alias_id
        for simple_id, var_str_id in sub_expr.simple_id_to_var_id.items():
            local_alias_id = declared_ids.get(var_str_id)
            if local_alias_id is None:
                raise ValueError(
                    f"User Error: Variable '{var_str_id}' is used in return expression, but not declared."
                )
            aliases[simple_id] = local_alias_id

    def _merge_declaration_ctx_with_expr_ctx(
        self, declaration_ctx: Ctx, expr_ctx: Ctx
    ) -> None:
        """Attach ``expr_ctx``'s return expression to ``declaration_ctx``.

        Copies both expression renderers from ``expr_ctx``; the callers assert
        that at most one of them is set. Then aliases every variable used by
        ``expr_ctx`` to its declaration in ``declaration_ctx``.

        ``declaration_ctx`` is mutated in place; nothing is returned.

        Raises:
            ValueError: if ``expr_ctx`` uses a variable that is not declared.
        """
        declaration_ctx.arith_expr = expr_ctx.arith_expr
        declaration_ctx.logical_expr = expr_ctx.logical_expr
        self._merge_sub_expr_without_declaration_ctx(declaration_ctx, expr_ctx)

    def _parse_arith_prog(self, nodes, context: Ctx) -> Ctx:
        """Build the context for ``ArithProg: Declarations Return ArithExpr``.

        ``nodes`` is ``[declarations, return_token, expression]``.

        Steps:
            1. Resolve pending arithmetic sub-expressions of the declaration
               context against its declarations.
            2. Drop pending logical sub-expressions.
            3. Attach the returned arithmetic expression.

        Returns:
            The declaration context, mutated in place to hold the program. A
            fresh Ctx is used when there are no declarations.

        Raises:
            ValueError: if an expression uses an undeclared variable.
        """
        assert len(nodes) == 3
        declaration_ctx = nodes[0]
        if declaration_ctx is None:
            # ``_parse_declarations`` returns None for an empty declaration
            # list, i.e. a program consisting only of its return expression.
            declaration_ctx = Z3ComposableDsl.Ctx()
        assert isinstance(declaration_ctx, Z3ComposableDsl.Ctx)
        expr_ctx = nodes[2]
        assert isinstance(expr_ctx, Z3ComposableDsl.Ctx)
        assert (
            declaration_ctx.logical_expr is None or declaration_ctx.arith_expr is None
        )
        # The returned expression of an ArithProg must be purely arithmetic.
        assert expr_ctx.arith_expr is not None and expr_ctx.logical_expr is None
        # Resolve pending arithmetic sub-expressions, in list order.
        for sub_expr in declaration_ctx.pending_arith_sub_expressions:
            assert isinstance(sub_expr, Z3ComposableDsl.Ctx)
            self._merge_sub_expr_without_declaration_ctx(declaration_ctx, sub_expr)
        declaration_ctx.pending_arith_sub_expressions.clear()
        # Optimization: pending logical sub-expressions are ignored because they
        # cannot be used in the arithmetic return expression.
        declaration_ctx.pending_logical_sub_expressions.clear()
        self._merge_declaration_ctx_with_expr_ctx(declaration_ctx, expr_ctx)
        return declaration_ctx

    def _parse_logical_prog(self, nodes, context: Ctx) -> Ctx:
        """Build the context for ``LogicalProg: Declarations Return LogicalExpr``.

        ``nodes`` is ``[declarations, return_token, expression]``.

        Steps:
            1. Resolve pending arithmetic sub-expressions of the declaration
               context against its declarations.
            2. Resolve pending logical sub-expressions the same way.
            3. Attach the returned logical expression.

        Returns:
            The declaration context, mutated in place to hold the program. A
            fresh Ctx is used when there are no declarations.

        Raises:
            ValueError: if an expression uses an undeclared variable.
        """
        assert len(nodes) == 3
        declaration_ctx = nodes[0]
        if declaration_ctx is None:
            # ``_parse_declarations`` returns None for an empty declaration
            # list, i.e. a program consisting only of its return expression.
            declaration_ctx = Z3ComposableDsl.Ctx()
        assert isinstance(declaration_ctx, Z3ComposableDsl.Ctx)
        expr_ctx = nodes[2]
        assert isinstance(expr_ctx, Z3ComposableDsl.Ctx)
        assert (
            declaration_ctx.logical_expr is None or declaration_ctx.arith_expr is None
        )
        # The returned expression of a LogicalProg must be purely logical.
        assert expr_ctx.logical_expr is not None and expr_ctx.arith_expr is None
        # Arithmetic sub-expressions first, as they can be used by logical ones.
        for sub_expr in declaration_ctx.pending_arith_sub_expressions:
            assert isinstance(sub_expr, Z3ComposableDsl.Ctx)
            self._merge_sub_expr_without_declaration_ctx(declaration_ctx, sub_expr)
        declaration_ctx.pending_arith_sub_expressions.clear()
        for sub_expr in declaration_ctx.pending_logical_sub_expressions:
            assert isinstance(sub_expr, Z3ComposableDsl.Ctx)
            self._merge_sub_expr_without_declaration_ctx(declaration_ctx, sub_expr)
        declaration_ctx.pending_logical_sub_expressions.clear()
        self._merge_declaration_ctx_with_expr_ctx(declaration_ctx, expr_ctx)
        return declaration_ctx

    def _parse_logical_expr(self, nodes, context: Ctx) -> Ctx:
        """Build the context for one ``LogicalExpr`` production.

        Handled shapes, by ``len(nodes)``:

        * 1: the ``True`` / ``False`` literals.
        * 3: ``ArithExpr <cmp> ArithExpr`` comparisons.
        * 4: ``Not(LogicalExpr)`` / ``Not(var)``.
        * 6: ``And``/``Or``/``Implies`` of two operands (each a LogicalExpr or
          a var); ``ForAll``/``Exists`` over a var; ``IsMember(ArithExpr, NatSet)``.

        Variables arrive as ``(simple_id, name_fn)`` tuples, where
        ``name_fn(ctx)`` returns the variable's current name in ``ctx``. The
        returned context's ``logical_expr`` renders source text for a given ctx.

        Raises:
            ValueError: for a node list matching none of the shapes above.
        """
        assert len(nodes) in (1, 3, 4, 6)
        new_context = Z3ComposableDsl.Ctx()

        def var_to_ctx(var_node):
            # Wrap a (simple_id, name_fn) variable tuple in a Ctx whose logical
            # expression renders just the variable's current name.
            assert len(var_node) == 2
            assert isinstance(
                var_node[0], int
            ), f"Variable should have an integer id: {nodes}"
            assert callable(
                var_node[1]
            ), f"Variable should have a callable id: {nodes}"
            wrapped = Z3ComposableDsl.Ctx()
            wrapped.simple_id_to_var_id[var_node[0]] = var_node[1](context)
            wrapped.logical_expr = var_node[1]
            return wrapped

        def adopt_vars(source):
            # Record every variable used by ``source`` in the new context.
            new_context.simple_id_to_var_id.update(source.simple_id_to_var_id)

        if len(nodes) == 1:
            # ``True`` / ``False`` alternatives of the LogicalExpr rule.
            literal = nodes[0]
            if isinstance(literal, bool) or literal in ("True", "False"):
                rendered = str(literal)
                new_context.logical_expr = lambda ctx: rendered
                return new_context
            raise ValueError(f"Invalid logical expression: {nodes}")

        if len(nodes) == 3:
            comparison = nodes[1]
            assert comparison in ["==", "!=", ">", "<", ">=", "<="]
            left_ctx, right_ctx = nodes[0], nodes[2]
            assert isinstance(left_ctx, Z3ComposableDsl.Ctx)
            assert isinstance(right_ctx, Z3ComposableDsl.Ctx)
            adopt_vars(left_ctx)
            adopt_vars(right_ctx)
            left = left_ctx.arith_expr
            right = right_ctx.arith_expr
            assert left is not None
            assert right is not None
            new_context.logical_expr = (
                lambda ctx: f"({left(ctx)} {comparison} {right(ctx)})"
            )
            return new_context

        if len(nodes) == 4:
            assert nodes[0] == "Not"
            keyword = nodes[0]
            operand = nodes[2]
            if isinstance(operand, tuple):
                operand = var_to_ctx(operand)
            elif not isinstance(operand, Z3ComposableDsl.Ctx):
                raise ValueError(f"Invalid logical expression: {nodes}")
            adopt_vars(operand)
            new_context.logical_expr = (
                lambda ctx: f"{keyword}({operand.logical_expr(ctx)})"
            )
            return new_context

        if len(nodes) == 6:
            keyword = nodes[0]
            if keyword in ["And", "Or", "Implies"]:
                left_ctx, right_ctx = nodes[2], nodes[4]
                if isinstance(left_ctx, tuple):
                    left_ctx = var_to_ctx(left_ctx)
                if isinstance(right_ctx, tuple):
                    right_ctx = var_to_ctx(right_ctx)
                assert isinstance(left_ctx, Z3ComposableDsl.Ctx)
                assert isinstance(right_ctx, Z3ComposableDsl.Ctx)
                adopt_vars(left_ctx)
                adopt_vars(right_ctx)
                left = left_ctx.logical_expr
                right = right_ctx.logical_expr
                assert (
                    left is not None
                ), f"Compiler Bug: Left expression is None: {nodes}"
                assert (
                    right is not None
                ), f"Compiler Bug: Right expression is None: {nodes}"
                new_context.logical_expr = (
                    lambda ctx: f"{keyword}({left(ctx)}, {right(ctx)})"
                )
                return new_context
            if keyword in ["ForAll", "Exists"]:
                bound = nodes[2]
                if not isinstance(bound, tuple):
                    # The grammar only allows a single var here.
                    raise ValueError(f"Invalid logical expression: {nodes}")
                bound_ctx = var_to_ctx(bound)
                body = nodes[4]
                assert isinstance(body, Z3ComposableDsl.Ctx)
                # Body variables first, then the bound variable (the bound name
                # wins on an id clash).
                adopt_vars(body)
                adopt_vars(bound_ctx)
                # Restrict the quantified integer to naturals: Exists uses
                # And(x >= 0, body), ForAll uses Implies(x >= 0, body).
                nat_valid_cond = "And" if keyword == "Exists" else "Implies"
                new_context.logical_expr = (
                    lambda ctx: f"{keyword}([{bound_ctx.logical_expr(ctx)}], {nat_valid_cond}({bound_ctx.logical_expr(ctx)} >= 0, {body.logical_expr(ctx)}))"
                )
                return new_context
            assert (
                keyword == "IsMember"
            ), f"Compiler Bug: Invalid logical expression: {nodes}"
            membership_expr = nodes[2]
            assert isinstance(membership_expr, Z3ComposableDsl.Ctx)
            adopt_vars(membership_expr)
            set_name = nodes[4]
            new_context.logical_expr = (
                lambda ctx: f"{keyword}({membership_expr.arith_expr(ctx)}, {set_name})"
            )
            return new_context

        raise ValueError(f"Invalid logical expression: {nodes}")

    def _parse_arith_expr(self, nodes, context) -> Ctx:
        """Build the context for one ``ArithExpr`` production.

        Accepted node shapes:

        * ``[num]``: an int literal, rendered with ``str``.
        * ``[(simple_id, name_fn)]``: a variable, rendered as its current name.
        * ``["(", inner, ")"]``: a parenthesised expression.
        * ``[lhs, op, rhs]``: a binary operation, with op one of ``+ - * / %``.

        Returns:
            A fresh Ctx whose ``arith_expr`` renders the expression for a ctx.

        Raises:
            ValueError: if ``nodes`` has neither 1 nor 3 elements.
        """
        result = Z3ComposableDsl.Ctx()
        arity = len(nodes)

        if arity == 1:
            leaf = nodes[0]
            if not isinstance(leaf, tuple):
                # Numeric literal.
                assert isinstance(leaf, int), f"Invalid arithmetic expression: {nodes}"
                result.arith_expr = lambda ctx: str(leaf)
                return result
            # Variable reference given as (simple_id, name_fn).
            assert len(leaf) == 2
            simple_id, name_fn = leaf
            assert isinstance(
                simple_id, int
            ), f"Variable should have an integer id: {nodes}"
            assert isinstance(
                name_fn, typing.Callable
            ), f"Variable should have a callable id: {nodes}"
            result.simple_id_to_var_id[simple_id] = name_fn(context)
            result.arith_expr = name_fn
            return result

        if arity != 3:
            raise ValueError(f"Invalid arithmetic expression: {nodes}")

        if nodes[0] == "(" and nodes[2] == ")":
            if isinstance(nodes[1], (str, int)):
                result.arith_expr = lambda ctx: str(nodes[1])
                return result
            assert isinstance(nodes[1], Z3ComposableDsl.Ctx)
            # Shares (does not copy) the inner expression's variable map.
            result.simple_id_to_var_id = nodes[1].simple_id_to_var_id
            result.arith_expr = lambda ctx: f"({nodes[1].arith_expr(ctx)})"
            return result

        lhs, symbol, rhs = nodes[0], nodes[1], nodes[2]
        assert isinstance(lhs, Z3ComposableDsl.Ctx)
        assert isinstance(rhs, Z3ComposableDsl.Ctx)
        assert symbol in ["+", "-", "*", "/", "%"]
        result.simple_id_to_var_id.update(lhs.simple_id_to_var_id)
        result.simple_id_to_var_id.update(rhs.simple_id_to_var_id)
        lhs_fn = lhs.arith_expr
        rhs_fn = rhs.arith_expr
        assert lhs_fn is not None, f"Compiler Bug: Left expression is None: {nodes}"
        assert rhs_fn is not None, f"Compiler Bug: Right expression is None: {nodes}"
        result.arith_expr = lambda ctx: f"({lhs_fn(ctx)} {nodes[1]} {rhs_fn(ctx)})"
        return result

    def _parse_declaration(self, nodes, context: Ctx) -> Ctx:
        """Build a fresh context holding a single ``Declaration``.

        ``nodes[0]`` is a ``(simple_id, name_fn)`` variable tuple and
        ``nodes[1]`` is ``":="``. Supported right-hand sides:

        * ``NatVar();`` renders ``<name> = NatVar()``.
        * ``Param;`` declares a parameter, resolved through ``ctx.params_map``
          at render time.
        * ``Exec(Prog);`` renders ``<name> = <sub-program's return expression>``.
        * ``ArithExpr;`` / ``LogicalExpr;`` render ``<name> = <expr>``.

        Declaration renderers take the ctx to render in, so names are resolved
        at render time (after any renaming).

        Returns:
            A new Ctx with the variable registered, its declaration renderer,
            and any sub-program or pending sub-expression to merge later.

        Raises:
            ValueError: if an ``Exec`` sub-program has no return expression.
            Exception: "Compiler Bug" if an expression context has neither an
                arithmetic nor a logical expression.
        """
        assert len(nodes) >= 3
        assert isinstance(nodes[0], tuple)
        assert len(nodes[0]) == 2
        var_id_int, var_id_map = nodes[0]
        assert isinstance(var_id_int, int)
        assert isinstance(var_id_map, typing.Callable)
        assert var_id_int < context.var_count
        var_id_str = var_id_map(context)
        assert nodes[1] == ":="
        declaration = None
        sub_program = None
        pending_arith_sub_expr = None
        pending_logical_sub_expr = None
        if nodes[2] == "NatVar":
            declaration = lambda ctx: f"{var_id_map(ctx)} = NatVar()"
        elif nodes[2] == "Param":
            # In ``var := Param;`` the Param token sits right after ":=" (index
            # 2); index 3 is the ";" terminator.
            declaration = None
        elif nodes[2] == "Exec":
            partial_prog = nodes[4]
            sub_program = partial_prog
            assert isinstance(sub_program, Z3ComposableDsl.Ctx)
            if partial_prog.logical_expr is not None:
                assert (
                    partial_prog.arith_expr is None
                ), f"Compiler Bug: Partial program has both logical and arithmetic expressions."
                declaration = (
                    lambda ctx: f"{var_id_map(ctx)} = {partial_prog.logical_expr(ctx)}"
                )
            elif partial_prog.arith_expr is not None:
                assert (
                    partial_prog.logical_expr is None
                ), f"Compiler Bug: Partial program has both logical and arithmetic expressions."
                declaration = (
                    lambda ctx: f"{var_id_map(ctx)} = {partial_prog.arith_expr(ctx)}"
                )
            else:
                raise ValueError(
                    "User Error: Invalid partial program: both logical and arithmetic expressions are None"
                )
        else:
            # This can be a logical or arithmetic expression
            expr_context = nodes[2]
            assert isinstance(expr_context, Z3ComposableDsl.Ctx)
            arith_expr = expr_context.arith_expr
            logical_expr = expr_context.logical_expr
            if logical_expr is not None:
                assert (
                    arith_expr is None
                ), f"Compiler Bug: Partial program has both logical and arithmetic expressions."
                pending_logical_sub_expr = expr_context
                declaration = lambda ctx: f"{var_id_map(ctx)} = {logical_expr(ctx)}"
            elif arith_expr is not None:
                assert (
                    logical_expr is None
                ), f"Compiler Bug: Partial program has both logical and arithmetic expressions."
                pending_arith_sub_expr = expr_context
                declaration = lambda ctx: f"{var_id_map(ctx)} = {arith_expr(ctx)}"
            else:
                raise Exception(
                    f"Compiler Bug: Both logical and arithmetic expressions are None: {nodes}"
                )
        # Declare the variable in a brand new context
        new_context = Z3ComposableDsl.Ctx()
        new_context.simple_id_to_var_id[var_id_int] = var_id_str
        new_context.var_id_to_simple_id[var_id_str] = var_id_int
        new_context.main_var_id = var_id_str
        if declaration is None:
            # Parameter declaration: the bound argument's simple_id is stored in
            # params_map (None for now) and looked up when rendering.
            new_context.params.append(var_id_int)
            new_context.params_map[var_id_int] = None
            declaration = (
                lambda ctx: f"{var_id_map(ctx)} = {ctx.simple_id_to_var_id[ctx.params_map[var_id_int]]}"
            )
        new_context.declarations[var_id_str] = declaration
        assert len(new_context.variables_declared_in_order) == 0
        new_context.variables_declared_in_order.append(var_id_int)
        if sub_program is not None:
            # Sub-program to be merged into the enclosing program later.
            new_context.sub_programs.append(sub_program)
        if pending_arith_sub_expr is not None:
            # Its variables are aliased to declarations by the *_prog actions.
            new_context.pending_arith_sub_expressions.append(pending_arith_sub_expr)
        if pending_logical_sub_expr is not None:
            new_context.pending_logical_sub_expressions.append(pending_logical_sub_expr)
        return new_context

    def flatten_list_recursively(self, lst):
        """Flatten arbitrarily nested lists into one flat list.

        Non-list items (including tuples and ``None``) are kept as leaves, in
        left-to-right order. Empty nested lists contribute nothing. A non-list
        argument is returned wrapped as ``[lst]``.

        Uses an explicit stack of iterators instead of recursion, so deeply
        nested input (such as a long right-recursive ``Declarations`` chain)
        cannot hit Python's recursion limit.
        """
        if not isinstance(lst, list):
            return [lst]
        flat = []
        pending = [iter(lst)]
        while pending:
            for item in pending[-1]:
                if isinstance(item, list):
                    # Descend; the parent iterator resumes where it stopped.
                    pending.append(iter(item))
                    break
                flat.append(item)
            else:
                # Current list exhausted: return to the enclosing level.
                pending.pop()
        return flat

    def _merge_program_context(
        self,
        global_context: Ctx,
        main_context: Ctx,
        sub_context: Ctx,
        allow_promotion: bool = False,
    ) -> None:
        """Merge the variables and declarations of ``sub_context`` into ``main_context``.

        For every variable in ``sub_context.variables_declared_in_order``:

        * With ``allow_promotion=True``, the sub-program's variable is renamed to
          ``"promoted_" + name`` inside ``sub_context``.
        * With ``allow_promotion=False``, the sub-program's name is kept. If
          ``main_context`` already uses that name, the main context's variable
          is renamed to ``"local_" + name`` and its declaration moves with it.

        The (possibly renamed) variable and its declaration renderer are then
        registered in ``main_context``, and the final name is mirrored into
        ``global_context.simple_id_to_var_id``. Afterwards, the sub-program's
        aliases and declaration order are appended to ``main_context``.

        All three contexts are mutated in place.

        NOTE(review): in the promotion branch, an existing
        ``"promoted_" + name`` entry in main_context is overwritten without a
        check. Confirm such collisions cannot occur.
        """
        vars_in_sub_prog = sub_context.variables_declared_in_order
        sub_prog_var_name_prefix = "promoted_"
        main_prog_var_name_prefix = "local_"
        for simple_var_id in vars_in_sub_prog:
            new_var_str_id = sub_context.simple_id_to_var_id[simple_var_id]
            # Name of the variable as originally declared in the sub-program.
            sub_prog_var = new_var_str_id
            if allow_promotion:
                new_var_str_id = f"{sub_prog_var_name_prefix}{sub_prog_var}"
                # We change the name of the variable in the sub program
                sub_context.simple_id_to_var_id[simple_var_id] = new_var_str_id
                sub_context.var_id_to_simple_id[new_var_str_id] = simple_var_id
                del sub_context.var_id_to_simple_id[sub_prog_var]
                global_context.simple_id_to_var_id[simple_var_id] = new_var_str_id
            else:
                # (`not allow_promotion` is always true in this branch.)
                if (
                    sub_prog_var in main_context.var_id_to_simple_id
                    and not allow_promotion
                ):
                    # We cannot change the name of the variable in the sub program
                    # So we promote the conflict variable in the main context/program
                    temp_var_str_id = f"{main_prog_var_name_prefix}{sub_prog_var}"
                    old_simple_var_id = main_context.var_id_to_simple_id[sub_prog_var]
                    main_context.simple_id_to_var_id[old_simple_var_id] = (
                        temp_var_str_id
                    )
                    main_context.var_id_to_simple_id[temp_var_str_id] = (
                        old_simple_var_id
                    )
                    del main_context.var_id_to_simple_id[sub_prog_var]
                    # Move the clashing variable's declaration to its new name.
                    main_context.declarations[temp_var_str_id] = main_context.declarations[sub_prog_var]
                    del main_context.declarations[sub_prog_var]
                    global_context.simple_id_to_var_id[old_simple_var_id] = temp_var_str_id
            # Register the (possibly renamed) sub-program variable in main_context.
            main_context.simple_id_to_var_id[simple_var_id] = new_var_str_id
            main_context.var_id_to_simple_id[new_var_str_id] = simple_var_id
            main_context.var_count += 1
            # sub_context.declarations is never re-keyed above, so the renderer
            # is still stored under the original sub-program name.
            main_context.declarations[new_var_str_id] = sub_context.declarations[
                sub_prog_var
            ]
            # NOTE: context is the global context, and does not have the variable in scope
            # Hence no one to one mapping from var_id_str to simple_id
            global_context.simple_id_to_var_id[simple_var_id] = new_var_str_id
        for alias_id, simple_var_id in sub_context.simple_id_to_alias_id.items():
            # Add corresponding alias in the main context
            assert (
                simple_var_id in main_context.simple_id_to_var_id
            ), f"Compiler Bug: Alias '{alias_id}' is not in the main context."
            main_context.simple_id_to_alias_id[alias_id] = simple_var_id
            new_var_str_id = main_context.simple_id_to_var_id[simple_var_id]
            # Take this opportunity to change the name of the variable in the global context too
            global_context.simple_id_to_var_id[alias_id] = new_var_str_id
        for variable_declared_in_order in sub_context.variables_declared_in_order:
            # Add the variable to the main context
            assert (
                variable_declared_in_order
                not in main_context.variables_declared_in_order
            ), f"Compiler Bug: Variable '{variable_declared_in_order}' is already declared in the main context."
            main_context.variables_declared_in_order.append(variable_declared_in_order)

    def _parse_declarations(self, nodes, context: Ctx) -> typing.Optional[Ctx]:
        """Merge the contexts of a list of parsed declarations into one context.

        Args:
            nodes: The (possibly nested) list produced by the ``Declarations``
                grammar rule. After flattening, every leaf must be a
                ``Z3ComposableDsl.Ctx`` or ``None``; ``None`` leaves are skipped.
            context: The context shared by all grammar actions; it is passed
                through as the first argument of ``_merge_program_context``.

        Returns:
            ``None`` when ``nodes`` is an empty list. Otherwise, a new ``Ctx`` into
            which every declaration (and its sub-programs) has been merged.
            Its ``main_var_id`` is the last declared variable.

        Raises:
            ValueError: If ``nodes`` is not a list, or if the same variable
                name is declared more than once in this scope.
        """
        if not isinstance(nodes, list):
            raise ValueError(f"Invalid declaration: {nodes}")
        if not nodes:
            return None

        leaves = self.flatten_list_recursively(nodes)
        assert all(
            isinstance(leaf, Z3ComposableDsl.Ctx) or leaf is None for leaf in leaves
        )

        # Pass 1: validate every declaration before any merging happens, so a
        # duplicate name is reported before _merge_program_context is called.
        seen_ids = set()
        entries = []  # (var_str_id, declaration ctx, declaration) in source order
        for leaf in leaves:
            if leaf is None:
                continue
            assert isinstance(leaf, Z3ComposableDsl.Ctx)
            assert leaf.main_var_id is not None
            decl = leaf.declarations[leaf.main_var_id]
            var_str_id = leaf.main_var_id
            assert isinstance(var_str_id, str), f"Invalid variable id: {var_str_id}"
            if var_str_id in seen_ids:
                raise ValueError(
                    f"User Error: Variable '{var_str_id}' is declared multiple times in the same local scope."
                )
            seen_ids.add(var_str_id)
            entries.append((var_str_id, leaf, decl))

        # Pass 2: fold each declaration into a single merged context.
        merged = Z3ComposableDsl.Ctx()
        for var_str_id, prog_ctx, _decl in entries:
            # Sub-programs this declaration depends on go in first
            # (merged with allow_promotion=True).
            for sub_ctx in prog_ctx.sub_programs:
                self._merge_program_context(
                    context, merged, sub_ctx, allow_promotion=True
                )
                self._merge_declaration_ctx_with_expr_ctx(merged, sub_ctx)
            # Then the declaration itself, followed by linking its expression
            # (aliases) into the merged context.
            self._merge_program_context(context, merged, prog_ctx)
            self._merge_declaration_ctx_with_expr_ctx(merged, prog_ctx)
            merged.pending_arith_sub_expressions.extend(
                prog_ctx.pending_arith_sub_expressions
            )
            merged.pending_logical_sub_expressions.extend(
                prog_ctx.pending_logical_sub_expressions
            )
            merged.main_var_id = var_str_id
        return merged

    def _parse_var(
        self, nodes, context: Ctx
    ) -> typing.Tuple[int, typing.Callable[[Ctx], str]]:
        """Register a variable occurrence and return a deferred name resolver.

        Each call allocates a fresh simple id from ``context.var_count`` and
        records the variable name under that id in
        ``context.simple_id_to_var_id``. Merging contexts later can remap simple
        ids to different string ids. For that reason, the name is not returned
        directly. Instead, the caller gets a resolver that looks the name up in
        whichever context it is eventually given.

        Args:
            nodes: The raw token text of the variable; must be a ``str``.
            context: The context in which the new simple id is allocated.

        Returns:
            A ``(simple_id, get_var_id)`` pair, where ``get_var_id(ctx)``
            returns this variable's string id as known in ``ctx``.

        Raises:
            ValueError: If the name is a keyword or starts with a keyword.
        """
        assert isinstance(nodes, str)
        var_name = str(nodes).strip()
        if var_name in Z3ComposableDsl.keywords:
            raise ValueError(
                f"User Error: A keyword '{var_name}' cannot be used as a variable name."
            )
        if any(var_name.startswith(keyword) for keyword in Z3ComposableDsl.keywords):
            raise ValueError(
                f"User Error: A variable name '{var_name}' cannot have a keyword as a prefix."
            )
        simple_id = context.var_count
        context.var_count += 1
        context.simple_id_to_var_id[simple_id] = var_name

        def _get_var_id(ctx: Z3ComposableDsl.Ctx) -> str:
            """Resolve this occurrence's string id in ``ctx``.

            A direct entry for ``simple_id`` in ``ctx.simple_id_to_var_id`` wins.
            Otherwise, ``ctx.simple_id_to_alias_id`` must map ``simple_id`` to
            another simple id, and that id's entry in ``simple_id_to_var_id``
            supplies the name.
            """
            var_str_id = ctx.simple_id_to_var_id.get(simple_id, None)
            if var_str_id is None:
                alias_id = ctx.simple_id_to_alias_id.get(simple_id, None)
                assert (
                    alias_id is not None
                ), f"Compiler Bug: Variable '{var_name}' was not registered in the local context."
                # NOTE: deliberately no fallback to the parse-time `context`
                # captured by this closure; that fallback was tried earlier and
                # flagged as bug-prone, so an unresolved alias must fail loudly.
                var_str_id = ctx.simple_id_to_var_id.get(alias_id, None)
                # The message parts are joined inside one expression; before,
                # the second f-string was a separate, discarded statement.
                assert var_str_id is not None, (
                    f"Compiler Bug: Variable '{var_name}' was registered in the "
                    "local context, but its alias was not assigned."
                )
            return var_str_id

        return simple_id, _get_var_id

    def get_action(self, inp=None):
        """Build the table of grammar-rule actions for one evaluation.

        A new ``Z3ComposableDsl.Ctx`` is created on every call, and all the
        returned actions share it.

        Args:
            inp: Not read by this method.

        Returns:
            A dict mapping grammar rule names to callables taking
            ``(_, nodes)``. The first argument is ignored. ``nodes`` is forwarded,
            together with the shared context, to the matching ``_parse_*``
            method; ``"num"`` simply converts its node with ``int``.
        """
        shared_ctx = Z3ComposableDsl.Ctx()

        def _route(method_name):
            # Look the parser method up when the action fires, exactly as a
            # direct `self._parse_x(...)` call inside a lambda would.
            return lambda _, nodes: getattr(self, method_name)(nodes, shared_ctx)

        return {
            "Prog": _route("_parse_prog"),
            "LogicalProg": _route("_parse_logical_prog"),
            "ArithProg": _route("_parse_arith_prog"),
            "Declarations": _route("_parse_declarations"),
            "Declaration": _route("_parse_declaration"),
            "LogicalExpr": _route("_parse_logical_expr"),
            "ArithExpr": _route("_parse_arith_expr"),
            "var": _route("_parse_var"),
            "num": lambda _, nodes: int(nodes),
        }

    def interpret_result(self, result) -> Z3TranslationGrammarResult:
        """Convert the context produced by a parse into a translation result.

        The declarations recorded in ``result`` are rendered in declaration
        order and joined with newlines into a constraint string. If the program
        ends in a logical expression, a fresh ``Z3PythonicTranslation`` tries
        to prove it against those constraints. An arithmetic expression is only
        rendered; it is not proved.

        Args:
            result: The ``Z3ComposableDsl.Ctx`` produced by the grammar actions.

        Returns:
            A ``Z3ComposableDsl.Z3TranslationGrammarResult``. For arithmetic
            programs, ``proved`` is ``False`` and ``counter_example`` is
            ``None``. ``timed_out`` is always ``False``.

        Raises:
            AssertionError: If ``result`` is not a ``Z3ComposableDsl.Ctx``.
            Exception: If the context holds neither a logical nor an
                arithmetic expression.
        """
        assert isinstance(
            result, Z3ComposableDsl.Ctx
        ), f"Result must be a Z3.Ctx object. Got {type(result)}"
        # variables_declared_in_order holds simple ids: map each to its final
        # string id, then call that variable's declaration renderer.
        ordered_declaration = [
            result.simple_id_to_var_id[var_id]
            for var_id in result.variables_declared_in_order
        ]
        ordered_declaration_values = [
            result.declarations[var](result) for var in ordered_declaration
        ]
        constraints = "\n".join(ordered_declaration_values)

        # A logical expression takes precedence over an arithmetic one.
        is_logical = result.logical_expr is not None
        expr_builder = result.logical_expr if is_logical else result.arith_expr
        if expr_builder is None:
            raise Exception(
                "Compiler Bug: Something wrong in the grammar. No expression found."
            )
        expr = expr_builder(result)

        if is_logical:
            z3_dsl = Z3PythonicTranslation()
            proof, counter_example = z3_dsl.prove(constraints, expr)
        else:
            # Arithmetic expressions are translated only; no solver is needed.
            proof, counter_example = False, None

        return Z3ComposableDsl.Z3TranslationGrammarResult(
            is_logical=is_logical,
            is_arith=not is_logical,
            proved=proof,
            counter_example=counter_example,
            timed_out=False,
            constraints=constraints,
            expr=expr,
        )

    def run(self, code: str) -> Z3TranslationGrammarResult:
        """Parse and evaluate ``code`` using the parent class's ``run``.

        This is a thin wrapper: the base ``run`` receives ``code`` and ``None``
        as its second positional argument, and its result is returned unchanged.
        """
        outcome = super().run(code, None)
        return outcome




