import argparse, time, csv, os, random, math, sys, glob
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from scipy.stats import ttest_ind
except Exception:
    ttest_ind = None



def parse_args(argv=None):
    """Build the command-line interface and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list makes the parser usable from
            tests or other code.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    # The title is a description: passed positionally it would become `prog`
    # and replace the script name in the usage line.
    p = argparse.ArgumentParser(
        description="EPSRO (URR + deterministic meta strategy optimization) for Kuhn — multi-seed"
    )

    # Number of outer iterations and cap on the size of each fixed strategy set.
    p.add_argument("--iters", type=int, default=40)
    p.add_argument("--kmax", type=int, default=41)

    # Inner URR solve: step count, base learning rates and their decay schedules.
    p.add_argument("--urr_steps", type=int, default=200)
    p.add_argument("--theta_lr", type=float, default=5e-4)
    p.add_argument("--beta_lr", type=float, default=0.1)
    p.add_argument("--beta_sched", choices=["const","sqrt","harmonic"], default="harmonic")
    p.add_argument("--theta_sched", choices=["const","sqrt","harmonic"], default="harmonic")

    # Weight of the KL penalty between θ's current and starting probabilities (0 disables it).
    p.add_argument("--beta_kl", type=float, default=0.0)

    # Softmax temperature applied to β and optional entropy bonus on σ.
    p.add_argument("--beta_temp", type=float, default=1.0)
    p.add_argument("--entropy_coef", type=float, default=0.0, help="Optional entropy regularizer for σ during β updates.")

    # Policy parameterisation: sigmoid temperature and initialisation scales.
    p.add_argument("--tau", type=float, default=1.0)
    p.add_argument("--init_std", type=float, default=1.0)
    p.add_argument("--noise_std", type=float, default=0.05, help="After fixing a BR into Π^r, re-init active θ from it + noise")

    # Numerical safety: gradient-norm clip, logit clamp and probability clamp.
    p.add_argument("--clip_grad", type=float, default=1.0)
    p.add_argument("--logit_cap", type=float, default=50.0)
    p.add_argument("--prob_eps", type=float, default=1e-6)

    # Number of active (pipeline) levels and the plateau test used for promotion.
    p.add_argument("--pipeline_levels", type=int, default=3)
    p.add_argument("--plateau_window", type=int, default=10)
    p.add_argument("--plateau_eps", type=float, default=1e-4)
    p.add_argument("--plateau_std_eps", type=float, default=1e-4)

    # Output locations, seeds and runtime options.
    p.add_argument("--outdir", type=str, default="results")
    p.add_argument("--csv_base", type=str, default="epsro_kuhn")
    p.add_argument("--seeds", type=str, default="0,1,2,3,4")
    p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    p.add_argument("--no_plots", action="store_true")

    p.add_argument("--ttest_against_glob", type=str, default=None,
                   help='Glob for baseline seed CSVs (same iters), e.g. "results/psro_kuhn_seed*.csv"')

    return p.parse_args(argv)



def _seed_everything(s: int):
    """Seed Python's `random`, NumPy and PyTorch with the same integer `s`.

    When CUDA is available the seed is also applied to every CUDA device.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(s)

def _pick_device(argdev: str):
    """Turn the --device choice into a torch.device.

    "auto" resolves to CUDA when `torch.cuda.is_available()` and to CPU
    otherwise; any other value is handed to `torch.device` unchanged.
    """
    if argdev != "auto":
        return torch.device(argdev)
    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(resolved)

def _mem_mb():
    """Report this process's memory use in megabytes plus a label for its source.

    Returns:
        (value, label): `label` is "rss" when psutil supplied the resident set
        size, "ru_maxrss" when the `resource` module supplied the value, and
        "n/a" (with a NaN value) when neither source worked.
    """
    # Preferred source: psutil (optional dependency).
    try:
        import psutil
        rss_bytes = psutil.Process().memory_info().rss
        return rss_bytes / (1024**2), "rss"
    except Exception:
        pass

    # Fallback: resource.getrusage; the divisor to reach MB depends on the platform.
    try:
        import resource
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        divisor = (1024**2) if sys.platform == "darwin" else 1024.0
        return peak / divisor, "ru_maxrss"
    except Exception:
        return float("nan"), "n/a"

def _softmax_np(x, cap=50.0):
    """Softmax of `x` after clipping every entry to [-cap, cap].

    The maximum is subtracted before exponentiating and 1e-12 is added to the
    normaliser, so the result sums to (very nearly) one.
    """
    clipped = np.clip(x, -cap, cap)
    weights = np.exp(clipped - np.max(clipped))
    return weights / (weights.sum() + 1e-12)

