import argparse, time, csv, os, random, math, sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# --------------------------- Args ---------------------------
def parse_args(argv=None):
    """Build the command-line parser and parse the run options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; an explicit list allows
            programmatic use (tests, notebooks).

    Returns:
        ``argparse.Namespace`` holding every option defined below.

    Exits via ``parser.error`` (status 2) when ``--kmax`` or ``--iters`` is
    below 1 or ``--tau`` is not positive. Such values would otherwise crash
    (empty population / empty history) or divide logits by a non-positive
    temperature.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program name
    # shown in the usage line), so this text must be passed as ``description``.
    p = argparse.ArgumentParser(
        description="GEMS-OMWU (constant, hardened) for Kuhn — multi-seed")
    p.add_argument("--kmax", type=int, default=8,
                   help="latent population size per player (kept constant)")
    p.add_argument("--iters", type=int, default=40,
                   help="number of outer iterations per seed")

    # OMWU meta
    p.add_argument("--eta", type=float, default=0.08, help="base OMWU step size")
    p.add_argument("--eta_sched", choices=["const","sqrt","harmonic"], default="harmonic",
                   help="step-size schedule: eta, eta/sqrt(t) or eta/(1+t/2)")
    p.add_argument("--ema", type=float, default=0.0,
                   help="EMA coefficient for the meta-values (0 disables)")

    # Oracle replacement
    p.add_argument("--pool_mut", type=int, default=2,
                   help="mutated candidate latents per oracle call")
    p.add_argument("--pool_rand", type=int, default=1,
                   help="fresh random candidate latents per oracle call")
    p.add_argument("--replace", choices=["least_mass","worst_ev"], default="least_mass",
                   help="rule choosing which population slot the oracle overwrites")

    # ABR-TR (constant work)
    p.add_argument("--abr_steps", type=int, default=30, help="gradient steps per ABR-TR call")
    p.add_argument("--abr_lr", type=float, default=5e-4, help="Adam learning rate for ABR-TR")
    p.add_argument("--beta_kl", type=float, default=1e-2,
                   help="weight of the Bernoulli-KL trust-region penalty")
    p.add_argument("--tau", type=float, default=1.0,
                   help="sigmoid temperature of the generator outputs (> 0)")

    # Guardrails
    p.add_argument("--clip_grad", type=float, default=1.0,
                   help="gradient-norm clip in ABR-TR (<= 0 disables)")
    p.add_argument("--logit_cap", type=float, default=50.0,
                   help="absolute clip applied to the meta-logits")
    p.add_argument("--prob_eps", type=float, default=1e-6,
                   help="probabilities are clipped into [eps, 1-eps]")
    p.add_argument("--mwu_grad_cap", type=float, default=None,
                   help="optional absolute clip on the OMWU gradient")

    # I/O
    p.add_argument("--outdir", type=str, default=".", help="directory for CSV outputs")
    p.add_argument("--csv_base", type=str, default="gems_kuhn_const",
                   help="file-name prefix for CSV outputs")
    p.add_argument("--device", choices=["auto","cpu","cuda"], default="auto")
    p.add_argument("--seeds", type=str, default="0,1,2,3,4",
                   help="Comma-separated list of integer seeds, default '0,1,2,3,4'")
    p.add_argument("--no_plots", action="store_true", help="skip the matplotlib figures")

    args = p.parse_args(argv)
    if args.kmax < 1:
        p.error("--kmax must be >= 1")
    if args.iters < 1:
        p.error("--iters must be >= 1")
    if not args.tau > 0:  # also rejects NaN
        p.error("--tau must be > 0")
    return args

# --------------------------- Utilities ---------------------------
def _seed_everything(s: int):
    """Seed Python's ``random``, NumPy's global RNG and torch with *s*.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

def _pick_device(argdev: str):
    """Translate the ``--device`` choice into a ``torch.device``.

    ``"auto"`` selects CUDA when it is available and the CPU otherwise; any
    other value is passed straight to ``torch.device``.
    """
    if argdev != "auto":
        return torch.device(argdev)
    name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(name)

