import os
import json
import argparse
import pandas as pd
import numpy as np
from scipy.interpolate import UnivariateSpline
from scipy.integrate import quad
import matplotlib.pyplot as plt
from config.config import (
    BASE_DIR,
    RESULTS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
    PREDICTIONS_TEST_INTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
)

def _sorted_unique_curve(coverage, accuracy):
    """Return the curve sorted by coverage with duplicate coverages removed.

    Args:
        coverage: 1-D array of coverage values (x axis).
        accuracy: 1-D array of accuracy values aligned with ``coverage``.

    Returns:
        Tuple ``(x, y)`` where ``x`` holds the distinct coverage values in
        ascending order and ``y`` the accuracy found at the first occurrence
        of each of them in the sorted data.
    """
    order = np.argsort(coverage)
    x_sorted = coverage[order]
    y_sorted = accuracy[order]
    x_unique, first_idx = np.unique(x_sorted, return_index=True)
    return x_unique, y_sorted[first_idx]


def _extend_to_full_coverage(x, y):
    """Append linearly extrapolated points so the curve reaches coverage 1.0.

    A straight line is fitted to the last (up to) 15 points and evaluated at
    10 evenly spaced coverages between ``x[-1] + 1e-6`` and ``1.0``.

    Args:
        x: Ascending, distinct coverage values.
        y: Accuracy values aligned with ``x``.

    Returns:
        Tuple ``(x, y)``, extended when an extrapolation was needed and
        returned unchanged otherwise.

    Raises:
        Exception: whatever ``np.polyfit`` raises when the linear fit fails.
    """
    start = x[-1] + 1e-6
    # If the curve already ends within 1e-6 of full coverage, the extra grid
    # would run backwards (or collapse onto 1.0) and make the abscissae
    # non-increasing, which breaks the spline fit. The spline's own
    # extrapolation covers that negligible gap instead.
    if start >= 1.0:
        return x, y

    # Slicing already copes with arrays shorter than 15 points.
    x_tail = x[-15:]
    y_tail = y[-15:]
    slope, intercept = np.polyfit(x_tail, y_tail, 1)

    extra_x = np.linspace(start, 1.0, 10)
    extra_y = slope * extra_x + intercept
    return np.concatenate([x, extra_x]), np.concatenate([y, extra_y])


def _normalized_area(integral, benchmark):
    """Normalize the area under the curve on [0.1, 1.0] against a benchmark.

    Computes ``(integral - 0.9 * benchmark) / (0.9 * (1 - benchmark))``, where
    0.9 is the width of the integration interval. The result is 0 when the
    area equals that of a constant curve at the benchmark accuracy and 1 when
    it equals that of a constant curve at accuracy 1.

    Args:
        integral: Area under the accuracy curve between coverage 0.1 and 1.0.
        benchmark: Benchmark accuracy of the prediction model, or ``None``.

    Returns:
        The normalized value, or ``np.nan`` when the benchmark is missing or
        the denominator is not strictly positive.
    """
    if benchmark is None:
        return np.nan
    adjusted = integral - 0.9 * benchmark
    denom = (1 - benchmark) * 0.9
    return adjusted / denom if denom > 0 else np.nan


def _save_spline_plot(spline, x, y, title, plot_path):
    """Plot the fitted spline over the data points and save it to ``plot_path``.

    The figure is closed even if drawing or saving fails, so a failing plot
    does not leave an open figure behind.
    """
    fig = plt.figure(figsize=(6, 4))
    try:
        x_plot = np.linspace(0.0, 1.0, 500)
        y_plot = spline(x_plot)
        plt.plot(x, y, 'o', label='Data + extrapolation')
        plt.plot(x_plot, y_plot, '-', label='Spline')
        plt.fill_between(x_plot, y_plot, alpha=0.1, label='Area under curve')
        plt.title(title)
        plt.xlabel("Coverage")
        plt.ylabel("Accuracy")
        plt.legend()
        plt.grid(True)
        plt.savefig(plot_path)
    finally:
        plt.close(fig)


