import z3
from common.task import FolTask
from common.verdict import SolverReply
from evaluators.adapters.fol.to_z3 import parse_to_z3, Signature


def solve_fol(*, task: FolTask, ctx: z3.Context, timeout_ms: int) -> SolverReply:
    """Check whether the LLM's first-order-logic answer is equivalent to the reference.

    Both ``task.answer`` (the reference) and ``task.llm_solution`` are parsed into
    Z3 terms over a single shared ``Signature``. Z3 is then asked whether
    ``reference XOR llm`` is satisfiable, i.e. whether some interpretation makes
    exactly one of the two formulas true.

    Args:
        task: Task carrying the reference ``answer`` and the ``llm_solution``.
        ctx: Z3 context in which the signature, the parsed terms and the solver
            are created.
        timeout_ms: Solver timeout in milliseconds, passed to ``solver.set``.

    Returns:
        A ``SolverReply`` whose verdict is:

        * ``"unknown"`` (error "Empty LLM answer") if the LLM answer is ``None``
          or empty;
        * ``"failure"`` with a "Parsing error" message if the LLM answer cannot
          be parsed / type-checked;
        * ``"success"`` if the XOR is unsatisfiable (formulas are equivalent);
        * ``"failure"`` with a stringified counterexample model if it is
          satisfiable;
        * ``"unknown"`` with Z3's ``reason_unknown()`` if Z3 cannot decide
          (e.g. the timeout was hit);
        * ``"failure"`` with a "Z3 error" message if Z3 raises while solving.

    Raises:
        Any exception raised by ``parse_to_z3`` for the *reference* answer. It is
        parsed outside the error handling below, so it propagates to the caller.
    """
    llm_answer = task.llm_solution
    if not llm_answer:  # covers both None and ""
        return SolverReply(verdict="unknown", error_message="Empty LLM answer")

    # Both answers are parsed against the same Signature, so symbols in the LLM
    # answer are presumably checked against the reference's declarations
    # (arity mismatches show up as parse errors below).
    signature = Signature(ctx=ctx)
    # Not guarded on purpose-or-not: a reference parse failure propagates.
    parsed_base_answer = parse_to_z3(task.answer, signature)
    try:
        parsed_llm_answer = parse_to_z3(llm_answer, signature)
    except Exception as e:
        # Happens when there's a parsing error,
        # or some type error like an arity mismatch, etc.
        return SolverReply(verdict="failure", error_message=f"Parsing error: {e}")

    try:
        solver = z3.Solver(ctx=ctx)
        solver.set(timeout=timeout_ms)

        # Equivalent iff no interpretation makes exactly one of them true.
        solver.add(z3.Xor(parsed_base_answer, parsed_llm_answer))
        result = solver.check()

        if result == z3.unsat:
            return SolverReply(verdict="success")
        if result == z3.sat:
            # Read the model immediately after check(): any push/pop/add on the
            # solver afterwards invalidates it ("model is not available"), which
            # previously turned every counterexample into a "Z3 error" reply.
            model = solver.model()
            counterexample = {decl.name(): model[decl] for decl in model.decls()}
            return SolverReply(verdict="failure", counterexample=str(counterexample))
        # z3.unknown: typically a timeout or an incomplete theory combination.
        return SolverReply(verdict="unknown", error_message=solver.reason_unknown())
    except Exception as e:
        return SolverReply(verdict="failure", error_message=f"Z3 error: {e}")