def _mem_mb():
    """Return ``(megabytes, source_tag)`` describing this process's memory use.

    The function tries two sources in order:

    1. psutil's current resident set size, tagged ``"rss"``.
    2. The stdlib ``resource`` module's ``ru_maxrss``, tagged ``"ru_maxrss"``
       (a peak value, not the current one). The darwin branch treats it as
       bytes; every other platform treats it as kilobytes.

    If neither source works, returns ``(nan, "n/a")``. Failures are swallowed
    deliberately, because memory reporting is best-effort.
    """
    try:
        import psutil
        rss_bytes = psutil.Process().memory_info().rss
        return rss_bytes / (1024**2), "rss"
    except Exception:
        pass
    try:
        import resource
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        divisor = (1024**2) if sys.platform == "darwin" else 1024.0
        return peak / divisor, "ru_maxrss"
    except Exception:
        return float("nan"), "n/a"

# --------------------------- Game constants & EV ---------------------------
CARDS = list(range(3))  # Kuhn card ranks, ordered 0 < 1 < 2

def ev_p1_vs_numpy(p1_vec6, p2_vec6):
    """Expected payoff to player 1 for two 6-entry behavioral strategies.

    Args:
        p1_vec6: ``[b1_0, b1_1, b1_2, c1_0, c1_1, c1_2]``. Per card, ``b1``
            weights P1's opening branch; ``c1`` weights P1's response when P2
            bets after P1's check.
        p2_vec6: ``[c2_0, c2_1, c2_2, b2_0, b2_1, b2_2]``. Per card, ``c2``
            weights P2's response to P1's bet; ``b2`` weights P2's bet after
            P1's check.

    Returns:
        P1's payoff summed over the 6 ordered pairs of distinct cards in
        ``CARDS``, divided by 6.
    """
    bet1, call1 = p1_vec6[:3], p1_vec6[3:]
    call2, bet2 = p2_vec6[:3], p2_vec6[3:]
    total = 0.0
    for mine in CARDS:
        for theirs in CARDS:
            if mine == theirs:
                continue
            win = mine > theirs
            big = 2.0 if win else -2.0    # showdown after a bet is called
            small = 1.0 if win else -1.0  # showdown after two checks
            # P1 bets: P2 calls (big showdown) or folds (P1 wins 1).
            after_bet = bet1[mine]*( call2[theirs]*big + (1.0-call2[theirs])*1.0 )
            # P1 checks, P2 bets: P1 calls (big showdown) or folds (P1 loses 1).
            p2_bets = bet2[theirs]*( call1[mine]*big + (1.0-call1[mine])*(-1.0) )
            # P1 checks: P2 bets as above, or checks back (small showdown).
            after_check = (1.0-bet1[mine])*( p2_bets + (1.0-bet2[theirs])*small )
            total += after_bet + after_check
    return total/6.0

def ev_p1_vs_torch(p1_probs6, p2_probs6, device):
    """Differentiable torch version of the P1 expected-payoff formula.

    Evaluates the same expression as ``ev_p1_vs_numpy`` on shape-[6]
    tensors, with the same entry layout (P1: bets then calls; P2: calls then
    bets). Gradients flow to both inputs. The scalar accumulator is created
    as float32 on *device*.
    """
    bet1, call1 = p1_probs6[:3], p1_probs6[3:]
    call2, bet2 = p2_probs6[:3], p2_probs6[3:]
    total = torch.zeros([], dtype=torch.float32, device=device)
    for mine in CARDS:
        for theirs in CARDS:
            if mine == theirs:
                continue
            win = mine > theirs
            big = 2.0 if win else -2.0
            small = 1.0 if win else -1.0
            after_bet = bet1[mine]*( call2[theirs]*big + (1.0-call2[theirs])*1.0 )
            p2_bets = bet2[theirs]*( call1[mine]*big + (1.0-call1[mine])*(-1.0) )
            after_check = (1.0-bet1[mine])*( p2_bets + (1.0-bet2[theirs])*small )
            total = total + after_bet + after_check
    return total/6.0

