#!/usr/bin/env python3
"""
Rewards-vs-Noise plot, full σ² range 0 … 20  (CartPole-v1, Pendulum-v1).

For each (run, obs-dim, σ²) triple we take the *episode-average* reward,
then aggregate over runs: mean ± 95 % CI.

Baseline runs (no noise) are injected as σ² = 0.0 for every observation
dimension, so each curve starts at zero on the x-axis.
"""

from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t


# ───────────────────────── helpers ─────────────────────────
def t95(n: int | np.ndarray) -> np.ndarray:
    """Return the two-sided 95 % Student-t critical value for sample size *n*.

    The degrees of freedom are ``n - 1``, clamped to a minimum of 1 so that
    ``n <= 1`` still yields a finite quantile.  Accepts a scalar or an array
    of sample sizes; the result has the matching shape.
    """
    sample_sizes = np.asarray(n)
    dof = np.maximum(1, sample_sizes - 1)
    # 0.975 quantile ⇒ 2.5 % in each tail ⇒ two-sided 95 % interval.
    critical = t.ppf(0.975, dof)
    return critical.astype(float)


def collect_csvs(root: Path, pattern: str) -> Dict[str, Path]:
    """Map run-directory name → CSV path for every run that has the file.

    Every entry of *root* whose name starts with ``m`` is treated as a run;
    *pattern* is a path relative to that entry.  Runs for which the resulting
    path is not an existing file are left out.  Entries are visited in
    sorted order, so the returned dict is ordered the same way.
    """
    candidates = (
        (run_dir.name, run_dir / pattern) for run_dir in sorted(root.glob("m*"))
    )
    return {run_id: csv_path for run_id, csv_path in candidates if csv_path.is_file()}


# ───────────────────── aggregate per environment ───────────
def aggregate_final_rewards(
    runs_root: Path,
    env: str,
    dims: List[int],
) -> Dict[float, Dict[int, Tuple[float, float]]]:
    """
    Aggregate per-run mean rewards of one environment over all runs.

    Noise runs are read from ``results/<env>/noised/noised_rewards.csv``
    (columns used: Environment, NoiseVariance, ObsDim, Reward); each
    (σ², dim) group of a run contributes its mean Reward.  Baseline runs are
    read from ``results/<env>/csv/baseline_learning_curve.csv`` (columns
    used: Environment, TotalReward); each run's mean TotalReward is injected
    as σ² = 0.0 for every entry of *dims*.

    Args:
        runs_root: directory holding the individual run folders.
        env:       environment name; used in the CSV paths and as row filter.
        dims:      observation-dimension ids that receive the baseline value.

    Returns: {σ² → {dim_id → (mean , half-CI)}}       (95 % two-sided CI)
        Keys and values are builtin float / int.  The half-CI is NaN when
        only a single run contributes, because the spread is undefined.

    Raises:
        RuntimeError: no noise CSV was found, or no record could be parsed.
    """
    # ── noise runs ──────────────────────────────────────────
    noise_csvs = collect_csvs(runs_root, f"results/{env}/noised/noised_rewards.csv")
    if not noise_csvs:
        raise RuntimeError(f"[{env}] no noised_rewards.csv found")

    records: List[Tuple[float, int, float]] = []  # (var , dim , run_mean)
    for csv_path in noise_csvs.values():
        df = pd.read_csv(csv_path)
        df = df[df.Environment == env]
        for (var, dim), g in df.groupby(["NoiseVariance", "ObsDim"]):
            records.append((var, dim, g.Reward.mean()))

    # ── baseline (σ² = 0) ──────────────────────────────────
    base_csvs = collect_csvs(runs_root, f"results/{env}/csv/baseline_learning_curve.csv")
    if not base_csvs:
        print(f"[warning] [{env}] no baseline CSVs found (σ² = 0 will be missing)")
    else:
        for rid, csv_path in base_csvs.items():
            df = pd.read_csv(csv_path)
            df = df[df.Environment == env]
            if df.empty:
                # The mean of zero rows is NaN and would turn the whole
                # σ² = 0 aggregate into NaN, so leave this run out instead.
                print(f"[warning] [{env}] baseline of run {rid} has no rows "
                      f"for this environment (skipped)")
                continue
            run_mean = df.TotalReward.mean()
            # copy to every dimension
            records.extend((0.0, dim, run_mean) for dim in dims)

    if not records:
        raise RuntimeError(f"[{env}] no data at all after parsing")

    per_run = pd.DataFrame(records, columns=["var", "dim", "run_mean"])

    # ── aggregate over runs ─────────────────────────────────
    out: Dict[float, Dict[int, Tuple[float, float]]] = {}
    for (var, dim), g in per_run.groupby(["var", "dim"]):
        vals = g.run_mean.to_numpy(float)
        n = vals.size
        if n > 1:
            half = float(t95(n) * vals.std(ddof=1) / np.sqrt(n))
        else:
            # Sample std with ddof=1 is undefined for one value; state the
            # NaN explicitly rather than via a divide-by-zero RuntimeWarning.
            half = float("nan")
        # Cast numpy scalars so the result matches the annotated types.
        out.setdefault(float(var), {})[int(dim)] = (float(vals.mean()), half)
    return out


