"""
Seed Robustness Analysis for Bayesian Deep Learning Models
Analyzes variability across different random seeds to assess result stability.
"""

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Load the per-seed results that sit next to this script.
# JSON is UTF-8 by specification, so the encoding is given explicitly instead
# of relying on the platform's locale default.
data_path = Path(__file__).parent / "seed_robustness_data.json"
with open(data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# Everything below indexes data[seed][model][metric]; fail early with a clear
# message when the file holds no seeds instead of an opaque IndexError.
if not data:
    raise ValueError(f"No seed results found in {data_path}")

# Extract seeds and models.
# Seeds are the top-level keys. The model list is taken from the first seed
# only, so every other seed is expected to contain the same models.
seeds = list(data.keys())
models = list(data[seeds[0]].keys())

# Metrics that every section of the analysis iterates over, in report order.
# NOTE(review): the original note says this list excludes config strings and
# the deterministic baseline's zero values; the list itself is only a set of
# column names, so any such filtering has to happen where it is consumed.
# NOTE(review): the _MI / _STD suffixes presumably name the uncertainty score
# behind the metric -- confirm against the code that produced the JSON.
key_metrics = [
    # General metrics
    "NLL",
    "AUROC",
    "ECE_best",
    # Success / error metrics (MI)
    "AUPR_Success_MI",
    "AUPR_Error_MI",
    # OOD metrics (MI)
    "AUROC_OOD_MI",
    "AUPR_OOD_MI",
    # OOD metrics (STD)
    "AUROC_OOD_STD",
    "AUPR_OOD_STD",
]

# Flatten the nested {seed: {model: {metric: value}}} mapping into one record
# per (seed, model) pair. String-valued entries are dropped; every other
# value becomes a column of the DataFrame.
rows = [
    {
        "seed": seed,
        "model": model,
        **{
            name: value
            for name, value in data[seed][model].items()
            if not isinstance(value, str)
        },
    }
    for seed in seeds
    for model in models
]

df = pd.DataFrame(rows)

# Turn +inf / -inf into NaN so the dropna() calls in the later sections
# exclude them together with genuinely missing values.
df = df.replace([np.inf, -np.inf], np.nan)

# Report header: banner, then the seeds and models found in the data file.
print("=" * 80, "SEED ROBUSTNESS ANALYSIS", "=" * 80, sep="\n")
print(
    f"\nSeeds analyzed: {seeds}",
    f"Models: {models}",
    f"Number of seeds: {len(seeds)}",
    sep="\n",
)

# ==============================================================================
# 1. Summary Statistics (Mean ± Std) for Each Model
# ==============================================================================
print("\n" + "=" * 80)
print("1. SUMMARY STATISTICS: Mean ± Std Across Seeds")
print("=" * 80)

# One row per model with "<metric>_mean", "<metric>_std" and "<metric>_cv%"
# entries. A metric with no non-NaN value for a model is left out of that
# model's row and therefore shows up as NaN in summary_df.
summary_stats = []
for model in models:
    model_df = df[df["model"] == model]
    stats = {"Model": model}
    for metric in key_metrics:
        if metric not in model_df.columns:
            continue
        values = model_df[metric].dropna()
        if values.empty:
            continue
        mean = values.mean()
        std = values.std()  # pandas default ddof=1: NaN when only one seed exists
        # CV = std / |mean|. The absolute value keeps the CV non-negative for a
        # metric whose mean is negative. With a zero mean the CV is undefined,
        # so it is stored as NaN instead of a misleading 0% (this matches the
        # CV table in section 2, which omits zero-mean metrics).
        cv = std / abs(mean) * 100 if mean != 0 else np.nan
        stats[f"{metric}_mean"] = mean
        stats[f"{metric}_std"] = std
        stats[f"{metric}_cv%"] = cv
    summary_stats.append(stats)

summary_df = pd.DataFrame(summary_stats)

# Print formatted table for key metrics
for metric in key_metrics:
    print(f"\n{metric}:")
    print("-" * 60)
    if f"{metric}_mean" not in summary_df.columns:
        continue  # no model reported this metric
    for model in models:
        model_stats = summary_df[summary_df["Model"] == model]
        mean = model_stats[f"{metric}_mean"].values[0]
        std = model_stats[f"{metric}_std"].values[0]
        cv = model_stats[f"{metric}_cv%"].values[0]
        if np.isnan(mean):
            continue  # this model has no value for the metric
        cv_text = "n/a" if np.isnan(cv) else f"{cv:.2f}%"
        print(f"  {model:30s}: {mean:.4f} ± {std:.4f}  (CV: {cv_text})")

# ==============================================================================
# 2. Coefficient of Variation Analysis
# ==============================================================================
print("\n" + "=" * 80)
print("2. COEFFICIENT OF VARIATION (CV%) - Lower is More Robust")
print("=" * 80)

# One row per model (the "Deterministic Baseline" is skipped) holding the CV%
# of every key metric that has data and a non-zero mean.
cv_data = []
for model in models:
    if model == "Deterministic Baseline":
        continue  # Skip deterministic (no uncertainty metrics)
    model_df = df[df["model"] == model]
    cv_row = {"Model": model}
    for metric in key_metrics:
        if metric not in model_df.columns:
            continue
        values = model_df[metric].dropna()
        if values.empty:
            continue
        mean = values.mean()
        if mean == 0:
            continue  # CV is undefined for a zero mean; leave the cell empty
        # Divide by |mean|: with a signed mean a negative-mean metric would get
        # a negative CV, pull the model's average down and make it look more
        # robust than it is in the ranking below.
        cv_row[metric] = values.std() / abs(mean) * 100
    cv_data.append(cv_row)

cv_df = pd.DataFrame(cv_data).set_index("Model")

print("\nCV% for Key Metrics (excluding Deterministic Baseline):")
print(cv_df.round(2).to_string())

# Average CV across metrics. The column is added to cv_df itself because the
# later sections (heatmap, summary table, key findings) read it from there.
# Missing cells are skipped by mean(), so models may be averaged over
# different sets of metrics.
cv_df["Average_CV%"] = cv_df.mean(axis=1)
print("\n\nOverall Robustness Ranking (by Average CV%, lower is better):")
print("-" * 60)
for idx, (model, avg_cv) in enumerate(cv_df["Average_CV%"].sort_values().items(), 1):
    print(f"  {idx}. {model:30s}: {avg_cv:.2f}%")

# ==============================================================================
# 3. Range Analysis - spread between the lowest and highest seed
# ==============================================================================
print("\n" + "=" * 80)
print("3. RANGE ANALYSIS (Max - Min Across Seeds)")
print("=" * 80)

# For every metric, list each model's max - min over its non-NaN seed values.
for metric in key_metrics:
    print(f"\n{metric}:")
    print("-" * 60)
    if metric not in df.columns:
        continue  # metric absent from the data: header only
    for model in models:
        observed = df.loc[df["model"] == model, metric].dropna()
        if observed.empty:
            continue
        lowest, highest = observed.min(), observed.max()
        print(f"  {model:30s}: Range = {highest - lowest:.4f} (Min: {lowest:.4f}, Max: {highest:.4f})")

# ==============================================================================
# 4. Visualizations
# ==============================================================================
print("\n" + "=" * 80, "4. GENERATING VISUALIZATIONS...", "=" * 80, sep="\n")

# Figures and reports are written into the directory that holds this script.
output_dir = Path(__file__).parent
plt.style.use('seaborn-v0_8-whitegrid')

# Subset without the "Deterministic Baseline" rows, used by the plots that
# compare the remaining models only.
is_baseline = df["model"] == "Deterministic Baseline"
df_bayesian = df[~is_baseline]

# 4.1 Box plots for key metrics
# One panel per metric, three panels per row. The grid is sized from
# key_metrics (3x3 for nine metrics) so that adding a metric cannot run past
# the end of the axes array.
n_cols = 3
n_rows = max(1, -(-len(key_metrics) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
axes = axes.flatten()

for idx, metric in enumerate(key_metrics):
    ax = axes[idx]
    if metric not in df_bayesian.columns:
        ax.set_visible(False)  # no data for this metric: hide the empty frame
        continue
    # seaborn 0.13 deprecates `palette` without `hue`; assigning the x variable
    # to hue gives the same per-model colours. dodge=False keeps the boxes
    # full-width and centred on their tick in seaborn versions that would
    # otherwise dodge by hue.
    sns.boxplot(data=df_bayesian, x="model", y=metric, hue="model",
                dodge=False, palette="Set2", ax=ax)
    # The hue legend only repeats the x tick labels; drop it when one is drawn.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_xlabel("")
    ax.set_ylabel(metric, fontsize=10)
    ax.tick_params(axis='x', rotation=45)
    ax.set_title(f"{metric}", fontsize=11, fontweight='bold')

# Hide grid cells beyond the last metric.
for ax in axes[len(key_metrics):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(output_dir / "seed_robustness_boxplots.png", dpi=150, bbox_inches='tight')
plt.close()
print("  Saved: seed_robustness_boxplots.png")

# 4.2 Bar chart with error bars for main metrics
main_metrics = ["AUROC", "NLL", "ECE_best", "AUROC_OOD_MI"]
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

# One fixed colour per model, covering ALL models: the deterministic baseline
# is only left out of the AUROC_OOD_MI panel, so the other panels need a
# colour for it too. Looking colours up by model (not by bar position) keeps
# each model's colour identical in every panel.
colors = plt.cm.Set2(np.linspace(0, 1, len(models)))
model_colors = dict(zip(models, colors))

for idx, metric in enumerate(main_metrics):
    ax = axes[idx]
    means = []
    stds = []
    model_names = []
    bar_colors = []

    for model in models:
        if model == "Deterministic Baseline" and metric in ["AUROC_OOD_MI"]:
            continue
        model_df = df[df["model"] == model]
        if metric not in model_df.columns:
            continue
        values = model_df[metric].dropna()
        if values.empty:
            continue
        means.append(values.mean())
        stds.append(values.std())
        # Break multi-word model names over several lines to fit under the bar.
        model_names.append(model.replace(" ", "\n"))
        bar_colors.append(model_colors[model])

    ax.set_ylabel(metric, fontsize=10)
    ax.set_title(f"{metric} (Mean ± Std)", fontsize=11, fontweight='bold')
    if not means:
        continue  # nothing to draw; also avoids taking the mean of an empty list

    x = np.arange(len(model_names))
    ax.bar(x, means, yerr=stds, capsize=5, color=bar_colors,
           edgecolor='black', linewidth=1, alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(model_names, fontsize=8)
    # Reference line at the mean of the per-model means shown in this panel.
    ax.axhline(y=np.mean(means), color='red', linestyle='--', alpha=0.5, label='Overall Mean')
    ax.legend(fontsize=8, loc='best')  # makes the 'Overall Mean' label visible

plt.tight_layout()
plt.savefig(output_dir / "seed_robustness_barplots.png", dpi=150, bbox_inches='tight')
plt.close()
print("  Saved: seed_robustness_barplots.png")

# 4.3 Heatmap of CV%
# Metrics on the rows and models on the columns; the "Average_CV%" summary
# column is removed so only per-metric values are drawn. The colour scale is
# fixed to the 0-15 range.
fig, ax = plt.subplots(figsize=(12, 5))
cv_plot_df = cv_df.drop(columns=["Average_CV%"]).transpose()
heatmap_options = {
    "annot": True,
    "fmt": ".1f",
    "cmap": "RdYlGn_r",
    "vmin": 0,
    "vmax": 15,
    "cbar_kws": {'label': 'CV%'},
}
sns.heatmap(cv_plot_df, ax=ax, **heatmap_options)
ax.set_title("Coefficient of Variation (%) - Lower is More Robust", fontsize=12, fontweight='bold')
ax.set_xlabel("Model", fontsize=10)
ax.set_ylabel("Metric", fontsize=10)
plt.tight_layout()
heatmap_path = output_dir / "seed_robustness_cv_heatmap.png"
plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
plt.close()
print("  Saved: seed_robustness_cv_heatmap.png")

# 4.4 Line plot showing performance across seeds
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

# One seed ordering shared by the plotted values AND the tick labels.
# Seeds are JSON keys (strings), so a plain sort would be lexicographic
# ("10" before "2"); sort numerically when every seed parses as an integer
# and fall back to string order otherwise.
try:
    seed_order = sorted(seeds, key=int)
except ValueError:
    seed_order = sorted(seeds)
x_positions = range(len(seed_order))

for idx, metric in enumerate(["AUROC", "NLL", "ECE_best", "AUROC_OOD_MI"]):
    ax = axes[idx]
    for model in models:
        if model == "Deterministic Baseline":
            continue
        model_df = df[df["model"] == model]
        if metric not in model_df.columns:
            continue
        # reindex() places each value at its seed's position in seed_order;
        # a seed without a row for this model becomes NaN (a gap in the line)
        # instead of shifting the remaining points onto the wrong tick.
        values = model_df.set_index("seed")[metric].reindex(seed_order).to_numpy()
        ax.plot(x_positions, values, marker='o', label=model, linewidth=2, markersize=6)

    ax.set_xticks(x_positions)
    ax.set_xticklabels(seed_order, fontsize=9)
    ax.set_xlabel("Seed", fontsize=10)
    ax.set_ylabel(metric, fontsize=10)
    ax.set_title(f"{metric} Across Seeds", fontsize=11, fontweight='bold')
    ax.legend(fontsize=8, loc='best')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / "seed_robustness_lineplot.png", dpi=150, bbox_inches='tight')
plt.close()
print("  Saved: seed_robustness_lineplot.png")

# ==============================================================================
# 5. Final Summary Table
# ==============================================================================
print("\n" + "=" * 80)
print("5. FINAL SUMMARY TABLE")
print("=" * 80)


def _mean_pm_std(series):
    """Format a metric column as "mean ± std" with four decimals."""
    return f"{series.mean():.4f} ± {series.std():.4f}"


# Table column label -> DataFrame column the statistic is computed from.
summary_columns = {
    "AUROC": "AUROC",
    "NLL": "NLL",
    "ECE": "ECE_best",
    "AUROC_OOD_MI": "AUROC_OOD_MI",
}

# One row per model except the "Deterministic Baseline".
final_summary = []
for model in models:
    if model == "Deterministic Baseline":
        continue
    model_rows = df[df["model"] == model]
    entry = {"Model": model}
    for label, column in summary_columns.items():
        entry[label] = _mean_pm_std(model_rows[column])
    entry["Avg CV%"] = f"{cv_df.loc[model, 'Average_CV%']:.2f}"
    final_summary.append(entry)

final_df = pd.DataFrame(final_summary)
print("\n" + final_df.to_string(index=False))

# Save summary to CSV
summary_csv_path = output_dir / "seed_robustness_summary.csv"
final_df.to_csv(summary_csv_path, index=False)
print("\n  Saved: seed_robustness_summary.csv")

# ==============================================================================
# 6. Key Findings
# ==============================================================================
print("\n" + "=" * 80)
print("6. KEY FINDINGS")
print("=" * 80)

# Find most and least robust models
most_robust = cv_df["Average_CV%"].idxmin()
least_robust = cv_df["Average_CV%"].idxmax()

# The assessment is derived from the CV table instead of being fixed text, so
# the printed conclusions cannot contradict the numbers reported above.
cv_threshold = 15.0  # same bound as the upper limit of the CV heatmap scale
metric_cv = cv_df.drop(columns=["Average_CV%"])
cv_cells = metric_cv.to_numpy(dtype=float)
n_defined = int(np.count_nonzero(~np.isnan(cv_cells)))
n_below = int(np.count_nonzero(cv_cells < cv_threshold))  # NaN compares False

assessment = []
if n_defined:
    assessment.append(
        f"{n_below} of {n_defined} model/metric CV values are below {cv_threshold:.0f}%"
    )
else:
    assessment.append("No CV values are available")

# NLL versus AUROC variability, counted over models that have both values.
if {"NLL", "AUROC"} <= set(metric_cv.columns):
    nll_vs_auroc = metric_cv[["NLL", "AUROC"]].dropna()
    if len(nll_vs_auroc):
        n_nll_higher = int((nll_vs_auroc["NLL"] > nll_vs_auroc["AUROC"]).sum())
        assessment.append(
            f"NLL shows higher variability than AUROC for "
            f"{n_nll_higher} of {len(nll_vs_auroc)} models"
        )

# Average CV over all AUROC_OOD_* cells that hold a value.
ood_columns = [col for col in metric_cv.columns if str(col).startswith("AUROC_OOD")]
if ood_columns:
    ood_cells = metric_cv[ood_columns].to_numpy(dtype=float)
    ood_cells = ood_cells[~np.isnan(ood_cells)]
    if ood_cells.size:
        assessment.append(
            f"OOD detection metrics (AUROC_OOD) have an average CV of {ood_cells.mean():.2f}%"
        )

assessment_text = "\n".join(f"   - {line}" for line in assessment)

# Results count as reproducible only when most CV values are under the bound.
if n_defined and n_below > n_defined / 2:
    reproducibility = "Results are generally reproducible across seeds"
else:
    reproducibility = "Results vary noticeably across seeds; treat single-seed numbers with care"

print(f"""
SEED ROBUSTNESS CONCLUSIONS:

1. Most Robust Model: {most_robust}
   - Average CV across metrics: {cv_df.loc[most_robust, 'Average_CV%']:.2f}%
   - Shows lowest variability across different random seeds

2. Least Robust Model: {least_robust}
   - Average CV across metrics: {cv_df.loc[least_robust, 'Average_CV%']:.2f}%
   - Shows highest variability across different random seeds

3. Overall Assessment:
{assessment_text}

4. Recommendation:
   - {reproducibility}
   - Report mean ± std when presenting results for publication
   - Consider using ensemble predictions for production deployment
""")

# Save full analysis to text file
# The summary table contains "±", so the file is written as UTF-8 explicitly
# rather than in whatever encoding the platform locale defaults to.
with open(output_dir / "seed_robustness_report.txt", "w", encoding="utf-8") as f:
    f.write("SEED ROBUSTNESS ANALYSIS REPORT\n")
    f.write("=" * 80 + "\n\n")
    f.write(f"Seeds analyzed: {seeds}\n")
    f.write(f"Models analyzed: {models}\n\n")
    f.write("Summary Table:\n")
    f.write(final_df.to_string(index=False))
    f.write("\n\nCoefficient of Variation (%):\n")
    f.write(cv_df.round(2).to_string())
    f.write("\n")  # end the file with a newline

print("\n  Saved: seed_robustness_report.txt")
print("\n" + "=" * 80)
print("Analysis complete!")
print("=" * 80)