# --------------------------- Nets ---------------------------
def _init_weights(m):
    """Initializer for ``Module.apply``: orthogonal weights (gain 1) and zero bias on ``nn.Linear``.

    Modules of any other type are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight, gain=1.0)
    nn.init.zeros_(m.bias)

class Gen(nn.Module):
    """Latent-conditioned generator of behavioral strategies for both seats.

    Two independent MLP heads (zdim -> 64 -> 32 -> 6 with ReLU) map a latent
    ``z`` to 6 logits. Probabilities are ``sigmoid(logits / tau)``. All Linear
    layers are re-initialized by ``_init_weights``.
    """

    def __init__(self, zdim=8):
        super().__init__()

        def head():
            return nn.Sequential(
                nn.Linear(zdim, 64), nn.ReLU(),
                nn.Linear(64, 32), nn.ReLU(),
                nn.Linear(32, 6),
            )

        self.p1 = head()
        self.p2 = head()
        self.apply(_init_weights)

    def p1_probs(self, z, tau):
        """Player-1 probabilities for latent batch *z* at temperature *tau*."""
        logits = self.p1(z)
        return torch.sigmoid(logits / tau)

    def p2_probs(self, z, tau):
        """Player-2 probabilities for latent batch *z* at temperature *tau*."""
        logits = self.p2(z)
        return torch.sigmoid(logits / tau)

# --------------------------- Pure policies for NashConv ---------------------------
def _enumerate_pure_policies_p1():
    """Return all 64 deterministic P1 policies as float32 vectors.

    Each vector has the layout ``[b1_0, b1_1, b1_2, c1_0, c1_1, c1_2]``.
    Entry k of the vector for enumeration index ``mask`` is bit k of ``mask``.
    """
    return [np.array([(mask >> bit) & 1 for bit in range(6)], dtype=np.float32)
            for mask in range(64)]
def _enumerate_pure_policies_p2():
    """Return all 64 deterministic P2 policies as float32 vectors.

    Each vector has the layout ``[c2_0, c2_1, c2_2, b2_0, b2_1, b2_2]``.
    Entry k of the vector for enumeration index ``mask`` is bit k of ``mask``.
    """
    return [np.array([(mask >> bit) & 1 for bit in range(6)], dtype=np.float32)
            for mask in range(64)]

# Every deterministic policy for each seat, built once at import time.
PURE1, PURE2 = _enumerate_pure_policies_p1(), _enumerate_pure_policies_p2()

# --------------------------- Runner (state is isolated per seed) ---------------------------
class KuhnGEMSRunner:
    """One seed's GEMS-OMWU run on Kuhn poker with a constant population size.

    All mutable state is created per instance by ``_setup_state``, so runs for
    different seeds are isolated. That state is the generator, the Adam
    optimizer, the latent populations, the meta-logits, the previous OMWU
    gains and the EMA caches. Each player keeps ``K = args.kmax`` latent
    codes; a player's meta-strategy is the softmax of its logits.

    One iteration of ``run``:
      1. exact meta-values;
      2. OMWU logit update;
      3. oracle replacement of one latent per player;
      4. trust-region gradient steps on the shared generator (ABR-TR);
      5. NashConv logging.
    """

    ZDIM = 8  # dimensionality of every latent code fed to Gen

    def __init__(self, args, seed: int, device: torch.device):
        """Store the configuration and build fresh, seeded state.

        Args:
            args: Parsed CLI namespace (fields as defined by ``parse_args``).
            seed: Seed applied to the Python/NumPy/torch RNGs before any state is built.
            device: Device for the generator's parameters and tensors.
        """
        self.args = args
        self.seed = seed
        self.device = device
        self._setup_state()

    def _setup_state(self):
        """Seed all RNGs, then build generator/optimizer, populations and meta logs."""
        _seed_everything(self.seed)

        # Model/opt
        self.gen = Gen(self.ZDIM).to(self.device)
        self.opt = torch.optim.Adam(self.gen.parameters(), lr=self.args.abr_lr)

        # Fixed-size populations: one [K, ZDIM] float32 block of latents per
        # player, with one meta-logit and one previous-gain entry per latent.
        self.K = self.args.kmax
        self.Z = [
            np.random.normal(0, 1, size=(self.K, self.ZDIM)).astype(np.float32),
            np.random.normal(0, 1, size=(self.K, self.ZDIM)).astype(np.float32),
        ]
        self.LOGS = [np.zeros(self.K, dtype=np.float64), np.zeros(self.K, dtype=np.float64)]
        self.G_PREV = [np.zeros(self.K, dtype=np.float64), np.zeros(self.K, dtype=np.float64)]

        # EMA caches for meta_estimate_exact (populated only when args.ema > 0).
        self._VV = None
        self._RR = None

    # ---------- helpers ----------
    def softmax_np(self, x):
        """Softmax of *x* after clipping it to ``[-logit_cap, logit_cap]`` (max-shifted for stability)."""
        cap = float(self.args.logit_cap)
        x = np.clip(x, -cap, cap)
        z = x - np.max(x)
        e = np.exp(z)
        return e / (e.sum() + 1e-12)

    def sigma_list(self):
        """Current meta-strategies ``[sigma_P1, sigma_P2]`` (softmax of each player's logits)."""
        return [self.softmax_np(self.LOGS[0]), self.softmax_np(self.LOGS[1])]

    def eta_t(self, eta0, t):
        """OMWU step size at iteration *t* under ``args.eta_sched``.

        ``const``: eta0. ``sqrt``: eta0 / max(1, sqrt(t)). ``harmonic``:
        eta0 / (1 + t/2). Any other value falls back to eta0.
        """
        sched = self.args.eta_sched
        if sched == "const":
            return eta0
        if sched == "sqrt":
            return eta0 / max(1.0, math.sqrt(t))
        if sched == "harmonic":
            return eta0 / (1.0 + 0.5 * t)
        return eta0

    def _scrub_probs_np(self, p):
        """Map NaN/+inf/-inf to 0.5/1/0, then clip into ``[prob_eps, 1 - prob_eps]`` (elementwise)."""
        p = np.nan_to_num(p, nan=0.5, posinf=1.0, neginf=0.0)
        eps = float(self.args.prob_eps)
        return np.clip(p, eps, 1.0 - eps)

    @torch.no_grad()
    def prob_rows(self, role, Zblock):
        """Scrubbed action probabilities for a block of latents.

        Args:
            role: 0 selects P1's head; any other value selects P2's head.
            Zblock: latent codes, shape [K, ZDIM].

        Returns:
            float32 array of shape [K, 6].
        """
        z = torch.tensor(Zblock, dtype=torch.float32, device=self.device)
        head = self.gen.p1_probs if role == 0 else self.gen.p2_probs
        p = head(z, self.args.tau).detach().cpu().numpy().astype(np.float32)
        # Scrubbing is elementwise, so one call over the whole block equals per-row calls.
        return self._scrub_probs_np(p).astype(np.float32, copy=False)

    @torch.no_grad()
    def ev_matrix(self):
        """Return ``(M, P1, P2)`` with ``M[i, j]`` = P1 payoff of P1 latent i vs P2 latent j."""
        P1 = self.prob_rows(0, self.Z[0])
        P2 = self.prob_rows(1, self.Z[1])
        M = np.zeros((self.K, self.K), dtype=np.float64)
        for i in range(self.K):
            for j in range(self.K):
                M[i, j] = ev_p1_vs_numpy(P1[i], P2[j])
        return M, P1, P2

    # ---------- meta-estimate & OMWU ----------
    def meta_estimate_exact(self, it):
        """Exact per-latent meta-values and the mixture value for both players.

        Returns ``([v_P1, v_P2], r)`` where ``v_P1 = M @ sigma_P2``,
        ``v_P2 = -(sigma_P1 @ M)`` and ``r = [value, -value]`` with
        ``value = sigma_P1 @ M @ sigma_P2``. When ``args.ema > 0`` the returned
        quantities are copies of exponential moving averages instead.
        *it* is accepted for interface compatibility and is not used.
        """
        M, _, _ = self.ev_matrix()
        s1, s2 = self.sigma_list()
        v0 = (M @ s2).astype(np.float64)
        v1 = (-(s1 @ M)).astype(np.float64)
        value = float(s1 @ M @ s2)
        r = np.array([value, -value], dtype=np.float64)
        if self.args.ema > 0.0:
            beta = self.args.ema
            if self._VV is None:
                self._VV = [v0.copy(), v1.copy()]
                self._RR = r.copy()
            self._VV[0] = (1 - beta) * self._VV[0] + beta * v0
            self._VV[1] = (1 - beta) * self._VV[1] + beta * v1
            self._RR = (1 - beta) * self._RR + beta * r
            return [self._VV[0].copy(), self._VV[1].copy()], self._RR.copy()
        return [v0, v1], r

    def mwu_update_omwu(self, vhat, rbar, it):
        """Optimistic MWU step on both players' meta-logits.

        For each player:
          1. ``gains = vhat - rbar``;
          2. the logits move by ``eta_t * (2 * gains - previous_gains)``,
             with the step optionally clipped to ``±mwu_grad_cap``;
          3. the logits are clipped to ``±logit_cap``.
        The unclipped gains are stored for the next optimistic correction.
        """
        eta = self.eta_t(self.args.eta, it)
        cap = float(self.args.mwu_grad_cap) if self.args.mwu_grad_cap is not None else None
        for p in (0, 1):
            gains = np.asarray(vhat[p]) - float(rbar[p])
            grad = 2.0 * gains - self.G_PREV[p]
            if cap is not None:
                grad = np.clip(grad, -cap, cap)
            self.LOGS[p] += eta * grad
            self.LOGS[p] = np.clip(self.LOGS[p], -self.args.logit_cap, self.args.logit_cap)
            self.G_PREV[p] = gains

    # ---------- Oracle (replacement, constant K) ----------
    def _pick_evict_index(self, p, P1, P2):
        """Return the slot of player *p*'s population to overwrite.

        ``least_mass``: the argmin of the player's current meta-strategy.
        ``worst_ev``: the latent with the lowest expected payoff against the
        opponent's current mixture, computed from *P1*/*P2*.
        """
        if self.args.replace == "least_mass":
            return int(np.argmin(self.sigma_list()[p]))
        s1, s2 = self.sigma_list()
        if p == 0:
            vals = np.array([sum(s2[j] * ev_p1_vs_numpy(P1[i], P2[j]) for j in range(self.K))
                             for i in range(self.K)], dtype=np.float64)
        else:
            vals = np.array([sum(s1[i] * (-ev_p1_vs_numpy(P1[i], P2[j])) for i in range(self.K))
                             for j in range(self.K)], dtype=np.float64)
        return int(np.argmin(vals))

    def _oracle_candidates(self, p):
        """Candidate latents for player *p*.

        The pool holds ``pool_mut`` mutations (Gaussian noise, std 0.25) of
        one uniformly drawn population member, followed by ``pool_rand``
        fresh N(0, 1) draws.
        """
        base = self.Z[p][random.randrange(self.K)]
        cands = [(base + np.random.normal(0, 0.25, size=(self.ZDIM,)).astype(np.float32))
                 for _ in range(self.args.pool_mut)]
        cands.extend(np.random.normal(0, 1, size=(self.ZDIM,)).astype(np.float32)
                     for _ in range(self.args.pool_rand))
        return cands

    @torch.no_grad()
    def _candidate_value(self, p, zc, s1, s2, P1, P2):
        """Expected payoff to player *p* of latent *zc* against the opponent's mixture."""
        z = torch.tensor(zc, device=self.device).unsqueeze(0)
        head = self.gen.p1_probs if p == 0 else self.gen.p2_probs
        probs = head(z, self.args.tau)[0]
        probs = torch.nan_to_num(probs, nan=0.5).clamp(self.args.prob_eps, 1 - self.args.prob_eps)
        probs = probs.detach().cpu().numpy()  # convert once, not once per opponent
        if p == 0:
            return sum(s2[j] * ev_p1_vs_numpy(probs, P2[j]) for j in range(self.K))
        return sum(s1[i] * (-ev_p1_vs_numpy(P1[i], probs)) for i in range(self.K))

    @torch.no_grad()
    def oracle_replace(self, P1, P2):
        """Replace one latent per player with the best of a small candidate pool.

        Candidates are scored against the meta-strategies and probability rows
        as they were on entry, so player 1's scoring does not see player 0's
        replacement. The winner overwrites the slot chosen by
        ``_pick_evict_index``. That slot's logit is set one below the current
        minimum, and its previous OMWU gain is set to 0.

        A player is left unchanged when its pool is empty
        (``--pool_mut 0 --pool_rand 0``) or every candidate scores NaN.
        """
        s1, s2 = self.sigma_list()
        for p in (0, 1):
            best, best_val = None, -np.inf
            for zc in self._oracle_candidates(p):
                val = self._candidate_value(p, zc, s1, s2, P1, P2)
                if val > best_val:
                    best, best_val = zc, val
            if best is None:
                # Assigning None into the float32 population would raise; skip instead.
                continue
            evict = self._pick_evict_index(p, P1, P2)
            self.Z[p][evict] = best
            self.LOGS[p][evict] = np.min(self.LOGS[p]) - 1.0
            self.G_PREV[p][evict] = 0.0

    # ---------- ABR-TR (torch, constant work) ----------
    def _bern_kl(self, p, q, eps=1e-12):
        """Summed ``KL(Bernoulli(q) || Bernoulli(p))`` with both clamped to ``[eps, 1-eps]``."""
        p = torch.clamp(p, eps, 1 - eps)
        q = torch.clamp(q, eps, 1 - eps)
        return (q * (torch.log(q) - torch.log(p))
                + (1 - q) * (torch.log(1 - q) - torch.log(1 - p))).sum()

    def _finite_params(self):
        """Return True iff every generator parameter is finite (no NaN/Inf)."""
        with torch.no_grad():
            return all(bool(torch.isfinite(prm).all()) for prm in self.gen.parameters())

    def abr_tr_step(self, P1_np, P2_np):
        """Run ``abr_steps`` trust-region Adam steps on the shared generator.

        The steps target the lowest-mass latent of each player. The objective
        maximised is:

        * P1's payoff at that latent against the fixed P2 mixture
          (rows *P2_np*, weights sigma_P2), plus
        * P2's payoff against the fixed P1 mixture,
        * minus ``beta_kl`` times the Bernoulli KL from the probabilities
          before the first step.

        If a step leaves a parameter non-finite, the generator weights AND the
        Adam state are rolled back to the last finite snapshot. The learning
        rate for the remaining steps is then halved.
        """
        import copy  # only this method needs deep copies (optimizer-state snapshots)

        s1, s2 = self.sigma_list()
        i = int(np.argmin(s1))
        j = int(np.argmin(s2))
        z1 = torch.tensor(self.Z[0][i], dtype=torch.float32, device=self.device).unsqueeze(0)
        z2 = torch.tensor(self.Z[1][j], dtype=torch.float32, device=self.device).unsqueeze(0)

        # Trust-region anchors: the targeted latents' probabilities before any step.
        with torch.no_grad():
            p1_snap = self.gen.p1_probs(z1, self.args.tau).detach()[0]
            p2_snap = self.gen.p2_probs(z2, self.args.tau).detach()[0]

        # Opponent rows and meta-weights are held fixed for every step.
        P1 = torch.tensor(P1_np, dtype=torch.float32, device=self.device)
        P2 = torch.tensor(P2_np, dtype=torch.float32, device=self.device)
        s1_t = torch.tensor(s1, dtype=torch.float32, device=self.device)
        s2_t = torch.tensor(s2, dtype=torch.float32, device=self.device)
        eps = self.args.prob_eps

        def _snapshot():
            weights = {k: v.clone() for k, v in self.gen.state_dict().items()}
            return weights, copy.deepcopy(self.opt.state_dict())

        saved_weights, saved_opt = _snapshot()
        lr_decay = 1.0

        for _ in range(self.args.abr_steps):
            p1_cur = self.gen.p1_probs(z1, self.args.tau)[0]
            p2_cur = self.gen.p2_probs(z2, self.args.tau)[0]
            p1_cur = torch.nan_to_num(p1_cur, nan=0.5).clamp(eps, 1 - eps)
            p2_cur = torch.nan_to_num(p2_cur, nan=0.5).clamp(eps, 1 - eps)

            # P1 payoff of the targeted P1 latent vs the fixed P2 mixture.
            ev1 = torch.zeros([], dtype=torch.float32, device=self.device)
            for jj in range(self.K):
                ev1 = ev1 + s2_t[jj] * ev_p1_vs_torch(p1_cur, P2[jj], self.device)

            # P2 payoff (negated P1 payoff) of the targeted P2 latent vs the fixed P1 mixture.
            ev2 = torch.zeros([], dtype=torch.float32, device=self.device)
            for ii in range(self.K):
                ev2 = ev2 + s1_t[ii] * (-ev_p1_vs_torch(P1[ii], p2_cur, self.device))

            kl1 = self._bern_kl(p1_cur, p1_snap)
            kl2 = self._bern_kl(p2_cur, p2_snap)
            loss = -(ev1 + ev2) + self.args.beta_kl * (kl1 + kl2)

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            if self.args.clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(self.gen.parameters(), self.args.clip_grad)
            for g in self.opt.param_groups:
                g["lr"] = self.args.abr_lr * lr_decay
            self.opt.step()

            if self._finite_params():
                saved_weights, saved_opt = _snapshot()  # advance the rollback point
            else:
                # Adam's moment buffers absorbed the same non-finite gradient; restoring
                # only the weights would re-poison them on the very next step.
                self.gen.load_state_dict(saved_weights)
                # Copy again: the loaded state may alias these tensors and Adam
                # updates its buffers in place.
                self.opt.load_state_dict(copy.deepcopy(saved_opt))
                lr_decay *= 0.5

    # ---------- NashConv ----------
    @staticmethod
    def _mix_p1_behavior(P1, s1):
        """Behavioral P1 strategy that is payoff-equivalent to the population mixture.

        P1's payoff contains the product ``(1 - b1[c]) * c1[c]``: the call
        decision is only reached after a check. So the plain weighted average
        of rows is not equivalent to playing the mixture. The equivalent
        behavior (Kuhn's theorem) is:

            b = sum_i s_i b_i
            c = sum_i s_i (1 - b_i) c_i / sum_i s_i (1 - b_i)

        Where the call decision is unreachable (zero denominator), ``c`` cannot
        affect payoffs and the plain weighted average is used.

        Args:
            P1: [K, 6] rows ``[b1_0..b1_2, c1_0..c1_2]``.
            s1: [K] mixture weights.

        Returns:
            float64 array of shape [6].
        """
        P1 = np.asarray(P1, dtype=np.float64)
        s1 = np.asarray(s1, dtype=np.float64)
        bet, call = P1[:, :3], P1[:, 3:]
        bet_mix = s1 @ bet
        check_w = s1[:, None] * (1.0 - bet)  # [K, 3] weight of reaching the call decision
        reach = check_w.sum(axis=0)
        plain = s1 @ call
        ok = reach > 1e-12
        call_mix = np.where(ok, (check_w * call).sum(axis=0) / np.where(ok, reach, 1.0), plain)
        return np.concatenate([bet_mix, call_mix])

    @torch.no_grad()
    def nashconv_current(self):
        """Return ``(NashConv, mixture value)`` of the current meta-strategies.

        NashConv = best pure P1 response value minus the best (minimising)
        pure P2 response value, floored at 0. Both responses are taken against
        exact behavioral equivalents of the opponent's mixture.
        """
        M, P1, P2 = self.ev_matrix()
        s1, s2 = self.sigma_list()
        val = float(s1 @ M @ s2)
        # P1's payoff is affine in P2's 6-vector (no product of two P2 entries),
        # so the weighted row average is an exact stand-in for P2's mixture.
        mix_p2 = sum(s2[j] * P2[j] for j in range(self.K))
        # P1's own entries multiply each other, so use the realization-weighted average.
        mix_p1 = self._mix_p1_behavior(P1, s1)
        br1 = max(ev_p1_vs_numpy(pi, mix_p2) for pi in PURE1)
        br2min = min(ev_p1_vs_numpy(mix_p1, pj) for pj in PURE2)
        nc = max(0.0, br1 - br2min)
        return nc, val

    # ---------- main loop (single seed) ----------
    def run(self, csv_path: str, print_header=True):
        """Run ``args.iters`` iterations, streaming one CSV row per iteration to *csv_path*.

        Returns:
            dict with:
              * ``ev``: float64 array of mixture P1 values;
              * ``nc``: float64 array of NashConv values;
              * ``dt``: float64 array of seconds per iteration;
              * ``mem``: float64 array of MB from ``_mem_mb``;
              * ``mem_type``: the memory source tag from the last iteration.
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if print_header:
            dev_info = f"{self.device.type}" + (
                f" gpu={torch.cuda.get_device_name(self.device)}" if self.device.type == 'cuda' else "")
            print(f"[GEMS] K={self.K} device={dev_info} | seed={self.seed}")

        hist_ev1, hist_nc, hist_dt, hist_mem = [], [], [], []
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["iter","timestamp","time_sec","mem_mb","mem_type","mix_ev_p1","nashconv"])
            for it in range(1, self.args.iters + 1):
                t0 = time.time()
                vhat, rbar = self.meta_estimate_exact(it)
                self.mwu_update_omwu(vhat, rbar, it)

                _, P1, P2 = self.ev_matrix()
                self.oracle_replace(P1, P2)
                self.abr_tr_step(P1, P2)

                nc, mix_ev = self.nashconv_current()
                dt = time.time() - t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype

                hist_ev1.append(mix_ev); hist_nc.append(nc); hist_dt.append(dt); hist_mem.append(mem)

                print(f"[seed {self.seed}] iter {it}/{self.args.iters} t={dt:.2f}s mem={mem:.1f}MB mixEV={mix_ev:+.4f} NashConv={nc:.4f}")
                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"), f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{mix_ev:+.6f}", f"{nc:.6f}"])
                fcsv.flush()

        return {
            "ev": np.array(hist_ev1, dtype=np.float64),
            "nc": np.array(hist_nc, dtype=np.float64),
            "dt": np.array(hist_dt, dtype=np.float64),
            "mem": np.array(hist_mem, dtype=np.float64),
            "mem_type": mem_type_rec
        }

# --------------------------- Aggregate helpers ---------------------------
def _parse_seeds(seeds_str: str):
    """Parse a comma-separated seed list such as ``"0, 1,2"`` into a list of ints.

    Blank entries are skipped, and order and duplicates are kept. If nothing
    remains, the default ``[0, 1, 2, 3, 4]`` is returned. Non-integer tokens
    raise ``ValueError``.
    """
    stripped = (piece.strip() for piece in seeds_str.split(","))
    parsed = [int(tok) for tok in stripped if tok]
    return parsed or [0, 1, 2, 3, 4]

def _save_meanstd_csv(path, iters, mem_type, dt_mean, dt_std, mem_mean, mem_std, ev_mean, ev_std, nc_mean, nc_std):
    """Write across-seed per-iteration statistics to the CSV file *path*.

    The parent directory is created if needed. The file gets a header row,
    then one row per iteration ``1..iters``.

    Each ``*_mean`` / ``*_std`` argument is indexed by 0-based iteration.
    Formatting per column:
      * times, EV and NashConv: 6 decimals (the EV mean carries an explicit sign);
      * memory: 3 decimals;
      * *mem_type*: repeated on every row.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    columns = ["iter","dt_mean","dt_std","mem_mb_mean","mem_mb_std","mix_ev_mean","mix_ev_std","nashconv_mean","nashconv_std","mem_type"]

    def row(k):
        return [k + 1,
                f"{dt_mean[k]:.6f}", f"{dt_std[k]:.6f}",
                f"{mem_mean[k]:.3f}", f"{mem_std[k]:.3f}",
                f"{ev_mean[k]:+.6f}", f"{ev_std[k]:.6f}",
                f"{nc_mean[k]:.6f}", f"{nc_std[k]:.6f}",
                mem_type]

    with open(path, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(columns)
        out.writerows(row(k) for k in range(iters))

# --------------------------- Entry ---------------------------
def _mean_std(stack):
    """Column-wise mean and std over seeds of an ``[S, T]`` collection of per-seed curves.

    The std is the sample std (``ddof=1``) when there are at least two seeds.
    With a single seed it falls back to ``ddof=0``, so the spread is reported
    as 0 instead of NaN (``ddof=1`` over one sample divides by zero).
    """
    arr = np.asarray(stack, dtype=np.float64)
    ddof = 1 if arr.shape[0] > 1 else 0
    return arr.mean(axis=0), arr.std(axis=0, ddof=ddof)


def main():
    """Run every requested seed, then write, print and (optionally) plot across-seed aggregates.

    Outputs under ``--outdir``:
      * ``<csv_base>_seed<s>.csv``: one file per seed, written by the runner;
      * ``<csv_base>_meanstd.csv``: per-iteration mean/std over seeds.
    """
    args = parse_args()
    seeds = _parse_seeds(args.seeds)
    device = _pick_device(args.device)
    os.makedirs(args.outdir, exist_ok=True)

    # Collect per-seed runs (each runner re-seeds and owns all of its state).
    per_seed = []
    for s in seeds:
        runner = KuhnGEMSRunner(args, seed=s, device=device)
        seed_csv = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        per_seed.append(runner.run(seed_csv, print_header=(s == seeds[0])))

    # Sanity on length. Use an explicit raise: `assert` is stripped under `python -O`.
    T = args.iters
    for rec in per_seed:
        if len(rec["ev"]) != T or len(rec["nc"]) != T:
            raise RuntimeError("All seeds must have identical iters")

    # Stack to [S, T] and aggregate over seeds.
    ev_mean, ev_std = _mean_std([rec["ev"] for rec in per_seed])
    nc_mean, nc_std = _mean_std([rec["nc"] for rec in per_seed])
    dt_mean, dt_std = _mean_std([rec["dt"] for rec in per_seed])
    mem_mean, mem_std = _mean_std([rec["mem"] for rec in per_seed])

    mem_type = per_seed[0]["mem_type"]
    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, mem_type, dt_mean, dt_std, mem_mean, mem_std, ev_mean, ev_std, nc_mean, nc_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")

    # Quick aggregate printout (guarded: an empty history has no final iteration).
    if T > 0:
        print(f"[AGG] final (iter {T}) — mixEV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
              f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f}")

    # Plots
    if not args.no_plots:
        its = np.arange(1, T + 1)
        plt.figure()
        plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean - ev_std, ev_mean + ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – P1 EV (mixture) — mean ± std")

        plt.figure()
        plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean - nc_std, nc_mean + nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (NashConv) — mean ± std")
        plt.show()

if __name__ == "__main__":
    # Start a run only when executed as a script; importing the module does not.
    main()
