import re
from typing import Literal

import z3

from common.task import Task
from common.verdict import SolverReply
from common.task import GsmPlusTask, SqlTask, CoqTask, RegexTask, FolTask

from evaluators.solvers.gsm_plus.gsm_solver import solve_gsm
from evaluators.solvers.sql.sql_solver import solve_sql
from evaluators.solvers.coq.coq_solver import solve_coq
from evaluators.solvers.regex.regex_solver import solve_regex
from evaluators.solvers.fol.fol_solver import solve_fol
from llm.llm import LLM
from llm.util import normalize_response


def evaluate_with_equivalence(task: Task, model: LLM, *, timeout_ms: int, do_log: bool = True) -> SolverReply:
    """Judge ``task`` by checking its LLM solution for equivalence with the ground truth.

    The check is delegated to the domain-specific solver matching the task's
    type. ``model`` is never queried here; it only serves as the key under
    which verifier results are read and logged.

    Args:
        task: Task to evaluate; must be one of the supported task types.
        model: LLM whose verifier log is read from / written to.
        timeout_ms: Time budget forwarded to the solver.
        do_log: When true, return a previously logged (truthy) result if one
            exists, and log a freshly computed result.

    Returns:
        The solver's reply.

    Raises:
        ValueError: If ``task`` is not one of the supported task types.
    """
    if do_log:
        cached = model.read_for_verifier(task.id_in_domain, task.domain)
        if cached:
            return cached

    # Ordered (task type, solver) pairs; the first isinstance match wins,
    # exactly like an if/elif chain. The z3-based solvers receive a fresh
    # Context each time they are invoked.
    dispatch = (
        (GsmPlusTask, lambda t: solve_gsm(task=t, ctx=z3.Context(), timeout_ms=timeout_ms)),
        (SqlTask, lambda t: solve_sql(task=t, timeout_ms=timeout_ms)),
        (CoqTask, lambda t: solve_coq(task=t, timeout_ms=timeout_ms)),
        (RegexTask, lambda t: solve_regex(task=t, timeout_ms=timeout_ms)),
        (FolTask, lambda t: solve_fol(task=t, ctx=z3.Context(), timeout_ms=timeout_ms)),
    )
    for task_type, solve in dispatch:
        if isinstance(task, task_type):
            verdict = solve(task)
            break
    else:
        # Should not happen
        raise ValueError(f"Unknown task type: {type(task)}")

    if do_log:
        model.log_for_verifier(task.id_in_domain, task.domain, verdict)
    return verdict


def evaluate_with_nl_judge(
    task: Task,
    model: LLM,
    *,
    correct_treshold: int = 2,
    use_base_answer: bool = False
) -> SolverReply:
    """Use the LLM itself as a judge of whether a proposed solution is correct.

    The LLM is shown the natural-language task together with a proposed
    solution and is expected to reply with an integer confidence score.
    Scores ``>= correct_treshold`` count as correct (by default, 2 or more).
    The judge is expected to answer with an integer from 0 to 3 (inclusive);
    note that the score is not range-checked here.

    Args:
        task: Task providing the natural-language statement and the solution.
        model: LLM used as the judge.
        correct_treshold: Minimum score counted as a success.
        use_base_answer: Judge the ground-truth ``task.answer`` instead of
            ``task.llm_solution``.

    Returns:
        A ``success``/``failure`` reply carrying the judge score, or an
        ``unknown`` reply when the LLM call fails or its response is not an
        integer.
    """
    if use_base_answer:
        solution = task.answer
        if isinstance(task, GsmPlusTask):
            # int() casts should not affect the answer. Strip only the call
            # form `int(`; a plain substring replace would also mangle
            # identifiers such as `points` -> `pos` or `interest` -> `erest`.
            solution = re.sub(r"\bint(?=\s*\()", "", solution)
    else:
        solution = task.llm_solution
    prompt = f"Task: {task.natural_language}\n\nProposed Solution: {solution}"

    try:
        llm_response = model.call(
            prompt=prompt,
            problem_id=task.id_in_domain,
            problem_domain=task.domain,
            judge_type="nl-base" if use_base_answer else "nl",
        )
    except Exception as e:
        # Best-effort boundary: any judge failure becomes an "unknown" verdict.
        print('Failed:', e)
        return SolverReply(
            verdict="unknown",
            error_message=f"LLM judge call failed: {e}",
        )

    try:
        llm_judge_answer = int(normalize_response(llm_response).lower())
    except ValueError:
        return SolverReply(
            verdict="unknown",
            error_message=f"LLM judge returned non-integer response: {llm_response}",
        )

    if llm_judge_answer >= correct_treshold:
        return SolverReply(verdict="success", judge_score=llm_judge_answer)
    return SolverReply(
        verdict="failure",
        error_message=f"LLM judge returned {llm_judge_answer}",
        judge_score=llm_judge_answer
    )


