import argparse, time, csv, os, random, math, sys
import numpy as np
import matplotlib.pyplot as plt

# --------------------------- Args ---------------------------
def parse_args():
    """Build the command-line parser for the experiment and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the outer-loop, PPO and numerics/I-O settings.
    """
    parser = argparse.ArgumentParser("A-PSRO (PPO BR) for Kuhn Poker — multi-seed")
    add = parser.add_argument

    # Outer loop / meta-solver.
    add("--iters", type=int, default=40)
    add("--meta_loops", type=int, default=200)
    add("--eta", type=float, default=0.25)
    add("--eta_sched", choices=["const", "sqrt", "harmonic"], default="harmonic")
    add("--kmax", type=int, default=0,
        help="0=unbounded; >0 cap per-player pool (evict least-mass)")
    add("--adv_thresh", type=float, default=0.0,
        help="skip adding if max advantage <= this")

    # PPO best-response learner.
    add("--ppo_rollouts", type=int, default=4000)
    add("--ppo_epochs", type=int, default=10)
    add("--ppo_batch", type=int, default=512)
    add("--ppo_lr", type=float, default=3e-4)
    add("--clip", type=float, default=0.2)
    add("--ent_beta", type=float, default=1e-3)
    add("--gamma", type=float, default=1.0)
    add("--gae_lambda", type=float, default=0.95)
    add("--max_grad_norm", type=float, default=1.0)

    # Numerical guards and output locations.
    add("--prob_eps", type=float, default=1e-6)
    add("--logit_cap", type=float, default=50.0)
    add("--outdir", type=str, default=".")
    add("--csv_base", type=str, default="apsro")
    add("--seeds", type=str, default="0,1,2,3,4")
    add("--no_plots", action="store_true")

    return parser.parse_args()

# --------------------------- Utils ---------------------------
def _seed_everything(s:int):
    random.seed(s); np.random.seed(s)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform=="darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _parse_seeds(s:str):
    out=[int(t.strip()) for t in s.split(",") if t.strip()!='']
    return out if out else [0,1,2,3,4]

# --------------------------- Game basics ---------------------------
CARDS = [0, 1, 2]


def ev_p1_vs(p1,p2):
    """Exact expected payoff to player 1 for one pair of behaviour vectors.

    Args:
        p1: 6 probabilities; entries 0-2 are P1's bet probability per card,
            entries 3-5 are P1's call probability per card (after check/bet).
        p2: 6 probabilities; entries 0-2 are P2's call probability per card,
            entries 3-5 are P2's bet probability per card (after a check).

    Returns:
        Mean P1 payoff over the six ordered deals of two distinct cards.
    """
    bet1, call1 = p1[:3], p1[3:]
    call2, bet2 = p2[:3], p2[3:]
    total = 0.0
    for mine in CARDS:
        for theirs in CARDS:
            if mine == theirs:
                continue
            p1_wins = mine > theirs
            big = 2.0 if p1_wins else -2.0     # showdown after a called bet
            small = 1.0 if p1_wins else -1.0   # showdown after check-check
            # P1 bets: P2 calls (showdown for 2) or folds (P1 wins 1).
            after_bet = call2[theirs]*big + (1-call2[theirs])*1.0
            # P1 checks, P2 bets: P1 calls (showdown for 2) or folds (loses 1).
            facing_bet = call1[mine]*big + (1-call1[mine])*(-1.0)
            after_check = bet2[theirs]*facing_bet + (1-bet2[theirs])*small
            total += bet1[mine]*after_bet + (1-bet1[mine])*after_check
    return total/6.0

def enumerate_pure_policies_p1():
    """List all 64 deterministic P1 policies as float64 vectors of length 6.

    Bit k of the mask (k = 0..5) becomes entry k: bits 0-2 are the bet
    decisions and bits 3-5 the call decisions.
    """
    return [
        np.array([(mask >> bit) & 1 for bit in range(6)], dtype=np.float64)
        for mask in range(64)
    ]