def _eta_t(eta0: float, t: int, sched: str):
    """Learning rate at step `t` for base rate `eta0` under schedule `sched`.

    "sqrt" divides by max(1, sqrt(t)), "harmonic" divides by (1 + 0.5*t);
    "const" and any unrecognised schedule name return `eta0` unchanged.
    """
    if sched == "sqrt":
        divisor = max(1.0, math.sqrt(t))
    elif sched == "harmonic":
        divisor = 1.0 + 0.5*t
    else:
        return eta0
    return eta0 / divisor

def bernoulli_kl(p, q, eps=1e-6):
    """Element-wise KL(Bernoulli(p) || Bernoulli(q)).

    Both inputs are clamped to [eps, 1 - eps] first so the logarithms stay finite.
    """
    lo, hi = eps, 1.0-eps
    p_c = torch.clamp(p, lo, hi)
    q_c = torch.clamp(q, lo, hi)
    on_term = p_c * torch.log(p_c/q_c)
    off_term = (1.0-p_c) * torch.log((1.0-p_c)/(1.0-q_c))
    return on_term + off_term



# Card ranks dealt to the two players; the higher rank wins a showdown.
CARDS = [0, 1, 2]


def ev_p1_vs_numpy(p1_vec6, p2_vec6):
    """Expected payoff to player 1, averaged over the six deals of distinct cards.

    Each policy is six per-card probabilities. For player 1 entries 0-2 weight
    the branch worth +/-2 (opponent continues) or +1 (opponent does not), and
    entries 3-5 weight continuing (+/-2) versus giving up (-1) after the
    opponent acts in the other branch. For player 2 entries 0-2 are the
    continue probabilities in the first branch and entries 3-5 the
    probabilities of acting in the second branch (otherwise a +/-1 showdown).
    In Kuhn-poker terms these read as bet/call and call/bet respectively.
    """
    p1_bet, p1_call = p1_vec6[:3], p1_vec6[3:]
    p2_call, p2_bet = p2_vec6[:3], p2_vec6[3:]
    total = 0.0
    for hero in CARDS:
        for villain in CARDS:
            if hero != villain:
                sign = 1.0 if hero > villain else -1.0
                big, small = 2.0 * sign, sign
                after_bet = p1_bet[hero] * (p2_call[villain]*big + (1.0-p2_call[villain])*1.0)
                facing_bet = p1_call[hero]*big + (1.0-p1_call[hero])*(-1.0)
                after_check = (1.0-p1_bet[hero]) * ( p2_bet[villain]*facing_bet + (1.0-p2_bet[villain])*small )
                total += after_bet + after_check
    return total / 6.0

def ev_p1_vs_torch(p1_probs6, p2_probs6):
    """Torch version of the player-1 expected payoff over the six distinct-card deals.

    Uses the same six-entry layout and payoff arithmetic as `ev_p1_vs_numpy`
    but operates on tensors, so the result is a 0-dim tensor that gradients
    can flow through. The accumulator starts as float32 on `p1_probs6.device`.
    """
    p1_bet, p1_call = p1_probs6[:3], p1_probs6[3:]
    p2_call, p2_bet = p2_probs6[:3], p2_probs6[3:]
    total = torch.zeros([], dtype=torch.float32, device=p1_probs6.device)
    deals = [(hero, villain) for hero in CARDS for villain in CARDS if hero != villain]
    for hero, villain in deals:
        sign = 1.0 if hero > villain else -1.0
        big, small = 2.0 * sign, sign
        after_bet = p1_bet[hero] * (p2_call[villain]*big + (1.0-p2_call[villain])*1.0)
        facing_bet = p1_call[hero]*big + (1.0-p1_call[hero])*(-1.0)
        after_check = (1.0-p1_bet[hero]) * ( p2_bet[villain]*facing_bet + (1.0-p2_bet[villain])*small )
        total = total + after_bet + after_check
    return total / 6.0

def _enumerate_pure_policies_p1():
    """List all 64 deterministic six-entry policies for player 1.

    Entry k of the policy for index `mask` is bit k of `mask`, stored as a
    float64 0/1 value; entries 0-2 and 3-5 are the two halves that
    `ev_p1_vs_numpy` reads from a player-1 vector.
    """
    policies = []
    for mask in range(64):
        bits = [(mask >> k) & 1 for k in range(6)]
        policies.append(np.array(bits, dtype=np.float64))
    return policies

def _enumerate_pure_policies_p2():
    """List all 64 deterministic six-entry policies for player 2.

    Entry k of the policy for index `mask` is bit k of `mask` as a float64
    0/1 value; the first three entries and the last three are the two halves
    that `ev_p1_vs_numpy` reads from a player-2 vector.
    """
    return [
        np.array([(mask >> k) & 1 for k in range(6)], dtype=np.float64)
        for mask in range(64)
    ]

