import os
from typing import Literal
from common.task import Domain, Task
from common.verdict import SolverReply
from evaluators.results import ResultCounter, OverallResults
from evaluators.strategy import (
    evaluate_with_equivalence,
    evaluate_with_nl_judge,
    evaluate_with_equivalence_judge,
)
from llm.llm import LLM
from llm.util import normalize_response
from task_loading.task_loading import load_all_tasks
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed


def evaluate_answer(
    task: Task,
    method: Literal["verifier", "nl-judge", "equivalence-judge", "nl-judge-base"],
    timeout_ms: int
) -> SolverReply:
    """Dispatch ``task`` to the evaluation strategy selected by ``method``.

    NOTE: the LLM handed to every strategy is not a parameter of this
    function. It is the module-level name ``model``, which this file only
    binds inside its ``__main__`` section, so that name must exist before any
    known method is evaluated.

    Args:
        task: The task to evaluate.
        method: Which strategy to use. ``"nl-judge-base"`` runs the same
            strategy as ``"nl-judge"`` but with ``use_base_answer=True``.
        timeout_ms: Forwarded to the ``"verifier"`` strategy only; the judge
            strategies do not receive it.

    Returns:
        Whatever the selected strategy returns.

    Raises:
        ValueError: If ``method`` is not one of the four known names.
    """
    if method in ("nl-judge", "nl-judge-base"):
        return evaluate_with_nl_judge(
            task,
            model,
            correct_treshold=2,
            use_base_answer=(method == "nl-judge-base"),
        )

    if method == "equivalence-judge":
        return evaluate_with_equivalence_judge(task, model, correct_treshold=2)

    if method == "verifier":
        return evaluate_with_equivalence(task, model, timeout_ms=timeout_ms)

    raise ValueError(f"Unknown method: {method}")


def collect_llm_responses(task: Task, model: LLM) -> Task:
    """Ask ``model`` to solve ``task`` and store the normalized reply on it.

    The model is called with the task's natural-language text as the prompt,
    plus the task's in-domain id and domain. The reply is passed through
    ``normalize_response`` and written to ``task.llm_solution``.

    Args:
        task: The task to solve; it is mutated in place.
        model: The LLM wrapper whose ``call`` method is invoked.

    Returns:
        The same ``task`` object, so callers can recover it from a future.
    """
    raw_reply = model.call(
        prompt=task.natural_language,
        problem_id=task.id_in_domain,
        problem_domain=task.domain,
    )
    task.llm_solution = normalize_response(raw_reply)
    return task


def _status_symbol(verdict) -> str:
    """Return the one-character progress-bar marker for a reply verdict."""
    if verdict == "success":
        return "✓"
    if verdict == "unknown":
        return "?"
    return "✗"


def _record_reply(
    results: ResultCounter, pbar: tqdm, task: Task, reply: SolverReply
) -> None:
    """Store one reply in ``results`` and advance the progress bar by one."""
    results.record(task.domain, task, reply)
    pbar.set_postfix_str(
        f"{task.domain.value}/{task.id_in_domain} | {_status_symbol(reply.verdict)}"
    )
    pbar.update(1)


def _evaluate_sequentially(
    tasks: list[Task],
    eval_method: Literal["verifier", "nl-judge", "equivalence-judge", "nl-judge-base"],
    timeout_ms: int,
    results: ResultCounter,
    pbar: tqdm,
) -> None:
    """Evaluate ``tasks`` one after another on the calling thread."""
    for task in tasks:
        _record_reply(results, pbar, task, evaluate_answer(task, eval_method, timeout_ms))


def _evaluate_in_threads(
    tasks: list[Task],
    eval_method: Literal["verifier", "nl-judge", "equivalence-judge", "nl-judge-base"],
    timeout_ms: int,
    max_workers: int,
    results: ResultCounter,
    pbar: tqdm,
) -> None:
    """Evaluate ``tasks`` on a thread pool, recording replies as they finish."""
    # ThreadPoolExecutor rejects max_workers <= 0, so never go below one worker.
    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as executor:
        future_to_task = {
            executor.submit(evaluate_answer, task, eval_method, timeout_ms): task
            for task in tasks
        }
        for future in as_completed(future_to_task):
            # result() re-raises any exception raised inside the worker.
            reply = future.result()
            _record_reply(results, pbar, future_to_task[future], reply)


def evaluate_tasks(
    eval_method: Literal["verifier", "nl-judge", "equivalence-judge", "nl-judge-base"],
    available_cores: int,
    evaluation_threads: int,
    tasks: list[Task],
    timeout_ms: int
) -> ResultCounter:
    """Evaluate every task with ``eval_method`` and collect the replies.

    Progress is reported through a ``tqdm`` bar, and a blank line is printed
    once all tasks are done.

    Args:
        eval_method: Strategy name forwarded to ``evaluate_answer``.
        available_cores: Threads are only used when this is greater than 1;
            any other value (including 0) evaluates sequentially.
        evaluation_threads: Worker count for the non-judge (verifier) path.
            The judge methods ignore it and use ``available_cores * 4``.
            Values below 1 are treated as 1.
        tasks: Tasks to evaluate.
        timeout_ms: Forwarded unchanged to ``evaluate_answer``.

    Returns:
        A ``ResultCounter`` holding one recorded reply per task.
    """
    results = ResultCounter()
    use_threads = available_cores > 1

    with tqdm(total=len(tasks), desc="Evaluating tasks") as pbar:
        if eval_method in ("nl-judge", "equivalence-judge", "nl-judge-base"):
            if use_threads:
                _evaluate_in_threads(
                    tasks, eval_method, timeout_ms, available_cores * 4, results, pbar
                )
            else:
                _evaluate_sequentially(tasks, eval_method, timeout_ms, results, pbar)
        else:
            # For the equivalence method, Coq and non-Coq tasks run separately.
            # CoqHammer launches multiple threads internally, which leads to
            # timeouts more often, so Coq tasks are never run on the pool.
            coq_tasks = [task for task in tasks if task.domain == Domain.Coq]
            non_coq_tasks = [task for task in tasks if task.domain != Domain.Coq]

            if non_coq_tasks and use_threads:
                _evaluate_in_threads(
                    non_coq_tasks, eval_method, timeout_ms, evaluation_threads, results, pbar
                )
            else:
                # Fallback for every core count that is not > 1, so non-Coq
                # tasks are never silently skipped.
                _evaluate_sequentially(non_coq_tasks, eval_method, timeout_ms, results, pbar)

            _evaluate_sequentially(coq_tasks, eval_method, timeout_ms, results, pbar)

    print()
    return results


