import pandas as pd
import numpy as np
import os
import argparse
from collections import defaultdict
import matplotlib.pyplot as plt
from config.config import (
    BASE_DIR,
    PREDICTIONS_TEST_INTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
    DISTANCES_TEST_INTERNAL,
    DISTANCES_TEST_EXTERNAL,
    RESULTS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
    COMBINATION_EXTERNAL_TEST,
    COMBINATION_INTERNAL_TEST
)

# Classifier names; each one is expected to have a "<name>.csv" predictions
# file and is a key of TABLE_DATA.
PREDICTION_MODELS = [
    "resnet50",
    "resnet101",
    "shufflenet_v2_x1_0",
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224",
]

# Display name of each embedding model -> name used in result directories and
# similarity file names on disk.
EMBEDDING_DISK_NAME_MAP = {
    "DINO-V1 ViT-B/16": "dinov1_b",
    "DINO-V1 ViT-B/8": "dinov1_b8",
    "DINO-V1 ViT-S/16": "dinov1_s",
    "DINO-V2 ViT-B/14": "dinov2_b",
    "DINO-V2 ViT-S/14": "dinov2_s",
    "MobileNet-V2": "mobilenet_v2",
}

# Embedding display names in the column order of TABLE_DATA; this is the
# insertion order of the mapping above.
EMBEDDING_MODELS_ORDER = list(EMBEDDING_DISK_NAME_MAP)

# Per classifier: one (best_N, score) tuple per embedding, positionally aligned
# with EMBEDDING_MODELS_ORDER. best_N selects the rows of the coverage/accuracy
# table that are used; score decides the order in which embeddings are tried
# (highest first).
TABLE_DATA = {
    "resnet50": [
        (1, 0.4107),
        (1, 0.4487),
        (1, 0.3819),
        (19, 0.4841),
        (2, 0.4304),
        (1, 0.2138),
    ],
    "resnet101": [
        (1, 0.3887),
        (1, 0.4251),
        (1, 0.3586),
        (50, 0.4616),
        (4, 0.4104),
        (1, 0.1934),
    ],
    "shufflenet_v2_x1_0": [
        (1, 0.4106),
        (1, 0.4270),
        (1, 0.3943),
        (7, 0.4110),
        (4, 0.3916),
        (1, 0.2441),
    ],
    "deit_tiny_patch16_224": [
        (1, 0.4096),
        (1, 0.4332),
        (1, 0.3869),
        (23, 0.4376),
        (4, 0.4082),
        (1, 0.2250),
    ],
    "deit_small_patch16_224": [
        (1, 0.4065),
        (1, 0.4466),
        (1, 0.3790),
        (14, 0.4612),
        (3, 0.4177),
        (1, 0.2000),
    ],
    "deit_base_patch16_224": [
        (1, 0.4040),
        (1, 0.4422),
        (1, 0.3729),
        (19, 0.4728),
        (4, 0.4197),
        (1, 0.2001),
    ],
}

def find_l_for_coverage(train_df: pd.DataFrame, target_coverage: float):
    """Return the largest ``L`` whose coverage reaches ``target_coverage``.

    Args:
        train_df: Table with at least the columns ``"L"`` and ``"coverage"``.
        target_coverage: Minimum coverage a row must have to be considered.

    Returns:
        The maximum of the ``"L"`` column over all rows with
        ``coverage >= target_coverage``, or ``None`` when no row qualifies.
    """
    reaches_target = train_df["coverage"] >= target_coverage
    candidates = train_df.loc[reaches_target, "L"]
    if candidates.empty:
        return None
    return candidates.max()

# Subset labels looked up in the "subset" column of the predictions CSVs.
_SUBSET_NAMES = ("subset_A", "subset_B", "subset_C")

# Column layout of the results CSV; fixing it keeps the frame usable even when
# no model produced any row.
_RESULT_COLUMNS = [
    "model",
    "coverage_target",
    "mean_accuracy",
    "std_accuracy",
    "mean_subset_coverage",
    "std_subset_coverage",
]


