import typing
from z3 import (Solver, ForAll, And, Implies, Const, SetSort, IntSort, Int, IsMember, parse_smt2_string, z3)

class Z3DslRuntime(object):
    """
    This class is a DSL for writing Z3 constraints in a more Pythonic way.

    A single Z3 ``Solver`` is shared by all calls. Its assertion stack is laid
    out as follows:

    * base level: the axiom that ``NatSet`` contains exactly the non-negative
      integers (added once in ``__init__``);
    * user scope: constraints registered through :meth:`Nat` / :meth:`NatVar`
      since the last :meth:`prove` call;
    * proof scope: the SMT-LIB assertions of the proof being checked (present
      only while :meth:`prove` runs).

    :meth:`prove` discards the user and proof scopes when it finishes, so each
    proof starts from the axiom alone.
    """

    # Z3's default value for the solver ``timeout`` parameter (UINT_MAX ms),
    # i.e. effectively no timeout. Used to undo a per-call timeout.
    _Z3_NO_TIMEOUT_MS = 4294967295

    def __init__(self):
        self._z3_pythonic_dsl_solver = Solver()
        # Maps names passed to Nat() to the Int constants created for them.
        self._z3_pythonic_dsl_variable_map = {}
        self._z3_pythonic_dsl_nat_var_prefix = "z3_dsl_nat_var_"
        # Numeric suffix for the next anonymous variable created by NatVar().
        self._z3_pythonic_dsl_nat_var_counter = 0
        self._z3_pythonic_dsl_nat_set = SetSort(IntSort())
        self._NatSet = Const("z3_pythonic_dsl_nat_set", self._z3_pythonic_dsl_nat_set)
        # Bound variable of the quantified axiom below.
        self._z3_pythonic_dsl_x = Int(self._z3_pythonic_dsl_nat_var_prefix + "_x")
        # Axiom: for every integer x, x is a member of NatSet iff x >= 0.
        self._z3_pythonic_dsl_solver.add(
            ForAll(
                [self._z3_pythonic_dsl_x],
                And(
                    Implies(
                        IsMember(self._z3_pythonic_dsl_x, self._NatSet),
                        self._z3_pythonic_dsl_x >= 0,
                    ),
                    Implies(
                        self._z3_pythonic_dsl_x >= 0,
                        IsMember(self._z3_pythonic_dsl_x, self._NatSet),
                    ),
                ),
            )
        )
        # Open the user scope that holds Nat()/NatVar() constraints, so that
        # prove() can discard them later without losing the axiom above.
        self._z3_pythonic_dsl_solver.push()

    def _constrain_nat(self, n):
        """Assert on the solver that ``n`` is non-negative and a member of NatSet."""
        self._z3_pythonic_dsl_solver.add(n >= 0)
        self._z3_pythonic_dsl_solver.add(IsMember(n, self._NatSet))

    def NatVar(self):
        """Create a fresh anonymous non-negative integer variable.

        The variable is named ``z3_dsl_nat_var_<k>``, where ``k`` is a counter
        that :meth:`prove` resets to 0.

        Returns:
            The new z3 ``Int`` constant, already constrained to be >= 0 and a
            member of NatSet.
        """
        n = Int(
            self._z3_pythonic_dsl_nat_var_prefix
            + str(self._z3_pythonic_dsl_nat_var_counter)
        )
        self._constrain_nat(n)
        self._z3_pythonic_dsl_nat_var_counter += 1
        return n

    def Nat(self, name):
        """Return the non-negative integer variable called ``name``.

        The first call for a given name creates the z3 ``Int`` constant and
        constrains it to be >= 0 and a member of NatSet. Later calls, until the
        next :meth:`prove`, return that same constant without adding
        constraints again.

        Args:
            name: Name of the z3 ``Int`` constant.

        Returns:
            The z3 ``Int`` constant named ``name``.
        """
        if name in self._z3_pythonic_dsl_variable_map:
            return self._z3_pythonic_dsl_variable_map[name]
        n = Int(name)
        self._constrain_nat(n)
        self._z3_pythonic_dsl_variable_map[name] = n
        return n

    def prove(
        self,
        smtlib_string: str,
        timeout=None
    ) -> typing.Tuple[bool, typing.Optional[z3.ModelRef]]:
        """Check whether ``return_pred`` is valid under an SMT-LIB script.

        ``smtlib_string`` must declare or define a Boolean ``return_pred``.
        This method appends ``(assert (not return_pred))`` and asks Z3 whether
        the resulting assertions are satisfiable, together with the NatSet
        axiom and any Nat()/NatVar() constraints registered since the last
        call.

        Args:
            smtlib_string: SMT-LIB2 commands (declarations and assertions).
            timeout: Optional Z3 solver timeout in milliseconds. It applies to
                this call only.

        Returns:
            ``(True, None)`` if Z3 reports unsat (``return_pred`` is proved);
            ``(False, model)`` with a counterexample model if Z3 reports sat;
            ``(False, None)`` if Z3 reports unknown (e.g. the timeout hit).

        Raises:
            AssertionError: if Z3 returns a result other than sat/unsat/unknown.
            Any exception raised by ``parse_smt2_string`` for a malformed script.

        Whether this method returns or raises, it discards every constraint
        added by this call and by Nat()/NatVar(). It also clears the name map,
        resets the NatVar counter and restores the solver's default timeout.
        """
        solver = self._z3_pythonic_dsl_solver
        solver.push()  # proof scope
        try:
            # Validity of return_pred <=> unsatisfiability of its negation.
            smt_exprs = parse_smt2_string(
                smtlib_string + "\n(assert (not return_pred))"
            )
            for smt_expr in smt_exprs:
                solver.add(smt_expr)
            if timeout is not None:
                solver.set(timeout=timeout)
            z3_pythonic_dsl_result = solver.check()
            if z3_pythonic_dsl_result == z3.unsat:
                return True, None
            if z3_pythonic_dsl_result == z3.sat:
                # The model is extracted before the finally block pops scopes.
                return False, solver.model()
            if z3_pythonic_dsl_result == z3.unknown:
                return False, None
            raise AssertionError(
                f"Z3 returned an unexpected result: {z3_pythonic_dsl_result}"
            )
        finally:
            # Solver parameters are not scoped by push/pop, so undo the
            # per-call timeout explicitly; otherwise it leaks into later calls.
            if timeout is not None:
                solver.set(timeout=self._Z3_NO_TIMEOUT_MS)
            solver.pop()   # drop this proof's assertions
            solver.pop()   # drop Nat()/NatVar() constraints ...
            solver.push()  # ... and reopen an empty user scope
            self._z3_pythonic_dsl_variable_map.clear()
            self._z3_pythonic_dsl_nat_var_counter = 0