def enumerate_pure_policies_p2():
    """List all 64 deterministic P2 policies as float64 vectors of length 6.

    Bit k of the mask (k = 0..5) becomes entry k: bits 0-2 are the call
    decisions and bits 3-5 the bet decisions.
    """
    policies = []
    for mask in range(64):
        bits = [(mask >> shift) & 1 for shift in range(6)]
        policies.append(np.array(bits, dtype=np.float64))
    return policies
# Deterministic policy tables for each seat, built once at import time.
PURE1 = enumerate_pure_policies_p1()
PURE2 = enumerate_pure_policies_p2()

# --------------------------- PPO (tabular) ---------------------------
class TabularPPO:
    """Tabular PPO learner over six independent Bernoulli decisions.

    ``theta`` holds one logit per decision index (0-5) and ``v`` one value
    estimate per index; both are optimised with Adam.

    Args:
        role: seat identifier, stored on the instance (not read by this class).
        args: namespace providing ``prob_eps``, ``logit_cap``, ``ppo_epochs``,
            ``ppo_batch``, ``ppo_lr``, ``clip``, ``ent_beta`` and
            ``max_grad_norm``.
        init_probs: optional 6 probabilities used to warm-start ``theta``;
            logits start at zero (probability 0.5) when omitted.
    """
    def __init__(self, role, args, init_probs=None):
        self.args=args; self.role=role
        if init_probs is None:
            self.theta=np.zeros(6, dtype=np.float64)
        else:
            # Clip before the logit transform so 0/1 inputs stay finite.
            p=np.clip(init_probs, args.prob_eps, 1.0-args.prob_eps)
            self.theta=np.log(p/(1.0-p))
        self.v=np.zeros(6, dtype=np.float64)
        # Adam first/second moments for the policy logits and the value table.
        self.m_t=np.zeros_like(self.theta); self.v_t=np.zeros_like(self.theta)
        self.m_v=np.zeros_like(self.v);     self.v_v=np.zeros_like(self.v)
        self.t_step=0          # total number of Adam updates (all parameters)
        self._adam_steps={}    # per-parameter step counts, keyed by id(moment buffer)

    def _sigm(self,x):
        """Logistic function with the input clipped to +/- ``logit_cap``."""
        cap=self.args.logit_cap
        return 1.0/(1.0+np.exp(-np.clip(x, -cap, cap)))

    def _logp(self,idx,a):
        """Log-probability of action ``a`` (1, or anything else = 0) at ``idx``."""
        p=np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return math.log(p) if a==1 else math.log(1.0-p)

    def _entropy(self,idx):
        """Bernoulli entropy of the decision at ``idx``."""
        p=np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return -(p*math.log(p) + (1-p)*math.log(1-p))

    def _adam_update(self, param, grad, m, v, lr, b1=0.9, b2=0.999, eps=1e-8, max_norm=None):
        """Apply one in-place Adam step to ``param``.

        ``m`` and ``v`` are the moment buffers belonging to ``param`` and are
        updated in place. When ``max_norm`` is given, the *gradient* is
        rescaled to that L2 norm before it enters the moment estimates.
        """
        grad=np.asarray(grad, dtype=np.float64)
        if max_norm is not None:
            gnorm=np.linalg.norm(grad)
            if gnorm>max_norm and gnorm>0: grad=grad*(max_norm/gnorm)
        self.t_step+=1
        # Bias correction must use the number of updates this parameter has
        # received; a counter shared between theta and v advances twice per
        # minibatch and under-corrects the early steps.
        t=self._adam_steps.get(id(m), 0)+1
        self._adam_steps[id(m)]=t
        m[:]=b1*m + (1-b1)*grad; v[:]=b2*v + (1-b2)*(grad*grad)
        mhat=m/(1-b1**t); vhat=v/(1-b2**t)
        param[:]=param - lr*mhat/(np.sqrt(vhat)+eps)

    def ppo_update(self,batch):
        """Run ``ppo_epochs`` passes of minibatch PPO over ``batch``.

        ``batch`` is a dict of equal-length sequences under the keys "idx",
        "act", "logp_old", "ret" and "adv". Advantages are standardised over
        the whole batch. An empty batch is a no-op.
        """
        N=len(batch["idx"])
        if N==0: return
        a=self.args
        idx=np.asarray(batch["idx"], dtype=np.intp)
        act=np.asarray(batch["act"], dtype=np.float64)
        logp_old=np.asarray(batch["logp_old"], dtype=np.float64)
        ret=np.asarray(batch["ret"], dtype=np.float64)
        adv=np.asarray(batch["adv"], dtype=np.float64)
        adv=(adv-adv.mean())/(adv.std()+1e-8)
        order=np.arange(N)
        for _ in range(a.ppo_epochs):
            np.random.shuffle(order)
            for j in range(0, N, a.ppo_batch):
                jj=order[j:j+a.ppo_batch]; n=len(jj)
                b_idx=idx[jj]; b_act=act[jj]; b_adv=adv[jj]
                p=np.clip(self._sigm(self.theta[b_idx]), a.prob_eps, 1.0-a.prob_eps)
                cur_logp=np.where(b_act==1, np.log(p), np.log(1.0-p))
                ratio=np.exp(cur_logp - logp_old[jj])
                surr1=ratio*b_adv
                surr2=np.clip(ratio, 1.0-a.clip, 1.0+a.clip)*b_adv
                # Gradient of min(surr1, surr2): where the clipped term is the
                # strict minimum the ratio sits outside the clip range, so that
                # sample contributes no policy gradient.
                active=(surr1<=surr2)
                # d ratio / d theta = ratio * (act - p) for a sigmoid Bernoulli.
                grad_theta=-(b_adv*ratio*(b_act-p))*active/n
                # Entropy bonus: dH/dp = -log(p/(1-p)), dp/dtheta = p(1-p).
                grad_theta+=-(a.ent_beta/n)*(-np.log(p/(1.0-p))*(p*(1.0-p)))
                # Scatter-add per-sample gradients into the six table slots.
                g_theta=np.zeros_like(self.theta); np.add.at(g_theta, b_idx, grad_theta)
                g_v=np.zeros_like(self.v); np.add.at(g_v, b_idx, (self.v[b_idx]-ret[jj])/n)
                self._adam_update(self.theta, g_theta, self.m_t, self.v_t, lr=a.ppo_lr, max_norm=a.max_grad_norm)
                self._adam_update(self.v,     g_v,     self.m_v, self.v_v, lr=a.ppo_lr*0.5, max_norm=a.max_grad_norm)

    def probs(self):
        """Return the six action probabilities, clipped to [prob_eps, 1-prob_eps]."""
        return np.clip(self._sigm(self.theta), self.args.prob_eps, 1.0-self.args.prob_eps)

