#!/usr/bin/env python3
"""
Aggregate learning curves when **dropping a single observation dimension**.

One figure is saved per environment (CartPole-v1, Pendulum-v1, Acrobot-v1).
Each coloured line is the across-run mean for *one* dropped dimension, drawn
with a shaded 95 % confidence-interval band; the noise-free baseline is shown
in black.

Runs are discovered as the ``m*`` sub-folders of ``--runs_dir``. The script
autodetects the dropped-dimension column name in each run's
results/<env>/dropped/dropped_rewards.csv – it accepts any of:
    ObsDim, DroppedDim, DroppedObsDim, Dim
"""

from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t


# ───────────────────── basic helpers ─────────────────────
def t95(n) -> np.ndarray:
    """Return the two-sided 95 % Student-t critical value(s) for sample size(s) *n*.

    The degrees of freedom are ``n - 1``, clipped from below at 1 so that
    n <= 1 still yields a finite value (the df = 1 critical value).
    Scalars and arrays are both accepted; the result is float-typed.
    """
    dof = np.clip(np.asarray(n) - 1, 1, None)
    critical = t.ppf(0.975, dof)
    return critical.astype(float)


def rolling_mean(a: np.ndarray, w: int) -> np.ndarray:
    """Trailing moving average of *a* over windows of length *w*.

    The result always has the same length as *a*. The first ``w - 1``
    entries are NaN because no full window exists yet, and so is every
    window that contains a NaN. NaNs stay local: once a NaN leaves the
    window, the average recovers. A cumulative-sum implementation would
    propagate the NaN to every later entry.

    Args:
        a: 1-D numeric sequence.
        w: window length; ``w <= 1`` returns *a* unchanged (same object).

    Returns:
        Float array of ``len(a)`` smoothed values (all NaN if ``w > len(a)``).
    """
    if w <= 1:
        return a
    # With the default min_periods (= window), any window that is not full or
    # that contains a NaN yields NaN, and the output length equals the input's.
    return pd.Series(a, dtype=float).rolling(w).mean().to_numpy()


def collect_csvs(root: Path, pattern: str) -> Dict[str, Path]:
    """Map each run folder ``root/m*`` to ``<run>/<pattern>`` when that file exists.

    Run folders are visited in sorted path order, so the returned dict is
    ordered the same way. Runs without the file are skipped silently.

    Args:
        root: directory containing the run folders.
        pattern: path of the wanted file relative to each run folder.

    Returns:
        ``{run folder name: path to the file}``.
    """
    candidates = ((run_dir.name, run_dir / pattern)
                  for run_dir in sorted(root.glob("m*")))
    return {name: csv_path for name, csv_path in candidates if csv_path.is_file()}


# ─────────────────── aggregation helpers ─────────────────
def _summarise_runs(frames: List[pd.DataFrame], smooth: int) -> pd.DataFrame:
    """Outer-join per-run curves on ``Episode`` and reduce them to mean ± 95 % CI.

    Args:
        frames: non-empty list with one DataFrame per run, each holding an
            ``Episode`` column plus a single reward column named after the run.
        smooth: rolling-window length applied to both the mean and the CI
            half-width; values <= 1 disable smoothing.

    Returns:
        DataFrame with columns Episode, mean, lower, upper. It has one row
        per episode present in at least one run, in ascending episode order.
    """
    merged = frames[0]
    for df in frames[1:]:
        merged = merged.merge(df, on="Episode", how="outer")
    # A lone run is never merged (hence never sorted), so sort explicitly.
    merged = merged.sort_values("Episode", kind="stable").reset_index(drop=True)

    # Rows are episodes and columns are runs by construction of the merge.
    # Statistics are therefore always taken across runs (axis=1). The axis
    # must not be guessed from the array's shape, which breaks when there
    # are fewer episodes than runs. Episodes a run lacks are NaN and are
    # ignored by the nan-aware reductions.
    episodes = merged["Episode"].to_numpy()
    data = merged.drop(columns="Episode").to_numpy(float)

    mean = np.nanmean(data, axis=1)
    n = np.sum(~np.isnan(data), axis=1)
    # With a single contributing run the sample std (ddof=1) is NaN, so the
    # CI is left undefined there instead of being drawn as zero-width.
    half = t95(n) * np.nanstd(data, axis=1, ddof=1) / np.sqrt(np.maximum(n, 1))

    if smooth > 1:
        # Slice so the smoothed series always stays aligned with `episodes`.
        mean = rolling_mean(mean, smooth)[: len(episodes)]
        half = rolling_mean(half, smooth)[: len(episodes)]

    return pd.DataFrame({
        "Episode": episodes,
        "mean":    mean,
        "lower":   mean - half,
        "upper":   mean + half,
    })