def parallel_collect_llm_responses(
    llm_response_threads: int, model: LLM, tasks: list[Task]
):
    """Run ``collect_llm_responses`` for every task on a thread pool.

    Each task is submitted together with ``model``; a ``tqdm`` bar shows the
    domain and in-domain id of the most recently finished task. Nothing is
    returned: the results live on the task objects, which
    ``collect_llm_responses`` mutates.

    Args:
        llm_response_threads: Maximum number of worker threads.
        model: The LLM wrapper passed to every ``collect_llm_responses`` call.
        tasks: Tasks to collect responses for.
    """
    with ThreadPoolExecutor(max_workers=llm_response_threads) as pool:
        pending = [pool.submit(collect_llm_responses, item, model) for item in tasks]

        with tqdm(total=len(tasks), desc="Collecting LLM responses") as progress:
            for done in as_completed(pending):
                # result() hands back the task and re-raises worker errors.
                finished = done.result()
                progress.set_postfix_str(
                    f"{finished.domain.value}/{finished.id_in_domain}"
                )
                progress.update(1)

if __name__ == "__main__":
    # Script entry point. For each configured model: load the tasks, collect
    # the model's answers, run all four evaluation methods, then print a
    # summary and write CSVs.
    #
    # Everything stays at module level on purpose: evaluate_answer() reads the
    # global name `model`, so it must not become a function-local variable.

    # Settings shared by every model run.
    timeout_ms = 4_000
    max_tokens = 20_000
    temperature = 1.0

    # (model name, reasoning effort); None means no reasoning_effort is passed.
    models_and_efforts = [
        ("openai/gpt-5", "medium"),
        ("openai/gpt-5-mini", "medium"),
        ("openai/gpt-oss-120b", "medium"),
        ("openai/gpt-oss-20b", "medium"),

        ("anthropic/claude-sonnet-4", None),

        ("google/gemini-2.5-flash", "medium"),
        ("google/gemini-2.5-pro", None),

        ("qwen/qwen3-next-80b-a3b-thinking", "medium"),
        ("qwen/qwen3-next-80b-a3b-instruct", None)
    ]

    # (banner printed before the pass, evaluation method), in execution order.
    evaluation_passes = [
        ("\nEvaluating BASE ANSWER using Natural Language Judge:", "nl-judge-base"),
        ("\nEvaluating using Natural Language Judge:", "nl-judge"),
        ("\nEvaluating using Equivalence Judge:", "equivalence-judge"),
        ("\nEvaluating using Verifiers:", "verifier"),
    ]

    for model_name, reasoning_effort in models_and_efforts:
        # Tasks are reloaded for each model; collecting responses writes
        # llm_solution onto the task objects.
        tasks = load_all_tasks()

        print(f"Total tasks loaded: {len(tasks)}")
        per_domain = {}
        for loaded_task in tasks:
            domain_name = loaded_task.domain.value
            per_domain[domain_name] = per_domain.get(domain_name, 0) + 1
        for domain_name, task_count in per_domain.items():
            print(f"{domain_name}: {task_count}")

        print()

        available_cores = os.cpu_count() or 1
        llm_response_threads = 4 * available_cores
        evaluation_threads = max(1, available_cores // 3)

        print(f"{available_cores} cores available")
        print(f"LLM response threads: {llm_response_threads}")
        print(f"Evaluation threads: {evaluation_threads}")
        print('Model name:', model_name, 'reasoning effort:', reasoning_effort)

        # reasoning_effort is only passed when one is configured.
        llm_kwargs = {
            "model_name": model_name,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if reasoning_effort:
            llm_kwargs["reasoning_effort"] = reasoning_effort
        model = LLM(**llm_kwargs)

        print()
        print("Retrieving LLM responses for all tasks...")
        parallel_collect_llm_responses(llm_response_threads, model, tasks)

        print("\nEvaluating all tasks...")

        results_by_method = {}
        for banner, method in evaluation_passes:
            print(banner)
            results_by_method[method] = evaluate_tasks(
                method,
                available_cores,
                evaluation_threads,
                tasks,
                timeout_ms
            )

        overall_results = OverallResults(
            results_verifier=results_by_method["verifier"],
            results_nl_judge=results_by_method["nl-judge"],
            results_equivalence_judge=results_by_method["equivalence-judge"],
            results_nl_base_judge=results_by_method["nl-judge-base"]
        )

        overall_results.summarize_all()
        # CSVs are named after the part of the model name following the provider prefix.
        overall_results.write_csvs(model_name.split('/')[1], 'out')

        print('\n' * 8)
