#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
from pathlib import Path
import glob
import pandas as pd
import matplotlib.pyplot as plt
from config.config import RESULTS_TEST_INTERNAL, RESULTS_TEST_EXTERNAL

# --- Fixed embedding model ---
# Embedding identifier; combined with each prediction-model name to form the
# per-model results folder "<model>__<embedding>" that the CSVs are read from.
EMBEDDING_MODEL_NAME = "dinov2_b"

# Maps prediction-model identifiers (as used in results folder names) to the
# human-readable labels shown in the plot legend.
MODEL_NAME_MAP = {
    "resnet50": "ResNet-50",
    "resnet101": "ResNet-101",
    "shufflenet_v2_x1_0": "ShuffleNet-V2 x1.0",
    "deit_tiny_patch16_224": "DeiT-Tiny/16",
    "deit_small_patch16_224": "DeiT-Small/16",
    "deit_base_patch16_224": "DeiT-Base/16",
}
# Models are plotted in the insertion order of MODEL_NAME_MAP; the position in
# this list also selects the curve color from STYLE_COLORS.
PREDICTION_MODELS = list(MODEL_NAME_MAP.keys())

# When True, the finished figure is written to disk as a PNG.
SAVE_FIG = True
# Coverage values are snapped to the nearest multiple of this step before
# accuracy is averaged per bin.
COVERAGE_STEP = 0.025

# --- Plot style ---
TITLE_FONTSIZE = 16
LABEL_FONTSIZE = 14
TICK_FONTSIZE = 12
GRID_LINESTYLE = ':'
GRID_LINEWIDTH = 0.8
LINEWIDTH = 1.0  # width of each mean-accuracy curve
STD_ALPHA = 0.25  # opacity of the shaded mean ± std band
# One color per model, indexed modulo the list length so extra models wrap around.
STYLE_COLORS = ["steelblue", "red", "orange", "green", "purple", "brown"]


def _results_root(dataset_type: str) -> Path:
    """Return the configured results directory for ``dataset_type``.

    The dataset type is matched case-insensitively against ``"INTERNAL"`` and
    ``"EXTERNAL"``.

    Raises:
        ValueError: If ``dataset_type`` is neither of the two accepted values.
    """
    kind = dataset_type.upper()
    if kind == "EXTERNAL":
        return Path(RESULTS_TEST_EXTERNAL)
    if kind == "INTERNAL":
        return Path(RESULTS_TEST_INTERNAL)
    raise ValueError("dataset_type must be 'INTERNAL' or 'EXTERNAL'")


def _load_subset_csvs_for_model(results_root: Path, model_name: str, embedding_name: str) -> pd.DataFrame:
    """Load and concatenate all per-subset CSVs for one model/embedding pair.

    CSVs are read from
    ``<results_root>/coverage_accuracy_best_N/<model_name>__<embedding_name>/*.csv``.
    The delimiter of each file is auto-detected. Only the columns ``N``, ``L``,
    ``coverage``, ``accuracy`` and ``subset`` are kept (those that exist);
    ``coverage`` and ``accuracy`` are coerced to numbers and rows where either
    is missing or non-numeric are dropped. Files without a ``subset`` column
    get the constant subset label ``"ALL"``.

    Args:
        results_root: Root directory of the evaluation results.
        model_name: Prediction-model identifier (first half of the folder name).
        embedding_name: Embedding identifier (second half of the folder name).

    Returns:
        A single DataFrame with the cleaned rows of every CSV, in sorted
        file-path order, with a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If the folder does not exist or contains no CSVs.
    """
    folder = results_root / "coverage_accuracy_best_N" / f"{model_name}__{embedding_name}"
    if not folder.exists():
        raise FileNotFoundError(f"Folder not found: {folder}")

    # Escape the directory part so glob metacharacters ("[", "]", "*", "?") in
    # the results path are matched literally; only "*.csv" acts as a pattern.
    pattern = str(Path(glob.escape(str(folder))) / "*.csv")
    csv_paths = sorted(glob.glob(pattern))
    if not csv_paths:
        raise FileNotFoundError(f"No CSV files found in: {folder}")

    frames = []
    for csv_path in csv_paths:
        # sep=None lets the python engine sniff the delimiter per file.
        df = pd.read_csv(csv_path, sep=None, engine="python")
        df.columns = [str(c).strip() for c in df.columns]
        keep = [c for c in ["N", "L", "coverage", "accuracy", "subset"] if c in df.columns]
        df = df[keep].copy()

        for col in ("coverage", "accuracy"):
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors="coerce")
            else:
                # A file lacking a required column contributes no rows: the
                # NaN placeholder is removed by the dropna below.
                df[col] = float("nan")

        if "subset" in df.columns:
            df["subset"] = df["subset"].astype(str).str.strip()
        else:
            df["subset"] = "ALL"

        df = df.dropna(subset=["coverage", "accuracy"])
        frames.append(df)

    return pd.concat(frames, ignore_index=True)


