import os
import json
import pandas as pd
import numpy as np
from scipy.interpolate import UnivariateSpline
from scipy.integrate import quad

import matplotlib.pyplot as plt

from config.config import PREDICTIONS_TRAIN, RESULTS_TRAIN

# ===========================
# Step 1: Setup
# ===========================


# Read the benchmark summary CSV and index each model's accuracy by model name.
# If a model name appears more than once, the last row wins.
benchmark_path = os.path.join(PREDICTIONS_TRAIN, "summary.csv")
benchmark_df = pd.read_csv(benchmark_path)
BENCHMARK_ACCURACY = {
    model_name: accuracy
    for model_name, accuracy in zip(
        benchmark_df['model_name'], benchmark_df['accuracy']
    )
}

# Embedding models, keyed by the short name that appears in the per-combination
# result directory ("<prediction model>__<embedding key>").
# NOTE(review): the values look like backbone identifiers; they are not read
# anywhere in this script — confirm whether other modules rely on them.
EMBEDDING_MODELS = {
    # MobileNet
    "mobilenet_v2": "mobilenetv2_100",
    # DINO v1
    "dinov1_s": "vit_small_patch16_224_dino",
    "dinov1_b": "vit_base_patch16_224_dino",
    "dinov1_b8": "vit_base_patch8_224_dino",
    # DINO v2
    "dinov2_s": "vit_small_patch14_dinov2.lvd142m",
    "dinov2_b": "vit_base_patch14_dinov2.lvd142m",
}
# Prediction models are exactly the models listed in the benchmark summary,
# in the order they appear there.
PREDICTION_MODELS = list(BENCHMARK_ACCURACY)

# Output: the integrals CSV is written to <RESULTS_TRAIN>/integrals,
# which is created if it does not exist yet.
output_dir = os.path.join(RESULTS_TRAIN, "integrals")
output_path = os.path.join(output_dir, "integrals_by_model_and_embedding.csv")
os.makedirs(output_dir, exist_ok=True)

# ===========================
# Step 2: Compute Integrals
# ===========================

# Coverage interval over which the accuracy curve is integrated.
COVERAGE_MIN = 0.1
COVERAGE_MAX = 1.0
COVERAGE_SPAN = COVERAGE_MAX - COVERAGE_MIN

# A cubic spline (UnivariateSpline's default degree, k=3) needs k + 1 points.
MIN_SPLINE_POINTS = 4
# Smoothing factor passed to UnivariateSpline.
SPLINE_SMOOTHING = 1e-4
# Linear tail extrapolation: how many trailing points are fitted, how many
# points are appended, and how far past the last measured coverage they start.
TAIL_FIT_POINTS = 15
EXTRAPOLATION_POINTS = 10
EXTRAPOLATION_OFFSET = 1e-6

# Set to True to show a diagnostic plot for every successfully fitted curve.
SHOW_FIT_PLOTS = False


def _result_row(model, emb_model, n_val, benchmark, integral=np.nan) -> dict:
    """Build one output row for a (model, embedding, N) combination.

    Args:
        model: Prediction model name.
        emb_model: Embedding model key.
        n_val: Value of the ``N`` column the curve belongs to.
        benchmark: Benchmark accuracy of ``model``.
        integral: Integral of the accuracy curve over the coverage interval,
            or NaN when it could not be computed.

    Returns:
        A dict with the raw integral, the integral minus the area under the
        constant benchmark line, and that difference divided by the area
        between the benchmark line and accuracy 1.0. All three are NaN when
        ``integral`` is NaN; the normalized value is also NaN when the
        denominator is not positive (benchmark >= 1).
    """
    if np.isnan(integral):
        adjusted = normalized = np.nan
    else:
        adjusted = integral - (COVERAGE_SPAN * benchmark)
        denominator = (1.0 - benchmark) * COVERAGE_SPAN
        normalized = adjusted / denominator if denominator > 0 else np.nan
    return {
        'model': model,
        'emb_model': emb_model,
        'N': n_val,
        'integral': integral,
        'benchmark': benchmark,
        'adjusted_integral': adjusted,
        'normalized_adjusted': normalized
    }


def _prepare_curve(coverage, accuracy) -> tuple:
    """Return finite, coverage-sorted, de-duplicated ``(x, y)`` arrays.

    Non-finite points are dropped because UnivariateSpline does not check its
    input for NaN/inf by default. ``np.unique`` sorts the coverage values and,
    for a repeated coverage, keeps the accuracy of its first occurrence in
    the input order.
    """
    coverage = np.asarray(coverage, dtype=float)
    accuracy = np.asarray(accuracy, dtype=float)
    finite = np.isfinite(coverage) & np.isfinite(accuracy)
    x, first_idx = np.unique(coverage[finite], return_index=True)
    return x, accuracy[finite][first_idx]