# All 64 deterministic (0/1-valued) policy vectors per player, built once at import time.
PURE1 = _enumerate_pure_policies_p1()
PURE2 = _enumerate_pure_policies_p2()



class BernoulliPolicy(nn.Module):
    """Six independent Bernoulli probabilities parameterised by trainable logits."""

    def __init__(self, init_std=1.0):
        """Create the six float32 logits and draw them from N(0, init_std**2)."""
        super().__init__()
        raw = torch.zeros(6, dtype=torch.float32)
        self.logits = nn.Parameter(raw)
        nn.init.normal_(self.logits, mean=0.0, std=init_std)

    def probs(self, tau: float, eps: float):
        """Return sigmoid(logits / tau), with NaNs mapped to 0.5, clamped to [eps, 1 - eps].

        `tau` is floored at 1e-8 so the division is always defined.
        """
        temperature = max(1e-8, tau)
        squashed = torch.sigmoid(self.logits / temperature)
        finite = torch.nan_to_num(squashed, nan=0.5)
        return torch.clamp(finite, eps, 1.0-eps)


class EPSRORunner:
    def __init__(self, args, seed: int, device: torch.device):
        """Remember the run configuration, then build all per-run state.

        Args:
            args: Parsed command-line namespace read throughout the runner.
            seed: Seed handed to `_seed_everything` by `_setup_state`.
            device: Torch device on which the active policies are placed.
        """
        self.args, self.seed, self.device = args, seed, device
        self._setup_state()

    def _setup_state(self):
        """Seed the RNGs and create the fixed sets, active policies and plateau histories.

        Order matters for reproducibility: the two fixed strategies are drawn
        from NumPy (player 1 first), then player 1's active policies are
        created, then player 2's, and finally the logits of each level are
        re-drawn alternating player 1 / player 2.
        """
        _seed_everything(self.seed)

        cfg = self.args
        self.max_k = int(cfg.kmax)
        self.J = int(getattr(cfg, "pipeline_levels", 3))

        # One random fixed strategy per player, kept strictly inside (0, 1).
        first_p1 = np.clip(np.random.rand(6), 1e-3, 1-1e-3).astype(np.float64)
        first_p2 = np.clip(np.random.rand(6), 1e-3, 1-1e-3).astype(np.float64)
        self.F1 = [first_p1]
        self.F2 = [first_p2]

        def fresh_policy():
            return BernoulliPolicy(init_std=cfg.init_std).to(self.device)

        def fresh_optimizer(policy):
            return torch.optim.Adam(policy.parameters(), lr=cfg.theta_lr)

        # J active policies per player, each with its own Adam optimizer.
        self.A1 = [fresh_policy() for _ in range(self.J)]
        self.A2 = [fresh_policy() for _ in range(self.J)]
        self.optA1 = [fresh_optimizer(pol) for pol in self.A1]
        self.optA2 = [fresh_optimizer(pol) for pol in self.A2]

        for pol1, pol2 in zip(self.A1, self.A2):
            self._random_init_policy(pol1)
            self._random_init_policy(pol2)

        # Histories of the level-0 solve values, consumed by the plateau test.
        self.low_hist_p1 = []
        self.low_hist_p2 = []


    def _sigma_from_beta(self, beta: torch.Tensor):
        temp = float(self.args.beta_temp)
        b = beta / max(1e-8, temp)
        b = torch.clamp(b, -float(self.args.logit_cap), float(self.args.logit_cap))
        return torch.softmax(b, dim=0)

    def _mix_from_set(self, Pi_list, sigma_t: torch.Tensor):
        Pi = torch.tensor(np.stack(Pi_list, 0), dtype=torch.float32, device=self.device)
        mix = (sigma_t[:, None] * Pi).sum(0)
        eps = float(self.args.prob_eps)
        return torch.clamp(torch.nan_to_num(mix, nan=0.5), eps, 1.0 - eps)

    def _random_init_policy(self, pol: torch.nn.Module):
        with torch.no_grad():
            pol.logits.copy_(float(self.args.init_std) * torch.randn_like(pol.logits))

    @torch.no_grad()
    def _snapshot_active(self, pol: torch.nn.Module):
        return pol.probs(self.args.tau, self.args.prob_eps).detach().cpu().numpy().astype(np.float64)

    def _opp_restricted_set(self, fixed_list, active_list, j: int):
        out = list(fixed_list)
        for t in range(j):
            out.append(self._snapshot_active(active_list[t]))
        return out


    def _solve_urr_p1(self, theta: torch.nn.Module, opt: torch.optim.Optimizer, opp_set: list[np.ndarray]):
        """Alternating updates of player 1's policy and a mixture over `opp_set`.

        For `urr_steps` iterations, `theta` takes one optimizer step that
        increases its expected payoff against the current opponent mixture,
        then the mixture logits `beta` take one gradient-ascent step on
        player 2's expected payoff against the updated `theta`.

        Args:
            theta: Player 1's active policy; its parameters are updated in place.
            opt: Optimizer bound to `theta`; its learning rate is overwritten each step.
            opp_set: Player-2 policies (six-entry probability vectors) the mixture ranges over.

        Returns:
            (sigma2, val): the final mixture weights as a float64 NumPy array and
            player 1's expected payoff against that final mixture as a float.
        """
        theta.train()
        # Probabilities at entry; used as the anchor of the optional KL penalty.
        prev_p = theta.probs(self.args.tau, self.args.prob_eps).detach()

        # Mixture logits start at zero, i.e. a uniform mixture over opp_set.
        beta = torch.zeros(len(opp_set), device=self.device, dtype=torch.float32, requires_grad=True)

        for t in range(1, int(self.args.urr_steps) + 1):
            lr_theta = _eta_t(self.args.theta_lr, t, self.args.theta_sched)
            lr_beta  = _eta_t(self.args.beta_lr,  t, self.args.beta_sched)
            for g in opt.param_groups:
                g["lr"] = lr_theta

            # --- theta step: the opponent mixture is detached, so only theta receives gradient.
            sigma2 = self._sigma_from_beta(beta)
            mix2 = self._mix_from_set(opp_set, sigma2).detach()

            p1 = theta.probs(self.args.tau, self.args.prob_eps)
            u1 = ev_p1_vs_torch(p1, mix2)
            loss_theta = -u1

            if float(self.args.beta_kl) > 0.0:
                kl = bernoulli_kl(p1, prev_p, eps=float(self.args.prob_eps)).mean()
                loss_theta = loss_theta + float(self.args.beta_kl) * kl

            opt.zero_grad(set_to_none=True)
            loss_theta.backward()
            if self.args.clip_grad and self.args.clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(theta.parameters(), float(self.args.clip_grad))
            opt.step()

            # --- beta step: payoffs of the updated theta against each opponent, held constant.
            with torch.no_grad():
                p1_now = theta.probs(self.args.tau, self.args.prob_eps).detach()
                u1_k = []
                for pi2 in opp_set:
                    p2k = torch.tensor(pi2, dtype=torch.float32, device=self.device)
                    u1_k.append(ev_p1_vs_torch(p1_now, p2k).detach())
                u1_k = torch.stack(u1_k, 0) 
                u2_k = -u1_k

            # Player 2's expected payoff under the mixture, plus the optional entropy bonus.
            sigma2_new = self._sigma_from_beta(beta)
            u2 = (sigma2_new * u2_k).sum()
            ent = -(sigma2_new * torch.log(sigma2_new + 1e-12)).sum()
            obj = u2 + float(self.args.entropy_coef) * ent

            # Manual gradient ascent on beta (no optimizer object is used for it).
            grad = torch.autograd.grad(obj, beta, retain_graph=False, create_graph=False)[0]
            with torch.no_grad():
                beta += lr_beta * grad

        theta.eval()
        with torch.no_grad():
            sigma2 = self._sigma_from_beta(beta).detach().cpu().numpy().astype(np.float64)
            mix2 = self._mix_from_set(opp_set, torch.tensor(sigma2, device=self.device)).detach()
            val = float(ev_p1_vs_torch(theta.probs(self.args.tau, self.args.prob_eps), mix2).detach().cpu())
        return sigma2, val

    def _solve_urr_p2(self, theta: torch.nn.Module, opt: torch.optim.Optimizer, opp_set: list[np.ndarray]):
        """Alternating updates of player 2's policy and a mixture over `opp_set`.

        For `urr_steps` iterations, `theta` takes one optimizer step that
        increases player 2's expected payoff (the negative of player 1's)
        against the current opponent mixture, then the mixture logits `beta`
        take one gradient-ascent step on player 1's expected payoff against
        the updated `theta`.

        Args:
            theta: Player 2's active policy; its parameters are updated in place.
            opt: Optimizer bound to `theta`; its learning rate is overwritten each step.
            opp_set: Player-1 policies (six-entry probability vectors) the mixture ranges over.

        Returns:
            (sigma1, u2): the final mixture weights as a float64 NumPy array and
            player 2's expected payoff against that final mixture as a float.
        """
        theta.train()
        # Probabilities at entry; used as the anchor of the optional KL penalty.
        prev_p = theta.probs(self.args.tau, self.args.prob_eps).detach()

        # Mixture logits start at zero, i.e. a uniform mixture over opp_set.
        beta = torch.zeros(len(opp_set), device=self.device, dtype=torch.float32, requires_grad=True)

        for t in range(1, int(self.args.urr_steps) + 1):
            lr_theta = _eta_t(self.args.theta_lr, t, self.args.theta_sched)
            lr_beta  = _eta_t(self.args.beta_lr,  t, self.args.beta_sched)
            for g in opt.param_groups:
                g["lr"] = lr_theta

            # --- theta step: the opponent mixture is detached, so only theta receives gradient.
            sigma1 = self._sigma_from_beta(beta)
            mix1 = self._mix_from_set(opp_set, sigma1).detach()

            p2 = theta.probs(self.args.tau, self.args.prob_eps)
            u1 = ev_p1_vs_torch(mix1, p2)
            u2 = -u1
            loss_theta = -u2

            if float(self.args.beta_kl) > 0.0:
                kl = bernoulli_kl(p2, prev_p, eps=float(self.args.prob_eps)).mean()
                loss_theta = loss_theta + float(self.args.beta_kl) * kl

            opt.zero_grad(set_to_none=True)
            loss_theta.backward()
            if self.args.clip_grad and self.args.clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(theta.parameters(), float(self.args.clip_grad))
            opt.step()

            # --- beta step: player 1's payoff for each opponent against the updated theta, held constant.
            with torch.no_grad():
                p2_now = theta.probs(self.args.tau, self.args.prob_eps).detach()
                u1_k = []
                for pi1 in opp_set:
                    p1k = torch.tensor(pi1, dtype=torch.float32, device=self.device)
                    u1_k.append(ev_p1_vs_torch(p1k, p2_now).detach())
                u1_k = torch.stack(u1_k, 0)

            # Player 1's expected payoff under the mixture, plus the optional entropy bonus.
            sigma1_new = self._sigma_from_beta(beta)
            u1exp = (sigma1_new * u1_k).sum()
            ent = -(sigma1_new * torch.log(sigma1_new + 1e-12)).sum()
            obj = u1exp + float(self.args.entropy_coef) * ent

            # Manual gradient ascent on beta (no optimizer object is used for it).
            grad = torch.autograd.grad(obj, beta, retain_graph=False, create_graph=False)[0]
            with torch.no_grad():
                beta += lr_beta * grad

        theta.eval()
        with torch.no_grad():
            sigma1 = self._sigma_from_beta(beta).detach().cpu().numpy().astype(np.float64)
            mix1 = self._mix_from_set(opp_set, torch.tensor(sigma1, device=self.device)).detach()
            u1 = float(ev_p1_vs_torch(mix1, theta.probs(self.args.tau, self.args.prob_eps)).detach().cpu())
            u2 = -u1
        return sigma1, u2


    def _plateau(self, hist: list[float]):
        W = int(getattr(self.args, "plateau_window", 10))
        eps = float(getattr(self.args, "plateau_eps", 1e-4))
        std_eps = float(getattr(self.args, "plateau_std_eps", 1e-4))

        if len(hist) < W:
            return False
        win = np.array(hist[-W:], dtype=np.float64)
        imp = float(np.max(win) - np.min(win))
        sd = float(np.std(win))
        return (imp < eps) and (sd < std_eps)


    @torch.no_grad()
    def current_meta(self):
        """Uniform meta-strategies over fixed + active policies and the payoff they induce.

        Returns:
            (s1, s2, val): uniform weight vectors for player 1 and player 2 and
            player 1's expected payoff `s1 @ M @ s2` under the meta payoff matrix.
        """
        P1, P2 = self._meta_sets()
        payoff = self._meta_payoff_matrix(P1, P2)
        n_rows, n_cols = len(P1), len(P2)

        s1 = np.ones(n_rows, dtype=np.float64) / n_rows
        s2 = np.ones(n_cols, dtype=np.float64) / n_cols
        return s1, s2, float(s1 @ payoff @ s2)
    

    @staticmethod
    def _solve_nash_zerosum_mwu(M: np.ndarray, steps: int = 4000, eta: float = 0.5):
        K1, K2 = M.shape
        x = np.ones(K1, dtype=np.float64) / K1
        y = np.ones(K2, dtype=np.float64) / K2

        scale = np.max(np.abs(M)) + 1e-12
        A = M / scale

        for t in range(1, steps + 1):
            g1 = A @ y
            g2 = -(x @ A)

            x *= np.exp(eta * g1)
            y *= np.exp(eta * g2)

            x_sum = x.sum()
            y_sum = y.sum()
            if x_sum <= 0 or not np.isfinite(x_sum):
                x = np.ones(K1, dtype=np.float64) / K1
            else:
                x /= x_sum
            if y_sum <= 0 or not np.isfinite(y_sum):
                y = np.ones(K2, dtype=np.float64) / K2
            else:
                y /= y_sum

        return x, y

    @torch.no_grad()
    def _meta_sets(self):
        P1 = list(self.F1) + [self._snapshot_active(p) for p in self.A1]
        P2 = list(self.F2) + [self._snapshot_active(p) for p in self.A2]
        return P1, P2

    @torch.no_grad()
    def _meta_payoff_matrix(self, P1, P2):
        K1, K2 = len(P1), len(P2)
        M = np.zeros((K1, K2), dtype=np.float64)
        for i in range(K1):
            for j in range(K2):
                M[i, j] = ev_p1_vs_numpy(P1[i], P2[j])
        return M

    @torch.no_grad()
    def _nashconv_on_meta(self):
        """Solve the meta-game and score the resulting mixtures against pure best responses.

        Returns:
            (nashconv, v, n1, n2): the score described below, the meta-game value
            `x @ M @ y`, and the number of player-1 / player-2 meta strategies.

        NOTE(review): the score is max(0, br1 - v, v - br2min), i.e. the larger
        single-player improvement, not the sum over both players that
        "NashConv" usually denotes — confirm which metric is intended.
        """
        P1, P2 = self._meta_sets()
        M = self._meta_payoff_matrix(P1, P2)

        solver_steps = max(2000, int(self.args.urr_steps) * 10)
        x, y = self._solve_nash_zerosum_mwu(M, steps=solver_steps, eta=0.5)
        v = float(x @ M @ y)

        # Collapse each population into a single six-entry policy using the meta weights.
        mix1 = np.sum(np.stack(P1, 0) * x[:, None], axis=0)
        mix2 = np.sum(np.stack(P2, 0) * y[:, None], axis=0)

        # Best pure responses: player 1 maximises, player 2 minimises player 1's payoff.
        p1_replies = [ev_p1_vs_numpy(pure, mix2) for pure in PURE1]
        p2_replies = [ev_p1_vs_numpy(mix1, pure) for pure in PURE2]
        br1 = max(p1_replies)
        br2min = min(p2_replies)

        nashconv = max(0.0, br1 - v, v - br2min)
        return nashconv, v, len(P1), len(P2)



    def run(self, csv_path: str, header=True):
        """Run all outer iterations for this seed, logging one CSV row per iteration.

        Args:
            csv_path: Destination CSV; its parent directory is created if needed
                and an existing file is overwritten.
            header: When true, print a one-line run banner to stdout first
                (the CSV column header is always written).

        Returns:
            Dict of per-iteration float64 arrays under "nc", "val", "dt", "mem",
            "n1", "n2", plus "mem_type" (label of the last memory reading).
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[EPSRO] seed={self.seed} | device={self.device} | J={self.J} | urr_steps={self.args.urr_steps}")

        hist_nc, hist_val, hist_dt, hist_mem = [], [], [], []
        hist_n1, hist_n2 = [], []
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)

            # Same column layout as the aggregated mean/std CSV.
            w.writerow(["iter",
                        "dt_mean","dt_std",
                        "mem_mb_mean","mem_mb_std",
                        "mix_ev_mean","mix_ev_std",
                        "nashconv_mean","nashconv_std",
                        "n_strats_p1_mean","n_strats_p1_std",
                        "n_strats_p2_mean","n_strats_p2_std",
                        "mem_type"])

            for it in range(1, int(self.args.iters) + 1):
                t0 = time.time()

                # Every level is re-drawn at random and re-solved in each outer iteration.
                for j in range(self.J):
                    # Opponents for level j: fixed strategies plus active levels below j.
                    opp2 = self._opp_restricted_set(self.F2, self.A2, j)
                    self._random_init_policy(self.A1[j])
                    _, v1 = self._solve_urr_p1(self.A1[j], self.optA1[j], opp2)

                    opp1 = self._opp_restricted_set(self.F1, self.A1, j)
                    self._random_init_policy(self.A2[j])
                    _, v2 = self._solve_urr_p2(self.A2[j], self.optA2[j], opp1)

                    # Only the lowest level's values feed the plateau test.
                    if j == 0:
                        self.low_hist_p1.append(float(v1))
                        self.low_hist_p2.append(float(v2))

                # Promotion: freeze level 0 into the fixed set, shift the remaining levels
                # down and append a new randomly initialised level with its own optimizer.
                # NOTE(review): the --noise_std help text describes re-initialising from the
                # promoted strategy plus noise, but the new level here is purely random — confirm intent.
                if len(self.F1) < self.max_k and self._plateau(self.low_hist_p1):
                    snap = self._snapshot_active(self.A1[0])
                    self.F1.append(snap)
                    self.A1 = self.A1[1:] + [BernoulliPolicy(init_std=self.args.init_std).to(self.device)]
                    self.optA1 = self.optA1[1:] + [torch.optim.Adam(self.A1[-1].parameters(), lr=self.args.theta_lr)]
                    self._random_init_policy(self.A1[-1])
                    self.low_hist_p1.clear()

                if len(self.F2) < self.max_k and self._plateau(self.low_hist_p2):
                    snap = self._snapshot_active(self.A2[0])
                    self.F2.append(snap)
                    self.A2 = self.A2[1:] + [BernoulliPolicy(init_std=self.args.init_std).to(self.device)]
                    self.optA2 = self.optA2[1:] + [torch.optim.Adam(self.A2[-1].parameters(), lr=self.args.theta_lr)]
                    self._random_init_policy(self.A2[-1])
                    self.low_hist_p2.clear()

                # Evaluation is part of the timed section (dt is taken after it).
                nc, val, n1, n2 = self._nashconv_on_meta()

                dt = time.time() - t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype

                hist_nc.append(nc)
                hist_val.append(val)
                hist_dt.append(dt)
                hist_mem.append(mem)
                hist_n1.append(n1)
                hist_n2.append(n2)

                print(f"[EPSRO] iter {it}/{self.args.iters} | meta(P1={n1}, P2={n2}) | "
                      f"NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB")

                # A single seed has no spread, so every *_std column is written as zero.
                w.writerow([
                    it,
                    f"{dt:.6f}",  f"{0.0:.6f}",
                    f"{mem:.6f}", f"{0.0:.6f}",
                    f"{val:+.6f}", f"{0.0:.6f}",
                    f"{nc:.6f}",  f"{0.0:.6f}",
                    f"{float(n1):.6f}", f"{0.0:.6f}",
                    f"{float(n2):.6f}", f"{0.0:.6f}",
                    mtype
                ])
                # Flush per row so partial results survive an interrupted run.
                fcsv.flush()

        return {
            "nc":  np.array(hist_nc, dtype=np.float64),
            "val": np.array(hist_val, dtype=np.float64),
            "dt":  np.array(hist_dt, dtype=np.float64),
            "mem": np.array(hist_mem, dtype=np.float64),
            "n1":  np.array(hist_n1, dtype=np.float64),
            "n2":  np.array(hist_n2, dtype=np.float64),
            "mem_type": mem_type_rec,
        }



def _save_meanstd_csv(path, T, mem_type,
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std):
    """Write the across-seed mean/std table, one row per iteration 1..T.

    The parent directory of `path` is created if missing. All numbers are
    written with six decimals; the mixture-value mean additionally carries an
    explicit sign. `mem_type` is repeated verbatim in the last column.
    """
    parent = os.path.dirname(path) or "."
    os.makedirs(parent, exist_ok=True)

    column_names = ["iter",
                    "dt_mean","dt_std",
                    "mem_mb_mean","mem_mb_std",
                    "mix_ev_mean","mix_ev_std",
                    "nashconv_mean","nashconv_std",
                    "n_strats_p1_mean","n_strats_p1_std",
                    "n_strats_p2_mean","n_strats_p2_std",
                    "mem_type"]

    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(column_names)
        for i in range(T):
            row = [i+1]
            row += [f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}"]
            row += [f"{mem_mean[i]:.6f}", f"{mem_std[i]:.6f}"]
            row += [f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}"]
            row += [f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}"]
            row += [f"{n1_mean[i]:.6f}", f"{n1_std[i]:.6f}"]
            row += [f"{n2_mean[i]:.6f}", f"{n2_std[i]:.6f}"]
            row.append(mem_type)
            w.writerow(row)

def _read_final_nashconv(csv_path: str):
    """Return the NashConv value stored in the last data row of a results CSV.

    Blank rows (for example trailing newlines) are ignored. When the first
    row is a header containing a "nashconv_mean" column, that column is used
    and the header is not treated as data; otherwise the file is assumed to
    be headerless with the value at column index 7, matching the layout
    written by this script.

    Args:
        csv_path: Path to a per-seed or baseline CSV.

    Returns:
        The final NashConv value as a float.

    Raises:
        ValueError: if the file has no data rows, the last row is too short,
            or the cell cannot be parsed as a float.
    """
    with open(csv_path, "r", newline="") as f:
        rows = [row for row in csv.reader(f) if any(cell.strip() for cell in row)]
    if not rows:
        raise ValueError(f"{csv_path}: file contains no rows")

    first = [cell.strip() for cell in rows[0]]
    if "nashconv_mean" in first:
        col = first.index("nashconv_mean")
        data_rows = rows[1:]
    else:
        col = 7
        data_rows = rows
    if not data_rows:
        raise ValueError(f"{csv_path}: file has a header but no data rows")

    last = data_rows[-1]
    if len(last) <= col:
        raise ValueError(f"{csv_path}: last row has {len(last)} columns, need at least {col + 1}")
    return float(last[col])

def _welch_ttest(a: np.ndarray, b: np.ndarray):
    """Welch's unequal-variance t-test between samples `a` and `b`.

    Returns:
        (t, p) as Python floats, or (nan, nan) when SciPy's `ttest_ind`
        could not be imported.
    """
    if ttest_ind is not None:
        stat, pvalue = ttest_ind(a, b, equal_var=False)
        return float(stat), float(pvalue)
    missing = float("nan")
    return missing, missing



def main():
    """Run EPSRO for every requested seed, then aggregate, compare and plot.

    Steps:
        1. Run one `EPSRORunner` per seed, each writing its own per-seed CSV.
        2. Write the across-seed mean/std CSV and print the final-iteration summary.
        3. Optionally run a Welch t-test of final NashConv against baseline CSVs.
        4. Unless --no_plots is given, save the NashConv and mixture-value figures.

    Raises:
        ValueError: if --seeds contains no seed.
    """
    args = parse_args()
    device = _pick_device(args.device)
    seeds = [int(x.strip()) for x in args.seeds.split(",") if x.strip() != ""]
    if not seeds:
        raise ValueError("--seeds must list at least one integer seed")

    os.makedirs(args.outdir, exist_ok=True)

    per = []
    seed_csvs = []
    for s in seeds:
        r = EPSRORunner(args, seed=s, device=device)
        csv_path = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        seed_csvs.append(csv_path)
        per.append(r.run(csv_path, header=(s == seeds[0])))

    T = int(args.iters)

    # The sample std (ddof=1) is undefined for a single run and would yield NaN
    # plus a runtime warning; use the population std (all zeros) in that case.
    ddof = 1 if len(per) > 1 else 0

    def _mean_std(key):
        stack = np.stack([rec[key] for rec in per], 0)
        return stack.mean(0), stack.std(0, ddof=ddof)

    nc_mean, nc_std = _mean_std("nc")
    ev_mean, ev_std = _mean_std("val")
    dt_mean, dt_std = _mean_std("dt")
    mem_mean, mem_std = _mean_std("mem")
    n1_mean, n1_std = _mean_std("n1")
    n2_mean, n2_std = _mean_std("n2")

    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, per[0]["mem_type"],
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    if T > 0:
        print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
              f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
              f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")

    if args.ttest_against_glob is not None:
        base_paths = sorted(glob.glob(args.ttest_against_glob))
        # Both groups are truncated to the same number of files.
        n = min(len(base_paths), len(seed_csvs))
        if n < 2:
            print(f"[TTEST] skipped: need at least 2 CSVs per group, "
                  f"found {len(seed_csvs)} EPSRO and {len(base_paths)} baseline "
                  f"for glob {args.ttest_against_glob!r}")
        else:
            a = np.array([_read_final_nashconv(p) for p in seed_csvs[:n]], dtype=np.float64)
            b = np.array([_read_final_nashconv(p) for p in base_paths[:n]], dtype=np.float64)
            t, pval = _welch_ttest(a, b)
            print(f"[TTEST] Welch t-test on final NashConv (EPSRO vs baseline): "
                  f"EPSRO mean={a.mean():.4f}±{a.std(ddof=1):.4f}, "
                  f"baseline mean={b.mean():.4f}±{b.std(ddof=1):.4f} | t={t:.4f}, p={pval:.4g}")

    # Nothing below runs with --no_plots: no figures are created or saved.
    if not args.no_plots:
        its = np.arange(1, T+1)
        # The NashConv file name mirrors the "_mixev_meanstd.png" pattern of the other figure.
        nash_path_png = os.path.join(args.outdir, f"{args.csv_base}_nashconv_meanstd.png")
        ev_path_png   = os.path.join(args.outdir, f"{args.csv_base}_mixev_meanstd.png")

        # Figures are addressed through their handles rather than by number, so
        # figures opened elsewhere cannot be saved or closed by mistake.
        fig_nc = plt.figure()
        plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (EPSRO) — mean ± std")
        fig_nc.tight_layout()
        fig_nc.savefig(nash_path_png, dpi=200)
        plt.close(fig_nc)

        fig_ev = plt.figure()
        plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value (EPSRO) — mean ± std")
        fig_ev.tight_layout()
        fig_ev.savefig(ev_path_png, dpi=200)
        plt.close(fig_ev)

        print(f"[PLOTS] saved: {nash_path_png}")
        print(f"[PLOTS] saved: {ev_path_png}")



# Script entry point: run the multi-seed experiment only when executed directly.
if __name__ == "__main__":
    main()
