import pandas as pd
import numpy as np
import os
import argparse
from collections import defaultdict
from scipy.interpolate import PchipInterpolator as _Spline

from config.config import (
    BASE_DIR,
    PREDICTIONS_TEST_INTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
    DISTANCES_TEST_INTERNAL,
    DISTANCES_TEST_EXTERNAL,
    RESULTS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
    COMBINATION_EXTERNAL_TEST,
    COMBINATION_INTERNAL_TEST
)

# Prediction (classifier) models to evaluate. Each one is looked up by
# 'model_name' in the benchmark CSV and must have an entry in TABLE_DATA.
PREDICTION_MODELS = [
    "resnet50",
    "resnet101",
    "shufflenet_v2_x1_0",
    "deit_tiny_patch16_224",
    "deit_small_patch16_224",
    "deit_base_patch16_224"
]

# Embedding models, in the fixed order that each TABLE_DATA row follows.
EMBEDDING_MODELS_ORDER = [
    "DINO-V1 ViT-B/16",
    "DINO-V1 ViT-B/8",
    "DINO-V1 ViT-S/16",
    "DINO-V2 ViT-B/14",
    "DINO-V2 ViT-S/14",
    "MobileNet-V2"
]

# Display name -> short name used on disk, forming the result directory
# "<prediction_model>__<disk_name>" under coverage_accuracy_best_N/.
EMBEDDING_DISK_NAME_MAP = {
    "DINO-V1 ViT-B/16": "dinov1_b",
    "DINO-V1 ViT-B/8": "dinov1_b8",
    "DINO-V1 ViT-S/16": "dinov1_s",
    "DINO-V2 ViT-B/14": "dinov2_b",
    "DINO-V2 ViT-S/14": "dinov2_s",
    "MobileNet-V2": "mobilenet_v2",
}

# Per prediction model: one (best_N, held-out score) pair per embedding, chosen
# earlier, listed in EMBEDDING_MODELS_ORDER order. Rows are paired with that list
# positionally via zip, so the lengths must match (zip silently drops extras).
# Only best_N is used (rows with N == best_N are kept); the score is not used here.
TABLE_DATA = {
    "resnet50": [(1, 0.4107), (1, 0.4487), (1, 0.3819), (19, 0.4841), (2, 0.4304), (1, 0.2138)],
    "resnet101": [(1, 0.3887), (1, 0.4251), (1, 0.3586), (50, 0.4616), (4, 0.4104), (1, 0.1934)],
    "shufflenet_v2_x1_0": [(1, 0.4106), (1, 0.4270), (1, 0.3943), (7, 0.4110), (4, 0.3916), (1, 0.2441)],
    "deit_tiny_patch16_224": [(1, 0.4096), (1, 0.4332), (1, 0.3869), (23, 0.4376), (4, 0.4082), (1, 0.2250)],
    "deit_small_patch16_224": [(1, 0.4065), (1, 0.4466), (1, 0.3790), (14, 0.4612), (3, 0.4177), (1, 0.2000)],
    "deit_base_patch16_224": [(1, 0.4040), (1, 0.4422), (1, 0.3729), (19, 0.4728), (4, 0.4197), (1, 0.2001)]
}

# Data subsets scored separately; their per-subset CSV columns are located by
# name (see _find_subset_columns).
SUBSETS = ["subset_A", "subset_B", "subset_C"]

# =============================================================
# Helpers
# =============================================================

def _linear_interp(xs: np.ndarray, ys: np.ndarray):
    """Wrap ``np.interp`` over the knots (xs, ys) as a one-argument callable.

    ``xs`` is expected to be increasing (an ``np.interp`` requirement).
    Outside ``[xs[0], xs[-1]]`` ``np.interp`` returns the end values ``ys[0]`` /
    ``ys[-1]``, i.e. the curve is held flat rather than extended.
    """
    return lambda points: np.interp(points, xs, ys)


