#!/usr/bin/env python3
"""
Plot MarkovScore versus Gaussian-noise variance (mean ± 95 % CI).

Reads every   m*/results/<ENV>/noised/noised_markov.csv   file produced by
your training runs.  For each environment it aggregates the MarkovScore
across runs at every (NoiseVariance, ObsDim) pair, then draws one line per
observation dimension with error-bar caps showing the CI on the mean.

Figure layout: two side-by-side panels (CartPole-v1, Acrobot-v1).

Usage
-----
    python plot_markov_vs_noise.py --runs_dir remote_results \
                                   --out fig_markov_vs_noise.png \
                                   --smooth 0
"""

from __future__ import annotations

import argparse
import math
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import t

###############################################################################
# ─────────────── helper functions ────────────────────────────────────────────
###############################################################################
def t95(n: np.ndarray | int | float) -> np.ndarray | float:
    """Two-sided 95 % critical value of Student-t with (n‒1) d.f."""
    df = np.maximum(np.asarray(n) - 1, 1)
    crit = t.ppf(0.975, df)
    return float(crit) if np.isscalar(n) else crit


def collect_csvs(root: Path, pattern: str) -> Dict[str, Path]:
    """
    Map each run directory name to its CSV file.

    Every entry of *root* whose name starts with ``m`` is visited in sorted
    order; *pattern* is a path relative to that entry.  Entries for which
    the resulting path is not an existing regular file are left out.
    """
    candidates = (
        (entry.name, entry / pattern) for entry in sorted(root.glob("m*"))
    )
    return {run_id: csv for run_id, csv in candidates if csv.is_file()}


def rolling_mean(a: np.ndarray, w: int) -> np.ndarray:
    """
    Trailing rolling mean of *a* with window *w*, NaN-padded at the front.

    Parameters
    ----------
    a : 1-D numeric array.
    w : window length.  ``w <= 1`` disables smoothing and returns *a*
        itself (same object, not a copy).

    Returns
    -------
    Array with the same number of elements as *a*.  The first ``w - 1``
    entries are NaN because no full window is available there; if the
    window is longer than the data, every entry is NaN.
    """
    if w <= 1:
        return a

    a = np.asarray(a)
    n = a.size
    if w > n:
        # No complete window fits.  Without this guard the NaN prefix of
        # length w - 1 would make the result longer than the input.
        return np.full(n, np.nan)

    # Prefix sums with a leading 0 so that cs[i + w] - cs[i] is the sum of
    # the window a[i : i + w].
    cs = np.cumsum(np.insert(a, 0, 0.0))
    sm = (cs[w:] - cs[:-w]) / float(w)
    return np.concatenate([np.full(w - 1, np.nan), sm])


###############################################################################
# ─────────────── aggregation ────────────────────────────────────────────────
###############################################################################
def aggregate_markov(
    csv_paths: Dict[str, Path], smooth: int
) -> Dict[float, Dict[int, Tuple[np.ndarray, np.ndarray]]]:
    """
    Pool MarkovScore over all runs and summarise every (variance, dim) cell.

    Parameters
    ----------
    csv_paths : mapping ``run_id -> csv path``.  Only the paths are read;
        each CSV must provide the columns ``NoiseVariance``, ``ObsDim`` and
        ``MarkovScore``.
    smooth : rolling-mean window applied to a cell's pooled score list
        before it is summarised; values ``<= 1`` leave the scores untouched.

    Returns
    -------
    ``{variance: {obs_dim: (mean, ci)}}``.  Despite the array-typed
    annotation, *mean* and *ci* are scalars: the NaN-ignoring mean of the
    cell and the half-width of its two-sided 95 % Student-t confidence
    interval.  *ci* is NaN when the cell holds fewer than two non-NaN
    scores, and *mean* is NaN when it holds none.
    """
    # (var, obs_dim) -> list[MarkovScore], pooled over runs in the iteration
    # order of csv_paths.
    buckets: Dict[Tuple[float, int], List[float]] = {}

    for path in csv_paths.values():
        df = pd.read_csv(path)
        for (var, dim), g in df.groupby(["NoiseVariance", "ObsDim"]):
            # Each CSV is expected to hold one row per (var, dim); extend()
            # simply pools any extra rows into the same cell.
            buckets.setdefault((var, dim), []).extend(g["MarkovScore"].tolist())

    # turn into mean & CI
    agg: Dict[float, Dict[int, Tuple[np.ndarray, np.ndarray]]] = {}
    for (var, dim), scores in buckets.items():
        arr = np.asarray(scores, dtype=float)
        if smooth > 1:
            # NOTE(review): this smooths along the pooled per-run list, not
            # along the noise-variance axis -- confirm that is intended.
            arr = rolling_mean(arr, smooth)

        n = int(np.count_nonzero(~np.isnan(arr)))

        # Guard the degenerate cells explicitly: nanmean on an all-NaN cell
        # and nanstd(ddof=1) on a single sample both yield NaN anyway, but
        # only after emitting a RuntimeWarning.
        m = np.nanmean(arr) if n > 0 else np.nan
        if n > 1:
            se = np.nanstd(arr, ddof=1) / math.sqrt(n)
            ci = t95(n) * se
        else:
            ci = np.nan

        agg.setdefault(var, {})[dim] = (m, ci)

    return agg


