import unittest
from frame.tools.z3_dsl import Z3ComposableDsl

class TestZ3Compiler(unittest.TestCase):
    """Smoke tests: each DSL program below must compile via Z3ComposableDsl.

    These tests only check that compilation raises no exception and yields a
    non-None result; they do not inspect the compiled output itself.
    """

    def setUp(self):
        # unittest calls setUp before every test, so each test gets its own
        # grammar instance (in case the compiler keeps internal state).
        self.grammar = Z3ComposableDsl()

    def compile(self, program):
        """Compile *program* with this test's Z3ComposableDsl and return the result."""
        return self.grammar.compile(program)

    def assertCompiles(self, program):
        """Assert that *program* compiles to a non-None result and return it.

        An exception raised by the compiler becomes a test failure whose
        message starts with "Compilation failed:"; the original exception is
        chained as the cause so its traceback is preserved. A ``None`` result
        fails through ``assertIsNotNone`` with its own message instead of
        being misreported as a compilation error.

        Returns:
            The compiled result, so callers can make further assertions on it.
        """
        # Keep the try body to the compile call only: assertIsNotNone raises
        # AssertionError (an Exception subclass), which must not be swallowed
        # and relabelled as a compiler failure.
        try:
            result = self.compile(program)
        except Exception as e:
            raise self.failureException(f"Compilation failed: {e}") from e
        self.assertIsNotNone(result, "Compilation returned None")
        print(f"Compilation successful: \n{result}")
        return result

    def test1(self):
        # Test case 1: Simple expression
        program = """
            params 1;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Exists([b_0], 2 + b_0 == x_0);
        """
        self.assertCompiles(program)

    def test2(self):
        # Test case 2: Function with parameters
        program = """
            params 1;
            bounded params 1;
            f_1 := Func(
                params 1;
                bounded params 0;
                ReturnExpr 2*x_0;
                ReturnPred True;
            );
            ReturnExpr None;
            ReturnPred Exists([b_0], b_0 == f_1(x_0=2));
        """
        self.assertCompiles(program)

    def test3(self):
        # Test case 3: Nested function calls
        program = """
            params 2;
            bounded params 2;
            f_1 := Func(
                params 1;
                bounded params 0;
                ReturnExpr 2*x_0;
                ReturnPred True;
            );
            f_2 := Func(
                params 0;
                bounded params 0;
                ReturnExpr 4;
                ReturnPred True;
            );
            ReturnExpr None;
            ReturnPred Exists([b_0, b_1], b_0 + b_1 == f_1(x_0=f_2()));
        """
        self.assertCompiles(program)

    def test4(self):
        # Test case 4: Complex expression with multiple variables
        # NOTE(review): `3^1` below -- confirm whether the DSL treats `^` as
        # exponentiation or (Python/Z3-style) XOR; the two give different values.
        program = """
            params 3;
            bounded params 3;
            f_1 := Func(
                params 1;
                bounded params 0;
                ReturnExpr x_0 + 1;
                ReturnPred True;
            );
            f_2 := Func(
                params 1;
                bounded params 0;
                ReturnExpr x_0 * 2;
                ReturnPred True;
            );
            ReturnExpr None;
            ReturnPred ForAll([b_0, b_1], 
              Implies(b_0 + b_1 == f_1(x_0=f_2(x_0=3)), f_1(x_0=f_2(x_0=3^1)) == b_1 + b_0));
        """
        self.assertCompiles(program)

    def test5(self):
        # Test case 5: Function with multiple parameters
        # NOTE(review): f_1 declares `params 2` but is called with only x_0
        # and also with no arguments (`f_1()`); confirm the DSL permits
        # omitted arguments, otherwise this program may be malformed.
        program = """
            params 2;
            bounded params 2;
            f_1 := Func(
                params 2;
                bounded params 0;
                ReturnExpr x_0 + x_1;
                ReturnPred True;
            );
            f_0 := Func(
                params 0;
                bounded params 0;
                ReturnExpr 1;
                ReturnPred True;
            );
            ReturnExpr None;
            ReturnPred Exists([b_0, b_1], b_0 + b_1 == f_0() + f_1(x_0=f_0() + f_1()));
        """
        self.assertCompiles(program)


if __name__ == "__main__":
    # Executed directly as a script (not imported): hand control to
    # unittest's command-line runner, which discovers and runs the tests
    # defined in this module.
    unittest.main()