def build_spline(x: np.ndarray, y: np.ndarray):
    """Return a callable ``f(c)`` estimating y at x = c, clipped to [0, 1].

    Preparation:
      * non-finite (x, y) pairs are dropped;
      * points are sorted by x and repeated x values are collapsed into one
        knot carrying the largest y, so the knots are strictly increasing.

    With at least three distinct knots a shape-preserving PCHIP spline
    (``_Spline``) is used. With fewer knots, or if the spline cannot be built,
    the function falls back to linear interpolation. If no finite points
    remain, ``f`` returns zeros shaped like its input.

    Note: the spline is built with ``extrapolate=True``, so outside the knot
    range it extends the end cubic pieces (then clips). The linear fallback
    instead holds the end values constant (``np.interp`` behaviour).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if x.size == 0:
        return lambda z: np.zeros_like(np.asarray(z, dtype=float))

    # Sort by x, then collapse each run of equal x into a single knot holding the
    # run's maximum y. np.unique returns the first index of every run in the
    # (already sorted) array, which is exactly what reduceat needs.
    order = np.argsort(x)
    x, y = x[order], y[order]
    x, run_starts = np.unique(x, return_index=True)
    y = np.maximum.reduceat(y, run_starts)

    if x.size >= 3 and _Spline is not None:
        try:
            spl = _Spline(x, y, extrapolate=True)
        except ValueError as err:
            # PCHIP raises ValueError on invalid knots; report it instead of
            # swallowing silently, then use the linear fallback below.
            print(f"  WARNING: spline fit failed ({err}); using linear interpolation")
        else:
            def f(z):
                z = np.asarray(z, dtype=float)
                return np.clip(spl(z), 0.0, 1.0)
            return f

    # Fallback: piecewise-linear interpolation (also used for 1-2 knots).
    lin = _linear_interp(x, y)

    def f(z):
        z = np.asarray(z, dtype=float)
        return np.clip(lin(z), 0.0, 1.0)
    return f


def integrate_func(func, a=0.10, b=1.00, step=0.001):
    """Integrate ``func`` over [a, b] with the trapezoidal rule on a uniform grid.

    The grid is ``np.arange(a, b + 1e-12, step)``. The tiny epsilon keeps ``b``
    itself on the grid when ``b - a`` is a multiple of ``step`` despite float
    rounding. Otherwise the grid stops at the last ``a + k*step`` not exceeding
    ``b`` (up to that epsilon).

    Args:
        func: vectorised callable mapping an array of grid points to values.
        a, b: integration bounds.
        step: grid spacing.

    Returns:
        The integral as a Python float (0.0 if the grid has fewer than 2 points).
    """
    grid = np.arange(a, b + 1e-12, step)
    vals = func(grid)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available and fall back for older NumPy versions.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return float(trapezoid(vals, grid))


def _find_subset_columns(df: pd.DataFrame, subset: str):
    """Locate the coverage/accuracy columns that belong to ``subset``.

    Matching is case-insensitive and runs in two passes over ``df.columns``:

    1. Names that, with underscores removed, contain both the metric word
       ('coverage' / 'accuracy') and the full subset name, e.g.
       'coverage_subset_A' for 'subset_A'. If several match, the last wins.
    2. Short alternates such as 'coverage_A' / 'accuracy_A', where the tag is
       the part of ``subset`` after its last underscore. The tag must be a
       separate name token ('coverage_A') or be glued directly to the metric
       word ('coverageA'). The first matching column wins; a column already
       found in pass 1 is kept.

    Returns ``(coverage_col, accuracy_col)``, or ``(None, None)`` if either
    column cannot be found.
    """
    s_norm = subset.lower().replace("_", "")
    cov_col = None
    acc_col = None
    for c in df.columns:
        c_norm = str(c).lower().replace("_", "")
        if "coverage" in c_norm and s_norm in c_norm:
            cov_col = c
        if "accuracy" in c_norm and s_norm in c_norm:
            acc_col = c
    if cov_col is not None and acc_col is not None:
        return cov_col, acc_col

    # Pass 2. A plain substring test on the one-letter tag is wrong: 'a' and 'c'
    # occur inside 'coverage'/'accuracy' themselves, so subset_C would pick up
    # 'coverage_A'. Match the tag as a whole token of the column name instead.
    tag = subset.split("_")[-1].lower()  # 'a', 'b', 'c'

    def matches(col, metric):
        # Tokens are runs of alphanumerics: 'Coverage_A' -> ['coverage', 'a'].
        tokens = "".join(ch if ch.isalnum() else " " for ch in str(col).lower()).split()
        if metric + tag in tokens:  # e.g. 'coverageA'
            return True
        return tag in tokens and any(metric in tok for tok in tokens)

    for c in df.columns:
        if cov_col is None and matches(c, "coverage"):
            cov_col = c
        if acc_col is None and matches(c, "accuracy"):
            acc_col = c
    if cov_col is not None and acc_col is not None:
        return cov_col, acc_col
    return None, None


# =============================================================
# Core logic
# =============================================================

def _trapezoid(y, x) -> float:
    """Trapezoidal-rule integral of samples ``y`` over points ``x``.

    Prefers ``np.trapezoid`` (NumPy >= 2.0, where ``np.trapz`` is deprecated)
    and falls back to ``np.trapz`` on older NumPy.
    """
    fn = getattr(np, "trapezoid", None)
    if fn is None:
        fn = np.trapz
    return float(fn(y, x))


def _resolve_dataset_dirs(dataset_type: str):
    """Return ``(predictions_dir, results_dir, combo_dir)`` for ``dataset_type``.

    "EXTERNAL" selects the external-test paths; any other value uses the
    internal-test paths.
    """
    if dataset_type == "EXTERNAL":
        return PREDICTIONS_TEST_EXTERNAL, RESULTS_TEST_EXTERNAL, COMBINATION_EXTERNAL_TEST
    return PREDICTIONS_TEST_INTERNAL, RESULTS_TEST_INTERNAL, COMBINATION_INTERNAL_TEST


def _load_best_n_table(path: str, best_n) -> pd.DataFrame:
    """Read one coverage/accuracy CSV, filtered to ``best_n`` and averaged per L.

    Rows are kept only where N == best_n (if an 'N' column exists). If an 'L'
    column exists, numeric columns are averaged per L value to reduce noise.
    """
    df = pd.read_csv(path)
    if 'N' in df.columns:
        df = df[df['N'] == best_n]
    if 'L' in df.columns:
        num_cols = df.select_dtypes(include='number').columns
        df = df.groupby('L', as_index=False)[num_cols].mean()
    return df


def _subset_points(df: pd.DataFrame, subset: str):
    """Return cleaned ``(coverage, accuracy)`` arrays for ``subset``, or None.

    Subset-specific columns are preferred; otherwise global 'coverage' /
    'accuracy' columns are used. Non-finite pairs are dropped and the values
    are clipped to [0, 1]. Returns None when no usable columns exist or fewer
    than two points remain.
    """
    cov_col, acc_col = _find_subset_columns(df, subset)
    if cov_col is None or acc_col is None:
        if 'coverage' not in df.columns or 'accuracy' not in df.columns:
            return None  # can't use this embedding for this subset
        cov_col, acc_col = 'coverage', 'accuracy'

    cov_vals = df[cov_col].values
    acc_vals = df[acc_col].values
    finite = np.isfinite(cov_vals) & np.isfinite(acc_vals)
    cov_vals, acc_vals = cov_vals[finite], acc_vals[finite]
    if cov_vals.size < 2:
        return None
    return np.clip(cov_vals, 0.0, 1.0), np.clip(acc_vals, 0.0, 1.0)


def _collect_subset_curves(model: str, results_dir: str, coverage_grid: np.ndarray):
    """Map each subset to the list of per-embedding accuracy curves on ``coverage_grid``.

    One curve is added per embedding whose best-N CSV exists and yields usable
    points for that subset. Missing CSVs are reported and skipped.
    """
    curves = {s: [] for s in SUBSETS}
    for emb_name, (best_n, _score) in zip(EMBEDDING_MODELS_ORDER, TABLE_DATA[model]):
        disk = EMBEDDING_DISK_NAME_MAP[emb_name]
        path = os.path.join(results_dir, "coverage_accuracy_best_N", f"{model}__{disk}", "coverage_accuracy_best_N.csv")
        if not os.path.exists(path):
            print(f"  Skipping: {path} (missing)")
            continue
        df = _load_best_n_table(path, best_n)
        for subset in SUBSETS:
            points = _subset_points(df, subset)
            if points is None:
                continue
            spline = build_spline(*points)
            curves[subset].append(spline(coverage_grid))
    return curves


def run_integral_report(dataset_type: str):
    """Score each prediction model by the normalized area under its best-case
    accuracy-vs-coverage curve, per subset, and write the results to CSV.

    For every model in PREDICTION_MODELS, each embedding's best-N
    coverage/accuracy results become one accuracy curve per subset on a
    0.001-spaced coverage grid over [0.10, 1.00]. Per subset, the envelope
    (pointwise max over embeddings) is integrated and normalized as

        (area - 0.9 * benchmark) / (0.9 * (1 - benchmark))

    where 0.9 is the width of the coverage range. ``benchmark`` is the model's
    accuracy from ``test_prediction.csv`` (0.0 if missing or non-finite).
    A flat curve at the benchmark scores 0; a flat curve at accuracy 1.0
    scores 1.

    Writes ``subset_integral_envelope_details.csv`` (one row per model/subset
    that has curves) and ``subset_integral_envelope_summary.csv`` (mean and
    population std of the scores over subsets, per model) into
    ``os.path.join(results_dir, combo_dir)``.

    Args:
        dataset_type: "EXTERNAL" selects external-test paths; any other value
            uses the internal-test paths.
    """
    print("=== Starting spline-integral evaluation ===")

    predictions_dir, results_dir, combo_dir = _resolve_dataset_dirs(dataset_type)

    # Load benchmark accuracies (one row per model)
    benchmark_path = os.path.join(predictions_dir, "test_prediction.csv")
    benchmark_df = pd.read_csv(benchmark_path)
    benchmark_accuracy = dict(zip(benchmark_df['model_name'], benchmark_df['accuracy']))

    coverage_grid = np.round(np.arange(0.10, 1.00 + 1e-9, 0.001), 3)
    # Width of the integration range [0.10, 1.00]: benchmark * span is the area
    # under a flat curve at the benchmark accuracy.
    span = 0.90

    per_model_subset_rows = []  # detailed rows per subset
    summary_rows = []           # mean/std across subsets per model

    for model in PREDICTION_MODELS:
        print(f"\n== Model: {model}")
        benchmark = float(benchmark_accuracy.get(model, np.nan))
        if not np.isfinite(benchmark):
            print(f"  WARNING: No benchmark for {model}; using 0.0")
            benchmark = 0.0

        subset_to_values = _collect_subset_curves(model, results_dir, coverage_grid)

        subset_norm_scores = {}
        for subset in SUBSETS:
            curves = subset_to_values[subset]
            if not curves:
                print(f"  WARNING: No curves for {model} / {subset}; setting score to NaN")
                subset_norm_scores[subset] = np.nan
                continue
            # Envelope = max over embeddings at each coverage
            envelope = np.max(np.vstack(curves), axis=0)  # [num_grid]
            area = _trapezoid(envelope, coverage_grid)  # integral over [0.10, 1.00]

            adjusted = area - (benchmark * span)
            denom = span * max(1e-12, (1.0 - benchmark))  # guard benchmark == 1
            normalized = adjusted / denom
            subset_norm_scores[subset] = normalized

            per_model_subset_rows.append({
                'model': model,
                'subset': subset,
                'benchmark_accuracy': benchmark,
                'integral_area_0.10_1.00': area,
                'adjusted_area_minus_benchmarkx0.9': adjusted,
                'normalized_score': normalized
            })

        # Summary stats across subsets (NaN scores are excluded)
        vals = [subset_norm_scores[s] for s in SUBSETS if np.isfinite(subset_norm_scores[s])]
        mean_val = float(np.mean(vals)) if vals else np.nan
        std_val = float(np.std(vals)) if vals else np.nan
        summary_rows.append({
            'model': model,
            'benchmark_accuracy': benchmark,
            'mean_normalized_integral_over_subsets': mean_val,
            'std_normalized_integral_over_subsets': std_val
        })
        print(f"  => mean={mean_val:.4f} | std={std_val:.4f}")

    # Save outputs
    output_dir = os.path.join(results_dir, combo_dir)
    os.makedirs(output_dir, exist_ok=True)

    detailed_path = os.path.join(output_dir, 'subset_integral_envelope_details.csv')
    summary_path = os.path.join(output_dir, 'subset_integral_envelope_summary.csv')

    pd.DataFrame(per_model_subset_rows).to_csv(detailed_path, index=False)
    pd.DataFrame(summary_rows).to_csv(summary_path, index=False)

    print("\nSaved:")
    print("  ", detailed_path)
    print("  ", summary_path)


if __name__ == '__main__':
    # Command-line entry: a single optional positional argument selects the
    # dataset (INTERNAL when omitted), then the report is run for it.
    cli = argparse.ArgumentParser(
        description="Compute normalized integral of accuracy-vs-coverage envelopes (0.10..1.00) per model & subset."
    )
    cli.add_argument(
        'dataset_type',
        nargs='?',
        choices=['INTERNAL', 'EXTERNAL'],
        default='INTERNAL',
        help='Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)',
    )
    run_integral_report(cli.parse_args().dataset_type)
