import os
import numpy as np
import json
from pathlib import Path
from lcb_runner.benchmarks import load_generation_dataset, CodeGenerationProblem
from lcb_runner.evaluation import codegen_metrics


# Load every code-generation problem once at import time, ordered by question id
# so it can be paired positionally with stored results sorted by the same key.
dataset = sorted(
    load_generation_dataset(),
    key=lambda problem: problem.question_id,
)


def check_model(model_key, num_process_evaluate=12):
    """Re-evaluate a model's stored generations and compare with the stored pass@1.

    Loads ``<root>/<model_key>/chat_0.2_checked.json``, where ``<root>`` comes from
    the ``LCB_OLD_RESULTS_ROOT`` environment variable (default
    ``run_models_outputs``, with ``~`` expanded). It re-runs ``codegen_metrics``
    on every problem's saved ``code_list`` and prints three things:

    - the old overall pass@1;
    - the new overall pass@1;
    - one ``idx old new`` line for each problem whose score differs by 1e-4 or more.

    If ``LCB_DEBUG_INDEX`` is set to a non-empty value, only that problem is
    re-evaluated, using its first stored generation and ``debug=True``. The full
    comparison is then skipped.

    Args:
        model_key: Name of the model's sub-directory under the results root.
        num_process_evaluate: Worker count passed to ``codegen_metrics`` for the
            full re-evaluation.

    Raises:
        ValueError: If the stored results, after sorting by question id, do not
            cover exactly the same questions as ``dataset``.
    """
    results_root = Path(
        os.environ.get("LCB_OLD_RESULTS_ROOT", "run_models_outputs")
    ).expanduser()
    path = results_root / model_key / "chat_0.2_checked.json"
    with open(path, encoding="utf-8") as f:
        old_results = json.load(f)
    old_results = sorted(old_results, key=lambda x: x["question_id"])

    # Everything below pairs old_results[i] with dataset[i] by position, so the
    # two lists must match element-for-element, not just at index 0.
    old_ids = [r["question_id"] for r in old_results]
    new_ids = [d.question_id for d in dataset]
    if old_ids != new_ids:
        first_mismatch = next(
            (i for i, (a, b) in enumerate(zip(old_ids, new_ids)) if a != b),
            min(len(old_ids), len(new_ids)),
        )
        raise ValueError(
            f"{path}: stored results are not aligned with the dataset "
            f"({len(old_ids)} vs {len(new_ids)} questions; "
            f"first mismatch at index {first_mismatch})"
        )

    # An empty LCB_DEBUG_INDEX is treated as unset rather than crashing in int("").
    debug_idx = os.getenv("LCB_DEBUG_INDEX", "").strip()
    if debug_idx:
        idx = int(debug_idx)
        # Re-run only the first stored generation, with verbose debugging.
        codegen_metrics(
            [dataset[idx].get_evaluation_sample()],
            [old_results[idx]["code_list"][:1]],
            debug=True,
        )
        return

    metrics = codegen_metrics(
        [d.get_evaluation_sample() for d in dataset],
        [r["code_list"] for r in old_results],
        num_process_evaluate=num_process_evaluate,
    )
    # metrics[0] holds the aggregate "pass@1" and the per-problem "detail" scores.
    new_scores = metrics[0]["detail"]["pass@1"]
    old_scores = [np.mean(r["pass1_list"]) for r in old_results]

    print(np.mean(old_scores))
    print(metrics[0]["pass@1"])

    # Compare every problem, not a fixed count, so nothing is skipped or overrun.
    for idx, old_pass1 in enumerate(old_scores):
        new_pass1 = new_scores[idx]
        # `not (x < tol)` also reports NaN scores, matching the original check.
        if not abs(old_pass1 - new_pass1) < 1e-4:
            print(idx, old_pass1, new_pass1)


if __name__ == "__main__":
    # Re-check stored generations for a set of models against the current evaluator.
    # Example usage:
    #   LCB_OLD_RESULTS_ROOT=/path/to/run_models_outputs python old_results_check.py
    #   LCB_DEBUG_INDEX=380 python old_results_check.py
    #   python old_results_check.py GPT-4-0613 Gemini-Pro   # check only these models
    import sys

    default_model_keys = [
        "Claude-3-Opus",
        "GPT-4-0613",
        "Mistral-Large",
        "Claude-3-Sonnet",
        "GPT-3.5-Turbo-0301",
        "Gemini-Pro",
    ]
    # Model keys given on the command line replace the default list;
    # with no arguments every default model is checked, as before.
    model_keys = sys.argv[1:] or default_model_keys
    for model_key in model_keys:
        check_model(model_key)