# ───────────────────────── plotting ─────────────────────────
# Colour list of matplotlib's property cycle, read from rcParams at import
# time; plot_panel picks COLORS[dim % len(COLORS)] for each dimension.
COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]


def plot_panel(
    ax: plt.Axes,
    env: str,
    agg: Dict[float, Dict[int, Tuple[float, float]]],
    obs_names: List[str],
):
    """Draw one error-bar curve per observation dimension onto *ax*.

    *agg* maps σ² → {dim → (mean, half-CI)}.  For every dimension that occurs
    anywhere in *agg*, the σ² values holding that dimension are plotted in
    ascending order with the half-CI as the y error.  *obs_names* is indexed
    by the dimension id to build the legend label.
    """
    variances = sorted(agg)                    # ascending σ² values
    all_dims = set()
    for per_dim in agg.values():
        all_dims.update(per_dim)

    for dim in sorted(all_dims):
        # One pass over σ², keeping only the values that carry this dim.
        points = [(var, *agg[var][dim]) for var in variances if dim in agg[var]]
        xs, means, errs = (list(column) for column in zip(*points))

        ax.errorbar(
            xs,
            means,
            yerr=errs,
            fmt="-o",
            capsize=3,
            c=COLORS[dim % len(COLORS)],
            label=f"Dim {dim} ({obs_names[dim]})",
        )

    ax.set_title(f"Rewards vs Noise — {env}", fontsize=9)
    ax.set(
        xlabel="Noise variance  σ²",
        ylabel="Mean final reward",
        xlim=(0, 20),
    )
    ax.grid(alpha=0.3)
    ax.legend(fontsize=6)


# ──────────────────────────── CLI ──────────────────────────
def main() -> None:
    """Parse the CLI, draw one rewards-vs-noise panel per environment, save.

    The environments and the names of their observation dimensions are fixed
    in ``OBS_NAMES``; the panels are laid out side by side in a single row.
    """
    # The title must go in as ``description``: the first positional parameter
    # of ArgumentParser is ``prog`` and would replace the program name in the
    # usage line.
    ap = argparse.ArgumentParser(description="Rewards-vs-Noise (σ² 0 … 20)")
    ap.add_argument("--runs_dir", default="remote_results",
                    help="directory containing the m* run folders "
                         "(default: %(default)s)")
    ap.add_argument("--out_file", default="rewards_vs_noise_full.png",
                    help="path of the image to write (default: %(default)s)")
    args = ap.parse_args()

    runs_root = Path(args.runs_dir).resolve()

    OBS_NAMES = {
        "CartPole-v1": [
            "var_cart_position",
            "var_cart_velocity",
            "var_pole_angle",
            "var_pole_angular_velocity",
        ],
        "Pendulum-v1": [
            "var_cos(theta)",
            "var_sin(theta)",
            "var_angular_velocity",
        ],
    }

    # squeeze=False keeps ``axes`` two-dimensional for any number of panels.
    fig, axes = plt.subplots(1, len(OBS_NAMES), figsize=(6 * len(OBS_NAMES), 3),
                             squeeze=False)

    for ax, (env, obs) in zip(axes[0], OBS_NAMES.items()):
        dims = list(range(len(obs)))           # 0…3 or 0…2
        agg  = aggregate_final_rewards(runs_root, env, dims)
        plot_panel(ax, env, agg, obs)

    # Operate on the figure created above, not on pyplot's implicit
    # "current figure".
    fig.tight_layout()
    fig.savefig(args.out_file, dpi=300)
    print("Saved →", Path(args.out_file).resolve())


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()