#!/usr/bin/env python3
"""
Aggregate Gaussian-noise learning curves – Pendulum-v1
σ² ∈ {1.0, 2.0} ,  3000 episodes, reward −1500 … −100
"""

from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t


# ─────────────────────────────────────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────────────────────────────────────
def t95(n):
    """Return the two-sided 95 % Student-t critical value for sample size(s) *n*.

    The degrees of freedom are ``n - 1``, clamped to a minimum of 1 so that
    sample sizes of 0 or 1 still produce a finite value.

    Returns a plain ``float`` when *n* is a scalar (per ``np.isscalar``),
    otherwise whatever ``scipy.stats.t.ppf`` returns for the array input.
    """
    sample_sizes = np.asarray(n)
    dof = np.maximum(sample_sizes - 1, 1)
    critical = t.ppf(0.975, dof)
    if np.isscalar(n):
        return float(critical)
    return critical


def rolling_mean(arr: np.ndarray, window: int):
    """Return the trailing moving average of *arr* over *window* samples.

    The output has the same length as the (flattened) input: the first
    ``window - 1`` entries are NaN because no full window is available yet,
    and entry ``i`` afterwards is the mean of ``arr[i - window + 1 : i + 1]``.

    * ``window <= 1`` returns *arr* unchanged (no smoothing).
    * ``window`` larger than the input yields an all-NaN array of the input
      length, so the result always stays aligned with the x-axis it is
      plotted against.
    * A NaN in the input only invalidates the windows that contain it; it
      does not propagate to every later sample.
    """
    if window <= 1:
        return arr

    values = np.asarray(arr, dtype=float).ravel()
    if window > values.size:
        # Not a single full window fits: keep the length, mark all as missing.
        return np.full(values.size, np.nan)

    # Direct windowed sums (rather than differences of a cumulative sum) keep
    # NaNs local to the windows that actually overlap them.
    window_sums = np.convolve(values, np.ones(window), mode="valid")
    return np.concatenate([np.full(window - 1, np.nan), window_sums / float(window)])


def collect_csvs(root: Path, pattern: str) -> Dict[str, Path]:
    """Map each run directory name to its ``<run>/<pattern>`` file.

    Only entries of *root* whose name matches ``m*`` are considered, in
    sorted order, and a run is included only when ``<run>/<pattern>`` is an
    existing regular file.
    """
    candidates = (
        (run_dir.name, run_dir / pattern) for run_dir in sorted(root.glob("m*"))
    )
    return {run_id: csv_path for run_id, csv_path in candidates if csv_path.is_file()}


# ─────────────────────────────────────────────────────────────────────────────
# Aggregation
# ─────────────────────────────────────────────────────────────────────────────
def aggregate_baseline(runs_root: Path, env: str, smooth: int) -> pd.DataFrame:
    """Average the baseline learning curve of *env* across all runs.

    Each run contributes ``results/<env>/csv/baseline_learning_curve.csv``;
    its rows for *env* are outer-joined with the other runs on ``Episode``.
    Per episode the NaN-aware mean and a 95 % Student-t confidence half-width
    are computed, and both are smoothed with ``rolling_mean`` when
    ``smooth > 1``.

    Returns:
        DataFrame with columns ``Episode``, ``mean``, ``lower``, ``upper``.

    Raises:
        FileNotFoundError: if no run directory contains the baseline CSV.
    """
    csv_by_run = collect_csvs(
        runs_root, f"results/{env}/csv/baseline_learning_curve.csv"
    )
    if not csv_by_run:
        raise FileNotFoundError(f"No baseline CSVs for {env}")

    # One two-column frame per run: Episode + that run's reward column.
    per_run: List[pd.DataFrame] = []
    for run_id, csv_path in csv_by_run.items():
        raw = pd.read_csv(csv_path)
        curve = raw[raw.Environment == env][["Episode", "TotalReward"]]
        per_run.append(curve.rename(columns={"TotalReward": run_id}))

    # Outer join keeps episodes that only some runs reached (missing → NaN).
    remaining = iter(per_run)
    table = next(remaining)
    for curve in remaining:
        table = table.merge(curve, on="Episode", how="outer")

    rewards = table.drop(columns="Episode").to_numpy(float)
    centre = np.nanmean(rewards, axis=1)
    n_runs = (~np.isnan(rewards)).sum(axis=1)
    std_err = np.nanstd(rewards, axis=1, ddof=1) / np.sqrt(np.maximum(n_runs, 1))
    half_width = t95(n_runs) * std_err

    if smooth > 1:
        centre = rolling_mean(centre, smooth)
        half_width = rolling_mean(half_width, smooth)

    return pd.DataFrame(
        {
            "Episode": table["Episode"],
            "mean": centre,
            "lower": centre - half_width,
            "upper": centre + half_width,
        }
    )


