#!/usr/bin/env python3
"""
aggregate_gaussian_plots.py  – CartPole-v1 only
  • σ² ∈ {0.02 (0-400 ep), 1.0 (0-600 ep)}
  • rewards > 300 are dropped (set to NaN, so they are excluded from the mean/CI)
  • legend font size cut in half
"""

from __future__ import annotations
import argparse, numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib  import Path
from typing   import Dict, List
from scipy.stats import t

# ─── CONFIG ──────────────────────────────────────────────────────────────
# Episode rewards strictly above this are replaced by NaN (excluded, not clipped).
DROP_ABOVE = 300.0
# Noise variance σ² -> number of episodes shown on that panel's x-axis.
VAR_EPISODES = {0.02: 400, 1.0: 600}
# Fixed y-axis range shared by every panel.
YLIM = (0, 300)
# Legend font size in points; deliberately tiny.
LEGEND_FSIZE = 3
# Matplotlib's active colour cycle, captured once when the module is imported.
COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# ─── HELPERS ─────────────────────────────────────────────────────────────
def t95(n):
    """Return the two-sided 95 % Student-t critical value for sample size(s) *n*.

    Degrees of freedom are ``n - 1``, floored at 1 so that ``n <= 1`` still
    yields a finite multiplier. Works element-wise on NumPy arrays.
    """
    dof = np.maximum(n - 1, 1)
    return t.ppf(0.975, dof)
def roll(a,w):
    """Trailing moving average of *a* over a window of *w* samples.

    The result always has the same length as *a*. The first ``w - 1``
    positions, where no full window exists yet, are NaN. NaN entries inside a
    window are skipped, so one missing value no longer poisons every later
    output. A window made entirely of NaN yields NaN. With ``w <= 1``, *a* is
    returned unchanged.
    """
    if w<=1: return a
    a=np.asarray(a,float)
    out=np.full(a.size,np.nan)
    if a.size<w: return out          # no complete window; keep len(out) == len(a)
    ok=~np.isnan(a)
    # Prefix sums of the finite values and of how many values are finite.
    cs=np.cumsum(np.insert(np.where(ok,a,0.0),0,0.0))
    cn=np.cumsum(np.insert(ok.astype(int),0,0))
    s=cs[w:]-cs[:-w]; c=cn[w:]-cn[:-w]
    # Divide only where the window has at least one finite value; others stay NaN.
    out[w-1:]=np.divide(s,c,out=np.full_like(s,np.nan),where=c>0)
    return out
def collect(root, patt):
    """Map each ``m*`` entry under *root* to ``entry / patt`` when that path is an existing file.

    Returns ``{entry_name: path}``; entries without the file are omitted.
    """
    found = {}
    for run_dir in root.glob("m*"):
        candidate = run_dir / patt
        if candidate.is_file():
            found[run_dir.name] = candidate
    return found

# ─── AGGREGATION ─────────────────────────────────────────────────────────
def _agg(frames:List[pd.DataFrame], smooth:int):
    """Combine per-run curves into a mean curve with a 95 % Student-t band.

    The frames are outer-joined on ``Episode``. Every other column is treated
    as one run's values, and per-episode statistics ignore NaN entries. When
    ``smooth > 1``, both the mean and the CI half-width pass through ``roll``
    before the bounds are formed.

    Returns a DataFrame with columns ``Episode``, ``mean``, ``lower``, ``upper``.
    """
    merged = frames[0]
    for extra in frames[1:]:
        merged = merged.merge(extra, on="Episode", how="outer")
    values = merged.drop(columns="Episode").to_numpy(float)

    counts = (~np.isnan(values)).sum(axis=1)
    centre = np.nanmean(values, axis=1)
    std_err = np.nanstd(values, axis=1, ddof=1) / np.sqrt(np.maximum(counts, 1))
    half = t95(counts) * std_err

    if smooth > 1:
        centre = roll(centre, smooth)
        half = roll(half, smooth)

    return pd.DataFrame({"Episode": merged["Episode"],
                         "mean": centre,
                         "lower": centre - half,
                         "upper": centre + half})

def agg_base(root, env, smooth):
    """Aggregate the baseline learning curves for *env* across all runs.

    Reads ``results/<env>/csv/baseline_learning_curve.csv`` from every run
    found by ``collect`` and keeps the rows whose ``Environment`` equals *env*.
    ``TotalReward`` values above ``DROP_ABOVE`` are replaced by NaN, and each
    run's column is renamed to the run id. The result is passed to ``_agg``.

    Raises:
        FileNotFoundError: if no run provides the baseline CSV.
    """
    csvs=collect(root,f"results/{env}/csv/baseline_learning_curve.csv")
    if not csvs: raise FileNotFoundError("baseline CSVs missing")
    frames=[]
    for rid,p in csvs.items():
        raw=pd.read_csv(p)
        # Single .loc + .copy(): an independent frame, so no chained-indexing
        # SettingWithCopyWarning on the assignment below.
        df=raw.loc[raw["Environment"]==env,["Episode","TotalReward"]].copy()
        # mask() returns a float column with NaN. Writing NaN in place into an
        # int column is a deprecated silent upcast.
        df["TotalReward"]=df["TotalReward"].mask(df["TotalReward"]>DROP_ABOVE)
        frames.append(df.rename(columns={"TotalReward":rid}))
    return _agg(frames,smooth)

