"""
Generate optimized radar plot for ICML paper - LSTM experiments
Metrics: MAE, ECE, PICP, AUROC-OOD, AUPR-OOD (drops NLL, uses MAE instead of R²)
"""

import numpy as np
import matplotlib.pyplot as plt

# Axis order of the radar plot; every per-model score tuple below follows it.
metrics = ['MAE', 'ECE', 'PICP', 'AUROC_OOD', 'AUPR_OOD']

# Hard-coded results copied from the experiment notebook, one row per model:
#   (model name, line colour, MAE, ECE, PICP, AUROC_OOD, AUPR_OOD)
# Row order determines the plotting / legend order downstream.
_model_rows = [
    ('Full-Rank BBB',  '#ff7f0e', 10.52, 0.165, 0.721, 0.485, 0.743),
    ('Low-Rank',       '#2ca02c', 10.48, 0.156, 0.723, 0.663, 0.798),
    ('Low-Rank (SVD)', '#1f77b4', 10.47, 0.205, 0.673, 0.735, 0.837),  # blue marks the best low-rank variant
    ('Rank-1 Mult.',   '#9467bd', 10.80, 0.320, 0.428, 0.632, 0.790),
    ('Deep Ensemble',  '#d62728', 10.40, 0.323, 0.295, 0.698, 0.866),
]

# model name -> {metric name -> raw score}
results_data = {
    model: dict(zip(metrics, scores))
    for model, _hex, *scores in _model_rows
}

# Metric directions (True = higher is better); MAE and ECE are the only
# metrics where a smaller raw value is the better one.
_lower_is_better = ('MAE', 'ECE')
higher_is_better = {
    metric: metric not in _lower_is_better
    for metric in metrics
}

# model name -> line/fill colour (low-rank models get the emphasised hues)
colors = {model: hex_code for model, hex_code, *_ in _model_rows}

def normalize_metrics(values, higher_is_better):
    """Min-max normalize metric values to [0, 1], where 1 is the best score.

    Args:
        values: Sequence of raw metric values, one per model.
        higher_is_better: True if larger raw values are better; False to
            flip the scale so the smallest raw value maps to 1.

    Returns:
        A float ndarray with one entry per input value. When all non-NaN
        values are equal (per ``np.isclose``) each of them maps to 0.5,
        since no ranking is possible. NaN entries stay NaN and are ignored
        when computing the range, so one missing score does not wipe out
        the whole metric. Empty input yields an empty array.
    """
    values = np.array(values, dtype=float)
    if values.size == 0:
        # np.min / np.max raise ValueError on empty input; nothing to scale.
        return values

    valid = ~np.isnan(values)
    if not valid.any():
        # Every score is missing: there is no range to normalize against.
        return values

    vmin, vmax = values[valid].min(), values[valid].max()
    if np.isclose(vmax, vmin):
        # Degenerate range: place every present score mid-scale.
        return np.where(valid, 0.5, np.nan)

    span = vmax - vmin
    if higher_is_better:
        return (values - vmin) / span
    return (vmax - values) / span

# --- Normalize every metric across models --------------------------------
# Each metric is rescaled independently so that 1.0 is the best model and
# 0.0 the worst; `normalized` maps model name -> scores in `metrics` order.
model_names = list(results_data)
_scaled_columns = [
    normalize_metrics(
        [results_data[model][metric] for model in model_names],
        higher_is_better[metric],
    )
    for metric in metrics
]
normalized = {
    model: [column[row] for column in _scaled_columns]
    for row, model in enumerate(model_names)
}

# --- Set up the polar axes ------------------------------------------------
fig = plt.figure(figsize=(7, 6))
ax = fig.add_subplot(111, projection='polar')

# Axis captions, listed in the same order as `metrics`.
labels = [
    'MAE\n(accuracy)',
    'ECE\n(calibration)',
    'PICP\n(coverage)',
    'AUROC\n(OOD)',
    'AUPR\n(OOD)',
]

num_vars = len(labels)
_spoke_angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
# Repeat the first angle at the end so each polygon closes on itself.
angles = _spoke_angles + _spoke_angles[:1]

# First axis at 12 o'clock, remaining axes laid out clockwise.
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
ax.set_thetagrids(np.degrees(_spoke_angles), labels, fontsize=10)
ax.set_ylim(0, 1)

# --- Draw one closed polygon per model -------------------------------------
# Low-rank variants are emphasised: thicker solid outline, stronger fill.
for name in model_names:
    if 'Low-Rank' in name:
        outline = {'linewidth': 2.5, 'linestyle': '-', 'alpha': 1.0}
        fill_alpha = 0.15
    else:
        outline = {'linewidth': 1.5, 'linestyle': '--', 'alpha': 0.7}
        fill_alpha = 0.05

    scores = normalized[name]
    vals = scores + scores[:1]  # close the polygon
    ax.plot(angles, vals, label=name, color=colors[name], **outline)
    ax.fill(angles, vals, alpha=fill_alpha, color=colors[name])

ax.set_title('LSTM Model Comparison\n(outer = better)',
             y=1.12, fontsize=12, fontweight='bold')
# Legend is anchored outside the axes so it does not cover the polygons.
ax.legend(loc='upper right', bbox_to_anchor=(1.35, 1.05), fontsize=9)

# Radial gridlines and their tick labels
ax.set_rticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], fontsize=8)
ax.grid(True, alpha=0.3)

# --- Export: 300-dpi PNG plus a PDF, then show the figure ------------------
plt.tight_layout()
plt.savefig('paper_radar_lstm.png', dpi=300, bbox_inches='tight')
plt.savefig('paper_radar_lstm.pdf', bbox_inches='tight')
print("Saved: paper_radar_lstm.png and paper_radar_lstm.pdf")
plt.show()