def _aggregate_mean_std_by_coverage_binned(df: pd.DataFrame, step: float) -> pd.DataFrame:
    """Summarise accuracy per coverage bin.

    Each coverage value is snapped to the nearest multiple of ``step``, the
    result is clamped to [0, 1] and rounded to 6 decimals, and accuracy is then
    aggregated per bin. The input frame is not modified.

    Args:
        df: Frame with numeric ``coverage`` and ``accuracy`` columns.
        step: Width of a coverage bin.

    Returns:
        A frame sorted by ``coverage`` with the columns ``coverage`` (bin
        value), ``acc_mean``, ``acc_std`` and ``count``.
    """
    bin_number = df["coverage"].div(step).round().astype(int)
    bin_value = bin_number.mul(step).clip(lower=0.0, upper=1.0).round(6)
    binned = df.assign(coverage_bin=bin_value)

    summary = binned.groupby("coverage_bin", as_index=False).agg(
        acc_mean=("accuracy", "mean"),
        acc_std=("accuracy", "std"),
        count=("accuracy", "count"),
    )
    summary = summary.sort_values("coverage_bin").reset_index(drop=True)
    return summary.rename(columns={"coverage_bin": "coverage"})


def run_coverage_accuracy_best_N(dataset_type: str) -> None:
    """Plot mean ± std accuracy against coverage for every prediction model.

    For each model in ``PREDICTION_MODELS`` the per-subset CSVs are loaded,
    accuracy is aggregated per coverage bin of width ``COVERAGE_STEP``, and the
    mean curve is drawn with a shaded ± one standard deviation band (only for
    bins with at least two samples, where a standard deviation exists). Models
    whose data cannot be loaded, or that have no valid rows, are reported on
    stdout and skipped. When ``SAVE_FIG`` is true the figure is written to
    ``<results_root>/accuracy_vs_coverage__<EMBEDDING_MODEL_NAME>.png``.

    Args:
        dataset_type: ``"INTERNAL"`` or ``"EXTERNAL"`` (case-insensitive).

    Raises:
        ValueError: If ``dataset_type`` is not recognised.
        RuntimeError: If no model could be plotted at all.
    """
    results_root = _results_root(dataset_type)

    fig, ax = plt.subplots(figsize=(10, 6))
    # try/finally guarantees the figure is released even when nothing could be
    # plotted or saving fails; otherwise pyplot keeps the figure alive.
    try:
        any_plotted = False

        for idx, model in enumerate(PREDICTION_MODELS):
            try:
                df = _load_subset_csvs_for_model(results_root, model, EMBEDDING_MODEL_NAME)
            except Exception as e:
                # Deliberate best-effort: a missing or unreadable model must
                # not prevent the remaining models from being plotted.
                print(f"[Skip] {model}: {e}")
                continue

            if df.empty:
                # Avoid an empty curve that would still add a legend entry.
                print(f"[Skip] {model}: no valid coverage/accuracy rows")
                continue

            agg = _aggregate_mean_std_by_coverage_binned(df, COVERAGE_STEP)

            # The color is tied to the model's position, so it stays stable
            # even when earlier models are skipped.
            color = STYLE_COLORS[idx % len(STYLE_COLORS)]
            label = MODEL_NAME_MAP.get(model, model.replace("_", " "))

            ax.plot(
                agg["coverage"].values,
                agg["acc_mean"].values,
                linestyle='-',
                linewidth=LINEWIDTH,
                color=color,
                label=label
            )

            # The sample standard deviation is undefined for single-sample bins.
            mask = agg["count"] >= 2
            if mask.any():
                upper = agg.loc[mask, "acc_mean"] + agg.loc[mask, "acc_std"]
                lower = agg.loc[mask, "acc_mean"] - agg.loc[mask, "acc_std"]
                ax.fill_between(
                    agg.loc[mask, "coverage"].values,
                    lower.values,
                    upper.values,
                    alpha=STD_ALPHA,
                    color=color
                )

            any_plotted = True

        if not any_plotted:
            raise RuntimeError("No models were plotted — check your paths and CSV availability.")

        ax.set_xlabel("Coverage", fontsize=LABEL_FONTSIZE)
        ax.set_ylabel("Accuracy", fontsize=LABEL_FONTSIZE)
        ax.set_title("Accuracy vs Coverage (mean ± std across subsets)\nEmbedding: DINO-V2 ViT-B/14",
                     fontsize=TITLE_FONTSIZE)
        ax.set_xlim(left=0.05, right=1.0)
        ax.set_ylim(bottom=0.15, top=1.0)
        ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
        ax.grid(True, linestyle=GRID_LINESTYLE, linewidth=GRID_LINEWIDTH)
        ax.legend(fontsize=12)
        fig.tight_layout()

        # Derive the file name from the embedding constant so the two cannot drift.
        fig_path = results_root / f"accuracy_vs_coverage__{EMBEDDING_MODEL_NAME}.png"
        if SAVE_FIG:
            # Save this figure explicitly rather than pyplot's "current" figure.
            fig.savefig(fig_path, dpi=300)
            print(f"Saved figure to: {fig_path}")
    finally:
        plt.close(fig)


if __name__ == '__main__':
    # Command-line entry point: one optional positional argument selects the
    # dataset whose results are plotted (INTERNAL when omitted).
    cli = argparse.ArgumentParser(
        description="Run coverage/accuracy evaluation for INTERNAL or EXTERNAL datasets."
    )
    cli.add_argument(
        "dataset_type",
        nargs="?",
        choices=["INTERNAL", "EXTERNAL"],
        default="INTERNAL",
        help="Dataset type to process (default: INTERNAL)",
    )
    selected_dataset = cli.parse_args().dataset_type

    run_coverage_accuracy_best_N(selected_dataset)