def aggregate_baseline(root: Path, env: str, smooth: int) -> pd.DataFrame:
    """Aggregate the baseline learning curve of *env* across all runs in *root*.

    Reads ``<root>/m*/results/<env>/csv/baseline_learning_curve.csv``, which
    has columns Episode and TotalReward. Rows are filtered to *env* when an
    Environment column is present.

    Returns:
        DataFrame with columns Episode, mean, lower, upper.

    Raises:
        FileNotFoundError: if no run contains the baseline CSV.
    """
    csvs = collect_csvs(root, f"results/{env}/csv/baseline_learning_curve.csv")
    if not csvs:
        raise FileNotFoundError(f"No baseline CSVs for {env}")

    frames: List[pd.DataFrame] = []
    for run_id, p in csvs.items():
        df = pd.read_csv(p)
        # Environment column is optional (same rule as the dropped CSVs).
        if "Environment" in df.columns:
            df = df[df.Environment == env]
        frames.append(df[["Episode", "TotalReward"]]
                      .rename(columns={"TotalReward": run_id}))

    return _summarise_runs(frames, smooth)


def _detect_dim_col(df: pd.DataFrame) -> str:
    """Return the first recognised dropped-dimension column name of *df*.

    Candidates are tried in this order: ObsDim, DroppedDim, DroppedObsDim, Dim.

    Raises:
        KeyError: if none of the candidates is a column of *df*.
    """
    for cand in ("ObsDim", "DroppedDim", "DroppedObsDim", "Dim"):
        if cand in df.columns:
            return cand
    raise KeyError("No dropped-dimension column found "
                   f"in CSV header: {list(df.columns)}")


def aggregate_dropped(root: Path, env: str, smooth: int
        ) -> Dict[int, pd.DataFrame]:
    """
    Aggregate the dropped-dimension learning curves of *env* across all runs.

    Reads ``<root>/m*/results/<env>/dropped/dropped_rewards.csv``, which has
    columns Episode and Reward plus a dimension column recognised by
    ``_detect_dim_col``. Rows are filtered to *env* when an Environment
    column is present.

    Returns {dim_id → dataframe(Episode, mean, lower, upper)}.

    Raises:
        FileNotFoundError: if no run contains the CSV.
        KeyError: if a CSV has no recognisable dimension column.
    """
    csvs = collect_csvs(root, f"results/{env}/dropped/dropped_rewards.csv")
    if not csvs:
        raise FileNotFoundError(f"No dropped_rewards.csv for {env}")

    buckets: Dict[int, List[pd.DataFrame]] = {}

    for run_id, p in csvs.items():
        df = pd.read_csv(p)

        # optional Environment column
        if "Environment" in df.columns:
            df = df[df.Environment == env]

        dim_col = _detect_dim_col(df)

        for dim, g in df.groupby(dim_col):
            # Plain int keys let callers index observation-name lists even if
            # the column was parsed as float.
            buckets.setdefault(int(dim), []).append(
                g[["Episode", "Reward"]].rename(columns={"Reward": run_id})
            )

    return {dim: _summarise_runs(frames, smooth)
            for dim, frames in buckets.items()}


# ─────────────────────── plotting ────────────────────────
# Colours of matplotlib's active colour cycle, read once at import time;
# plot_panel indexes it modulo its length to colour each dropped-dim curve.
COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]