# --------------------------- Episodes ---------------------------
def draw_cards():
    """Deal two distinct cards from {0, 1, 2}; returns ``(card_p1, card_p2)``."""
    first = random.randrange(3)
    second = random.randrange(2)
    # Shift the second draw past the first so the two cards never coincide.
    return first, (second + 1 if second >= first else second)
def simulate_episode(learner_role, theta_probs, opp_probs):
    """Play one hand and report it from the learner's point of view.

    Args:
        learner_role: 0 if the learner sits as player 1, anything else for
            player 2.
        theta_probs: the learner's 6 action probabilities.
        opp_probs: the opponent's 6 action probabilities.

    Returns:
        (reward, decisions): the learner's payoff and the list of
        ``(decision_index, action)`` pairs the learner took this hand.
        P1 indices: 0-2 bet (by card), 3-5 call. P2 indices: 0-2 call, 3-5 bet.
    """
    card1, card2 = draw_cards()
    learner_is_p1 = (learner_role == 0)
    if learner_is_p1:
        p1_table, p2_table = theta_probs, opp_probs
    else:
        p1_table, p2_table = opp_probs, theta_probs

    showdown = 1.0 if card1 > card2 else -1.0
    decisions = []

    def sample(prob):
        # One uniform draw per decision, in the order decisions occur.
        return 1 if random.random() < prob else 0

    p1_bets = sample(p1_table[card1])
    if learner_is_p1:
        decisions.append((card1, p1_bets))

    if p1_bets == 1:
        # P2 answers the bet: call -> showdown for 2, fold -> P1 wins 1.
        p2_calls = sample(p2_table[card2])
        if not learner_is_p1:
            decisions.append((card2, p2_calls))
        payoff_p1 = 2.0 * showdown if p2_calls == 1 else 1.0
    else:
        # P1 checked; P2 may bet or check behind.
        p2_bets = sample(p2_table[3 + card2])
        if not learner_is_p1:
            decisions.append((3 + card2, p2_bets))
        if p2_bets == 1:
            # P1 answers: call -> showdown for 2, fold -> P1 loses 1.
            p1_calls = sample(p1_table[3 + card1])
            if learner_is_p1:
                decisions.append((3 + card1, p1_calls))
            payoff_p1 = 2.0 * showdown if p1_calls == 1 else -1.0
        else:
            payoff_p1 = showdown

    reward = payoff_p1 if learner_is_p1 else -payoff_p1
    return reward, decisions

