import pandas as pd
import numpy as np
import os
import argparse
from config.config import (
    PREDICTIONS_TEST_INTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
    DISTANCES_TEST_INTERNAL,
    DISTANCES_TEST_EXTERNAL,
    RESULTS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
    COMBINATION_EXTERNAL_TEST,
    COMBINATION_INTERNAL_TEST
)

# Classifier names. Each name is used to locate `<name>.csv` in the predictions
# directory and to build the `<name>_pred_class` / `<name>_true_class` column
# names read from that file.
PREDICTION_MODELS = [
    "resnet50",
    "resnet101",
    "shufflenet_v2_x1_0",
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224"
]

# Display names of the embedding models. The ORDER matters: every list in
# TABLE_DATA is zipped positionally against this list.
EMBEDDING_MODELS_ORDER = [
    "DINO-V1 ViT-B/16",
    "DINO-V1 ViT-B/8",
    "DINO-V1 ViT-S/16",
    "DINO-V2 ViT-B/14",
    "DINO-V2 ViT-S/14",
    "MobileNet-V2"
]

# Display name -> name fragment used on disk, both in the
# `<model>__<fragment>` results sub-directory and in the
# `<fragment>_cosine_similarities.csv` file name.
EMBEDDING_DISK_NAME_MAP = {
    "DINO-V1 ViT-B/16": "dinov1_b",
    "DINO-V1 ViT-B/8": "dinov1_b8",
    "DINO-V1 ViT-S/16": "dinov1_s",
    "DINO-V2 ViT-B/14": "dinov2_b",
    "DINO-V2 ViT-S/14": "dinov2_s",
    "MobileNet-V2": "mobilenet_v2",
}

# Per prediction model: one (N, score) tuple per embedding, in
# EMBEDDING_MODELS_ORDER order. The first element selects the rows with that
# `N` in the coverage/accuracy table; the second is used only as a descending
# sort key to rank the embeddings.
# NOTE(review): presumably these are the best N and the score it achieved in a
# prior tuning run — confirm the provenance of the numbers.
TABLE_DATA = {
    "resnet50": [(1, 0.4107), (1, 0.4487), (1, 0.3819), (19, 0.4841), (2, 0.4304), (1, 0.2138)],
    "resnet101": [(1, 0.3887), (1, 0.4251), (1, 0.3586), (50, 0.4616), (4, 0.4104), (1, 0.1934)],
    "shufflenet_v2_x1_0": [(1, 0.4106), (1, 0.4270), (1, 0.3943), (7, 0.4110), (4, 0.3916), (1, 0.2441)],
    "deit_tiny_patch16_224": [(1, 0.4096), (1, 0.4332), (1, 0.3869), (23, 0.4376), (4, 0.4082), (1, 0.2250)],
    "deit_small_patch16_224": [(1, 0.4065), (1, 0.4466), (1, 0.3790), (14, 0.4612), (3, 0.4177), (1, 0.2000)],
    "deit_base_patch16_224": [(1, 0.4040), (1, 0.4422), (1, 0.3729), (19, 0.4728), (4, 0.4197), (1, 0.2001)]
}

def find_l_for_coverage(train_df, target_coverage):
    """Return the largest ``L`` whose ``coverage`` is at least *target_coverage*.

    Args:
        train_df: DataFrame with an ``L`` column and a ``coverage`` column.
        target_coverage: Minimum coverage a row must reach to qualify.

    Returns:
        The maximum ``L`` among qualifying rows, or ``None`` when no row
        reaches the requested coverage.
    """
    # The caller expects coverage to shrink (roughly) as L grows, so the
    # largest qualifying L is the strictest threshold that still meets the target.
    ordered = train_df.sort_values("L")
    qualifying_l = ordered.loc[ordered["coverage"] >= target_coverage, "L"]
    if qualifying_l.empty:
        # Target coverage is unreachable with the available thresholds.
        return None
    return qualifying_l.max()

def _load_prediction_frames(predictions_dir):
    """Read ``<model>.csv`` for every entry of PREDICTION_MODELS.

    Column names are whitespace-stripped and each frame is indexed by
    ``image_id`` with every ``.jpg`` occurrence removed and whitespace stripped.
    """
    pred_dfs = {}
    for model in PREDICTION_MODELS:
        pred_path = os.path.join(predictions_dir, model + ".csv")
        # sep=None lets the python engine sniff the delimiter.
        pred_df = pd.read_csv(pred_path, sep=None, engine="python")
        pred_df.columns = pred_df.columns.str.strip()
        pred_df.set_index("image_id", inplace=True)
        pred_df.index = pred_df.index.str.replace(".jpg", "", regex=False).str.strip()
        pred_dfs[model] = pred_df
    return pred_dfs