def plot_panel(env: str,
               base: pd.DataFrame,
               dropped: Dict[int, pd.DataFrame],
               obs_names: List[str],
               xmax: int,
               yrange: tuple[float, float],
               out_path: Path):
    """Draw the baseline and every dropped-dimension curve of *env* and save a PNG.

    Args:
        env: environment name, used in the title.
        base: baseline curve with columns Episode, mean, lower, upper.
        dropped: ``{dimension index: curve}`` with the same columns as *base*.
        obs_names: readable name per observation index; indices outside
            ``[0, len(obs_names))`` are labelled ``"Dim <idx>"``.
        xmax: right x-axis limit (the left limit is 0).
        yrange: ``(ymin, ymax)`` y-axis limits.
        out_path: destination file; missing parent directories are created.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        # baseline
        ax.plot(base["Episode"], base["mean"], c="black", lw=2.5,
                label="Baseline")
        ax.fill_between(base["Episode"], base["lower"], base["upper"],
                        color="black", alpha=0.15)

        # dropped-dimension curves
        for i, dim in enumerate(sorted(dropped)):
            df = dropped[dim]
            c = COLORS[i % len(COLORS)]
            # Bounds-check both ends: a negative index would otherwise pick a
            # name from the end of the list and mislabel the curve.
            obs_label = (obs_names[dim] if 0 <= dim < len(obs_names)
                         else f"Dim {dim}")
            ax.plot(df["Episode"], df["mean"], c=c,
                    label=f"Dropped {obs_label} (idx={dim})")
            ax.fill_between(df["Episode"], df["lower"], df["upper"],
                            color=c, alpha=0.20)

        ax.set_title(f"{env}: Dropped-Dimension Learning Curves (mean ±95 % CI)",
                     fontsize=6)
        ax.set_xlabel("Episode")
        ax.set_ylabel("Reward")
        ax.set_xlim(0, xmax)
        ax.set_ylim(*yrange)
        ax.legend(fontsize=3)
        ax.grid(alpha=0.3)
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure even when plotting/saving fails, so a loop over
        # environments does not accumulate open figures.
        plt.close(fig)
    print("Saved →", out_path)


# ────────────────────────── CLI ──────────────────────────
def main() -> None:
    """Command-line entry point: aggregate and plot every known environment.

    For each environment listed in ``OBS`` below, the baseline and
    dropped-dimension curves are aggregated from ``--runs_dir``. The plot is
    saved as ``<output_dir>/<env>/<env>_dropped_mean_CI.png``. Environments
    whose CSVs are missing are reported and skipped.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in the usage line); the summary belongs in ``description``.
    ap = argparse.ArgumentParser(
        description="Aggregate dropped-dimension learning curves")
    ap.add_argument("--runs_dir",   default="remote_results",
                    help="directory holding the m* run folders")
    ap.add_argument("--output_dir", default="aggregated_plots",
                    help="directory the PNG figures are written to")
    ap.add_argument("--smooth",     type=int, default=10,
                    help="rolling-window length (episodes)")
    # axis limits (tweak as needed)
    ap.add_argument("--cart_xmax",  type=int,   default=550)
    ap.add_argument("--pend_xmax",  type=int,   default=2300)
    ap.add_argument("--acro_xmax",  type=int,   default=400)
    ap.add_argument("--cart_ylim",  type=float, nargs=2, default=[0, 550])
    ap.add_argument("--pend_ylim",  type=float, nargs=2, default=[-1500, -100])
    ap.add_argument("--acro_ylim",  type=float, nargs=2, default=[-500, 0])
    args = ap.parse_args()

    LIMITS = {
        "CartPole-v1": dict(xmax=args.cart_xmax, yrange=tuple(args.cart_ylim)),
        "Pendulum-v1": dict(xmax=args.pend_xmax, yrange=tuple(args.pend_ylim)),
        "Acrobot-v1":  dict(xmax=args.acro_xmax, yrange=tuple(args.acro_ylim)),
    }

    # Observation names per environment, indexed by observation dimension.
    OBS = {
        "CartPole-v1": [
            "var_cart_position",
            "var_cart_velocity",
            "var_pole_angle",
            "var_pole_angular_velocity",
        ],
        "Pendulum-v1": [
            "var_cos(theta)",
            "var_sin(theta)",
            "var_angular_velocity",
        ],
        "Acrobot-v1": [
            "var_cos(theta1)",
            "var_sin(theta1)",
            "var_cos(theta2)",
            "var_sin(theta2)",
            "var_angular_velocity1",
            "var_angular_velocity2",
        ],
    }

    root = Path(args.runs_dir).resolve()
    out  = Path(args.output_dir).resolve()

    for env, obs_names in OBS.items():
        try:
            base    = aggregate_baseline(root, env, smooth=args.smooth)
            dropped = aggregate_dropped(root, env, smooth=args.smooth)
        except FileNotFoundError as e:
            # Missing data for one environment should not stop the others.
            print(e)
            continue

        cfg = LIMITS[env]
        plot_panel(env, base, dropped, obs_names,
                   xmax=cfg["xmax"], yrange=cfg["yrange"],
                   out_path=out / env / f"{env}_dropped_mean_CI.png")

    print("All done.")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()