def _load_prediction_tables(predictions_dir):
    """Read the predictions CSV of every model in PREDICTION_MODELS.

    Returns a dict ``model -> DataFrame`` indexed by image id, where ids are
    cast to ``str``, have ".jpg" removed and are stripped of whitespace.

    Raises:
        FileNotFoundError: a model's CSV does not exist.
        KeyError: a CSV has no ``image_id`` column.
    """
    tables = {}
    for model in PREDICTION_MODELS:
        pred_path = os.path.join(predictions_dir, model + ".csv")
        if not os.path.exists(pred_path):
            raise FileNotFoundError(f"Missing predictions CSV: {pred_path}")
        # sep=None makes the python engine detect the delimiter itself.
        pred_df = pd.read_csv(pred_path, sep=None, engine="python")
        pred_df.columns = pred_df.columns.str.strip()
        if "image_id" not in pred_df.columns:
            raise KeyError(f"'image_id' column not found in {pred_path}")
        pred_df = pred_df.set_index("image_id")
        pred_df.index = pred_df.index.astype(str).str.replace(".jpg", "", regex=False).str.strip()
        tables[model] = pred_df
    return tables


def _coverage_thresholds(model, sorted_embeddings, results_dir, coverage_keys):
    """Build ``embedding -> {coverage target -> L threshold}`` for one model.

    Embeddings whose result file is missing, lacks required columns, has no
    row for the configured best N, or yields no threshold at all are left out
    (a warning is printed for the first three cases).
    """
    lookup = {}
    for emb_name, (best_n, _score) in sorted_embeddings:
        disk_name = EMBEDDING_DISK_NAME_MAP[emb_name]
        result_path = os.path.join(
            results_dir, "coverage_accuracy_best_N", f"{model}__{disk_name}", "coverage_accuracy_best_N.csv"
        )
        if not os.path.exists(result_path):
            print(f"[warn] Missing results: {result_path}")
            continue

        df = pd.read_csv(result_path)
        if not {"N", "L", "coverage"}.issubset(df.columns):
            print(f"[warn] Required columns missing in {result_path}")
            continue

        df_n = df[df["N"] == best_n]
        if df_n.empty:
            print(f"[warn] No rows for N={best_n} in {result_path}")
            continue

        # Only the mean coverage per L is needed downstream; aggregating just
        # that column avoids re-selecting the grouping key "L" itself.
        df_grouped = df_n.groupby("L", as_index=False)["coverage"].mean()

        per_target = {}
        for cov_key in coverage_keys:
            l_val = find_l_for_coverage(df_grouped, cov_key)
            if l_val is not None:
                per_target[cov_key] = float(l_val)
        if per_target:
            lookup[emb_name] = per_target
    return lookup


def _max_similarity_excluding_self(sim_df):
    """Return ``row label -> highest similarity to any other column``.

    The column carrying the same label as the row (the image compared with
    itself) and NaN cells are ignored. A row without any remaining value maps
    to NaN, so every ``>=`` comparison against it is False. If a row label
    occurs more than once, the largest value over those rows is kept.
    """
    values = sim_df.to_numpy(dtype=float, copy=True)

    column_positions = defaultdict(list)
    for pos, label in enumerate(sim_df.columns):
        column_positions[label].append(pos)

    for row, label in enumerate(sim_df.index):
        for pos in column_positions.get(label, ()):
            values[row, pos] = -np.inf
    values[np.isnan(values)] = -np.inf

    # initial=-inf keeps the reduction defined for a table without columns.
    row_max = values.max(axis=1, initial=-np.inf)
    row_max[np.isneginf(row_max)] = np.nan

    best = {}
    for label, value in zip(sim_df.index, row_max):
        previous = best.get(label)
        if previous is None or np.isnan(previous) or value > previous:
            best[label] = float(value)
    return best


