import unittest
from frame.tools.z3_helper import Z3ComposableDsl


class TestZ3TranslationGrammar(unittest.TestCase):
    """End-to-end tests for ``Z3ComposableDsl``.

    Each test feeds a small DSL program to ``Z3ComposableDsl.run`` and checks
    two properties of the returned result: whether the final expression is
    logical (as opposed to arithmetic) and whether a proof was found.
    """

    def setUp(self):
        # Create a fresh DSL instance before every test.
        self.grammar = Z3ComposableDsl()

    def run_program(self, program, is_logical=True, proved=True):
        """Run *program* through the DSL, print the result, and check it.

        Prints the program, the low-level Z3 constraints and expression, the
        expression type, the proof outcome and, if present, the counter
        example.

        Args:
            program: DSL source text passed to ``Z3ComposableDsl.run``.
            is_logical: Expected value of ``run_result.is_logical``.
            proved: Expected value of ``run_result.proved``.

        Raises:
            AssertionError: If a counter example is reported together with a
                successful proof, or if ``is_logical`` / ``proved`` differ
                from the expected values.
        """
        run_result = self.grammar.run(program)
        constraints = run_result.constraints
        expr = run_result.expr
        print("Program:")
        print("-" * 50)
        print(program)
        print("Z3 Low level:")
        print("-" * 50)
        print(f"{constraints}\n{expr}")
        expr_type = "Logical" if run_result.is_logical else "Arithmetic"
        print(f"Expression Type: {expr_type}")
        print(f"Proof:{'found' if run_result.proved else 'failed'}")
        if run_result.counter_example is not None:
            # A counter example must never accompany a successful proof.
            self.assertFalse(
                run_result.proved,
                msg=f"counter example reported for a proved program:\n{program}",
            )
            print("Counter Example:")
            print(run_result.counter_example)
        print("========================================")
        # Include the program in failure messages so a failing case is
        # identifiable without scrolling through the printed log.
        self.assertEqual(
            run_result.is_logical,
            is_logical,
            msg=f"unexpected expression type for program:\n{program}",
        )
        self.assertEqual(
            run_result.proved,
            proved,
            msg=f"unexpected proof result for program:\n{program}",
        )

    def test_simple_expression(self):
        """``x + y >= 0`` over two ``NatVar`` values is proved."""
        program = """
x := NatVar();
y := NatVar();
return x + y >= 0
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_logical_expression(self):
        """A value built in a nested ``Exec`` (``2*x``) keeps ``(2*x + y) % 2 == 0`` provable."""
        program = """
x := NatVar();
y := Exec(
    x := NatVar();
    return 2*x
);
return ((2*x + y) % 2) == 0
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_complex_expression1(self):
        """An arithmetic result (``x * y``) is reported as non-logical and unproved."""
        program = """
x := NatVar();
y := NatVar();
return x * y"""
        self.run_program(program, is_logical=False, proved=False)

    def test_complex_expression2(self):
        """Doubly nested ``Exec`` blocks that shadow ``x`` yield an arithmetic result."""
        program = """
x := Exec(
    y := Exec(
        x := NatVar();
        return x + 1);
   return y + 1
);
return x"""
        self.run_program(program, is_logical=False, proved=False)

    def test_complex_expression3(self):
        """Two ``Exec`` chains reusing the names ``x`` and ``y`` combine arithmetically."""
        program = """
y := Exec(
    x := NatVar();
    return x + 1);
x := Exec(
    y := Exec(
        x := NatVar();
        return x + 1);
   return y + 1
);
return x + y"""
        self.run_program(program, is_logical=False, proved=False)

    def test_complex_expression4(self):
        """Nested ``Implies`` using an intermediate binding (``x_plus_y``) is proved."""
        program = """
x := NatVar();
y := NatVar();
z := NatVar();
x_plus_y := x + y;
return Implies(z == x_plus_y, Implies(x_plus_y + z == 2*z, z - x - y == 0))"""
        self.run_program(program, is_logical=True, proved=True)

    def test_complex_expression5(self):
        """``z == x + y`` does not imply ``x + y + z == 3*z``, so no proof is found."""
        program = """
x := NatVar();
y := NatVar();
z := NatVar();
x_plus_y := x + y;
return Implies(z == x_plus_y, x_plus_y + z == 3*z)"""
        self.run_program(program, is_logical=True, proved=False)

    def test_exists_prime(self):
        """``Not(Exists(x, 1 < x < n and n % x == 0))`` is proved iff ``n`` is prime."""
        cases = [
            (2, True),
            (31, True),
            (51, False),
            (53, True),
            (87, False),
            (71, True),
            (85, False),
            (90, False),
            (84, False),
        ]
        for n, is_prime in cases:
            # subTest reports every failing n instead of stopping at the first.
            with self.subTest(n=n):
                program = f"""
n := {n};
x := NatVar();
return Not(Exists(x, And(And(x < n, x > 1), n % x == 0)))
"""
                self.run_program(program, is_logical=True, proved=is_prime)

    def test_forall(self):
        """``ForAll(x, x > 0 => x + y > y)`` is proved."""
        program = """
x := NatVar();
y := NatVar();
return ForAll(x, Implies(x > 0, x + y > y))
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_is_member1(self):
        """The sum of two ``NatVar`` values is proved to be a member of ``NatSet``."""
        program = """
x := NatVar();
y := NatVar();
return IsMember(x + y, NatSet)
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_is_member2(self):
        """The difference ``x - y`` is not proved to be in ``NatSet`` (presumably because it can be negative)."""
        program = """
x := NatVar();
y := NatVar();
return IsMember(x - y, NatSet)
"""
        self.run_program(program, is_logical=True, proved=False)

    def test_exists_addition(self):
        """No ``NatVar`` ``y`` satisfies ``7 + y == 0``."""
        program = """
x := 7;
y := NatVar();
return Not(Exists(y, x + y == 0))
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_bug1(self):
        """Regression: nested ``Exec`` blocks each rebinding ``n``; the inner boolean is negated outside."""
        program = """
n := Exec(
    n := Exec(
        n := 1;
        return n);
    x := NatVar();
    return Exists(x, n  == 2 * x)
);
return Not(n)
"""
        self.run_program(program, is_logical=True, proved=True)

    def test_bug2(self):
        """Regression: values returned by separate ``Exec`` blocks compare equal (1 == 0 + 1)."""
        program = """
left := Exec(n := 1;
return n);
right := Exec(n := Exec(n := 0;
return n);
return n + 1);
return left == right"""
        self.run_program(program, is_logical=True, proved=True)

    def test_bug3(self):
        """Regression: boolean results of separate ``Exec`` blocks combine under ``And``."""
        program = """
result1 := Exec(n := 2;
return n > 1);
result2 := Exec(n := 1;
return n < 2);
return And(result1, result2)"""
        self.run_program(program, is_logical=True, proved=True)

    def test_bug4(self):
        """Regression: ``ForAll(x, x > 1)`` over a ``NatVar`` must not be proved."""
        program = """
x := NatVar();
return ForAll(x, x > 1)"""
        self.run_program(program, is_logical=True, proved=False)

if __name__ == '__main__':
    # Run the whole suite through the unittest runner. A single test can
    # still be selected from the command line, e.g.
    #   python <this_file>.py TestZ3TranslationGrammar.test_bug4
    unittest.main()