def aggregate_noised(
    runs_root: Path, env: str, smooth: int
) -> Dict[float, Dict[int, pd.DataFrame]]:
    """Average the noised reward curves of *env* across all runs.

    Each run contributes ``results/<env>/noised/noised_rewards.csv``. Rows for
    *env* are split by ``(NoiseVariance, ObsDim)``; within each split the runs
    are outer-joined on ``Episode`` and reduced to a NaN-aware mean with a
    95 % Student-t confidence half-width, both smoothed with ``rolling_mean``
    when ``smooth > 1``.

    Returns:
        ``{variance: {obs_dim: DataFrame}}`` where every DataFrame has the
        columns ``Episode``, ``mean``, ``lower``, ``upper``.

    Raises:
        FileNotFoundError: if no run directory contains the noised CSV.
    """
    csv_by_run = collect_csvs(runs_root, f"results/{env}/noised/noised_rewards.csv")
    if not csv_by_run:
        raise FileNotFoundError(f"No noised_rewards.csv for {env}")

    # (variance, dim) → list of per-run frames (Episode + that run's reward).
    grouped: Dict[tuple[float, int], List[pd.DataFrame]] = {}
    for run_id, csv_path in csv_by_run.items():
        raw = pd.read_csv(csv_path)
        env_rows = raw[raw.Environment == env]
        for (var, dim), rows in env_rows.groupby(["NoiseVariance", "ObsDim"]):
            curve = rows[["Episode", "Reward"]].rename(columns={"Reward": run_id})
            key = (var, dim)
            if key not in grouped:
                grouped[key] = []
            grouped[key].append(curve)

    result: Dict[float, Dict[int, pd.DataFrame]] = {}
    for (var, dim), curves in grouped.items():
        # Outer join keeps episodes that only some runs reached (missing → NaN).
        remaining = iter(curves)
        table = next(remaining)
        for curve in remaining:
            table = table.merge(curve, on="Episode", how="outer")

        rewards = table.drop(columns="Episode").to_numpy(float)
        centre = np.nanmean(rewards, axis=1)
        n_runs = (~np.isnan(rewards)).sum(axis=1)
        std_err = np.nanstd(rewards, axis=1, ddof=1) / np.sqrt(np.maximum(n_runs, 1))
        half_width = t95(n_runs) * std_err

        if smooth > 1:
            centre = rolling_mean(centre, smooth)
            half_width = rolling_mean(half_width, smooth)

        if var not in result:
            result[var] = {}
        result[var][dim] = pd.DataFrame(
            {
                "Episode": table["Episode"],
                "mean": centre,
                "lower": centre - half_width,
                "upper": centre + half_width,
            }
        )
    return result


# ─────────────────────────────────────────────────────────────────────────────
# Plotting
# ─────────────────────────────────────────────────────────────────────────────
# Line colours taken from matplotlib's configured property cycle
# (rcParams "axes.prop_cycle"); plot_panel indexes this list modulo its length.
COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]


def plot_panel(env, var, dim_map, base, obs_names, out_path):
    """Draw and save one panel: baseline vs. every noised dimension.

    Args:
        env: environment name, used in the title.
        var: noise variance shown in the title.
        dim_map: mapping ``dim -> DataFrame`` with the columns ``Episode``,
            ``mean``, ``lower`` and ``upper``; plotted in sorted key order.
        base: baseline DataFrame with the same four columns.
        obs_names: sequence indexed by ``dim`` giving the legend name.
        out_path: destination image path; parent directories are created.
    """
    fig, ax = plt.subplots(figsize=(4, 3))

    # Baseline curve with its confidence band.
    base_episodes = base["Episode"]
    ax.plot(base_episodes, base["mean"], c="black", lw=2.5, label="Baseline")
    ax.fill_between(
        base_episodes, base["lower"], base["upper"], color="black", alpha=0.15
    )

    # One curve + band per noised observation dimension, cycling the palette.
    for idx, dim in enumerate(sorted(dim_map)):
        curve = dim_map[dim]
        colour = COLORS[idx % len(COLORS)]
        ax.plot(
            curve["Episode"], curve["mean"], c=colour,
            label=f"Dim={dim} ({obs_names[dim]})",
        )
        ax.fill_between(
            curve["Episode"], curve["lower"], curve["upper"],
            color=colour, alpha=0.20,
        )

    ax.set_title(rf"{env}, $\sigma^2={var}$  (mean ± 95% CI)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    ax.set_xlim(0, 5000)
    ax.set_ylim(-1700, 0)
    ax.legend(fontsize=5)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    print("Saved →", out_path)


# ─────────────────────────────────────────────────────────────────────────────
# CLI driver
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: aggregate the Pendulum-v1 runs and save one panel per σ².

    Command-line options:
        --runs_dir    directory holding the ``m*`` run folders
        --output_dir  directory that receives the PNG panels
        --smooth      rolling-window length passed to the aggregators

    If the baseline or noised CSVs cannot be found, the error message is
    printed and the function returns without plotting.
    """
    # ArgumentParser's first positional parameter is `prog`, not the summary;
    # pass the text as `description=` so the usage line keeps the real
    # program name and the summary appears in the --help body.
    ap = argparse.ArgumentParser(
        description="Aggregate Gaussian-noise plots (Pendulum-v1)"
    )
    ap.add_argument("--runs_dir",   default="remote_results")
    ap.add_argument("--output_dir", default="aggregated_plots")
    ap.add_argument("--smooth",     default=10, type=int,
                    help="rolling-window length (0 = no smoothing)")
    args = ap.parse_args()

    ENV  = "Pendulum-v1"
    # Only these noise variances get a panel; others in the CSVs are skipped.
    VARS = {1.0, 2.0}

    # Legend names, indexed by the observation dimension that was noised.
    OBS_NAMES = [
        "var_cos(theta)",
        "var_sin(theta)",
        "var_angular_velocity",
    ]

    runs_root = Path(args.runs_dir).resolve()
    out_root  = Path(args.output_dir).resolve()

    try:
        baseline_df = aggregate_baseline(runs_root, ENV, smooth=args.smooth)
        noised_dict = aggregate_noised(runs_root, ENV, smooth=args.smooth)
    except FileNotFoundError as e:
        print(e)
        return

    for var in sorted(noised_dict.keys()):
        if var not in VARS:
            continue
        plot_panel(
            ENV, var,
            noised_dict[var],
            baseline_df,
            OBS_NAMES,
            out_root / ENV / f"{ENV}_var_{var}_mean_CI.png"
        )

    print("All done.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()