#!/usr/bin/env python3
"""
Aggregate Gaussian-noise learning curves – **Acrobot-v1**
σ² ∈ {0.02, 2.0} | 400 episodes | reward range −500 … 0
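No runs or episodes are filtered out before aggregation. With smoothing enabled
(--smooth > 1) the first (window − 1) episodes have no value and are left blank,
and the x-axis is clipped at EPISODE_MAX.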
"""

from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t


# ─────────────────────── plot axis limits ───────────────────────
EPISODE_MAX: int = 400                     # right edge of the x-axis (episodes)
REWARD_RANGE: tuple[int, int] = (-500, 0)  # (bottom, top) of the y-axis


# ───────────────────── helper utilities ───────────────────────
def t95(n):
    """Return the two-sided 95 % Student-t critical value for sample size(s) *n*.

    Degrees of freedom are ``n - 1``, floored at 1 so that ``n <= 1`` still
    yields a finite value. *n* may be a scalar or an array; the result has
    the matching shape and a float dtype.
    """
    dof = np.clip(np.asarray(n) - 1, 1, None)
    critical = t.ppf(0.975, dof)
    return critical.astype(float)


def rolling_mean(a: np.ndarray, w: int) -> np.ndarray:
    """Trailing rolling mean of the 1-D array *a* over *w* samples.

    The output always has the same length as *a*. Position ``i`` holds the
    mean of ``a[i-w+1 : i+1]``, and the first ``w - 1`` positions are NaN
    because no complete window ends there. A window that contains a NaN
    yields NaN at that position only; the NaN does not leak into later
    windows.

    Args:
        a: 1-D values; NaN marks a missing value.
        w: window length. ``w <= 1`` disables smoothing and returns *a*
            itself, unchanged.

    Returns:
        Float array with ``len(a)`` elements.
    """
    if w <= 1:
        return a
    a = np.asarray(a, dtype=float)
    out = np.full(a.shape, np.nan)
    if a.size < w:  # not even one complete window: all NaN, same length
        return out

    valid = ~np.isnan(a)
    # Prefix sums over NaN-zeroed values, plus a prefix count of valid
    # entries. A plain cumsum would make every sum after the first NaN NaN.
    cs = np.concatenate(([0.0], np.cumsum(np.where(valid, a, 0.0))))
    cnt = np.concatenate(([0], np.cumsum(valid)))
    sums = cs[w:] - cs[:-w]
    complete = (cnt[w:] - cnt[:-w]) == w  # window holds no NaN
    out[w - 1:] = np.where(complete, sums / float(w), np.nan)
    return out


def collect_csvs(root: Path, pattern: str) -> Dict[str, Path]:
    """Map each ``root/m*`` entry name to ``<entry>/<pattern>`` when that file exists.

    Entries are visited in sorted order, so the returned dict is ordered by
    run name. Entries without the file are omitted.
    """
    candidates = ((entry.name, entry / pattern) for entry in sorted(root.glob("m*")))
    return {name: path for name, path in candidates if path.is_file()}


def _aggregate(run_dfs: List[pd.DataFrame], smooth: int) -> pd.DataFrame:
    """Align episodes across runs and compute the per-episode mean ± 95 % CI.

    Each frame in *run_dfs* has an ``Episode`` column plus one uniquely named
    value column (one per run). Frames are outer-joined on ``Episode``, so an
    episode missing from some runs is kept. Its statistics use only the runs
    that have it.

    Args:
        run_dfs: per-run frames of shape ``[Episode, <run_id>]``.
        smooth: rolling-window length applied to the mean and the CI
            half-width; ``<= 1`` disables smoothing.

    Returns:
        DataFrame with columns ``Episode``, ``mean``, ``lower``, ``upper``,
        sorted by episode with a fresh 0..k-1 index.

    Raises:
        ValueError: if *run_dfs* is empty.
    """
    if not run_dfs:
        raise ValueError("_aggregate() needs at least one run")

    merged = run_dfs[0]
    for df in run_dfs[1:]:
        merged = merged.merge(df, on="Episode", how="outer")
    # A single run is never merged, so sort explicitly to guarantee a
    # monotone x-axis in every case.
    merged = merged.sort_values("Episode", kind="stable").reset_index(drop=True)

    # By construction rows are episodes and columns are runs. Do not
    # transpose: a shape-based guess fails whenever #episodes < #runs.
    data = merged.drop(columns="Episode").to_numpy(float)

    mean = np.nanmean(data, axis=1)
    n    = np.sum(~np.isnan(data), axis=1)
    # Sample std (ddof=1) is NaN where only one run covers the episode.
    se   = np.nanstd(data, axis=1, ddof=1) / np.sqrt(np.maximum(n, 1))
    ci   = t95(n) * se

    if smooth > 1:
        mean, ci = rolling_mean(mean, smooth), rolling_mean(ci, smooth)

    return pd.DataFrame({
        "Episode": merged["Episode"],
        "mean":   mean,
        "lower":  mean - ci,
        "upper":  mean + ci,
    })


# ─────────────────── per-source aggregation ───────────────────
def aggregate_baseline(root: Path, env: str, smooth: int) -> pd.DataFrame:
    """Load every run's baseline learning curve for *env* and aggregate it.

    Each run directory contributes one column, named after the run, holding
    ``TotalReward`` per ``Episode`` (rows filtered to ``Environment == env``).
    The result is the mean ± 95 % CI frame built by ``_aggregate``.

    Raises:
        FileNotFoundError: if no run directory contains the baseline CSV.
    """
    found = collect_csvs(root, f"results/{env}/csv/baseline_learning_curve.csv")
    if not found:
        raise FileNotFoundError(f"No baseline CSVs for {env}")

    per_run = [
        pd.read_csv(path)
        .loc[lambda frame: frame.Environment == env, ["Episode", "TotalReward"]]
        .rename(columns={"TotalReward": run_id})
        for run_id, path in found.items()
    ]
    return _aggregate(per_run, smooth)


def aggregate_noised(
    root: Path, env: str, smooth: int
) -> Dict[float, Dict[int, pd.DataFrame]]:
    """Aggregate noised-observation rewards for *env* per (σ², obs-dim) pair.

    Every run's ``noised_rewards.csv`` is filtered to ``Environment == env``
    and split by ``(NoiseVariance, ObsDim)``. Each split contributes one
    ``Reward`` column, named after the run, to that pair's ``_aggregate``
    call.

    Returns:
        ``{NoiseVariance: {ObsDim: aggregated frame}}``.

    Raises:
        FileNotFoundError: if no run directory contains the noised CSV.
    """
    found = collect_csvs(root, f"results/{env}/noised/noised_rewards.csv")
    if not found:
        raise FileNotFoundError(f"No noised_rewards.csv for {env}")

    # (variance, obs-dim) -> list of per-run [Episode, <run_id>] frames
    grouped: Dict[tuple[float, int], List[pd.DataFrame]] = {}
    for run_id, path in found.items():
        frame = pd.read_csv(path)
        frame = frame.loc[frame.Environment == env]
        for key, part in frame.groupby(["NoiseVariance", "ObsDim"]):
            column = part[["Episode", "Reward"]].rename(columns={"Reward": run_id})
            grouped.setdefault(key, []).append(column)

    result: Dict[float, Dict[int, pd.DataFrame]] = {}
    for (variance, obs_dim), per_run in grouped.items():
        result.setdefault(variance, {})[obs_dim] = _aggregate(per_run, smooth)
    return result


# ────────────────────────── plotting ──────────────────────────
# Colour sequence of the active matplotlib property cycle; observation dims
# are coloured by their sorted position in this list.
COLORS = [entry["color"] for entry in plt.rcParams["axes.prop_cycle"]]

def plot_panel(env, var, dim_map, base, obs, out_path):
    """Draw one σ² panel (baseline plus one curve per noised dim) and save it.

    Args:
        env: environment name shown in the title.
        var: noise variance σ² shown in the title.
        dim_map: ``{obs_dim: frame}`` where each frame has ``Episode``,
            ``mean``, ``lower`` and ``upper`` columns.
        base: baseline frame with the same columns.
        obs: labels indexed by observation dimension, used in the legend.
        out_path: destination image path; missing parent dirs are created.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        ax.plot(base["Episode"], base["mean"], c="black", lw=2.2, label="Baseline")
        ax.fill_between(base["Episode"], base["lower"], base["upper"],
                        color="black", alpha=0.14)

        ordered = sorted(dim_map.items(), key=lambda item: item[0])
        for i, (dim, df) in enumerate(ordered):
            col = COLORS[i % len(COLORS)]
            # Fall back to "?" for dims with no label instead of aborting
            # with IndexError. Negative dims would otherwise silently pick a
            # label from the end of the list.
            name = obs[dim] if 0 <= dim < len(obs) else "?"
            ax.plot(df["Episode"], df["mean"], c=col, label=f"Dim {dim} ({name})")
            ax.fill_between(df["Episode"], df["lower"], df["upper"],
                            color=col, alpha=0.20)

        ax.set_title(rf"{env}, $\sigma^2={var}$  (mean ± 95 % CI)")
        ax.set_xlabel("Episode")
        ax.set_ylabel("Reward")
        ax.set_xlim(0, EPISODE_MAX)
        ax.set_ylim(REWARD_RANGE)
        ax.legend(fontsize=5, framealpha=0.9)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print("Saved →", out_path)


# ──────────────────────────── CLI driver ───────────────────────────
def main():
    """Command-line entry point: aggregate all runs and save one figure per σ².

    Reads ``m*/results/<env>/...`` CSVs below ``--runs_dir`` and writes one
    PNG per selected variance to ``<output_dir>/<env>/``. When the baseline
    or noised CSVs are missing, it exits with a non-zero status and prints
    the message to stderr.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the
    # program name in the usage line), so the text must go to description=.
    ap = argparse.ArgumentParser(
        description="Aggregate Gaussian-noise plots (Acrobot-v1)")
    ap.add_argument("--runs_dir",   default="remote_results")
    ap.add_argument("--output_dir", default="aggregated_plots")
    ap.add_argument("--smooth",     type=int, default=10,
                    help="rolling-window length (0 or 1 = none)")
    args = ap.parse_args()

    ENV  = "Acrobot-v1"
    VARS = (0.02, 2.0)  # σ² values to plot

    # Legend label per observation dimension (list index = ObsDim).
    OBS = [
        "var_cos(theta1)",
        "var_sin(theta1)",
        "var_cos(theta2)",
        "var_sin(theta2)",
        "var_angular_velocity1",
        "var_angular_velocity2",
    ]

    root = Path(args.runs_dir).resolve()
    out  = Path(args.output_dir).resolve()

    try:
        base  = aggregate_baseline(root, ENV, args.smooth)
        noise = aggregate_noised(root, ENV, args.smooth)
    except FileNotFoundError as e:
        # Non-zero exit so shell pipelines notice the missing data.
        raise SystemExit(str(e)) from None

    plotted = []
    for var in sorted(noise):
        # σ² keys are parsed from CSV text, so match with a tolerance
        # instead of exact float equality.
        if not any(np.isclose(var, target) for target in VARS):
            continue
        plot_panel(
            ENV, var,
            noise[var],
            base,
            OBS,
            out / ENV / f"{ENV}_var_{var}_mean_CI.png",
        )
        plotted.append(var)

    for target in VARS:
        if not any(np.isclose(target, var) for var in plotted):
            print(f"Warning: no noised data for σ² = {target}")

    print("All done.")


if __name__ == "__main__":
    # Script entry: main() returns None, so this exits with status 0 unless
    # main() itself raises.
    raise SystemExit(main())