###############################################################################
# ─────────────── plotting ────────────────────────────────────────────────────
###############################################################################
# Legend labels for the observation components of each environment.
# A label's list position is the observation-dimension index it is paired
# with when the panels are drawn.
_CARTPOLE_OBS_NAMES = [
    "var_cart_position",
    "var_cart_velocity",
    "var_pole_angle",
    "var_pole_angular_velocity",
]
_ACROBOT_OBS_NAMES = [
    "var_cos(theta1)",
    "var_sin(theta1)",
    "var_cos(theta2)",
    "var_sin(theta2)",
    "var_angular_velocity1",
    "var_angular_velocity2",
]

OBS_NAMES = {
    "CartPole-v1": _CARTPOLE_OBS_NAMES,
    "Acrobot-v1": _ACROBOT_OBS_NAMES,
}

# Colour list of the active matplotlib property cycle, read once at import.
_prop_cycle = plt.rcParams["axes.prop_cycle"]
COLORS = _prop_cycle.by_key()["color"]


def make_panel(ax, env: str, agg: Dict[float, Dict[int, Tuple[np.ndarray, np.ndarray]]]):
    """
    Draw the MarkovScore-vs-noise curves of one environment onto *ax*.

    Parameters
    ----------
    ax : axes object to draw on (its ``errorbar``, tick, label, grid and
        legend methods are used).
    env : environment id; must be a key of ``OBS_NAMES`` (``KeyError``
        otherwise).
    agg : ``{variance: {obs_dim: (mean, ci)}}`` summary for this environment.

    One error-bar line is drawn per entry of ``OBS_NAMES[env]``; the entry's
    list index is used as the observation dimension.  Cells missing from
    *agg* are plotted as NaN, and dimensions present in *agg* but without a
    name in ``OBS_NAMES[env]`` are not drawn.
    """
    # Shared, ascending x grid of noise variances.
    noise_levels = sorted(agg.keys())
    x = np.asarray(noise_levels)

    absent = (np.nan, np.nan)  # stands in for a (var, dim) cell no run produced

    for dim, name in enumerate(OBS_NAMES[env]):
        cells = [agg[level].get(dim, absent) for level in noise_levels]
        means = np.asarray([mean for mean, _ in cells], float)
        cis = np.asarray([half_width for _, half_width in cells], float)

        ax.errorbar(
            x,
            means,
            yerr=cis,
            label=f"Dim {dim} ({name})",
            marker="o",
            capsize=3,
            linewidth=1.5,
        )

    # NOTE(review): a log axis cannot display a variance of 0 -- confirm the
    # CSVs never contain a zero-noise row.
    ax.set_xscale("log")
    tick_text = [f"{v:g}" for v in x]
    ax.set_xticks(x)
    ax.set_xticklabels(tick_text, rotation=20)

    ax.set_xlabel(r"Noise variance $\sigma^2$")
    ax.set_ylabel("MarkovScore")
    ax.set_title(env)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=7, framealpha=0.9, ncol=1)


###############################################################################
# ─────────────── main ────────────────────────────────────────────────────────
###############################################################################
def main():
    """
    Command-line entry point: build the two-panel figure and save it.

    Parses ``--runs_dir``, ``--out`` and ``--smooth``, then for each
    environment collects the per-run CSVs below the runs directory,
    aggregates them and draws one panel.  Raises ``RuntimeError`` if an
    environment has no CSV at all.
    """
    cli = argparse.ArgumentParser("Plot MarkovScore vs. noise variance")
    cli.add_argument("--runs_dir", default="remote_results")
    cli.add_argument("--out", default="fig_markov_vs_noise.png")
    cli.add_argument(
        "--smooth",
        type=int,
        default=0,
        help="optional rolling-mean window (0 = off)",
    )
    args = cli.parse_args()

    runs_root = Path(args.runs_dir).resolve()
    envs = ["CartPole-v1", "Acrobot-v1"]
    n_panels = len(envs)

    fig, axes = plt.subplots(
        1, n_panels, figsize=(n_panels * 4.7, 3.2), sharey=False
    )
    # With a single panel, subplots() returns a bare Axes; wrap it so the
    # zip below works either way.
    panel_axes = [axes] if n_panels == 1 else axes

    for ax, env in zip(panel_axes, envs):
        rel_csv = f"results/{env}/noised/noised_markov.csv"
        csvs = collect_csvs(runs_root, rel_csv)
        if not csvs:
            raise RuntimeError(f"No noised_markov.csv files found for {env}")

        make_panel(ax, env, aggregate_markov(csvs, smooth=args.smooth))

    plt.tight_layout()
    plt.savefig(args.out, dpi=300)
    print(f"Figure saved → {args.out}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()