# Code/experiments/run_ablation_and_baselines.py

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# imports (no more Code.datasets problem)
from Code.experiments.synthetic_experiments import run_synthetic_experiment
from Code.experiments.sota_comparison import run_maml_baseline, run_cvae_augmentation
from Code.utils.training_utils import set_seed


"""
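Runs ablations over module count M, scaling exponent s, and effective sample size Neff.
Compares AWML against naive augmentation, CVAE augmentation, and a MAML-like baseline.
Saves the results to a CSV file and one RMSE-vs-Neff plot per (M, s) pair under artifacts/ablations.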
"""

# Experiment grid: every (M, s, Neff) combination drawn from these three
# lists is evaluated by the sweep below.
module_counts = [2, 4, 8]  # values passed as M
scaling_exponents = [0.5, 1.0, 1.5]  # values passed as scaling_exponent
Neff_values = [1, 5, 20, 100]  # effective sample sizes

# Seed once, up front, before any experiment is run.
set_seed(0)

# The CSV and all figures are written below this directory.
out_dir = "artifacts/ablations"
os.makedirs(out_dir, exist_ok=True)

# Sweep the experiment grid and collect one result row per (M, s, Neff).
rows = []
for M in module_counts:
    # The MAML-like baseline is only given M (it ignores both s and Neff),
    # so it is run once per module count and its result is reused for every
    # (s, Neff) row with that M, instead of being re-run for each grid point.
    # Every row with the same M therefore carries the same MAML value.
    maml_metrics = run_maml_baseline(M=M)

    for s in scaling_exponents:
        for Neff in Neff_values:
            # AWML synthetic experiment.
            awml_metrics = run_synthetic_experiment(
                M=M, scaling_exponent=s, Neff=Neff, method="AWML"
            )
            # Naive augmentation.
            naive_metrics = run_synthetic_experiment(
                M=M, scaling_exponent=s, Neff=Neff, method="naive"
            )
            # CVAE augmentation.
            cvae_metrics = run_cvae_augmentation(M=M, scaling_exponent=s, Neff=Neff)

            rows.append({
                "M": M,
                "s": s,
                "Neff": Neff,
                "AWML_RMSE": awml_metrics["rmse"],
                "Naive_RMSE": naive_metrics["rmse"],
                "CVAE_RMSE": cvae_metrics["rmse"],
                "MAML_RMSE": maml_metrics["rmse"],
            })

# Persist the full grid of results as a single CSV (one row per grid point).
df = pd.DataFrame(rows)
csv_path = os.path.join(out_dir, "ablations_and_baselines.csv")
df.to_csv(csv_path, index=False)
print(f"Saved results to {csv_path}")

# One RMSE-vs-Neff figure per (M, s) pair, comparing AWML with the baselines.
# Each entry: (results column, matplotlib format string, legend label).
curve_specs = [
    ("AWML_RMSE", 'o-', "AWML"),
    ("Naive_RMSE", 'x--', "Naive"),
    ("CVAE_RMSE", 's-.', "CVAE"),
]
for M in module_counts:
    for s in scaling_exponents:
        # Rows of the results table belonging to this (M, s) pair.
        panel = df[(df["M"] == M) & (df["s"] == s)]
        neff = panel["Neff"]

        fig, ax = plt.subplots()
        for column, fmt, label in curve_specs:
            ax.plot(neff, panel[column], fmt, label=label)
        # MAML is drawn as a flat reference line spanning the Neff range,
        # using the value from the first row of this (M, s) subset.
        ax.hlines(
            panel["MAML_RMSE"].iloc[0],
            xmin=neff.min(),
            xmax=neff.max(),
            colors='k',
            linestyles=':',
            label="MAML",
        )
        ax.set_xscale('log')
        ax.set_xlabel("Effective sample size Neff (log)")
        ax.set_ylabel("RMSE")
        ax.set_title(f"Ablation M={M}, s={s}")
        ax.legend()

        fig_path = os.path.join(out_dir, f"rmse_M{M}_s{s}.png")
        fig.savefig(fig_path, dpi=150)
        plt.close(fig)
        print(f"Saved {fig_path}")

print("All ablations and baseline comparisons done.")