# --------------------------- Runner ---------------------------
class APSRORunner:
    """One seeded A-PSRO run on Kuhn poker.

    State:
        Z[p]    -- policy pool of player p: a list of 6-entry probability vectors.
        LOGS[p] -- log-weights of the meta-strategy over Z[p]. They are always
                   stored as logits; ``sigma()`` is the only place that turns
                   them into probabilities (via softmax).
    """
    def __init__(self, args, seed:int):
        self.args=args; self.seed=seed; self._setup_state()

    def _setup_state(self):
        """Seed the RNGs and start both pools with the all-0.5 policy."""
        _seed_everything(self.seed)
        self.Z=[ [np.array([0.5]*6, dtype=np.float64)], [np.array([0.5]*6, dtype=np.float64)] ]
        self.LOGS=[ np.zeros(1, dtype=np.float64), np.zeros(1, dtype=np.float64) ]

    def ev_matrix(self):
        """Payoff matrix M[i, j] = P1's expected value of Z[0][i] vs Z[1][j]."""
        K1,K2=len(self.Z[0]), len(self.Z[1]); M=np.zeros((K1,K2), dtype=np.float64)
        for i,pol1 in enumerate(self.Z[0]):
            for j,pol2 in enumerate(self.Z[1]):
                M[i,j]=ev_p1_vs(pol1, pol2)
        return M

    def softmax_np(self,x):
        """Softmax of ``x`` after clipping it to +/- ``logit_cap``."""
        x=np.clip(x, -self.args.logit_cap, self.args.logit_cap); z=x-x.max(); e=np.exp(z)
        return e/(e.sum()+1e-12)

    def sigma(self):
        """Current meta-strategies ``[sigma_p1, sigma_p2]``."""
        return [self.softmax_np(self.LOGS[0]), self.softmax_np(self.LOGS[1])]

    def eta_t(self,t):
        """Step size at outer iteration ``t`` for the configured schedule."""
        s=self.args.eta_sched; eta0=self.args.eta
        if s=="const": return eta0
        if s=="sqrt": return eta0/max(1.0, math.sqrt(t))
        if s=="harmonic": return eta0/(1.0+0.5*t)
        return eta0

    def mwu_meta(self,it):
        """Run ``meta_loops`` simultaneous multiplicative-weights updates.

        Returns the resulting pair of meta-strategies ``(s1, s2)``.
        """
        eta=self.eta_t(it)
        # The pools are fixed while the meta-solver runs, so build M once.
        M=self.ev_matrix()
        for _ in range(self.args.meta_loops):
            s1,s2=self.sigma(); v0=M @ s2; v1=-(s1 @ M)
            # Keep the state in log space: writing normalised probabilities
            # here would make sigma() apply softmax to probabilities, which
            # flattens the meta-strategy towards uniform.
            l0=np.log(s1+1e-12)+eta*v0; self.LOGS[0]=l0-l0.max()
            l1=np.log(s2+1e-12)+eta*v1; self.LOGS[1]=l1-l1.max()
        return self.softmax_np(self.LOGS[0]), self.softmax_np(self.LOGS[1])

    def nashconv(self, s1, s2):
        """Return ``(NashConv, value)`` of the meta-strategies ``(s1, s2)``.

        Best-response values are expectations over the pool members. Averaging
        the behaviour vectors first is not valid for P1, whose payoff contains
        the product of its bet and call probabilities.
        """
        M=self.ev_matrix()
        val=float(s1 @ M @ s2)
        # U1[a, j]: pure P1 policy a against pool member j of P2.
        U1=np.array([[ev_p1_vs(pi, q) for q in self.Z[1]] for pi in PURE1], dtype=np.float64)
        # U2[i, b]: pool member i of P1 against pure P2 policy b.
        U2=np.array([[ev_p1_vs(q, pj) for pj in PURE2] for q in self.Z[0]], dtype=np.float64)
        br1=float((U1 @ s2).max())
        br2min=float((s1 @ U2).min())
        return max(0.0, br1 - br2min), val

    def ensure_match_mix(self,role,mix):
        """Coerce ``mix`` into a distribution over the opponent pool of ``role``.

        Pads with zeros or truncates to the pool size and renormalises; falls
        back to uniform when the total mass is non-finite or non-positive.
        """
        K=len(self.Z[1-role]); m=np.asarray(mix, dtype=np.float64)
        if len(m)!=K: m=np.pad(m, (0,max(0,K-len(m))))[:K]
        s=m.sum()
        if not np.isfinite(s) or s<=0: return np.ones(K)/K
        return m/s

    def collect_br_data(self,role,sigma_opp,rollouts,init_probs=None):
        """Roll out a fresh PPO agent against opponents sampled from ``sigma_opp``.

        Returns ``(agent, data)`` where ``data`` holds one row per learner
        decision under the keys "idx", "act", "logp_old", "ret" and "adv".
        """
        sigma_opp=self.ensure_match_mix(role, sigma_opp)
        data={"idx":[], "act":[], "logp_old":[], "ret":[], "adv":[]}
        agent=TabularPPO(role, self.args, init_probs=init_probs)
        for _ in range(rollouts):
            j=np.random.choice(len(self.Z[1-role]), p=sigma_opp)
            opp=self.Z[1-role][j]; p_now=agent.probs()
            r,decs=simulate_episode(role, p_now, opp)
            for (idx,a) in decs:
                # Advantage is the episode return minus the value-table entry.
                logp=agent._logp(idx,a); v_pred=agent.v[idx]; adv=r - v_pred
                data["idx"].append(idx); data["act"].append(a); data["logp_old"].append(logp)
                data["ret"].append(r);   data["adv"].append(adv)
        return agent,data

    def train_br_with_ppo(self,role,sigma_opp,init_probs=None):
        """Train a response for ``role`` and return its 6 action probabilities."""
        agent,batch=self.collect_br_data(role, sigma_opp, self.args.ppo_rollouts, init_probs)
        if len(batch["idx"])==0: return np.full(6,0.5)
        agent.ppo_update(batch); return agent.probs()

    def maybe_evict_least_mass(self,p,s):
        """If player ``p``'s pool is at the ``kmax`` cap, drop its least-mass policy.

        ``s`` is the meta-strategy over ``Z[p]``. No-op when ``kmax <= 0`` or
        the pool is still below the cap.
        """
        if self.args.kmax<=0 or len(self.Z[p])<self.args.kmax: return
        idx=int(np.argmin(s)); self.Z[p].pop(idx); self.LOGS[p]=np.delete(self.LOGS[p], idx, axis=0)

    def add_anchor(self,p,vec):
        """Append ``vec`` to player ``p``'s pool with initial meta-mass 1e-6."""
        vec=np.clip(vec, self.args.prob_eps, 1.0-self.args.prob_eps)
        self.Z[p].append(vec.copy())
        eps=1e-6
        old=self.LOGS[p]
        if old.size==0:
            # The pool was emptied by eviction: the new policy takes all mass.
            self.LOGS[p]=np.zeros(1, dtype=np.float64); return
        # Mix in probability space, then convert back to log-weights so that
        # LOGS stays consistent with sigma().
        mass=np.concatenate([self.softmax_np(old)*(1-eps), np.array([eps],dtype=np.float64)],0)
        logw=np.log(mass); self.LOGS[p]=logw-logw.max()

    def compute_advantages(self,s1,s2):
        """Per-policy advantage over the mixture value for both players.

        Returns ``(A1, A2, val)``: ``A1[i]`` is how much pool policy i of P1
        gains over ``val`` against ``s2``; ``A2[j]`` likewise for P2, from
        P2's (negated-payoff) point of view.
        """
        M=self.ev_matrix(); row_vals=M @ s2; col_vals=(s1 @ M); val=float(s1 @ M @ s2)
        A1=row_vals - val; A2=val - col_vals
        return A1,A2,val

    def _grow_pool(self,p,anchor_idx,s_self,s_opp):
        """Train a response for player ``p`` warm-started at ``Z[p][anchor_idx]`` and add it."""
        # Copy the warm start before evicting: eviction pops from Z[p], which
        # would shift or invalidate ``anchor_idx``.
        init=self.Z[p][anchor_idx].copy()
        self.maybe_evict_least_mass(p, s_self)
        self.add_anchor(p, self.train_br_with_ppo(p, s_opp, init_probs=init))

    def run(self, csv_path:str, header=True):
        """Run ``args.iters`` outer iterations, logging one CSV row per iteration.

        Args:
            csv_path: destination of the per-iteration CSV (parent directory
                is created if needed).
            header: print a one-line banner before the first iteration.

        Returns:
            dict of per-iteration arrays "nc", "val", "dt", "mem", "n1", "n2"
            plus "mem_type", the memory source reported by ``_mem_mb``.
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header: print(f"[A-PSRO] seed={self.seed} | PPO LR={self.args.ppo_lr}")
        hist_nc,hist_val,hist_dt,hist_mem,hist_n1,hist_n2=[],[],[],[],[],[]
        mem_type_rec="n/a"
        with open(csv_path,"w",newline="") as fcsv:
            w=csv.writer(fcsv)
            w.writerow(["iter","timestamp","n_strats_p1","n_strats_p2","time_sec","mem_mb","mem_type","nashconv","mix_ev_p1"])
            for it in range(1, self.args.iters+1):
                t0=time.time()
                s1,s2=self.mwu_meta(it)
                A1,A2,_=self.compute_advantages(s1,s2)
                i1=int(np.argmax(A1)); i2=int(np.argmax(A2))
                if A1[i1] > self.args.adv_thresh:
                    self._grow_pool(0, i1, s1, s2)
                # Re-solve the meta-game so P2 responds to the updated P1 pool.
                s1,s2=self.mwu_meta(it)
                if A2[i2] > self.args.adv_thresh:
                    self._grow_pool(1, i2, s2, s1)
                s1,s2=self.mwu_meta(it)
                nc,val=self.nashconv(s1,s2)
                dt=time.time()-t0; mem,mtype=_mem_mb(); mem_type_rec=mtype
                hist_nc.append(nc); hist_val.append(val); hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(len(self.Z[0])); hist_n2.append(len(self.Z[1]))
                print(f"[A-PSRO] seed={self.seed} iter {it}/{self.args.iters} | P1={len(self.Z[0])} P2={len(self.Z[1])} | "
                      f"NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB")
                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            len(self.Z[0]), len(self.Z[1]),
                            f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                # Flush each row so partial results survive an interrupted run.
                fcsv.flush()
        return {"nc":np.array(hist_nc), "val":np.array(hist_val),
                "dt":np.array(hist_dt), "mem":np.array(hist_mem),
                "n1":np.array(hist_n1), "n2":np.array(hist_n2),
                "mem_type":mem_type_rec}

# --------------------------- Aggregate ---------------------------
def _save_meanstd_csv(path,T,mem_type,dt_mean,dt_std,mem_mean,mem_std,ev_mean,ev_std,nc_mean,nc_std,n1_mean,n1_std,n2_mean,n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path,"w",newline="") as f:
        w=csv.writer(f)
        w.writerow(["iter","dt_mean","dt_std","mem_mb_mean","mem_mb_std","mix_ev_mean","mix_ev_std",
                    "nashconv_mean","nashconv_std","n_strats_p1_mean","n_strats_p1_std",
                    "n_strats_p2_mean","n_strats_p2_std","mem_type"])
        for i in range(T):
            w.writerow([i+1, f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                        f"{mem_mean[i]:.3f}", f"{mem_std[i]:.3f}",
                        f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                        f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                        f"{n1_mean[i]:.3f}", f"{n1_std[i]:.3f}",
                        f"{n2_mean[i]:.3f}", f"{n2_std[i]:.3f}", mem_type])

def main():
    args=parse_args(); seeds=_parse_seeds(args.seeds)
    os.makedirs(args.outdir, exist_ok=True)
    per=[]
    for s in seeds:
        r=APSRORunner(args, seed=s)
        csv_path=os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        per.append(r.run(csv_path, header=(s==seeds[0])))
    T=args.iters
    for rec in per: assert len(rec["nc"])==T and len(rec["val"])==T
    nc_stack=np.stack([r["nc"] for r in per],0); ev_stack=np.stack([r["val"] for r in per],0)
    dt_stack=np.stack([r["dt"] for r in per],0); mem_stack=np.stack([r["mem"] for r in per],0)
    n1_stack=np.stack([r["n1"] for r in per],0); n2_stack=np.stack([r["n2"] for r in per],0)
    nc_mean,nc_std=nc_stack.mean(0), nc_stack.std(0,ddof=1)
    ev_mean,ev_std=ev_stack.mean(0), ev_stack.std(0,ddof=1)
    dt_mean,dt_std=dt_stack.mean(0), dt_stack.std(0,ddof=1)
    mem_mean,mem_std=mem_stack.mean(0), mem_stack.std(0,ddof=1)
    n1_mean,n1_std=n1_stack.mean(0), n1_stack.std(0,ddof=1)
    n2_mean,n2_std=n2_stack.mean(0), n2_stack.std(0,ddof=1)
    agg_csv=os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv,T, per[0]["mem_type"], dt_mean,dt_std,mem_mean,mem_std,ev_mean,ev_std,nc_mean,nc_std,n1_mean,n1_std,n2_mean,n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
          f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
          f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")
    if not args.no_plots:
        its=np.arange(1,T+1)
        plt.figure(); plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (A-PSRO) — mean ± std")
        plt.figure(); plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value (A-PSRO) — mean ± std")
        plt.show()

# Script entry point: run the experiment only when executed directly.
if __name__=="__main__":
    main()
