import argparse, time, csv, os, random, math, sys
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from scipy.stats import ttest_ind
except Exception:
    ttest_ind = None
import glob

def parse_args():
    """Parse the command line for a multi-seed α-PSRO run.

    Options fall into four groups: the outer PSRO loop (``--iters``,
    ``--alpha``, ``--kmax``), the PPO best-response trainer (``--ppo_*``,
    ``--clip``, ``--ent_beta``, ``--gamma``, ``--gae_lambda``,
    ``--max_grad_norm``), numerics/output (``--prob_eps``, ``--logit_cap``,
    ``--outdir``, ``--csv_base``, ``--seeds``, ``--no_plots``) and the
    optional t-test against baseline CSVs (``--ttest_*``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser("α-PSRO (PPO BR) for Kuhn Poker — multi-seed")

    # (flag, add_argument keyword options). Registered in listing order so
    # the generated --help text keeps its layout.
    option_table = [
        # Outer loop / meta-solver.
        ("--iters", dict(type=int, default=40)),
        ("--alpha", dict(type=float, default=10.0,
                         help="selection intensity for α-Rank")),
        ("--kmax", dict(type=int, default=0,
                        help="0=unbounded; >0 cap per-player pool (evict least-mass)")),
        # PPO best-response trainer.
        ("--ppo_rollouts", dict(type=int, default=4000,
                                help="episodes per BR training")),
        ("--ppo_epochs", dict(type=int, default=10)),
        ("--ppo_batch", dict(type=int, default=512)),
        ("--ppo_lr", dict(type=float, default=3e-4,
                          help="LEARNING RATE for PPO Adam optimizer")),
        ("--clip", dict(type=float, default=0.2)),
        ("--ent_beta", dict(type=float, default=1e-3)),
        ("--gamma", dict(type=float, default=1.0)),
        ("--gae_lambda", dict(type=float, default=0.95)),
        ("--max_grad_norm", dict(type=float, default=1.0)),
        # Numerics and output.
        ("--prob_eps", dict(type=float, default=1e-6)),
        ("--logit_cap", dict(type=float, default=50.0)),
        ("--outdir", dict(type=str, default=".")),
        ("--csv_base", dict(type=str, default="alpha_psro")),
        ("--seeds", dict(type=str, default="0,1,2,3,4",
                         help="Comma-separated seeds")),
        ("--no_plots", dict(action="store_true")),
        # Optional significance test against baseline runs.
        ("--ttest_against_glob", dict(
            type=str, default=None,
            help='Glob for baseline seed CSVs (same iters), e.g. "results/psro_kuhn_seed*.csv"')),
        ("--ttest_iter", dict(type=int, default=0,
                              help="Iteration to test at (0=final).")),
        ("--ttest_metric", dict(choices=["nashconv", "mix_ev_p1"],
                                default="nashconv")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def _seed_everything(s: int):
    random.seed(s); np.random.seed(s)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform == "darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _parse_seeds(seeds_str: str):
    out=[]
    for tok in seeds_str.split(","):
        tok=tok.strip()
        if tok: out.append(int(tok))
    return out if out else [0,1,2,3,4]

# The three card ranks; a higher number beats a lower one at showdown.
CARDS = [0, 1, 2]


def ev_p1_vs(p1, p2):
    """Exact expected payoff to player 1 for a pair of behavioural policies.

    Both policies are length-6 sequences of probabilities indexed by the
    holder's card:

    * ``p1[:3]`` — P1 bets first; ``p1[3:]`` — P1 calls after checking and
      facing a bet.
    * ``p2[:3]`` — P2 calls a bet; ``p2[3:]`` — P2 bets after P1 checks.

    Payoffs to P1: ±2 when a bet is called (sign by higher card), +1 when P2
    folds, -1 when P1 folds, ±1 on check-check. The result averages over the
    six ordered deals of two distinct cards.
    """
    bet1, call1 = p1[:3], p1[3:]
    call2, bet2 = p2[:3], p2[3:]

    total = 0.0
    deals = ((c, d) for c in CARDS for d in CARDS if c != d)
    for c, d in deals:
        p1_wins = c > d
        called_pot = 2.0 if p1_wins else -2.0    # showdown after bet + call
        checked_pot = 1.0 if p1_wins else -1.0   # showdown after check-check

        # P1 opens with a bet: P2 either calls (showdown) or folds (+1).
        when_p1_bets = bet1[c] * (call2[d] * called_pot + (1 - call2[d]) * 1.0)

        # P1 checks: P2 may bet, after which P1 calls (showdown) or folds (-1).
        facing_bet = call1[c] * called_pot + (1 - call1[c]) * (-1.0)
        when_p1_checks = (1 - bet1[c]) * (
            bet2[d] * facing_bet + (1 - bet2[d]) * checked_pot
        )

        total += when_p1_bets + when_p1_checks
    return total / 6.0

def enumerate_pure_policies_p1():
    """List all 64 deterministic player-1 policies as float64 6-vectors.

    Policy number ``mask`` has entry ``k`` equal to bit ``k`` of ``mask``:
    bits 0-2 fill the first three slots (bet) and bits 3-5 the last three
    (call).
    """
    return [
        np.array([(mask >> bit) & 1 for bit in range(6)], dtype=np.float64)
        for mask in range(64)
    ]
def enumerate_pure_policies_p2():
    """List all 64 deterministic player-2 policies as float64 6-vectors.

    For policy number ``mask``, bits 0-2 fill the first three slots (call)
    and bits 3-5 the last three (bet).
    """
    policies = []
    for mask in range(64):
        bits = [(mask >> k) & 1 for k in range(6)]
        call_bits, bet_bits = bits[:3], bits[3:]
        policies.append(np.array(call_bits + bet_bits, dtype=np.float64))
    return policies
# All 64 deterministic policies per player, built once at import time so the
# best-response search does not have to re-enumerate them on every call.
PURE1 = enumerate_pure_policies_p1()
PURE2 = enumerate_pure_policies_p2()

def _logistic_stable(x):
    z = np.clip(x, -50.0, 50.0)
    return 1.0/(1.0+np.exp(-z))

def alpha_rank_meta(M, alpha=10.0, tol=1e-12, max_steps=200000):
    """Compute per-player meta-strategies from an α-Rank-style Markov chain.

    States are joint profiles ``(i, j)`` of the ``K1 x K2`` payoff matrix
    ``M`` (payoff to the row player; the column player is scored with the
    negated difference). From each state one player unilaterally switches to
    another strategy with probability proportional to a logistic of
    ``alpha`` times that player's payoff change. The stationary distribution
    is found by power iteration and marginalised per player.

    Args:
        M: 2-D payoff array of shape ``(K1, K2)``.
        alpha: selection intensity inside the logistic.
        tol: L1 change between successive iterates that counts as converged.
        max_steps: cap on power-iteration steps.

    Returns:
        ``(s1, s2)``: arrays of length ``K1`` and ``K2``, each summing to
        (approximately) one. Empty or 1x1 inputs return uniform / trivial
        distributions without building the chain.
    """
    n_rows, n_cols = M.shape
    n_states = n_rows * n_cols

    # Degenerate pools: nothing to rank.
    if n_rows == 0 or n_cols == 0:
        return np.ones(n_rows) / max(1, n_rows), np.ones(n_cols) / max(1, n_cols)
    if n_rows == 1 and n_cols == 1:
        return np.array([1.0]), np.array([1.0])

    # Share of deviation proposals given to each player. A player with a
    # single strategy cannot deviate, so the other player gets all of them.
    if n_rows > 1 and n_cols > 1:
        row_share, col_share = 0.5, 0.5
    elif n_rows > 1:
        row_share, col_share = 1.0, 0.0
    else:
        row_share, col_share = 0.0, 1.0

    trans = np.zeros((n_states, n_states), dtype=np.float64)
    for state in range(n_states):
        i, j = divmod(state, n_cols)
        here = M[i, j]
        leave = 0.0  # total probability of moving away from this state

        if n_rows > 1 and row_share > 0.0:
            # Row player considers every alternative row; higher M is better.
            for alt in range(n_rows):
                if alt == i:
                    continue
                w = row_share * _logistic_stable(alpha * (M[alt, j] - here)) / (n_rows - 1)
                trans[state, alt * n_cols + j] += w
                leave += w

        if n_cols > 1 and col_share > 0.0:
            # Column player considers every alternative column; lower M is better.
            for alt in range(n_cols):
                if alt == j:
                    continue
                w = col_share * _logistic_stable(alpha * (here - M[i, alt])) / (n_cols - 1)
                trans[state, i * n_cols + alt] += w
                leave += w

        trans[state, state] = max(0.0, 1.0 - leave)

    # Make every row a proper distribution; all-zero rows become absorbing.
    totals = trans.sum(axis=1, keepdims=True)
    dead = np.where(totals[:, 0] == 0.0)[0]
    if dead.size:
        trans[dead, dead] = 1.0
        totals = trans.sum(axis=1, keepdims=True)
    trans = trans / totals

    # Tiny uniform mixing so every state can reach every other one.
    teleport = 1e-12
    trans = (1.0 - teleport) * trans + teleport * (np.ones_like(trans) / n_states)

    # Power iteration from the uniform distribution.
    dist = np.ones(n_states) / n_states
    for _ in range(max_steps):
        nxt = dist @ trans
        converged = np.linalg.norm(nxt - dist, 1) < tol
        dist = nxt
        if converged:
            break
    dist = dist / (dist.sum() + 1e-12)

    # Marginalise the joint stationary mass onto each player's strategies.
    row_mass = np.zeros(n_rows)
    col_mass = np.zeros(n_cols)
    for state, mass in enumerate(dist):
        i, j = divmod(state, n_cols)
        row_mass[i] += mass
        col_mass[j] += mass
    row_mass /= row_mass.sum() + 1e-12
    col_mass /= col_mass.sum() + 1e-12
    return row_mass, col_mass

class TabularPPO:
    """Tabular PPO learner: one Bernoulli logit and one value per decision slot.

    There are six decision slots (indices 0-5). ``theta[k]`` is the logit of
    taking action 1 at slot ``k`` and ``v[k]`` is the value estimate for that
    slot. Both are trained with separate Adam states.

    Attributes read from ``args``: ``logit_cap``, ``prob_eps``, ``ppo_epochs``,
    ``ppo_batch``, ``clip``, ``ent_beta``, ``ppo_lr``, ``max_grad_norm``.
    """

    def __init__(self, role, args):
        self.args = args
        self.role = role
        self.theta = np.zeros(6, dtype=np.float64)   # policy logits
        self.v = np.zeros(6, dtype=np.float64)       # value table
        # Adam first/second moments for the policy ...
        self.m_t = np.zeros_like(self.theta)
        self.v_t = np.zeros_like(self.theta)
        # ... and for the value table.
        self.m_v = np.zeros_like(self.v)
        self.v_v = np.zeros_like(self.v)
        self.t_step = 0         # total Adam updates applied, all groups
        self._adam_steps = {}   # per-moment-buffer update count, keyed by id(m)

    def _sigm(self, x):
        """Sigmoid with the argument clamped to ``±args.logit_cap``."""
        cap = self.args.logit_cap
        return 1.0 / (1.0 + np.exp(-np.clip(x, -cap, cap)))

    def _clipped_prob(self, logits):
        """Action-1 probability for ``logits``, kept inside ``[eps, 1-eps]``."""
        eps = self.args.prob_eps
        return np.clip(self._sigm(logits), eps, 1.0 - eps)

    def _logp(self, idx, a):
        """Log-probability of action ``a`` (0 or 1) at decision slot ``idx``."""
        p = self._clipped_prob(self.theta[idx])
        return math.log(p) if a == 1 else math.log(1.0 - p)

    def _entropy(self, idx):
        """Bernoulli entropy of the policy at decision slot ``idx``."""
        p = self._clipped_prob(self.theta[idx])
        return -(p * math.log(p) + (1 - p) * math.log(1 - p))

    def _adam_update(self, param, grad, m, v, lr, b1=0.9, b2=0.999, eps=1e-8, max_norm=None):
        """Apply one in-place Adam step to ``param``.

        ``m`` and ``v`` are the moment buffers belonging to ``param`` and are
        updated in place. Bias correction uses the number of updates applied
        to *this* pair of buffers, so parameter groups sharing the learner do
        not advance each other's step count. If ``max_norm`` is given, the
        final step vector (not the raw gradient) is rescaled to that L2 norm.
        """
        self.t_step += 1
        key = id(m)
        t = self._adam_steps.get(key, 0) + 1
        self._adam_steps[key] = t

        m[:] = b1 * m + (1 - b1) * grad
        v[:] = b2 * v + (1 - b2) * (grad * grad)
        mhat = m / (1 - b1**t)
        vhat = v / (1 - b2**t)
        step = lr * mhat / (np.sqrt(vhat) + eps)
        if max_norm is not None:
            norm = np.linalg.norm(step)
            if norm > max_norm and norm > 0:
                step *= max_norm / norm
        param[:] = param - step

    def ppo_update(self, batch):
        """Run ``ppo_epochs`` passes of minibatch PPO over ``batch``.

        ``batch`` is a dict of equal-length sequences: ``idx`` (decision
        slot), ``act`` (0/1), ``logp_old`` (behaviour log-prob), ``ret``
        (return) and ``adv`` (advantage; standardised here). An empty batch
        is a no-op.

        The policy gradient is that of the clipped surrogate
        ``min(r*A, clip(r, 1-c, 1+c)*A)``: samples for which the clipped
        branch is the active minimum contribute no gradient. An entropy
        bonus weighted by ``ent_beta`` is added, and the value table is
        fitted to the returns with a squared-error loss at half the policy
        learning rate.
        """
        n_samples = len(batch["idx"])
        if n_samples == 0:
            return
        a = self.args
        idx = np.asarray(batch["idx"], dtype=np.intp)
        act = np.asarray(batch["act"])
        logp_old = np.asarray(batch["logp_old"], dtype=np.float64)
        ret = np.asarray(batch["ret"], dtype=np.float64)
        adv = np.asarray(batch["adv"], dtype=np.float64)
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

        order = np.arange(n_samples)
        for _ in range(a.ppo_epochs):
            np.random.shuffle(order)
            for start in range(0, n_samples, a.ppo_batch):
                mb = order[start:start + a.ppo_batch]
                mb_idx, mb_act, mb_adv = idx[mb], act[mb], adv[mb]
                n_mb = len(mb)

                p = self._clipped_prob(self.theta[mb_idx])
                cur_logp = np.where(mb_act == 1, np.log(p), np.log(1.0 - p))
                ratio = np.exp(cur_logp - logp_old[mb])

                # Clipped surrogate: the gradient flows only where the
                # unclipped term is the one selected by the min().
                unclipped = ratio * mb_adv
                clipped = np.clip(ratio, 1.0 - a.clip, 1.0 + a.clip) * mb_adv
                active = unclipped <= clipped

                dlogp_dtheta = mb_act - p   # d log pi(a) / d logit for a Bernoulli
                grad_theta = -(1.0 / n_mb) * np.where(
                    active, mb_adv * ratio * dlogp_dtheta, 0.0
                )

                # Entropy bonus: dH/dtheta = dH/dp * dp/dtheta.
                dHdp = -np.log(p / (1.0 - p))
                dpdtheta = p * (1.0 - p)
                grad_theta += -(a.ent_beta / n_mb) * (dHdp * dpdtheta)

                # Scatter-add per-sample gradients onto the six table slots
                # (add.at accumulates correctly for repeated indices).
                g_theta = np.zeros_like(self.theta)
                np.add.at(g_theta, mb_idx, grad_theta)
                g_v = np.zeros_like(self.v)
                np.add.at(g_v, mb_idx, (self.v[mb_idx] - ret[mb]) / n_mb)

                self._adam_update(self.theta, g_theta, self.m_t, self.v_t,
                                  lr=a.ppo_lr, max_norm=a.max_grad_norm)
                self._adam_update(self.v, g_v, self.m_v, self.v_v,
                                  lr=a.ppo_lr * 0.5, max_norm=a.max_grad_norm)

    def probs(self):
        """Current action-1 probabilities for all six slots, in ``[eps, 1-eps]``."""
        return self._clipped_prob(self.theta)

def draw_cards():
    """Deal two distinct cards from ``{0, 1, 2}`` and return them as ``(a, b)``.

    The first card is drawn from three options, the second from the two
    remaining ones.
    """
    first = random.randrange(3)
    second = random.randrange(2)
    # Skip over the slot taken by the first card so the two never coincide.
    return first, (second + 1 if second >= first else second)

def simulate_episode(learner_role, theta_probs, opp_probs):
    """Play one sampled hand and return ``(learner_payoff, learner_decisions)``.

    Args:
        learner_role: ``0`` if the learner is player 1 (acts first); any
            other value makes it player 2.
        theta_probs: the learner's six action-1 probabilities.
        opp_probs: the opponent's six action-1 probabilities.

    Policy layout (indexed by the holder's card ``c``): player 1 uses slot
    ``c`` to bet and ``3 + c`` to call; player 2 uses slot ``c`` to call and
    ``3 + c`` to bet.

    Returns:
        The payoff from the learner's point of view and the list of
        ``(slot, action)`` pairs for the learner's own decisions in the
        order they were taken (opponent decisions are not recorded).
    """
    card1, card2 = draw_cards()
    learner_is_p1 = learner_role == 0
    if learner_is_p1:
        p1_probs, p2_probs = theta_probs, opp_probs
    else:
        p1_probs, p2_probs = opp_probs, theta_probs
    p1_has_high_card = card1 > card2
    decisions = []

    def act(prob, slot, by_learner):
        # Sample a 0/1 action; log it only if the learner is the one acting.
        choice = 1 if random.random() < prob else 0
        if by_learner:
            decisions.append((slot, choice))
        return choice

    if act(p1_probs[card1], card1, learner_is_p1):
        # P1 bet: P2 calls into a 2-point showdown or folds for +1 to P1.
        if act(p2_probs[card2], card2, not learner_is_p1):
            payoff_p1 = 2.0 if p1_has_high_card else -2.0
        else:
            payoff_p1 = 1.0
    elif act(p2_probs[3 + card2], 3 + card2, not learner_is_p1):
        # P1 checked and P2 bet: P1 calls (showdown) or folds for -1.
        if act(p1_probs[3 + card1], 3 + card1, learner_is_p1):
            payoff_p1 = 2.0 if p1_has_high_card else -2.0
        else:
            payoff_p1 = -1.0
    else:
        # Check-check: 1-point showdown.
        payoff_p1 = 1.0 if p1_has_high_card else -1.0

    learner_payoff = payoff_p1 if learner_is_p1 else -payoff_p1
    return learner_payoff, decisions

class AlphaPSRORunner:
    """One seeded α-PSRO run: policy pools, meta-solver, PPO best responses.

    ``self.Z[p]`` is player ``p``'s pool of length-6 behavioural policies
    (``p = 0`` for player 1, ``p = 1`` for player 2). Each iteration solves
    the meta-game over the pools, optionally evicts a low-mass policy, trains
    one PPO best response per player against the opponent's meta-strategy,
    appends both to the pools and logs exploitability.
    """

    def __init__(self, args, seed: int):
        self.args = args
        self.seed = seed
        self._setup_state()

    def _setup_state(self):
        """Seed the RNGs and start each pool with the uniform (all-0.5) policy."""
        _seed_everything(self.seed)
        seed_pol = np.array([0.5] * 6, dtype=np.float64)
        self.Z = [
            [seed_pol.copy()],
            [seed_pol.copy()],
        ]

    def ev_matrix(self):
        """Return the ``K1 x K2`` matrix of exact P1 payoffs between pool members."""
        K1, K2 = len(self.Z[0]), len(self.Z[1])
        M = np.zeros((K1, K2), dtype=np.float64)
        for i in range(K1):
            for j in range(K2):
                M[i, j] = ev_p1_vs(self.Z[0][i], self.Z[1][j])
        return M

    @staticmethod
    def _fit_mix(mix, size):
        """Return ``mix`` as a normalised float64 vector of length ``size``.

        Falls back to the uniform distribution when the length does not match.
        """
        mix = np.asarray(mix, dtype=np.float64)
        if mix.shape[0] != size:
            return np.ones(size, dtype=np.float64) / size
        return mix / max(1e-12, mix.sum())

    def _solve_meta(self):
        """Solve the meta-game on the current pools; return ``(s1, s2)``."""
        s1, s2 = alpha_rank_meta(self.ev_matrix(), alpha=self.args.alpha)
        return (self._fit_mix(s1, len(self.Z[0])),
                self._fit_mix(s2, len(self.Z[1])))

    def nashconv(self, s1, s2):
        """Return ``(nashconv, value)`` of the meta-strategy pair ``(s1, s2)``.

        ``value`` is P1's expected payoff ``s1 @ M @ s2``. NashConv is P1's
        best pure-response payoff against P2's mixture minus the payoff P1's
        mixture keeps against P2's best pure response, floored at zero.
        """
        M = self.ev_matrix()
        val = float(s1 @ M @ s2)

        # P2's parameters enter ev_p1_vs affinely, so averaging P2's pool
        # probabilities with weights s2 gives the same payoff as sampling
        # one pool member per episode.
        mix_p2 = sum(s2[j] * self.Z[1][j] for j in range(len(self.Z[1])))
        br1 = max(ev_p1_vs(pi, mix_p2) for pi in PURE1)

        # P1's payoff contains the product (1 - bet) * call, so averaging
        # P1's probabilities would describe a different strategy than the
        # per-episode mixture. Evaluate the mixture by linearity over pool
        # members instead.
        n1 = len(self.Z[0])
        br2min = min(
            sum(s1[i] * ev_p1_vs(self.Z[0][i], pj) for i in range(n1))
            for pj in PURE2
        )
        return max(0.0, br1 - br2min), val

    def ensure_match_mix(self, learner_role, mix):
        """Fit ``mix`` to the size of the learner's opponent pool and normalise it.

        Short vectors are zero-padded, long ones truncated. If the result has
        a non-finite or non-positive sum, the uniform distribution is returned.
        """
        K = len(self.Z[1 - learner_role])
        m = np.asarray(mix, dtype=np.float64).copy()

        if m.shape[0] != K:
            if m.shape[0] < K:
                m = np.pad(m, (0, K - m.shape[0]), mode="constant")
            else:
                m = m[:K]

        s = float(m.sum())
        if (not np.isfinite(s)) or s <= 0:
            return np.ones(K, dtype=np.float64) / K
        return m / s

    def collect_br_data(self, learner_role, sigma_opp, rollouts):
        """Roll out a fresh PPO agent against opponents sampled from ``sigma_opp``.

        One opponent policy is drawn per episode. Every learner decision
        becomes a sample whose return is the episode payoff and whose
        advantage is that payoff minus the agent's value for the slot.

        Returns:
            ``(agent, data)`` — the untrained agent and the batch dict with
            keys ``idx``, ``act``, ``logp_old``, ``ret``, ``adv``.
        """
        sigma_opp = self.ensure_match_mix(learner_role, sigma_opp)

        data = {"idx": [], "act": [], "logp_old": [], "ret": [], "adv": []}
        agent = TabularPPO(learner_role, self.args)
        opp_pool = self.Z[1 - learner_role]
        # The agent is not updated while collecting, so its action
        # probabilities are the same for every episode.
        p_now = agent.probs()

        for _ in range(rollouts):
            j = np.random.choice(len(opp_pool), p=sigma_opp)
            r, decs = simulate_episode(learner_role, p_now, opp_pool[j])

            for (idx, a) in decs:
                data["idx"].append(idx)
                data["act"].append(a)
                data["logp_old"].append(agent._logp(idx, a))
                data["ret"].append(r)
                data["adv"].append(r - agent.v[idx])

        return agent, data

    def train_br_with_ppo(self, role, sigma_opp):
        """Train a PPO response for ``role`` against ``sigma_opp``; return its policy.

        Returns the uniform policy if no learner decisions were collected.
        """
        sigma_opp = self.ensure_match_mix(role, sigma_opp)
        agent, batch = self.collect_br_data(role, sigma_opp, self.args.ppo_rollouts)
        if len(batch["idx"]) == 0:
            return np.full(6, 0.5, dtype=np.float64)
        agent.ppo_update(batch)
        return agent.probs()

    def maybe_evict_least_mass(self, p, mix, protect_seed=True):
        """Drop player ``p``'s lowest-mass pool entry once the pool has ``kmax`` members.

        No-op when ``kmax <= 0`` or the pool is still below the cap. With
        ``protect_seed`` the first pool entry is never evicted unless it is
        the only one.
        """
        if self.args.kmax <= 0:
            return
        if len(self.Z[p]) < self.args.kmax:
            return

        pool_size = len(self.Z[p])
        m = np.asarray(mix, dtype=np.float64)
        if m.shape[0] != pool_size:
            # Pad/truncate to the pool size, then renormalise (uniform if degenerate).
            if m.shape[0] < pool_size:
                m = np.pad(m, (0, pool_size - m.shape[0]), mode="constant")
            else:
                m = m[:pool_size]
            s = float(m.sum())
            m = (np.ones(pool_size) / pool_size) if (not np.isfinite(s) or s <= 0) else (m / s)

        start = 1 if protect_seed and pool_size > 1 else 0
        idx = start + int(np.argmin(m[start:]))
        self.Z[p].pop(idx)

    def add_anchor(self, p, vec):
        """Append a copy of ``vec``, clipped to ``[prob_eps, 1 - prob_eps]``, to pool ``p``."""
        vec = np.asarray(vec, dtype=np.float64)
        vec = np.clip(vec, self.args.prob_eps, 1.0 - self.args.prob_eps)
        self.Z[p].append(vec.copy())

    def run(self, csv_path: str, header=True):
        """Run ``args.iters`` PSRO iterations, logging one CSV row per iteration.

        Args:
            csv_path: output CSV path; parent directories are created.
            header: whether to print the one-line run banner first.

        Returns:
            Dict of per-iteration NumPy arrays ``nc`` (NashConv), ``val``
            (mixture value), ``dt`` (seconds), ``mem`` (MB), ``n1``/``n2``
            (pool sizes) plus ``mem_type``, the label of the memory source
            used on the last iteration (``"n/a"`` if there were none).
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[α-PSRO] seed={self.seed} | PPO LR={self.args.ppo_lr} | alpha={self.args.alpha}")

        hist = {"nc": [], "val": [], "dt": [], "mem": [], "n1": [], "n2": []}
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["iter", "timestamp", "n_strats_p1", "n_strats_p2",
                        "time_sec", "mem_mb", "mem_type", "nashconv", "mix_ev_p1"])

            for it in range(1, self.args.iters + 1):
                t0 = time.time()

                # 1) Meta-strategy on the current pools, used to choose evictions.
                s1, s2 = self._solve_meta()
                sizes_before = (len(self.Z[0]), len(self.Z[1]))
                self.maybe_evict_least_mass(0, s1, protect_seed=True)
                self.maybe_evict_least_mass(1, s2, protect_seed=True)

                # 2) Re-solve only if a pool changed; otherwise the
                #    (deterministic) solution is still valid.
                if (len(self.Z[0]), len(self.Z[1])) != sizes_before:
                    s1, s2 = self._solve_meta()

                # 3) Best responses against the opponent's meta-strategy.
                br1 = self.train_br_with_ppo(role=0, sigma_opp=s2)
                br2 = self.train_br_with_ppo(role=1, sigma_opp=s1)
                self.add_anchor(0, br1)
                self.add_anchor(1, br2)

                # 4) Evaluate the meta-strategy of the enlarged pools.
                s1, s2 = self._solve_meta()
                nc, val = self.nashconv(s1, s2)

                dt = time.time() - t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype
                n1, n2 = len(self.Z[0]), len(self.Z[1])

                hist["nc"].append(nc)
                hist["val"].append(val)
                hist["dt"].append(dt)
                hist["mem"].append(mem)
                hist["n1"].append(n1)
                hist["n2"].append(n2)

                print(f"[α-PSRO] seed={self.seed} iter {it}/{self.args.iters} | "
                      f"P1={n1} P2={n2} | NashConv={nc:.5f} val={val:+.5f} | "
                      f"{dt:.2f}s mem={mem:.1f}MB")

                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            n1, n2,
                            f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                fcsv.flush()  # keep partial results on disk if the run is interrupted

        result = {key: np.array(values) for key, values in hist.items()}
        result["mem_type"] = mem_type_rec
        return result