def _build_threshold_lookup(model, ranked_embeddings, results_dir, coverage_range):
    """Map ``(embedding name, target coverage)`` to an ``L`` threshold for *model*.

    Embeddings whose coverage/accuracy table is missing, and target coverages
    that no ``L`` reaches, get no entry.
    """
    threshold_lookup = {}
    for emb_name, (best_n, _) in ranked_embeddings:
        disk_name = EMBEDDING_DISK_NAME_MAP[emb_name]
        result_path = os.path.join(
            results_dir,
            "coverage_accuracy_best_N",
            model + "__" + disk_name,
            "coverage_accuracy_best_N.csv",
        )
        if not os.path.exists(result_path):
            # Best-effort skip, but say so: otherwise the output silently
            # loses this embedding.
            print(f"[warn] Missing coverage/accuracy table, skipping {emb_name}: {result_path}")
            continue
        df = pd.read_csv(result_path)
        df_n_filtered = df[df["N"] == best_n]
        numeric_cols = df_n_filtered.select_dtypes(include="number").columns
        # Collapse repeated rows per L by averaging their numeric columns.
        df_grouped = df_n_filtered.groupby("L", as_index=False)[numeric_cols].mean()
        for cov_thresh in coverage_range:
            l_val = find_l_for_coverage(df_grouped, cov_thresh)
            if l_val is not None:
                threshold_lookup[(emb_name, cov_thresh)] = l_val
    return threshold_lookup


def _load_similarity_frame(distances_dir, emb_name):
    """Load the similarity CSV of *emb_name*, or return None if it is missing."""
    disk_name = EMBEDDING_DISK_NAME_MAP[emb_name]
    sim_path = os.path.join(distances_dir, disk_name + "_cosine_similarities.csv")
    if not os.path.exists(sim_path):
        print(f"[warn] Missing similarity file, skipping {emb_name}: {sim_path}")
        return None
    df = pd.read_csv(sim_path, index_col=0)
    df.index = df.index.astype(str).str.strip()
    df.columns = df.columns.astype(str).str.strip()
    return df


def _max_neighbor_similarity(sim_df, img_id):
    """Return the largest value in row *img_id*, ignoring the column labelled *img_id*.

    Non-numeric cells count as ``-inf``. The result is NaN when no column
    remains after dropping the self column.
    """
    sims = sim_df.loc[img_id].drop(labels=[img_id], errors="ignore")
    sims = pd.to_numeric(sims, errors="coerce").fillna(-np.inf)
    return sims.max()