def evaluate_with_equivalence_judge(
    task: Task,
    model: LLM,
    *,
    correct_treshold: int = 2,
) -> SolverReply:
    """Ask the LLM, in both argument orders, whether two specifications are equivalent.

    The judge is queried twice: once with the ground-truth answer as "A" and
    once with the LLM solution as "A". The two votes are combined by majority
    (as described in the paper):

    * if either query is inconclusive, that reply is returned (the
      answer-first query takes precedence);
    * if both agree on success or on failure, that verdict is returned;
    * on a split decision, the summed scores are compared against
      ``2 * correct_treshold`` (i.e. the mean score against the threshold).

    A single query counts as a success when its score is
    ``>= correct_treshold`` (2 by default). The judge is expected to answer
    with an integer from 0 to 3 (inclusive).
    """
    first = _evaluate_with_equivalence_judge(
        task, model, order="answer_first", correct_treshold=correct_treshold
    )
    second = _evaluate_with_equivalence_judge(
        task, model, order="llm_solution_first", correct_treshold=correct_treshold
    )

    # Any inconclusive vote makes the whole judgement inconclusive.
    for reply in (first, second):
        if reply.verdict == "unknown":
            return reply

    # Unanimous votes decide directly.
    if first.verdict == second.verdict == "success":
        return SolverReply(verdict="success")
    if first.verdict == second.verdict == "failure":
        return SolverReply(verdict="failure")

    # Split decision: fall back to the combined confidence of both votes.
    assert first.judge_score is not None
    assert second.judge_score is not None
    combined = first.judge_score + second.judge_score
    if combined < 2 * correct_treshold:
        return SolverReply(verdict="failure")
    return SolverReply(verdict="success")


def _evaluate_with_equivalence_judge(
    task: Task,
    model: LLM,
    order: Literal["answer_first", "llm_solution_first"],
    *,
    correct_treshold: int = 2,
) -> SolverReply:
    """Run a single equivalence-judge query for ``task`` in the given order.

    The prompt places the ground-truth answer and the LLM solution as "A" and
    "B" (swapped according to ``order``), followed by task-specific
    natural-language context where the task type provides one.

    Args:
        task: Task providing the ground truth and the LLM solution.
        model: LLM used as the judge.
        order: Which specification is shown first ("A").
        correct_treshold: Minimum score counted as a success.

    Returns:
        A ``success``/``failure`` reply carrying the judge score, or an
        ``unknown`` reply when the LLM call fails or its response is not an
        integer.

    Raises:
        ValueError: If ``order`` is not one of the two supported values.
    """
    if order not in ("answer_first", "llm_solution_first"):
        raise ValueError("Unknown judging order")

    # Pick the reference answer and the task-specific context in one pass.
    reference = task.answer
    context = ""
    if isinstance(task, SqlTask):
        context = task.nl_constraints_only
    elif isinstance(task, FolTask):
        context = task.nl_symbols_only
    elif isinstance(task, GsmPlusTask):
        # int() casts should not affect the comparison. Strip only the call
        # form `int(`; a plain substring replace would also mangle
        # identifiers such as `points` -> `pos` or `interest` -> `erest`.
        reference = re.sub(r"\bint(?=\s*\()", "", reference)
        context = task.nl_constraints_only
    elif isinstance(task, CoqTask):
        context = task.nl_context_only
    # Regex is intentionally omitted, because there isn't
    # really a context to provide for it

    if order == "answer_first":
        first, second = reference, task.llm_solution
    else:
        first, second = task.llm_solution, reference
    prompt = f"A: ```{first}```\n\nB: ```{second}```" + "\n" + context

    try:
        llm_response = model.call(
            prompt=prompt,
            problem_id=task.id_in_domain,
            problem_domain=task.domain,
            judge_type="equivalence",
            judge_order=order,
        )
    except Exception as e:
        # Best-effort boundary: any judge failure becomes an "unknown" verdict.
        print('Failed (should not happen as we have retries):', e)
        return SolverReply(
            verdict="unknown",
            error_message=f"LLM judge call failed: {e}",
        )

    try:
        llm_judge_answer = int(normalize_response(llm_response).lower())
    except ValueError:
        return SolverReply(
            verdict="unknown",
            error_message=f"LLM judge returned non-integer response: {llm_response}",
        )

    if llm_judge_answer >= correct_treshold:
        return SolverReply(verdict="success", judge_score=llm_judge_answer)
    return SolverReply(
        verdict="failure",
        error_message=f"LLM judge returned {llm_judge_answer}",
        judge_score=llm_judge_answer
    )