def agg_noise(root, env, smooth):
    """Aggregate the noised-observation reward curves for *env* across all runs.

    Reads ``results/<env>/noised/noised_rewards.csv`` from every run found by
    ``collect`` and keeps the rows whose ``Environment`` equals *env*.
    ``Reward`` values above ``DROP_ABOVE`` are replaced by NaN. The rows are
    grouped by ``(NoiseVariance, ObsDim)``, and each group is aggregated
    across runs with ``_agg``.

    Returns:
        ``{noise_variance: {obs_dim: aggregated DataFrame}}``.

    Raises:
        FileNotFoundError: if no run provides ``noised_rewards.csv``.
    """
    csvs=collect(root,f"results/{env}/noised/noised_rewards.csv")
    if not csvs: raise FileNotFoundError("noised_rewards.csv missing")
    buck:Dict[tuple[float,int],List[pd.DataFrame]]={}
    for rid,p in csvs.items():
        raw=pd.read_csv(p)
        # Copy the filtered rows so the assignment below targets an independent frame.
        df=raw.loc[raw["Environment"]==env].copy()
        # mask() upcasts to float with NaN instead of in-place NaN assignment.
        df["Reward"]=df["Reward"].mask(df["Reward"]>DROP_ABOVE)
        for (v,d),g in df.groupby(["NoiseVariance","ObsDim"]):
            buck.setdefault((v,d),[]).append(
                g[["Episode","Reward"]].rename(columns={"Reward":rid}))
    out:Dict[float,Dict[int,pd.DataFrame]]={}
    for (v,d),lst in buck.items(): out.setdefault(v,{})[d]=_agg(lst,smooth)
    return out

# ─── PLOTTING ────────────────────────────────────────────────────────────
def plot_panel(env,var,dmap,base,obs,out):
    """Render one noise-variance panel and save it as a PNG.

    Draws the baseline curve (black), then one curve per observation dim in
    *dmap* in sorted dim order. Each curve has its shaded ``lower``/``upper``
    band, and all curves are truncated to the first ``VAR_EPISODES[var]``
    episodes.

    Args:
        env:  environment name used in the title.
        var:  noise variance σ²; must be a key of ``VAR_EPISODES``.
        dmap: mapping ``obs_dim -> DataFrame`` with columns
              ``Episode``, ``mean``, ``lower``, ``upper``.
        base: baseline DataFrame with the same columns.
        obs:  sequence of labels indexed by observation dim.
        out:  destination path; missing parent directories are created.
    """
    horizon = VAR_EPISODES[var]
    fig, ax = plt.subplots(figsize=(4, 3))

    ref = base.loc[base.Episode <= horizon]
    ax.plot(ref["Episode"], ref["mean"], c="black", lw=2.5, label="Baseline")
    ax.fill_between(ref["Episode"], ref["lower"], ref["upper"],
                    color="black", alpha=0.10)

    for slot, dim in enumerate(sorted(dmap)):
        frame = dmap[dim]
        frame = frame.loc[frame.Episode <= horizon]
        colour = COLORS[slot % len(COLORS)]
        ax.plot(frame["Episode"], frame["mean"], c=colour,
                label=f"Dim={dim} ({obs[dim]})")
        ax.fill_between(frame["Episode"], frame["lower"], frame["upper"],
                        color=colour, alpha=0.2)

    ax.set_title(rf"{env}, $\sigma^2={var}$  (mean ± 95% CI)")
    ax.set_xlabel("Episode")
    ax.set_ylabel("Reward")
    ax.set_xlim(0, horizon)
    ax.set_ylim(*YLIM)
    ax.legend(fontsize=LEGEND_FSIZE)
    ax.grid(True, alpha=.3)
    fig.tight_layout()

    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=300)
    plt.close(fig)
    print("Saved →", out)

# ─── MAIN ────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: aggregate CartPole-v1 runs and write the σ² panels.

    One PNG is written for each noise variance in ``VAR_EPISODES`` that has
    noised data. Variances without data are reported on stderr and skipped.

    Returns:
        0 on success, 1 if the baseline or noised CSVs are missing.
    """
    import sys

    ar=argparse.ArgumentParser(
        description="Aggregate CartPole-v1 Gaussian-noise runs into mean ± 95% CI plots.")
    ar.add_argument("--runs_dir",default="remote_results",
                    help="directory containing the m* run folders")
    ar.add_argument("--output_dir",default="aggregated_plots",
                    help="directory the PNG panels are written to")
    ar.add_argument("--smooth",type=int,default=10,
                    help="moving-average window in episodes (<=1 disables smoothing)")
    args=ar.parse_args()
    root,out=Path(args.runs_dir).resolve(),Path(args.output_dir).resolve()

    env="CartPole-v1"
    # Legend labels for each observation dimension, indexed by ObsDim.
    obs=["var_cart_position","var_cart_velocity","var_pole_angle","var_pole_angular_velocity"]
    try:
        base=agg_base(root,env,args.smooth)
        noise=agg_noise(root,env,args.smooth)
    except FileNotFoundError as e:
        print("✘",e,file=sys.stderr)
        return 1

    # Iterate the configured variances so the plotted set cannot drift from VAR_EPISODES.
    for var in VAR_EPISODES:
        if var not in noise:
            print(f"⚠ no noised data for σ²={var}; skipped",file=sys.stderr)
            continue
        plot_panel(env,var,noise[var],base,obs,
                   out/env/f"{env}_var_{var}_mean_CI.png")
    print("All done.")
    return 0

if __name__=="__main__":
    # Propagate main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())