def _load_max_similarities(embedding_names, distances_dir):
    """Load each embedding's similarity CSV and reduce it to per-image maxima.

    Embeddings without a similarity file are skipped with a warning. Reducing
    right after loading means the full matrix is not kept in memory.
    """
    maxima = {}
    for emb_name in embedding_names:
        disk_name = EMBEDDING_DISK_NAME_MAP[emb_name]
        sim_path = os.path.join(distances_dir, f"{disk_name}_cosine_similarities.csv")
        if not os.path.exists(sim_path):
            print(f"[warn] Missing similarities: {sim_path}")
            continue
        sim_df = pd.read_csv(sim_path, index_col=0)
        sim_df.index = sim_df.index.astype(str).str.strip()
        sim_df.columns = sim_df.columns.astype(str).str.strip()
        maxima[emb_name] = _max_similarity_excluding_self(sim_df)
    return maxima


def _save_coverage_accuracy_plot(results_df, dataset_type, output_dir):
    """Draw mean accuracy (with std error bars) per model and save it as PNG.

    Returns the path of the written image. The figure is always closed so
    repeated calls do not accumulate open figures.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plotted_any = False
        for model in PREDICTION_MODELS:
            df_model = results_df[results_df["model"] == model]
            if df_model.empty:
                # Model was skipped during evaluation; nothing to draw.
                continue
            plt.errorbar(df_model["coverage_target"], df_model["mean_accuracy"], yerr=df_model["std_accuracy"],
                         label=model, capsize=3, marker='o', linestyle='-')
            plotted_any = True
        plt.xlabel("Coverage Target")
        plt.ylabel("Mean Accuracy (across subsets)")
        plt.title(f"Model-wise Coverage vs Accuracy ({dataset_type})")
        if plotted_any:
            plt.legend()
        plt.grid(True)
        plot_path = os.path.join(output_dir, f"coverage_accuracy_plot_{dataset_type}.png")
        plt.savefig(plot_path, dpi=300)
    finally:
        plt.close(fig)
    return plot_path


def run_modelwise_confident_combination(dataset_type: str) -> None:
    """Evaluate, per prediction model, accuracy on "confidently covered" images.

    For every model and every coverage target, an image counts as assigned
    when, for at least one embedding, its highest similarity to another image
    reaches that embedding's L threshold for the target. Accuracy and coverage
    are computed per subset (subset_A/B/C) and then averaged; the mean/std
    values are printed, written to a CSV and plotted.

    Args:
        dataset_type: ``"EXTERNAL"`` selects the external-test directories;
            any other value selects the internal-test directories.

    Raises:
        FileNotFoundError: a predictions CSV is missing.
        KeyError: a predictions CSV lacks ``image_id``, ``subset`` or the
            ``<model>_pred_class`` / ``<model>_true_class`` columns.
    """
    print("=== Starting model-wise evaluation ===")

    if dataset_type == "EXTERNAL":
        predictions_dir = PREDICTIONS_TEST_EXTERNAL
        distances_dir = DISTANCES_TEST_EXTERNAL
        results_dir = RESULTS_TEST_EXTERNAL
        combo_dir = COMBINATION_EXTERNAL_TEST
    else:
        predictions_dir = PREDICTIONS_TEST_INTERNAL
        distances_dir = DISTANCES_TEST_INTERNAL
        results_dir = RESULTS_TEST_INTERNAL
        combo_dir = COMBINATION_INTERNAL_TEST

    pred_dfs = _load_prediction_tables(predictions_dir)

    # The image universe is defined by the first model's predictions.
    total_ids = set(pred_dfs[PREDICTION_MODELS[0]].index)

    # Rounded plain floats so the same keys are used for storing and lookup.
    coverage_keys = [
        float(round(float(cov), 2))
        for cov in np.round(np.arange(0.02, 1.01, 0.04), 2)
    ]

    all_results = []

    for model in PREDICTION_MODELS:
        print(f"\n== Evaluating model: {model}")

        # Embeddings ordered by their table score, best first. The order only
        # decides which embedding is consulted first; an image is assigned as
        # soon as any of them passes, so the outcome does not depend on it.
        paired = zip(EMBEDDING_MODELS_ORDER, TABLE_DATA[model])
        sorted_embeddings = sorted(paired, key=lambda item: item[1][1], reverse=True)

        threshold_lookup = _coverage_thresholds(model, sorted_embeddings, results_dir, coverage_keys)
        if not threshold_lookup:
            print(f"[warn] No usable embeddings for model {model}; skipping model.")
            continue

        pred_df = pred_dfs[model]
        pred_col = f"{model}_pred_class"
        true_col = f"{model}_true_class"
        if pred_col not in pred_df.columns or true_col not in pred_df.columns:
            raise KeyError(f"Expected columns '{pred_col}'/'{true_col}' not found in predictions for {model}.")
        if "subset" not in pred_df.columns:
            raise KeyError(f"Expected 'subset' column not found in predictions for {model}.")

        usable_embeddings = [emb for emb, _ in sorted_embeddings if emb in threshold_lookup]
        max_sims = _load_max_similarities(usable_embeddings, distances_dir)

        subset_map = pred_df["subset"].to_dict()
        correct_map = (pred_df[pred_col] == pred_df[true_col]).to_dict()

        # Subset membership does not depend on the coverage target. Ids that
        # are absent from this model's predictions have no subset and are
        # therefore never part of any statistic.
        subset_members = {
            name: [iid for iid in total_ids if subset_map.get(iid) == name]
            for name in _SUBSET_NAMES
        }

        for cov_key in coverage_keys:
            # (threshold, per-image max similarity) for every embedding that
            # has both a threshold at this target and similarity data.
            active = [
                (threshold_lookup[emb][cov_key], max_sims[emb])
                for emb in usable_embeddings
                if emb in max_sims and cov_key in threshold_lookup[emb]
            ]

            # --- Mean/std among subsets ---
            acc_list = []
            cov_list = []
            for name in _SUBSET_NAMES:
                members = subset_members[name]
                assigned = [
                    iid for iid in members
                    if any(sims.get(iid, np.nan) >= l_thresh for l_thresh, sims in active)
                ]
                cov_list.append(len(assigned) / len(members) if members else 0.0)
                # Subsets without any assigned image contribute no accuracy.
                if assigned:
                    acc_list.append(float(np.mean([correct_map[iid] for iid in assigned])))

            acc_vals = np.array(acc_list)
            cov_vals = np.array(cov_list)

            mean_acc = float(np.mean(acc_vals)) if acc_vals.size else 0.0
            std_acc = float(np.std(acc_vals)) if acc_vals.size else 0.0
            mean_cov = float(np.mean(cov_vals)) if cov_vals.size else 0.0
            std_cov = float(np.std(cov_vals)) if cov_vals.size else 0.0

            print(
                f"Model: {model} | Coverage Target: {cov_key:.2f} | "
                f"Mean Acc: {mean_acc:.4f} | Std Acc: {std_acc:.4f} | "
                f"Mean Cov: {mean_cov:.4f} | Std Cov: {std_cov:.4f}"
            )

            all_results.append({
                "model": model,
                "coverage_target": cov_key,
                "mean_accuracy": mean_acc,
                "std_accuracy": std_acc,
                "mean_subset_coverage": mean_cov,
                "std_subset_coverage": std_cov
            })

    output_dir = os.path.join(results_dir, combo_dir)
    os.makedirs(output_dir, exist_ok=True)
    # Explicit columns: with no rows the frame would otherwise have no
    # "model" column and the plotting step would fail.
    results_df = pd.DataFrame(all_results, columns=_RESULT_COLUMNS)
    out_path = os.path.join(output_dir, "final_modelwise_coverage_accuracy_results.csv")
    results_df.to_csv(out_path, index=False)
    print(f"\nResults saved to {out_path}")

    # --- Plot ---
    plot_path = _save_coverage_accuracy_plot(results_df, dataset_type, output_dir)
    print(f"Plot saved to {plot_path}")


# Script entry point: the optional positional argument selects which test set
# (INTERNAL by default, or EXTERNAL) is evaluated.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Run model-wise confident combination evaluation for INTERNAL/EXTERNAL datasets."
    )
    cli.add_argument(
        "dataset_type",
        nargs="?",
        default="INTERNAL",
        choices=["INTERNAL", "EXTERNAL"],
        help="Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)",
    )
    selected_dataset = cli.parse_args().dataset_type
    run_modelwise_confident_combination(selected_dataset)
