import z3
import ast
from common.task import GsmPlusTask
from common.verdict import SolverReply
from evaluators.solvers.gsm_plus.z3_parsing import (
    parse_private_constraint,
    parse_public_constraint,
    z3_parse,
    z3_var_from,
)


def solve_gsm(*, task: GsmPlusTask, ctx: z3.Context, timeout_ms: int) -> SolverReply:
    """Check the LLM's symbolic answer against the reference answer with Z3.

    A solver is populated with a positivity constraint for every ``int`` /
    ``float`` variable, the task's public constraints that pass the textual
    filter below, and the task's private constraints. The reference answer
    (``task.answer``) and the LLM answer (``task.llm_solution``) are both
    parsed with ``ast`` and translated with ``z3_parse``; the solver is then
    asked whether the two expressions can differ.

    Args:
        task: Task providing ``llm_solution``, ``answer``, ``variable_types``,
            ``constraints`` and ``private_constraints``.
        ctx: Z3 context used for every variable, value and the solver itself.
        timeout_ms: Value passed to the solver's ``timeout`` option.

    Returns:
        A ``SolverReply`` whose verdict is:

        * ``"success"`` when the solver proves the answers cannot differ
          (``unsat``);
        * ``"failure"`` when a counterexample exists (its model is attached
          as a string), when the LLM answer cannot be parsed or translated,
          or when the solver raises;
        * ``"unknown"`` when the LLM answer is empty or the solver could not
          decide (for example because the timeout was hit).

    Raises:
        AssertionError: If the reference answer parses to anything other than
            exactly one top-level statement.
        Exception: Whatever ``ast.parse`` / ``z3_parse`` / the constraint
            parsers raise for the reference answer or the constraints; only
            errors caused by the LLM answer and by solving are converted into
            a ``SolverReply``.
    """
    llm_answer = task.llm_solution
    if llm_answer is None or llm_answer == "":
        return SolverReply(verdict="unknown", error_message="Empty LLM answer")

    solver = z3.Solver(ctx=ctx)

    # Create every declared variable; numeric ones are restricted to
    # strictly positive values.
    for var_name, var_type in task.variable_types.items():
        z3_var = z3_var_from(name=var_name, orig_type=var_type, ctx=ctx)
        if var_type in ("int", "float"):
            solver.add(z3_var > 0)

    for public_constraint in task.constraints:
        # TODO?: Replace the hack (the entire code snippet below)
        # annotation vs 'parse-able' indicator in JSON AND the direct == check
        looks_parseable = (
            "in" in public_constraint.split(" ") or "=" in public_constraint
        ) and "alphabets" not in public_constraint
        if not looks_parseable:
            continue

        if public_constraint == "unit2 = 'meters' if unit1 == 'meters' else 'yards'":
            # This one conditional-assignment constraint is encoded by hand
            # instead of going through parse_public_constraint.
            unit1_var = z3_var_from(name="unit1", orig_type="str", ctx=ctx)
            unit2_var = z3_var_from(name="unit2", orig_type="str", ctx=ctx)
            solver.add(
                unit2_var
                == z3.If(
                    unit1_var == z3.StringVal("meters", ctx),
                    z3.StringVal("meters", ctx),
                    z3.StringVal("yards", ctx),
                )
            )
            continue

        constraint = parse_public_constraint(public_constraint, task, ctx)
        if constraint is not None:
            solver.add(constraint)
        ############# (end replacement)

    for private_constraint in task.private_constraints:
        constraint = parse_private_constraint(private_constraint, task, ctx)
        if constraint is not None:
            solver.add(constraint)

    # The reference answer comes from the task itself, so problems here are
    # allowed to propagate instead of being reported as an LLM failure.
    answer_tree = ast.parse(task.answer)
    assert len(answer_tree.body) == 1, "Unhandled module len > 1"
    base_answer = z3_parse(answer_tree.body[0], task, ctx)

    try:
        generated_tree = ast.parse(llm_answer)
    except Exception as e:
        return SolverReply(verdict="failure", error_message=f"Parsing error: {e}")

    # Only the first top-level statement of the LLM answer is translated. An
    # answer with no statements at all raises IndexError here and is reported
    # as a parsing failure; extra statements are ignored rather than asserted
    # against, because that just means the LLM did not follow the expected
    # output format.
    try:
        symbolic_generated_response = z3_parse(generated_tree.body[0], task, ctx)
    except Exception as e:
        return SolverReply(verdict="failure", error_message=f"Parsing error: {e}")

    solver.push()

    try:
        solver.set("timeout", timeout_ms)
        # The answers are equivalent iff "they differ" is unsatisfiable.
        solver.add(base_answer != symbolic_generated_response)
        result = solver.check()

        if result == z3.unsat:
            return SolverReply(verdict="success")

        if result == z3.unknown:
            # No model exists in this state (e.g. after a timeout), so report
            # the outcome as undecided rather than as a wrong answer.
            return SolverReply(
                verdict="unknown",
                error_message=f"Solver returned unknown: {solver.reason_unknown()}",
            )

        model = solver.model()
        counterexample = {d.name(): model[d] for d in model.decls()}
        return SolverReply(verdict="failure", counterexample=str(counterexample))
    except Exception as e:
        return SolverReply(
            verdict="failure", error_message=f"Error during solving: {e}"
        )
