import unittest
import time
from frame.tools.z3_dsl import Z3Program

class TestZ3Program(unittest.TestCase):
    """Checks that Z3Program DSL snippets are (or are not) proved when run."""

    def tryProve(self, code):
        """Build a Z3Program from ``code``, run it, and print diagnostics.

        Args:
            code: DSL source text passed to ``Z3Program.from_code``.

        Returns:
            The object returned by ``Z3Program.run()``. This class reads its
            ``proved``, ``counter_example``, ``timed_out`` and ``smt2``
            attributes.
        """
        print('='*50)
        print("Code:")
        print(code)
        print('-'*50)
        # perf_counter is monotonic, so the measured interval cannot be skewed
        # by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        program = Z3Program.from_code(code)
        run_result = program.run()
        end_time = time.perf_counter()
        print("Execution Time:", end_time - start_time)
        print("Proved:", run_result.proved)
        print("Counter Example:", run_result.counter_example)
        print("Timed Out:", run_result.timed_out)
        print(f"SMT2 Code:\n{run_result.smt2}")
        print("=" * 50)
        return run_result

    def assertProved(self, code):
        """Fail unless running ``code`` reports ``proved`` as truthy."""
        run_result = self.tryProve(code)
        self.assertTrue(
            run_result.proved,
            "expected the program to be proved "
            f"(counter_example={run_result.counter_example!r}, "
            f"timed_out={run_result.timed_out!r})",
        )

    def assertNotProved(self, code):
        """Fail unless running ``code`` is not proved AND yields a counter example."""
        run_result = self.tryProve(code)
        self.assertFalse(
            run_result.proved,
            "expected the program NOT to be proved",
        )
        # A missing counter example means the run did not actually refute the
        # claim, so report the timeout flag to help tell the cases apart.
        self.assertIsNotNone(
            run_result.counter_example,
            "expected a counter example "
            f"(timed_out={run_result.timed_out!r})",
        )

    def test1(self):
        # Test case 1: Simple expression
        code = """
            params 0;
            bounded params 1;
            ReturnExpr None;
            ReturnPred Exists([b_0], 2 + b_0 == 3);
        """
        self.assertProved(code)

    def test2(self):
        # Test case 2: Function with parameters
        code = """
            params 0;
            bounded params 0;
            f_0 := Func(
                params 0;
                bounded params 0;
                ReturnExpr 1;
                ReturnPred None;
            );
            f_1 := Func(
                params 0;
                bounded params 0;
                ReturnExpr 2;
                ReturnPred None;
            );
            ReturnExpr None;
            ReturnPred f_0() + f_1() == 3;
        """
        self.assertProved(code)

    def test3(self):
        # Test case 3: the sum of two non-negative values is non-negative.
        # Implies takes exactly one antecedent and one consequent, so the two
        # hypotheses are combined with And; passing them as separate operands
        # would leave p_0(x_0=b_0 + b_1) outside the implication.
        code = """
            params 0;
            bounded params 2;
            p_0 := Pred(
                params 1;
                bounded params 0;
                ReturnExpr None;
                ReturnPred x_0 >= 0;
            );
            ReturnExpr None;
            ReturnPred ForAll([b_0, b_1],
                Implies(
                    And(
                        p_0(x_0=b_0),
                        p_0(x_0=b_1)
                    ),
                    p_0(x_0=b_0 + b_1)
                )
            );
        """
        self.assertProved(code)

    def test4(self):
        # Test case 4: 6 is even
        code = """
            params 0;
            bounded params 1;
            f_0 := Func(
                params 1;
                bounded params 0;
                ReturnExpr 2*x_0;
                ReturnPred None;
            );
            f_1 := Func(
                params 0;
                bounded params 0;
                ReturnExpr 6;
                ReturnPred None;
            );
            ReturnExpr None;
            ReturnPred Exists([b_0],
                f_0(x_0=b_0) == f_1()
            );
        """
        self.assertProved(code)

    def test5(self):
        # Test case 5: 7 is prime
        # NOTE(review): the inner p_0 is called with only x_1 bound;
        # presumably x_0 resolves to the enclosing predicate's x_0 -- confirm
        # against the DSL's scoping rules.
        code = """
            params 0;
            bounded params 1;
            p_0 := Pred(
                params 1;
                bounded params 1;
                p_0 := Pred(
                    params 2;
                    bounded params 0;
                    ReturnExpr None;
                    ReturnPred x_0 % x_1 == 0;
                );
                ReturnExpr None;
                ReturnPred Not(
                    Exists([b_0], 
                        And(
                            1 < b_0, 
                            b_0 < x_0, 
                            p_0(x_1=b_0)
                        )
                    )
                );
            );
            ReturnExpr None;
            ReturnPred p_0(x_0=7);
        """
        self.assertProved(code)

    def test6(self):
        # checking gcd, composing programs through their .params and .dsl()
        divides_code = Z3Program.from_code("""
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred x_0 % x_1 == 0;
        """)
        less_than_eq_code = Z3Program.from_code("""
            params 2;
            bounded params 0;
            ReturnExpr None;
            ReturnPred x_0 <= x_1;
        """)
        # gcd(x_0, x_1) = x_2
        # As written, the predicate states that every b_0 passing both
        # divisibility checks satisfies b_0 <= x_2; it does not itself require
        # x_2 to divide x_0 and x_1.
        # NOTE(review): p_0(x_1=b_0) leaves x_0 unbound at the call site --
        # presumably it resolves to the enclosing x_0; confirm against the DSL.
        is_gcd = Z3Program.from_code(f"""
            params {divides_code.params + 1};
            bounded params {divides_code.params - 1};
            p_0 := Pred(
            {divides_code.dsl()}
            );
            p_1 := Pred(
            {less_than_eq_code.dsl()}
            );
            ReturnExpr None;
            ReturnPred ForAll([b_0],
                Implies(
                    And(
                        p_0(x_1=b_0), 
                        p_0(x_0=x_1, x_1=b_0)
                    ),
                    p_1(x_0=b_0, x_1=x_2)
                )
            );
        """)
        # gcd of 70 and 175 is 35
        gcd_of = f"""
            params 0;
            bounded params 0;
            p_0 := Pred(
            {is_gcd.dsl()}
            );
            ReturnExpr None;
            ReturnPred p_0(x_0=70, x_1=175, x_2=35);
        """
        self.assertProved(gcd_of)

    def test7(self):
        # Test case 7: a claim that must NOT be provable.
        # p_0(x) holds when no b_0 satisfies b_0 + 1 == x; p_1(x, y) holds when
        # some b_0 satisfies x + b_0 == y. The program claims p_0(b_0) implies
        # p_1(b_1, b_0) for all b_0, b_1, and a counter example is expected.
        code = """
        params 0;
        bounded params 2;

        p_0 := Pred(
            params 1;
            bounded params 0;

            p_0 := Pred(
                params 1;
                bounded params 1;

                f_0 := Func(
                    params 1;
                    bounded params 0;
                    ReturnExpr x_0 + 1;
                    ReturnPred None;
                );

                ReturnExpr None;
                ReturnPred Exists([b_0], f_0(x_0 = b_0) == x_0);
            );

            ReturnExpr None;
            ReturnPred Not(p_0(x_0 = x_0));
        );

        p_1 := Pred(
            params 2;
            bounded params 1;

            f_0 := Func(
                params 2;
                bounded params 0;
                ReturnExpr x_0 + x_1;
                ReturnPred None;
            );

            ReturnExpr None;
            ReturnPred Exists([b_0], f_0(x_0 = x_0, x_1 = b_0) == x_1);
        );

        ReturnExpr None;
        ReturnPred ForAll([b_0, b_1], Implies(p_0(x_0 = b_0), p_1(x_0 = b_1, x_1 = b_0)));
        """
        self.assertNotProved(code)

# Run this module's test cases through the unittest runner when the file is
# executed directly as a script (not when it is imported).
if __name__ == "__main__":
    unittest.main()