def _extend_to_full_coverage(x, y) -> tuple:
    """Linearly extrapolate the curve up to ``COVERAGE_MAX`` if it ends early.

    A straight line is fitted to the last ``TAIL_FIT_POINTS`` points (or all
    points when there are fewer) and evaluated on a grid that ends at
    ``COVERAGE_MAX``. The inputs are returned unchanged when the curve
    already reaches ``COVERAGE_MAX``.

    Raises:
        Whatever ``np.polyfit`` raises when the tail fit fails.
    """
    if not x[-1] < COVERAGE_MAX:
        return x, y

    slope, intercept = np.polyfit(x[-TAIL_FIT_POINTS:], y[-TAIL_FIT_POINTS:], 1)

    start = x[-1] + EXTRAPOLATION_OFFSET
    if start < COVERAGE_MAX:
        extra_x = np.linspace(start, COVERAGE_MAX, EXTRAPOLATION_POINTS)
    else:
        # The last point is within the offset of COVERAGE_MAX: a linspace from
        # `start` would run backwards and break the increasing-x requirement
        # of the spline, so only the end point itself is added.
        extra_x = np.array([COVERAGE_MAX])
    extra_y = slope * extra_x + intercept

    return np.concatenate([x, extra_x]), np.concatenate([y, extra_y])


def _integrate_curve(x, y) -> tuple:
    """Fit a smoothing spline to ``(x, y)`` and integrate it over the coverage interval.

    Returns:
        ``(integral, spline)``. The spline extrapolates outside the data
        range, so the integral covers all of [COVERAGE_MIN, COVERAGE_MAX]
        even when the data starts above ``COVERAGE_MIN``.
    """
    spline = UnivariateSpline(x, y, s=SPLINE_SMOOTHING, ext='extrapolate')
    integral, _ = quad(spline, COVERAGE_MIN, COVERAGE_MAX)
    return integral, spline


def _plot_fit(spline, x_data, y_data, x_fit, y_fit, title) -> None:
    """Show the fitted spline with the measured and extrapolated points."""
    xx = np.linspace(COVERAGE_MIN, COVERAGE_MAX, 300)

    plt.figure(figsize=(6, 4))
    plt.plot(xx, spline(xx), label='Spline curve')
    plt.scatter(x_data, y_data, color='blue', label='Original data')
    if len(x_fit) > len(x_data):
        plt.scatter(x_fit[len(x_data):], y_fit[len(y_data):], color='orange', label='Extrapolated')
    # NOTE(review): the 0.3 marker does not match the integration lower bound
    # (COVERAGE_MIN = 0.1) — confirm which value is intended.
    plt.axvline(0.3, linestyle='--', color='gray', label='0.3 cutoff')
    plt.title(title)
    plt.xlabel("Coverage")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.tight_layout()
    plt.show()


all_results = []

for pred_model in PREDICTION_MODELS:
    benchmark = BENCHMARK_ACCURACY[pred_model]

    for emb_key in EMBEDDING_MODELS:
        file_path = os.path.join(
            RESULTS_TRAIN,
            "coverage_accuracy",
            f"{pred_model}__{emb_key}",
            "coverage_accuracy_sweep.csv",
        )

        if not os.path.exists(file_path):
            print(f"⚠️ File not found: {file_path}")
            continue

        df = pd.read_csv(file_path)
        df = df.astype({'N': int, 'L': float, 'coverage': float, 'accuracy': float})

        # One coverage/accuracy curve per distinct N, in order of first appearance.
        for n_val in df['N'].unique():
            group = df[df['N'] == n_val]
            x_data, y_data = _prepare_curve(
                group['coverage'].values, group['accuracy'].values
            )

            if len(x_data) < MIN_SPLINE_POINTS:
                print(f"⚠️ Not enough usable x-points for N = {n_val} ({pred_model} | {emb_key})")
                all_results.append(_result_row(pred_model, emb_key, n_val, benchmark))
                continue

            # Best effort: if the tail fit fails, integrate the measured curve as-is.
            x_fit, y_fit = x_data, y_data
            try:
                x_fit, y_fit = _extend_to_full_coverage(x_data, y_data)
            except Exception as e:
                print(f"⚠️ Extrapolation failed for N = {n_val} ({pred_model} | {emb_key}): {e}")

            # Best effort: a failed fit/integration yields a NaN row instead of
            # aborting the whole sweep.
            try:
                integral, spline = _integrate_curve(x_fit, y_fit)
                row = _result_row(pred_model, emb_key, n_val, benchmark, integral)
            except Exception as e:
                print(f"❌ Error for N = {n_val} ({pred_model} | {emb_key}): {e}")
                row = _result_row(pred_model, emb_key, n_val, benchmark)
            else:
                if SHOW_FIT_PLOTS:
                    _plot_fit(
                        spline, x_data, y_data, x_fit, y_fit,
                        f"{pred_model} | {emb_key} | N={n_val}",
                    )

            all_results.append(row)

print("\n🎉 Done with all model combinations!")

# ===========================
# Step 3: Save Results
# ===========================

# Column order of the output CSV (same as the key order of each result row).
# Passing it explicitly keeps the header in the file even when no sweep file
# was found and all_results is empty.
RESULT_COLUMNS = [
    'model',
    'emb_model',
    'N',
    'integral',
    'benchmark',
    'adjusted_integral',
    'normalized_adjusted',
]

df_integrals = pd.DataFrame(all_results, columns=RESULT_COLUMNS)
df_integrals.to_csv(output_path, index=False)
print(f"✅ Final results saved to: {output_path}")
