import pandas as pd
import matplotlib.pyplot as plt
import os
import json
from config.config import RESULTS_TRAIN, PREDICTIONS_TRAIN

def load_coverage_data(model, embedding, base_path):
    """Read the coverage/accuracy sweep CSV for one model/embedding pair.

    The file is expected at
    ``<base_path>/<model>__<embedding>/coverage_accuracy_sweep.csv``.

    Args:
        model: Model name used in the run directory name.
        embedding: Embedding name used in the run directory name.
        base_path: Directory containing the per-run subdirectories.

    Returns:
        The parsed CSV as a pandas DataFrame.
    """
    run_dir = f"{model}__{embedding}"
    csv_path = os.path.join(base_path, run_dir, "coverage_accuracy_sweep.csv")
    return pd.read_csv(csv_path, sep=",")


def load_model_threshold(model, summary_file):
    """Return the benchmark accuracy recorded for ``model`` in a summary CSV.

    The CSV must have ``model_name`` and ``accuracy`` columns. If several
    rows match ``model``, the first one wins.

    Args:
        model: Value to match against the ``model_name`` column.
        summary_file: Path to the summary CSV.

    Returns:
        The matching row's ``accuracy`` as a float.

    Raises:
        ValueError: If no row has ``model_name == model``.
    """
    summary = pd.read_csv(summary_file)
    hits = summary.loc[summary["model_name"] == model]
    if len(hits) == 0:
        raise ValueError(f"Threshold not found for model: {model}")
    return float(hits["accuracy"].iloc[0])


def plot_multiple_N(df, N_values, save_path, colors=None, y_min=0.6):
    """Scatter accuracy against coverage for several values of N on one figure.

    Each value in ``N_values`` becomes its own labelled scatter series, drawn
    from the rows of ``df`` whose ``N`` column equals that value.

    Args:
        df: DataFrame with at least ``N``, ``coverage`` and ``accuracy`` columns.
        N_values: Iterable of N values to plot, one series each.
        save_path: Destination image path. Missing parent directories are
            created.
        colors: Optional sequence of matplotlib colours, cycled across the
            series. Defaults to steelblue, red, orange.
        y_min: Lower limit of the y axis (default 0.6). Points below it are
            outside the visible area.
    """
    if not colors:
        # Also guards against an empty sequence, which would make the
        # modulo below divide by zero.
        colors = ["steelblue", "red", "orange"]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for i, N in enumerate(N_values):
            subset = df[df["N"] == N]
            ax.scatter(
                subset["coverage"],
                subset["accuracy"],
                s=4,
                label=f"N={N}",
                color=colors[i % len(colors)],
            )

        ax.set_xlabel("Coverage", fontsize=14)
        ax.set_ylabel("Accuracy", fontsize=14)
        ax.set_title("Accuracy vs Coverage for Different N values", fontsize=16)
        ax.set_xlim(left=0.0)
        ax.set_ylim(bottom=y_min)
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.legend(fontsize=12)
        ax.grid(True, linestyle=':', linewidth=0.8)
        fig.tight_layout()

        # os.path.dirname returns "" for a bare filename, and
        # os.makedirs("") raises, so only create a directory when there is one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def plot_single_N_with_fill(df, chosen_N, value_threshold, save_path, min_coverage=0.1):
    """Plot the accuracy/coverage curve for one N and shade gains over a benchmark.

    The rows of ``df`` with ``N == chosen_N`` are drawn as a line, sorted by
    coverage. If the highest coverage is below 1.0, the curve is extended to
    1.0 at its last accuracy value. For coverage >= ``min_coverage`` two
    regions are shaded:

    * "Maximum Confidence Gain": the band from ``value_threshold`` up to 1.0.
    * "Confidence Gain": from ``value_threshold`` up to the curve, where the
      curve is at or above the threshold.

    A dashed vertical line marks ``min_coverage``, and a dashed horizontal line
    marks ``value_threshold`` ("Benchmark performance").

    Args:
        df: DataFrame with at least ``N``, ``coverage`` and ``accuracy`` columns.
        chosen_N: The N value whose curve is plotted.
        value_threshold: Benchmark accuracy used as the baseline of the shading.
        save_path: Destination image path. Missing parent directories are
            created.
        min_coverage: Coverage below which nothing is shaded (default 0.1).

    Raises:
        ValueError: If ``df`` contains no rows with ``N == chosen_N``.
    """
    subset = df[df["N"] == chosen_N].sort_values("coverage")
    if subset.empty:
        # Fail loudly instead of silently saving an empty chart.
        raise ValueError(f"No coverage data found for N = {chosen_N}")

    # Pad to x=1.0 if missing: repeat the accuracy of the highest-coverage
    # point so the curve and the shaded regions reach full coverage.
    if subset["coverage"].max() < 1.0:
        last_acc = subset["accuracy"].iloc[-1]
        padding = pd.DataFrame(
            {"coverage": [1.0], "accuracy": [last_acc], "N": [chosen_N]}
        )
        subset = pd.concat([subset, padding], ignore_index=True).sort_values("coverage")

    coverage = subset["coverage"].to_numpy()
    accuracy = subset["accuracy"].to_numpy()
    in_range = coverage >= min_coverage

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(
            coverage,
            accuracy,
            linestyle='-',
            linewidth=1,
            color='steelblue',
            label="Confidence Curve",
        )

        ax.fill_between(
            coverage,
            value_threshold,
            1.0,
            where=in_range,
            interpolate=False,
            color='steelblue',
            alpha=0.25,
            label="Maximum Confidence Gain",
        )

        ax.fill_between(
            coverage,
            value_threshold,
            accuracy,
            where=in_range & (accuracy >= value_threshold),
            interpolate=False,
            color='steelblue',
            alpha=0.5,
            label="Confidence Gain",
        )

        ax.axvline(x=min_coverage, color='steelblue', linestyle='--', linewidth=0.8)
        ax.axhline(
            y=value_threshold,
            color='darkred',
            linestyle='--',
            linewidth=1.5,
            label="Benchmark performance",
        )

        ax.set_xlabel("Coverage", fontsize=14)
        ax.set_ylabel("Accuracy", fontsize=14)
        ax.set_title(f"Accuracy vs Coverage for N = {chosen_N}", fontsize=16)
        ax.set_xlim(left=0.0, right=1.0)
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, linestyle=':', linewidth=0.8)
        ax.legend(fontsize=12)
        fig.tight_layout()

        # os.path.dirname returns "" for a bare filename, and
        # os.makedirs("") raises, so only create a directory when there is one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)





# === Parameters ===
model = "resnet101"
embedding = "dinov2_b"
N_values = [1, 100, 540]  # N values overlaid in the multi-N scatter plot
chosen_N = 1  # N value shown in the single-curve plot with shaded gains

# === Paths ===
coverage_dir = os.path.join(RESULTS_TRAIN, "coverage_accuracy")
summary_file = os.path.join(PREDICTIONS_TRAIN, "summary.csv")
image_dir = os.path.join(RESULTS_TRAIN, "images")

# === Load data ===
df = load_coverage_data(model, embedding, coverage_dir)
value_threshold = load_model_threshold(model, summary_file)

# === Plot and Save ===
plot_multiple_N(
    df,
    N_values=N_values,
    save_path=os.path.join(image_dir, "accuracy_vs_coverage_multiple_N.png")
)

# Build the file name from chosen_N so it always names the N actually plotted.
plot_single_N_with_fill(
    df,
    chosen_N=chosen_N,
    value_threshold=value_threshold,
    save_path=os.path.join(image_dir, f"accuracy_vs_coverage_N{chosen_N}.png")
)
