import argparse, time, csv, os, random, math, sys
import numpy as np
import matplotlib.pyplot as plt

# --------------------------- Args ---------------------------
def parse_args(argv=None):
    """Parse the command-line options for the multi-seed PSRO experiment.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    # The title is a description, not a program name: passed positionally it
    # lands in ``prog`` and gets printed inside the usage line.
    p = argparse.ArgumentParser(description="PSRO for Kuhn Poker — multi-seed")
    # Outer loop
    p.add_argument("--iters", type=int, default=40)
    p.add_argument("--meta_loops", type=int, default=200)
    p.add_argument("--eta", type=float, default=0.25)
    p.add_argument("--eta_sched", choices=["const","sqrt","harmonic"], default="harmonic")
    p.add_argument("--kmax", type=int, default=0, help="0=unbounded; >0 cap per-player pool (least-mass eviction)")

    # PPO BR hyperparams (EXPLICIT LR)
    p.add_argument("--ppo_rollouts", type=int, default=4000, help="episodes per BR training")
    p.add_argument("--ppo_epochs", type=int, default=10)
    p.add_argument("--ppo_batch", type=int, default=512)
    p.add_argument("--ppo_lr", type=float, default=3e-4, help="LEARNING RATE for PPO Adam optimizer")
    p.add_argument("--clip", type=float, default=0.2)
    p.add_argument("--ent_beta", type=float, default=1e-3)
    p.add_argument("--gamma", type=float, default=1.0)
    p.add_argument("--gae_lambda", type=float, default=0.95)  # kept for symmetry; gamma=1 episodic
    p.add_argument("--max_grad_norm", type=float, default=1.0)

    # Numerics
    p.add_argument("--prob_eps", type=float, default=1e-6)
    p.add_argument("--logit_cap", type=float, default=50.0)

    # I/O
    p.add_argument("--outdir", type=str, default=".")
    p.add_argument("--csv_base", type=str, default="psro_kuhn_ppo_br")
    # Help text now states the default that is actually configured.
    p.add_argument("--seeds", type=str, default="0,1,2,3,4",
                   help="Comma-separated list of seeds, default '0,1,2,3,4'")
    p.add_argument("--no_plots", action="store_true")
    return p.parse_args(argv)

# --------------------------- Utils ---------------------------
def _seed_everything(s: int):
    random.seed(s); np.random.seed(s)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform == "darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _parse_seeds(seeds_str: str):
    out=[]
    for tok in seeds_str.split(","):
        tok = tok.strip()
        if not tok: continue
        out.append(int(tok))
    return out if out else [0,1,2,3,4]

# --------------------------- Kuhn EV (exact) ---------------------------
CARDS = [0, 1, 2]

def ev_p1_vs(p1, p2):
    """Exact expected payoff to P1 in Kuhn poker.

    ``p1`` is ``[bet(card 0..2), call(card 0..2)]`` and ``p2`` is
    ``[call(card 0..2), bet(card 0..2)]``; each entry is the probability of
    taking that action with the given card. The result averages over the six
    deals in which the two players hold different cards.
    """
    p1_bet, p1_call = p1[:3], p1[3:]
    p2_call, p2_bet = p2[:3], p2[3:]
    total = 0.0
    for mine in CARDS:
        for theirs in CARDS:
            if mine == theirs:
                continue
            sign = 1.0 if mine > theirs else -1.0
            big_showdown = 2.0 * sign   # showdown after a bet was called
            small_showdown = sign       # showdown after both players check
            # P1 bets: P2 calls (showdown for 2) or folds (P1 wins 1).
            after_bet = p2_call[theirs]*big_showdown + (1-p2_call[theirs])*1.0
            # P1 checks and P2 bets: P1 calls (showdown for 2) or folds (loses 1).
            facing_bet = p1_call[mine]*big_showdown + (1-p1_call[mine])*(-1.0)
            after_check = p2_bet[theirs]*facing_bet + (1-p2_bet[theirs])*small_showdown
            total += p1_bet[mine]*after_bet + (1-p1_bet[mine])*after_check
    return total/6.0

# --------------------------- Pure policy enumeration (for NashConv eval) ---------------------------
def enumerate_pure_policies_p1():
    """Return all 64 deterministic P1 policies as float64 vectors of length 6.

    Entry ``k`` of the vector for ``mask`` is bit ``k`` of ``mask``: bits 0-2
    fill the three bet slots and bits 3-5 the three call slots.
    """
    policies = []
    for mask in range(64):
        bits = [(mask >> k) & 1 for k in range(6)]
        policies.append(np.array(bits, dtype=np.float64))
    return policies
def enumerate_pure_policies_p2():
    """Return all 64 deterministic P2 policies as float64 vectors of length 6.

    Entry ``k`` of the vector for ``mask`` is bit ``k`` of ``mask``: bits 0-2
    fill the three call slots and bits 3-5 the three bet slots.
    """
    return [
        np.array([(mask >> k) & 1 for k in range(6)], dtype=np.float64)
        for mask in range(64)
    ]
# Built once at import time: every deterministic policy of each player.
# PSRORunner.nashconv scans these lists to evaluate best responses.
PURE1 = enumerate_pure_policies_p1()
PURE2 = enumerate_pure_policies_p2()

# --------------------------- PPO BR (tabular Bernoulli + tabular value) ---------------------------
class TabularPPO:
    """
    6 Bernoulli logits (policy) + 6 value scalars (baseline).
    PPO with clipped objective & entropy bonus. Optimized with Adam at --ppo_lr.

    ``args`` must provide: logit_cap, prob_eps, ppo_epochs, ppo_batch, ppo_lr,
    clip, ent_beta and max_grad_norm.
    """
    def __init__(self, role, args):
        self.args = args
        self.role = role  # 0=P1, 1=P2
        self.theta = np.zeros(6, dtype=np.float64)  # logits
        self.v     = np.zeros(6, dtype=np.float64)  # value table
        # Adam states
        self.m_t = np.zeros_like(self.theta); self.v_t = np.zeros_like(self.theta)
        self.m_v = np.zeros_like(self.v);     self.v_v = np.zeros_like(self.v)
        self.t_step = 0    # total number of Adam updates applied (all parameters)
        self._adam_t = {}  # per-parameter Adam timestep, keyed by id() of its m buffer

    def _sigm(self, x):
        """Sigmoid of ``x`` after clipping the logit to ``[-logit_cap, logit_cap]``."""
        c = np.clip(x, -self.args.logit_cap, self.args.logit_cap)
        return 1.0/(1.0+np.exp(-c))

    def _logp(self, idx, a):
        """Log-probability of action ``a`` (1 or anything else = 0) at info index ``idx``."""
        p = np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return math.log(p) if a==1 else math.log(1.0-p)

    def _entropy(self, idx):
        """Bernoulli entropy (nats) of the policy at info index ``idx``."""
        p = np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return -(p*math.log(p) + (1-p)*math.log(1-p))

    def _adam_update(self, param, grad, m, v, lr, b1=0.9, b2=0.999, eps=1e-8, max_norm=None):
        """Apply one in-place Adam step to ``param`` using moment buffers ``m``/``v``.

        The bias-correction timestep is tracked per moment buffer, so the
        policy logits and the value table each start at t=1. (A single shared
        counter advanced twice per minibatch and mis-scaled both.) The buffers
        are expected to be the long-lived ones created in ``__init__``.
        """
        self.t_step += 1
        t = self._adam_t.get(id(m), 0) + 1
        self._adam_t[id(m)] = t
        m[:] = b1*m + (1-b1)*grad
        v[:] = b2*v + (1-b2)*(grad*grad)
        mhat = m/(1-b1**t)
        vhat = v/(1-b2**t)
        step = lr * mhat / (np.sqrt(vhat) + eps)
        # NOTE(review): max_norm caps the norm of the Adam *step*, not of the
        # gradient, even though callers pass --max_grad_norm. Left unchanged.
        if max_norm is not None:
            gnorm = np.linalg.norm(step)
            if gnorm > max_norm and gnorm > 0:
                step = step * (max_norm / gnorm)
        param[:] = param - step

    def ppo_update(self, batch):
        """Run ``ppo_epochs`` passes of minibatch PPO over ``batch``.

        Args:
            batch: dict of equal-length sequences under the keys ``"idx"``
                (info index 0..5), ``"act"`` (0/1), ``"logp_old"``, ``"ret"``
                and ``"adv"``. An empty batch is a no-op.

        Advantages are standardised over the whole batch. Minibatch order is
        shuffled with NumPy's global RNG. Updates ``theta`` and ``v`` in place.
        """
        N=len(batch["idx"])
        if N==0: return
        args=self.args
        idx=np.array(batch["idx"], dtype=np.int64)
        act=np.array(batch["act"], dtype=np.int64)
        logp_old=np.array(batch["logp_old"], dtype=np.float64)
        ret=np.array(batch["ret"], dtype=np.float64)
        adv=np.array(batch["adv"], dtype=np.float64)

        adv = (adv - adv.mean()) / (adv.std()+1e-8)

        clip_lo, clip_hi = 1.0-args.clip, 1.0+args.clip
        order=np.arange(N)
        for _ in range(args.ppo_epochs):
            np.random.shuffle(order)
            for j in range(0, N, args.ppo_batch):
                jj=order[j:j+args.ppo_batch]
                mb_idx, mb_act, mb_adv = idx[jj], act[jj], adv[jj]
                n_mb = len(jj)

                p = np.clip(self._sigm(self.theta[mb_idx]), args.prob_eps, 1.0-args.prob_eps)
                cur_logp = np.where(mb_act==1, np.log(p), np.log(1.0-p))
                ratio = np.exp(cur_logp - logp_old[jj])

                # Clipped surrogate min(surr1, surr2): a sample contributes a
                # gradient only while the unclipped term is the active one.
                # Where the clipped term is strictly smaller the objective is
                # flat in theta, so that sample's policy gradient is zero.
                surr1 = ratio * mb_adv
                surr2 = np.clip(ratio, clip_lo, clip_hi) * mb_adv
                unclipped_active = surr1 <= surr2

                dlogp_dtheta = (mb_act - p)  # d log pi(a) / d theta for a Bernoulli logit
                grad_theta = -(1.0/n_mb) * np.where(unclipped_active, mb_adv*ratio*dlogp_dtheta, 0.0)

                # Entropy bonus: loss term is -ent_beta * mean(H).
                dHdp = -np.log(p/(1.0-p))
                dpdtheta = p*(1.0-p)
                grad_theta += -(args.ent_beta/n_mb) * (dHdp * dpdtheta)

                # Scatter-add per-sample gradients into the 6 table slots
                # (np.add.at accumulates correctly for repeated indices).
                g_theta=np.zeros_like(self.theta)
                np.add.at(g_theta, mb_idx, grad_theta)

                # Gradient of 0.5 * mean((ret - v)^2) with respect to v.
                g_v = np.zeros_like(self.v)
                np.add.at(g_v, mb_idx, (self.v[mb_idx] - ret[jj]) / n_mb)

                self._adam_update(self.theta, g_theta, self.m_t, self.v_t, lr=args.ppo_lr, max_norm=args.max_grad_norm)
                self._adam_update(self.v,     g_v,     self.m_v, self.v_v, lr=args.ppo_lr*0.5, max_norm=args.max_grad_norm)

    def probs(self):
        """Return the 6 action probabilities, clipped to ``[prob_eps, 1-prob_eps]``."""
        return np.clip(self._sigm(self.theta), self.args.prob_eps, 1.0-self.args.prob_eps)

# --------------------------- Rollout simulator (episodes) ---------------------------
def draw_cards():
    """Deal two distinct cards from {0, 1, 2} using the stdlib ``random`` module.

    The second card is drawn from the two remaining ranks by sampling 0..1 and
    stepping over the first card's rank.
    """
    first = random.randrange(3)
    second = random.randrange(2)
    return first, (second + 1 if second >= first else second)  # distinct

def simulate_episode(learner_role, theta_probs, opp_probs):
    """Play one hand and return ``(reward_for_learner, decisions)``.

    ``learner_role`` 0 means the learner is P1; any other value means P2.
    ``theta_probs`` is the learner's 6-entry probability vector and
    ``opp_probs`` the opponent's. ``decisions`` lists ``(info_idx, action)``
    for the learner's own moves only, in the order they were taken.
    """
    c1, c2 = draw_cards()
    learner_is_p1 = (learner_role == 0)
    if learner_is_p1:
        p1_probs, p2_probs = theta_probs, opp_probs
    else:
        p1_probs, p2_probs = opp_probs, theta_probs

    decs = []
    showdown = 1.0 if c1 > c2 else -1.0  # P1's sign at showdown

    # P1 opens: bet (1) or check (0).
    p1_bets = 1 if random.random() < p1_probs[c1] else 0
    if learner_is_p1:
        decs.append((c1, p1_bets))

    if p1_bets == 1:
        # P2 answers the bet: call (1) or fold (0).
        p2_calls = 1 if random.random() < p2_probs[c2] else 0
        if not learner_is_p1:
            decs.append((c2, p2_calls))
        r_p1 = 2.0 * showdown if p2_calls == 1 else 1.0
    else:
        # P2 acts after the check: bet (1) or check back (0).
        p2_bets = 1 if random.random() < p2_probs[3 + c2] else 0
        if not learner_is_p1:
            decs.append((3 + c2, p2_bets))
        if p2_bets == 1:
            # P1 answers the bet: call (1) or fold (0).
            p1_calls = 1 if random.random() < p1_probs[3 + c1] else 0
            if learner_is_p1:
                decs.append((3 + c1, p1_calls))
            r_p1 = 2.0 * showdown if p1_calls == 1 else -1.0
        else:
            r_p1 = showdown

    # Zero-sum: the P2 learner receives the negated P1 payoff.
    return (r_p1 if learner_is_p1 else -r_p1), decs

# --------------------------- Runner ---------------------------
class PSRORunner:
    """Runs the PSRO outer loop for Kuhn poker under a single random seed.

    State:
        Z:    ``Z[p]`` is player ``p``'s population, a list of 6-entry
              probability vectors in the layout expected by ``ev_p1_vs``.
        LOGS: ``LOGS[p]`` holds one meta-weight per member of ``Z[p]``.

    NOTE(review): ``meta_solve`` and ``add_anchor`` store normalised,
    probability-like weights in ``LOGS``, while ``sigma_list`` passes ``LOGS``
    through a softmax. The meta-strategy actually used is therefore
    ``softmax(weights)``. This is kept unchanged because the original comments
    flag it as deliberate; confirm before treating sigma as a plain MWU iterate.
    """
    def __init__(self, args, seed: int):
        self.args = args
        self.seed = seed
        self._setup_state()

    def _setup_state(self):
        """Seed the global RNGs and start each population with the uniform policy."""
        _seed_everything(self.seed)
        # Populations: lists of 6-prob vectors
        self.Z = [ [np.array([0.5]*6, dtype=np.float64)], [np.array([0.5]*6, dtype=np.float64)] ]
        # Meta-weight containers; sigma_list applies softmax to them
        self.LOGS = [ np.zeros(1, dtype=np.float64), np.zeros(1, dtype=np.float64) ]

    # ---- restricted game ----
    def ev_matrix(self):
        """Return the K1 x K2 matrix of exact P1 payoffs between population members."""
        K1,K2=len(self.Z[0]), len(self.Z[1])
        M=np.zeros((K1,K2), dtype=np.float64)
        for i in range(K1):
            for j in range(K2):
                M[i,j]=ev_p1_vs(self.Z[0][i], self.Z[1][j])
        return M

    # ---- meta-solver (MWU-ish) ----
    def softmax_np(self, x):
        """Softmax of ``x`` after clipping its entries to ``[-logit_cap, logit_cap]``."""
        cap=float(self.args.logit_cap)
        x=np.clip(x, -cap, cap)
        z=x-x.max(); e=np.exp(z)
        return e/(e.sum()+1e-12)

    def sigma_list(self):
        """Return ``[sigma_p1, sigma_p2]``: the softmax of each player's ``LOGS``."""
        return [self.softmax_np(self.LOGS[0]), self.softmax_np(self.LOGS[1])]

    def eta_t(self, t):
        """Step size for outer iteration ``t`` under the configured ``--eta_sched``."""
        s=self.args.eta_sched; eta0=self.args.eta
        if s=="const": return eta0
        if s=="sqrt":  return eta0/max(1.0, math.sqrt(t))
        if s=="harmonic": return eta0/(1.0+0.5*t)
        return eta0

    def meta_solve(self, it):
        """Run ``meta_loops`` multiplicative-weights style updates on ``LOGS``."""
        eta=self.eta_t(it)
        # The populations are not modified in here, so the payoff matrix is
        # loop-invariant: build it once instead of once per meta iteration.
        M=self.ev_matrix()
        for _ in range(self.args.meta_loops):
            s1,s2=self.sigma_list()
            v0=M @ s2
            v1=-(s1 @ M)
            # note: LOGS holds prob-like weights after this step
            l0=np.log(s1+1e-12)+eta*v0; l0-=l0.max(); self.LOGS[0]=np.exp(l0); self.LOGS[0]/=self.LOGS[0].sum()
            l1=np.log(s2+1e-12)+eta*v1; l1-=l1.max(); self.LOGS[1]=np.exp(l1); self.LOGS[1]/=self.LOGS[1].sum()

    # ---- nashconv ----
    def nashconv(self):
        """Return ``(nashconv, value)`` of the current meta-strategies.

        ``value`` is P1's expected payoff under ``(sigma_p1, sigma_p2)`` and
        ``nashconv`` is ``max(0, BR1 - BR2)``, where BR1 is P1's best pure
        payoff against P2's mixture and BR2 is the lowest P1 payoff any pure
        P2 policy can force against P1's mixture.

        Best-response values are computed as sigma-weighted sums of payoffs
        against each population member. Averaging the probability vectors
        themselves is not equivalent for P1: its call entry only matters after
        its own check, so ``ev_p1_vs`` contains the product ``(1-bet)*call``
        and is not linear in P1's vector.
        """
        s1,s2 = self.sigma_list()
        M = self.ev_matrix()
        val = float(s1 @ M @ s2)
        br1 = max(
            sum(w*ev_p1_vs(pi, z) for w, z in zip(s2, self.Z[1]))
            for pi in PURE1
        )
        br2min = min(
            sum(w*ev_p1_vs(z, pj) for w, z in zip(s1, self.Z[0]))
            for pj in PURE2
        )
        return max(0.0, br1 - br2min), val

    # ---- population ops ----
    def maybe_evict_least_mass(self, p):
        """If ``kmax > 0`` and player ``p``'s pool is full, drop its least-weighted member."""
        if self.args.kmax<=0: return
        if len(self.Z[p])<self.args.kmax: return
        s = self.sigma_list()[p]
        idx=int(np.argmin(s))
        self.Z[p].pop(idx); self.LOGS[p]=np.delete(self.LOGS[p], idx, axis=0)

    def add_anchor(self, p, vec):
        """Append policy ``vec`` (clipped to ``[prob_eps, 1-prob_eps]``) to player ``p``'s pool.

        The existing weights are scaled by ``1-eps``, the new member gets weight
        ``eps`` (1e-6), and the weights are renormalised to sum to one.
        """
        vec=np.clip(vec, self.args.prob_eps, 1.0-self.args.prob_eps)
        self.Z[p].append(vec.copy())
        eps=1e-6
        self.LOGS[p]=np.concatenate([self.LOGS[p]*(1-eps), np.array([eps],dtype=np.float64)], 0)
        self.LOGS[p]/=self.LOGS[p].sum()

    # ---- robust mix guard ----
    def ensure_match_mix(self, learner_role, mix):
        """Return ``mix`` as a distribution over the opponent's current population.

        The vector is zero-padded or truncated to the opponent pool size and
        renormalised; a non-finite or non-positive total falls back to uniform.
        """
        K = len(self.Z[1-learner_role])
        m = np.asarray(mix, dtype=np.float64).copy()
        if len(m) != K:
            if len(m) < K:
                m = np.concatenate([m, np.zeros(K - len(m), dtype=np.float64)], 0)
            else:
                m = m[:K]
        s = m.sum()
        if not np.isfinite(s) or s <= 0:
            m = np.ones(K, dtype=np.float64) / K
        else:
            m = m / s
        return m

    # ---- PPO BR data & training ----
    def collect_br_data(self, learner_role, sigma_opp, rollouts):
        """Roll out a fresh ``TabularPPO`` agent against opponents sampled from ``sigma_opp``.

        Returns ``(agent, data)`` where ``data`` has the keys consumed by
        ``TabularPPO.ppo_update``. Every learner decision in an episode is
        credited with that episode's reward as its return.
        """
        sigma_opp = self.ensure_match_mix(learner_role, sigma_opp)
        data = {"idx":[], "act":[], "logp_old":[], "ret":[], "adv":[]}
        agent = TabularPPO(learner_role, self.args)
        opp_pool = self.Z[1-learner_role]
        # The agent is not updated during collection, so its action
        # probabilities are the same for every rollout.
        p_now = agent.probs()
        for _ in range(rollouts):
            j = np.random.choice(len(opp_pool), p=sigma_opp)
            opp = opp_pool[j]
            r, decs = simulate_episode(learner_role, p_now, opp)
            for (idx, a) in decs:
                logp = agent._logp(idx, a)
                v_pred = agent.v[idx]
                adv = r - v_pred
                data["idx"].append(idx)
                data["act"].append(a)
                data["logp_old"].append(logp)
                data["ret"].append(r)
                data["adv"].append(adv)
        return agent, data

    def train_br_with_ppo(self, role, sigma_opp):
        """Collect one batch, run one ``ppo_update`` on it and return the agent's 6 probabilities.

        Returns the uniform policy when no learner decisions were collected.
        """
        sigma_opp = self.ensure_match_mix(role, sigma_opp)
        agent, batch = self.collect_br_data(role, sigma_opp, self.args.ppo_rollouts)
        if len(batch["idx"]) == 0:
            return np.full(6, 0.5, dtype=np.float64)
        agent.ppo_update(batch)
        return agent.probs()

    # ---- run one seed ----
    def run(self, csv_path: str, header=True):
        """Run ``args.iters`` PSRO iterations, logging one CSV row per iteration.

        Args:
            csv_path: Output CSV path; its parent directory is created if needed.
            header: When true, print a one-line banner before starting.

        Returns:
            dict with per-iteration float arrays ``nc``, ``val``, ``dt``,
            ``mem``, ``n1``, ``n2`` and the string ``mem_type``.
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[PSRO] seed={self.seed} | PPO LR={self.args.ppo_lr}")

        hist_nc, hist_val = [], []
        hist_dt, hist_mem = [], []
        hist_n1, hist_n2  = [], []
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w=csv.writer(fcsv)
            w.writerow(["iter","timestamp","n_strats_p1","n_strats_p2","time_sec","mem_mb","mem_type","nashconv","mix_ev_p1"])

            for it in range(1, self.args.iters+1):
                t0=time.time()

                # 1) Solve the restricted game
                self.meta_solve(it)
                s1,s2=self.sigma_list()

                # 2) Add PPO BRs (evict least-mass if capped)
                self.maybe_evict_least_mass(0)
                br1 = self.train_br_with_ppo(role=0, sigma_opp=s2)
                self.add_anchor(0, br1)

                # Re-solve after P1 change
                self.meta_solve(it)
                s1,s2=self.sigma_list()

                self.maybe_evict_least_mass(1)
                br2 = self.train_br_with_ppo(role=1, sigma_opp=s1)
                self.add_anchor(1, br2)

                # 3) Metrics
                nc, val = self.nashconv()
                dt=time.time()-t0; mem, mtype = _mem_mb()
                mem_type_rec = mtype

                hist_nc.append(nc); hist_val.append(val)
                hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(len(self.Z[0])); hist_n2.append(len(self.Z[1]))

                print(f"[PSRO] seed={self.seed} iter {it}/{self.args.iters} | P1={len(self.Z[0])} P2={len(self.Z[1])} "
                      f"| NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB")

                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            len(self.Z[0]), len(self.Z[1]), f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                # Flush each row so partial results survive an interrupted run.
                fcsv.flush()

        return {
            "nc":  np.array(hist_nc, dtype=np.float64),
            "val": np.array(hist_val, dtype=np.float64),
            "dt":  np.array(hist_dt, dtype=np.float64),
            "mem": np.array(hist_mem, dtype=np.float64),
            "n1":  np.array(hist_n1, dtype=np.float64),
            "n2":  np.array(hist_n2, dtype=np.float64),
            "mem_type": mem_type_rec
        }

# --------------------------- Aggregate helpers ---------------------------
def _save_meanstd_csv(path, T, mem_type,
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "iter",
            "dt_mean","dt_std",
            "mem_mb_mean","mem_mb_std",
            "mix_ev_mean","mix_ev_std",
            "nashconv_mean","nashconv_std",
            "n_strats_p1_mean","n_strats_p1_std",
            "n_strats_p2_mean","n_strats_p2_std",
            "mem_type"
        ])
        for i in range(T):
            w.writerow([
                i+1,
                f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                f"{mem_mean[i]:.3f}", f"{mem_std[i]:.3f}",
                f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                f"{n1_mean[i]:.3f}", f"{n1_std[i]:.3f}",
                f"{n2_mean[i]:.3f}", f"{n2_std[i]:.3f}",
                mem_type
            ])

# --------------------------- Entry ---------------------------
def main():
    """Run PSRO for every requested seed, then aggregate, report and plot.

    Writes one CSV per seed plus a ``<csv_base>_meanstd.csv`` aggregate into
    ``--outdir``, prints a final summary, and shows two matplotlib figures
    unless ``--no_plots`` is given.
    """
    args = parse_args()
    seeds = _parse_seeds(args.seeds)
    os.makedirs(args.outdir, exist_ok=True)

    per_seed = []
    for s in seeds:
        runner = PSRORunner(args, seed=s)
        seed_csv = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        per_seed.append(runner.run(seed_csv, header=(s==seeds[0])))

    T = args.iters
    for rec in per_seed:
        assert len(rec["nc"])==T and len(rec["val"])==T, "All seeds must have identical iters"

    # The sample standard deviation (ddof=1) is undefined for a single seed and
    # would yield NaN plus a RuntimeWarning; report a spread of 0 in that case.
    ddof = 1 if len(per_seed) > 1 else 0

    def _mean_std(key):
        # Stack one row per seed and reduce across seeds.
        stack = np.stack([rec[key] for rec in per_seed], axis=0)
        return stack.mean(axis=0), stack.std(axis=0, ddof=ddof)

    nc_mean,  nc_std  = _mean_std("nc")
    ev_mean,  ev_std  = _mean_std("val")
    dt_mean,  dt_std  = _mean_std("dt")
    mem_mean, mem_std = _mean_std("mem")
    n1_mean,  n1_std  = _mean_std("n1")
    n2_mean,  n2_std  = _mean_std("n2")

    mem_type = per_seed[0]["mem_type"]
    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, mem_type,
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    if T > 0:  # nothing to summarise when no iteration was run
        print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
              f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
              f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")

    if not args.no_plots:
        its = np.arange(1, T+1)
        plt.figure()
        plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (NashConv) — mean ± std")

        plt.figure()
        plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value — mean ± std")
        plt.show()

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
