"""
Add Interval Score to existing uncertainty quantification results.
"""

import importlib
from pathlib import Path

import numpy as np
import pandas as pd

# Reload evaluation module to get the new compute_interval_score function
import modules.evaluation
importlib.reload(modules.evaluation)
from modules.evaluation import compute_interval_score

print("="*80)
print("COMPUTING INTERVAL SCORE FOR ALL MODELS")
print("="*80)

# Expects these globals from the surrounding notebook session:
#   uncertainty_results - dict mapping model name -> dict of that model's
#                         results; each entry needs 'lower_bound' and
#                         'upper_bound' (array-like)
#   y_test_original     - ground-truth targets (array-like)
#   CONFIDENCE_LEVEL    - passed through to compute_interval_score

# The targets are identical for every model, so flatten them once.
# np.asarray(...).ravel() accepts ndarrays, lists and pandas objects alike,
# whereas .flatten() only exists on ndarrays.
y_true = np.asarray(y_test_original).ravel()

# Compute Interval Score for each model and store it in that model's results.
for name, result in uncertainty_results.items():
    print(f"\nComputing Interval Score for {name}...")

    # Get prediction-interval bounds as flat arrays
    lower = np.asarray(result['lower_bound']).ravel()
    upper = np.asarray(result['upper_bound']).ravel()

    # Fail loudly instead of letting broadcasting (or an error deep inside
    # the metric) hide a misalignment between bounds and targets.
    if not (lower.size == upper.size == y_true.size):
        raise ValueError(
            f"{name}: size mismatch between y_test_original ({y_true.size}), "
            f"lower_bound ({lower.size}) and upper_bound ({upper.size})"
        )

    interval_score = compute_interval_score(
        y_true,
        lower,
        upper,
        confidence_level=CONFIDENCE_LEVEL
    )

    # Add to results (mutates the inner dict, not the dict being iterated)
    result['Interval_Score'] = interval_score

    print(f"  Interval Score: {interval_score:.3f} (LOWER IS BETTER)")

# Create updated summary table
print("\n" + "="*80)
print("UPDATED UNCERTAINTY QUANTIFICATION RESULTS")
print("="*80)

# Metric keys copied from each model's results dict, in display order.
UQ_METRIC_COLUMNS = ['NLL', 'CRPS', 'ECE', 'PICP', 'MPIW', 'Interval_Score']

# Explicit columns keep the schema (and the sort below) valid even when
# uncertainty_results is empty.
df_uq_updated = pd.DataFrame(
    [
        {'Model': model_name, **{col: metrics[col] for col in UQ_METRIC_COLUMNS}}
        for model_name, metrics in uncertainty_results.items()
    ],
    columns=['Model', *UQ_METRIC_COLUMNS],
)

# Sort by Interval Score (best = lowest). A stable sort keeps tied models
# in their original order, so the ranking is deterministic.
df_uq_updated = df_uq_updated.sort_values('Interval_Score', kind='mergesort')

print(df_uq_updated.to_string(index=False))

# Save updated results. pandas' to_csv does not create missing parent
# directories, so make sure the output folder exists first.
IS_RESULTS_CSV = 'results_csv/uncertainty_quantification_results_with_IS.csv'
Path(IS_RESULTS_CSV).parent.mkdir(parents=True, exist_ok=True)
df_uq_updated.to_csv(IS_RESULTS_CSV, index=False)
print("\n" + "="*80)
print(f"Saved: {IS_RESULTS_CSV}")
print("="*80)

# Print ranking by Interval Score (df_uq_updated is already sorted ascending)
print("\n" + "="*80)
print("MODEL RANKING BY INTERVAL SCORE (Lower = Better)")
print("="*80)
# itertuples avoids building a Series per row (and its dtype upcasting);
# the index is not needed because the rank comes from enumerate.
for rank, row in enumerate(df_uq_updated.itertuples(index=False), start=1):
    print(f"{rank}. {row.Model:30s} IS={row.Interval_Score:8.3f}  "
          f"PICP={row.PICP:.3f}  MPIW={row.MPIW:.1f}")