def run_modelwise_confident_combination(dataset_type: str):
    """Evaluate per-subset coverage and accuracy of similarity-gated predictions.

    For every prediction model the embeddings are ranked by the second element
    of their TABLE_DATA tuple (descending). For each target coverage in
    0.04, 0.06, ..., 1.00 and each image, the embeddings are tried in rank
    order; the image is accepted by the first embedding in which at least one
    non-self similarity reaches that embedding's ``L`` threshold for the target
    coverage. Accepted images contribute the model's predicted and true class.
    For each of ``subset_A``/``subset_B``/``subset_C`` one row with the achieved
    coverage and accuracy is recorded, in coverage-threshold order.

    Args:
        dataset_type: ``"EXTERNAL"`` selects the external-test directories;
            any other value selects the internal-test directories.

    Side effects:
        Prints progress, creates the output directory if needed and writes
        ``modelwise_subset_coverage_accuracy.csv`` (columns: model, subset,
        coverage, accuracy) into it.
    """
    print("=== Starting model-wise evaluation ===")

    # Pick paths
    if dataset_type == "EXTERNAL":
        predictions_dir = PREDICTIONS_TEST_EXTERNAL
        distances_dir = DISTANCES_TEST_EXTERNAL
        results_dir = RESULTS_TEST_EXTERNAL
        output_dir = COMBINATION_EXTERNAL_TEST
    else:
        predictions_dir = PREDICTIONS_TEST_INTERNAL
        distances_dir = DISTANCES_TEST_INTERNAL
        results_dir = RESULTS_TEST_INTERNAL
        output_dir = COMBINATION_INTERNAL_TEST

    pred_dfs = _load_prediction_frames(predictions_dir)

    # The image universe is defined by the first model's prediction file.
    total_ids = set(pred_dfs[PREDICTION_MODELS[0]].index)
    coverage_range = np.round(np.arange(0.04, 1.01, 0.02), 2)
    subset_names = ["subset_A", "subset_B", "subset_C"]

    # Both caches are independent of the prediction model and of the coverage
    # threshold, so they are shared across all iterations instead of being
    # rebuilt for each (model, threshold) pair.
    sim_dfs = {}        # embedding name -> similarity frame (None if file missing)
    max_sim_cache = {}  # (embedding name, image id) -> max non-self similarity,
                        # or None when the image has no row in that frame

    subset_rows = []

    for model in PREDICTION_MODELS:
        print(f"\n== Evaluating model: {model}")
        paired = list(zip(EMBEDDING_MODELS_ORDER, TABLE_DATA[model]))
        sorted_embeddings = sorted(paired, key=lambda x: x[1][1], reverse=True)

        threshold_lookup = _build_threshold_lookup(
            model, sorted_embeddings, results_dir, coverage_range
        )

        # Only embeddings that have at least one threshold can ever be consulted.
        used_embeddings = {emb for emb, _ in threshold_lookup}
        for emb_name, _ in sorted_embeddings:
            if emb_name in used_embeddings and emb_name not in sim_dfs:
                sim_dfs[emb_name] = _load_similarity_frame(distances_dir, emb_name)

        pred_df = pred_dfs[model]
        subset_map = pred_df["subset"].to_dict()
        pred_col = f"{model}_pred_class"
        true_col = f"{model}_true_class"

        # Subset sizes do not depend on the coverage threshold.
        subset_sizes = {
            subset: sum(1 for i in total_ids if subset_map.get(i) == subset)
            for subset in subset_names
        }

        for cov_thresh in coverage_range:
            # Embeddings usable at this target coverage, in rank order.
            active = [
                (emb_name, threshold_lookup[(emb_name, cov_thresh)])
                for emb_name, _ in sorted_embeddings
                if (emb_name, cov_thresh) in threshold_lookup
                and sim_dfs.get(emb_name) is not None
            ]

            assigned_preds = {}
            assigned_truth = {}
            for img_id in total_ids:
                for emb_name, l_thresh in active:
                    key = (emb_name, img_id)
                    if key not in max_sim_cache:
                        sim_df = sim_dfs[emb_name]
                        max_sim_cache[key] = (
                            _max_neighbor_similarity(sim_df, img_id)
                            if img_id in sim_df.index
                            else None
                        )
                    best_sim = max_sim_cache[key]
                    # "any similarity >= threshold" is the same as
                    # "max similarity >= threshold"; a NaN max compares False.
                    if best_sim is not None and best_sim >= l_thresh:
                        assigned_preds[img_id] = pred_df.loc[img_id, pred_col]
                        assigned_truth[img_id] = pred_df.loc[img_id, true_col]
                        break

            for subset in subset_names:
                subset_size = subset_sizes[subset]
                if subset_size == 0:
                    cov = np.nan
                    acc = np.nan
                else:
                    assigned_subset_ids = [
                        i for i in assigned_preds if subset_map.get(i) == subset
                    ]
                    cov = len(assigned_subset_ids) / subset_size
                    if assigned_subset_ids:
                        acc = float(np.mean(
                            [assigned_preds[i] == assigned_truth[i] for i in assigned_subset_ids]
                        ))
                    else:
                        # Nothing accepted in this subset: accuracy is undefined.
                        acc = np.nan

                subset_rows.append({
                    "model": model,
                    "subset": subset,
                    "coverage": cov,
                    "accuracy": acc
                })

    os.makedirs(output_dir, exist_ok=True)
    detailed_df = pd.DataFrame(subset_rows)
    out_path = os.path.join(output_dir, "modelwise_subset_coverage_accuracy.csv")
    detailed_df.to_csv(out_path, index=False)
    print(f"\nDetailed per-subset results saved to {out_path}")

if __name__ == "__main__":
    # Command-line entry point: one optional positional argument selects the
    # dataset; omitting it runs the INTERNAL evaluation.
    cli_parser = argparse.ArgumentParser(
        description="Run model-wise confident combination evaluation for INTERNAL/EXTERNAL datasets.",
    )
    cli_parser.add_argument(
        "dataset_type",
        nargs="?",
        default="INTERNAL",
        choices=["INTERNAL", "EXTERNAL"],
        help="Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)",
    )
    cli_args = cli_parser.parse_args()
    run_modelwise_confident_combination(cli_args.dataset_type)