def _save_meanstd_csv(path, T, mem_type, dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std, n1_mean, n1_std, n2_mean, n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path,"w",newline="") as f:
        w=csv.writer(f)
        w.writerow(["iter","dt_mean","dt_std","mem_mb_mean","mem_mb_std",
                    "mix_ev_mean","mix_ev_std","nashconv_mean","nashconv_std",
                    "n_strats_p1_mean","n_strats_p1_std","n_strats_p2_mean","n_strats_p2_std","mem_type"])
        for i in range(T):
            w.writerow([i+1, f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                        f"{mem_mean[i]:.3f}", f"{mem_std[i]:.3f}",
                        f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                        f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                        f"{n1_mean[i]:.3f}", f"{n1_std[i]:.3f}",
                        f"{n2_mean[i]:.3f}", f"{n2_std[i]:.3f}",
                        mem_type])



def _collect_seed_files(glob_pattern: str):
    out = {}
    for pth in sorted(glob.glob(glob_pattern)):
        m = re.search(r"seed(\d+)\.csv$", os.path.basename(pth))
        if m:
            out[int(m.group(1))] = pth
    return out

def _read_metric_at_iter(csv_path: str, metric: str, it: int = 0):
    import csv as _csv
    with open(csv_path, "r") as f:
        rows = list(_csv.reader(f))
    header = rows[0]
    if metric not in header:
        raise ValueError(f"Metric '{metric}' not in CSV header: {header}")
    idx = header.index(metric)
    if it == 0:
        return float(rows[-1][idx])
    for r in rows[1:]:
        if int(float(r[0])) == it:
            return float(r[idx])
    raise ValueError(f"Iteration {it} not found in {csv_path}")

