import asyncio
import json
import logging
import os

import numpy as np
from matplotlib import pyplot as plt
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss, roc_auc_score, roc_curve

from pipeline.blue_team.estimating_difficulty import (
    add_execution_difficulty_estimation_to_pipeline,
)
from src.code_evaluation.executor_standard import run_solutions
from src.datatypes.enums import DifficultyEstimationType
from src.model_querying.prompt_creation import (
    get_difficulty_problem_only,
    get_difficulty_problem_solution,
    get_two_shot_prompt_standard,
)
from src.model_querying.solution_extraction import (
    extract_difficulty,
    extract_solution_in_block,
)
from src.pipeline.pipeline import Pipeline, PipelineConfig
from src.tools.dataloaders import get_data_dir, load_problem_subset
from src.tools.path_utils import get_default_results_directory, get_root_directory

logger = logging.getLogger(__name__)


def get_pipeline(
    strong_model,
    weak_model,
    name=None,
    use_logprob=False,
    use_cache=True,
    num_problems=None,
    estimation_method="problem-only",
):
    ROOT_DIR = get_root_directory()
    DATA_DIR = ROOT_DIR / "data" / "APPS"
    TEST_DIR = DATA_DIR / "test"

    assert (
        strong_model is not None and weak_model is not None
    ), "Both strong and weak models must be provided"

    if name is None:
        pipeline_name = "difficulty_estimation"
    else:
        pipeline_name = f"difficulty_estimation-{name}"

    pipeline_config = PipelineConfig(
        pipeline_name,
        anthropic_num_threads=40,
        openai_fraction_rate_limit=0.99,
        num_problems=num_problems,
        use_cache=use_cache,
    )

    pipeline = Pipeline(pipeline_config)
    load_data = pipeline.add_load_data_step(
        "load_problems",
        load_problem_subset("INTERVIEW"),
        TEST_DIR,
    )
    get_weak_solutions = pipeline.add_query_step(
        "get_weak_solutions",
        weak_model,
        get_two_shot_prompt_standard,
        extract_solution_in_block,
        dependencies=[load_data],
    )

    evaluate_weak_solutions = pipeline.add_code_evaluation_step(
        "evaluate_weak_solutions",
        weak_model,
        run_solutions,
        dependencies=[get_weak_solutions],
    )

    if estimation_method == DifficultyEstimationType.PROBLEM_ONLY:
        get_difficulty = pipeline.add_query_step(
            "get_difficulty",
            weak_model,
            get_difficulty_problem_only,
            extract_difficulty,
            dependencies=[evaluate_weak_solutions],
        )
    elif estimation_method == DifficultyEstimationType.PROBLEM_SOLUTION:
        get_difficulty = pipeline.add_query_step(
            "get_difficulty_problem_solution",
            weak_model,
            get_difficulty_problem_solution,
            extract_difficulty,
            dependencies=[evaluate_weak_solutions],
        )
    elif estimation_method == DifficultyEstimationType.PROBLEM_SOLUTION_EXECUTION:
        get_difficulty = add_execution_difficulty_estimation_to_pipeline(
            pipeline, strong_model, weak_model, evaluate_weak_solutions
        )

    else:
        raise Exception(f"Unimplemented estimation method: {estimation_method}")

    return pipeline


def analyze(
    eval_results,
    binary_eval_results,
    difficulty_results,
    name,
    strong_model,
    weak_model,
):
    # Calculate accuracy
    accuracy = np.mean(
        [i == (j > 0.5) for i, j in zip(binary_eval_results, difficulty_results)]
    )

    # Calculate AUC
    auc = roc_auc_score(binary_eval_results, difficulty_results)

    # Calculate calibration (Brier score)
    calibration = brier_score_loss(binary_eval_results, difficulty_results)

    print(f"Accuracy: {accuracy:.4f}")
    print(f"AUC: {auc:.4f}")
    print(f"Calibration (Brier score): {calibration:.4f}")

    save_dir = f"{get_default_results_directory()}/{name}/merged_results/{strong_model}+{weak_model}"
    os.makedirs(save_dir, exist_ok=True)
    metrics = {"accuracy": accuracy, "auc": auc, "calibration": calibration}
    with open(f"{save_dir}/metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    plt.figure(figsize=(10, 6))
    plt.hist(difficulty_results, bins=20, edgecolor="black")
    plt.title("Distribution of Difficulty Results")
    plt.xlabel("Difficulty")
    plt.ylabel("Frequency")
    plt.savefig(f"{save_dir}/difficulty_distribution.png")
    plt.close()

    fpr, tpr, _ = roc_curve(binary_eval_results, difficulty_results)
    with open(f"{save_dir}/roc_curve.json", "w") as f:
        json.dump({"fpr": fpr.tolist(), "tpr": tpr.tolist()}, f, indent=2)

    plt.figure(figsize=(10, 6))
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.savefig(f"{save_dir}/roc_curve.png")
    plt.close()

    prob_true, prob_pred = calibration_curve(
        binary_eval_results, difficulty_results, n_bins=10
    )
    plt.figure(figsize=(10, 6))
    plt.plot(prob_pred, prob_true, marker="o")
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.title("Calibration Curve")
    plt.xlabel("Difficulty")
    plt.ylabel("Test Case Accuracy")
    plt.savefig(f"{save_dir}/calibration_curve.png")
    plt.close()


if __name__ == "__main__":
    strong_model = "gpt-4-0125-preview"
    weak_model = "gpt-3.5-turbo-1106"
    pipeline = get_pipeline(
        strong_model=strong_model,
        weak_model=weak_model,
        num_problems=200,
        estimation_method=DifficultyEstimationType.PROBLEM_SOLUTION_EXECUTION,
        name="problem_solution_execution",
    )
    results = asyncio.run(pipeline.run())
    eval_results = [
        np.mean([j["correct"] for j in i["test_cases"]])
        for i in results["evaluate_weak_solutions"].values()
    ]
    binary_eval_results = [i == 1 for i in eval_results]
    difficulty_results = [
        1 - (i["difficulty"] / 10) for i in results["get_difficulty"].values()
    ]
    analyze(
        eval_results,
        binary_eval_results,
        difficulty_results,
        pipeline.config.name,
        strong_model,
        weak_model,
    )