def main(dataset_type: str) -> None:
    """Compute normalized coverage-accuracy integrals per subset and save them.

    For every (prediction model, embedding model) pair whose
    ``coverage_accuracy_best_N.csv`` exists, a smoothing spline is fitted to
    the accuracy-vs-coverage curve of each subset, integrated over
    [0.1, 1.0] and normalized against the model's benchmark accuracy read from
    ``test_prediction.csv``. Pairs with at least two usable subsets are
    summarized (mean, sample standard deviation and the values of subsets
    A, B and C) in ``normalized_integrals_by_subset_variation.csv``; one plot
    per subset is written to the ``spline_plots`` directory.

    Args:
        dataset_type: ``"EXTERNAL"`` selects the external test paths; any
            other value selects the internal ones.
    """
    # ============================
    # Select dataset paths
    # ============================
    if dataset_type == "EXTERNAL":
        results_dir = RESULTS_TEST_EXTERNAL
        predictions_dir = PREDICTIONS_TEST_EXTERNAL
    else:  # INTERNAL (default)
        results_dir = RESULTS_TEST_INTERNAL
        predictions_dir = PREDICTIONS_TEST_INTERNAL

    input_dir = os.path.join(results_dir, "coverage_accuracy_best_N")
    output_dir = results_dir
    plot_dir = os.path.join(output_dir, "spline_plots")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "normalized_integrals_by_subset_variation.csv")

    # ============================
    # Load benchmark accuracies
    # ============================
    benchmark_path = os.path.join(predictions_dir, "test_prediction.csv")
    benchmark_df = pd.read_csv(benchmark_path)
    benchmark_accuracy = dict(zip(benchmark_df['model_name'], benchmark_df['accuracy']))

    # Only the keys are used here (they form the result folder names).
    embedding_models = {
        "mobilenet_v2": "mobilenetv2_100",
        "dinov1_s": "vit_small_patch16_224_dino",
        "dinov1_b": "vit_base_patch16_224_dino",
        "dinov1_b8": "vit_base_patch8_224_dino",
        "dinov2_s": "vit_small_patch14_dinov2.lvd142m",
        "dinov2_b": "vit_base_patch14_dinov2.lvd142m"
    }

    # Fixed column order so the CSV keeps its header even with no rows.
    summary_columns = [
        "prediction_model",
        "embedding_model",
        "mean_normalized_integral",
        "std_normalized_integral",
        "A",
        "B",
        "C",
    ]

    # ============================
    # Processing
    # ============================
    summary = []

    for pred_model in benchmark_accuracy:
        for emb_key in embedding_models:
            model_folder = f"{pred_model}__{emb_key}"
            file_path = os.path.join(input_dir, model_folder, "coverage_accuracy_best_N.csv")

            if not os.path.exists(file_path):
                print(f" File not found: {file_path}")
                continue

            df = pd.read_csv(file_path)
            if 'subset' not in df.columns:
                # Files without a subset column are treated as one subset.
                df['subset'] = 'A'

            df = df.astype({'N': int, 'L': float, 'coverage': float, 'accuracy': float, 'subset': str})
            subset_norms = {}

            for subset in df['subset'].unique():
                sub_df = df[df['subset'] == subset]

                # Too few rows to fit the spline; skip the subset.
                if len(sub_df) < 4:
                    continue

                x_unique, y_unique = _sorted_unique_curve(
                    sub_df['coverage'].values, sub_df['accuracy'].values
                )

                if len(x_unique) < 4:
                    continue

                # Extrapolate to 1.0 if needed
                try:
                    x_unique, y_unique = _extend_to_full_coverage(x_unique, y_unique)
                except Exception as e:
                    print(f" Extrapolation failed ({pred_model} | {emb_key} | {subset}): {e}")
                    continue

                try:
                    spline = UnivariateSpline(x_unique, y_unique, s=1e-4, ext='extrapolate')
                    integral, _ = quad(spline, 0.1, 1.0)

                    normalized = _normalized_area(integral, benchmark_accuracy.get(pred_model, None))

                    # Stored before plotting so a plot failure keeps the value.
                    subset_norms[subset] = normalized
                    print("Normalized", normalized)

                    # Plot the curve
                    _save_spline_plot(
                        spline,
                        x_unique,
                        y_unique,
                        f"{pred_model} | {emb_key} | Subset {subset}",
                        os.path.join(plot_dir, f"{pred_model}__{emb_key}__{subset}.png"),
                    )

                except Exception as e:
                    print(f" Error fitting spline ({pred_model} | {emb_key} | {subset}): {e}")
                    continue

            # The sample standard deviation needs at least two subsets.
            if len(subset_norms) >= 2:
                values = list(subset_norms.values())
                print(values)
                summary.append({
                    "prediction_model": pred_model,
                    "embedding_model": emb_key,
                    "mean_normalized_integral": np.mean(values),
                    "std_normalized_integral": np.std(values, ddof=1),
                    "A": subset_norms.get('A', np.nan),
                    "B": subset_norms.get('B', np.nan),
                    "C": subset_norms.get('C', np.nan)
                })

    # ============================
    # Save Results
    # ============================
    df_summary = pd.DataFrame(summary, columns=summary_columns)
    df_summary.to_csv(output_file, index=False)
    print(f" Saved to: {output_file}")
    print(f" Plots saved to: {plot_dir}")


if __name__ == "__main__":
    # Command-line entry point: one optional positional argument selects the
    # dataset; when omitted, the INTERNAL dataset is processed.
    cli = argparse.ArgumentParser(
        description="Compute normalized integrals by subset variation for INTERNAL/EXTERNAL datasets."
    )
    cli.add_argument(
        "dataset_type",
        nargs="?",
        choices=["INTERNAL", "EXTERNAL"],
        default="INTERNAL",
        help="Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)",
    )
    main(cli.parse_args().dataset_type)