def _welch_ttest(a: np.ndarray, b: np.ndarray):
    if ttest_ind is None:
        return float("nan"), float("nan")
    t, p = ttest_ind(a, b, equal_var=False)
    return float(t), float(p)

def main():
    """Run α-PSRO for every requested seed, then aggregate, test and plot.

    Steps:
      1. Run one ``AlphaPSRORunner`` per seed, writing ``<csv_base>_seed<s>.csv``.
      2. Write the across-seed mean/std table to ``<csv_base>_meanstd.csv``.
      3. If ``--ttest_against_glob`` is given, run a Welch t-test on the
         chosen metric against baseline CSVs with matching seed numbers.
      4. Unless ``--no_plots`` is set, save mean ± std plots of NashConv and
         mixture value as PNG files in the output directory.
    """
    args = parse_args()
    seeds = _parse_seeds(args.seeds)
    os.makedirs(args.outdir, exist_ok=True)

    per = []
    for s in seeds:
        r = AlphaPSRORunner(args, seed=s)
        csv_path = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        per.append(r.run(csv_path, header=(s == seeds[0])))

    T = args.iters
    for rec in per:
        assert len(rec["nc"]) == T and len(rec["val"]) == T

    # The sample standard deviation (ddof=1) is undefined for a single run
    # and would fill the table with NaN; report a spread of zero instead.
    ddof = 1 if len(per) > 1 else 0

    def _mean_std(key):
        # Stack the per-seed series (seeds x iters) and reduce over seeds.
        stack = np.stack([rec[key] for rec in per], 0)
        return stack.mean(0), stack.std(0, ddof=ddof)

    nc_mean, nc_std = _mean_std("nc")
    ev_mean, ev_std = _mean_std("val")
    dt_mean, dt_std = _mean_std("dt")
    mem_mean, mem_std = _mean_std("mem")
    n1_mean, n1_std = _mean_std("n1")
    n2_mean, n2_std = _mean_std("n2")

    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, per[0]["mem_type"], dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std, n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    if T > 0:  # there is no final iteration to summarise when iters == 0
        print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
              f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
              f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")

    if args.ttest_against_glob is not None:
        base_map = _collect_seed_files(args.ttest_against_glob)
        our_map = {s: os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv") for s in seeds}
        # Only seeds present in both result sets take part in the test.
        common = [s for s in seeds if (s in base_map) and os.path.exists(our_map[s])]
        if len(common) < 2:
            print(f"[TTEST] Not enough common seeds for t-test. Found {len(common)} common seeds.")
        else:
            a = np.array([_read_metric_at_iter(our_map[s], args.ttest_metric, args.ttest_iter) for s in common], dtype=np.float64)
            b = np.array([_read_metric_at_iter(base_map[s], args.ttest_metric, args.ttest_iter) for s in common], dtype=np.float64)
            t, p = _welch_ttest(a, b)
            if math.isnan(p):
                print("[TTEST] scipy not available; install scipy to compute p-values.")
            print(f"[TTEST] Welch t-test on {args.ttest_metric} at iter {args.ttest_iter or args.iters} (α-PSRO vs baseline) | "
                  f"α-PSRO mean={a.mean():.4f}±{a.std(ddof=1):.4f}, baseline mean={b.mean():.4f}±{b.std(ddof=1):.4f} | t={t:.4f}, p={p:.4g}")

    if not args.no_plots:
        its = np.arange(1, T + 1)

        def _band_plot(mean, std, label, title, png_path):
            # One figure: the mean curve with a ±1 std band, saved and closed.
            fig = plt.figure()
            ax = fig.gca()
            ax.plot(its, mean, label=label)
            ax.fill_between(its, mean - std, mean + std, alpha=0.25, label="±1 std")
            ax.grid(True)
            ax.legend()
            ax.set_xlabel("iter")
            ax.set_title(title)
            fig.tight_layout()
            fig.savefig(png_path, dpi=200)
            plt.close(fig)

        nash_path_png = os.path.join(args.outdir, f"{args.csv_base}_nashconv_meanstd.png")
        ev_path_png = os.path.join(args.outdir, f"{args.csv_base}_mixev_meanstd.png")

        _band_plot(nc_mean, nc_std, "NashConv (mean)",
                   "Kuhn – Exploitability (α-PSRO) — mean ± std", nash_path_png)
        _band_plot(ev_mean, ev_std, "E[P1 payoff] (mean)",
                   "Kuhn – Mixture Value (α-PSRO) — mean ± std", ev_path_png)

        print(f"[PLOTS] saved: {nash_path_png}")
        print(f"[PLOTS] saved: {ev_path_png}")


# Run the experiment only when executed as a script, not when imported.
if __name__=="__main__":
    main()
