import argparse, time, csv, os, random, math, sys
import numpy as np
import matplotlib.pyplot as plt

# --------------------------- Args ---------------------------
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    The human-readable title is passed as ``description``.  Passing it as the
    first positional argument of ``ArgumentParser`` would set ``prog`` instead,
    which replaces the script name in the usage line.

    Returns:
        argparse.Namespace: outer-loop, PPO, numerics and I/O options.
    """
    p = argparse.ArgumentParser(
        description="α-PSRO (PPO BR) for Kuhn Poker — multi-seed")
    # Outer loop
    p.add_argument("--iters", type=int, default=40)
    p.add_argument("--alpha", type=float, default=10.0, help="selection intensity for α-Rank")
    p.add_argument("--kmax", type=int, default=0, help="0=unbounded; >0 cap per-player pool (evict least-mass)")

    # PPO BR hyperparams (explicit LR)
    p.add_argument("--ppo_rollouts", type=int, default=4000, help="episodes per BR training")
    p.add_argument("--ppo_epochs", type=int, default=10)
    p.add_argument("--ppo_batch", type=int, default=512)
    p.add_argument("--ppo_lr", type=float, default=3e-4, help="LEARNING RATE for PPO Adam optimizer")
    p.add_argument("--clip", type=float, default=0.2)
    p.add_argument("--ent_beta", type=float, default=1e-3)
    p.add_argument("--gamma", type=float, default=1.0)
    p.add_argument("--gae_lambda", type=float, default=0.95)
    p.add_argument("--max_grad_norm", type=float, default=1.0)

    # Numerics / I/O
    p.add_argument("--prob_eps", type=float, default=1e-6)
    p.add_argument("--logit_cap", type=float, default=50.0)
    p.add_argument("--outdir", type=str, default=".")
    p.add_argument("--csv_base", type=str, default="alpha_psro")
    p.add_argument("--seeds", type=str, default="0,1,2,3,4", help="Comma-separated seeds")
    p.add_argument("--no_plots", action="store_true")
    return p.parse_args()

# --------------------------- Utils ---------------------------
def _seed_everything(s: int):
    random.seed(s); np.random.seed(s)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform == "darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _parse_seeds(seeds_str: str):
    out=[]
    for tok in seeds_str.split(","):
        tok=tok.strip()
        if tok: out.append(int(tok))
    return out if out else [0,1,2,3,4]

# --------------------------- Kuhn EV (exact) ---------------------------
CARDS = [0, 1, 2]


def ev_p1_vs(p1, p2):
    """Exact expected payoff to player 1 for a pair of behavioural policies.

    Args:
        p1: length-6 sequence; entries 0-2 are P1's bet probability per card,
            entries 3-5 are P1's call probability per card (after check/bet).
        p2: length-6 sequence; entries 0-2 are P2's call probability per card
            (facing a bet), entries 3-5 are P2's bet probability per card
            (after a check).

    Returns:
        The payoff averaged over the six ordered deals of distinct cards.
    """
    bet1, call1 = p1[:3], p1[3:]
    call2, bet2 = p2[:3], p2[3:]
    total = 0.0
    for c in CARDS:
        for d in CARDS:
            if c == d:
                continue
            p1_wins = c > d
            big_pot = 2.0 if p1_wins else -2.0      # showdown after bet+call
            small_pot = 1.0 if p1_wins else -1.0    # showdown after check+check
            # P1 bets: P2 calls (showdown) or folds (+1 for P1)
            after_bet = call2[d]*big_pot + (1-call2[d])*1.0
            # P1 checks, P2 bets: P1 calls (showdown) or folds (-1 for P1)
            facing_bet = call1[c]*big_pot + (1-call1[c])*(-1.0)
            after_check = bet2[d]*facing_bet + (1-bet2[d])*small_pot
            total += bet1[c]*after_bet + (1-bet1[c])*after_check
    return total/6.0

def enumerate_pure_policies_p1():
    """List all 64 deterministic P1 policies as float64 vectors of length 6.

    Policy number ``mask`` has entry ``k`` equal to bit ``k`` of ``mask``:
    bits 0-2 are the bet decisions, bits 3-5 the call decisions.
    """
    return [
        np.array([(mask >> bit) & 1 for bit in range(6)], dtype=np.float64)
        for mask in range(64)
    ]
def enumerate_pure_policies_p2():
    """List all 64 deterministic P2 policies as float64 vectors of length 6.

    Policy number ``mask`` has entry ``k`` equal to bit ``k`` of ``mask``:
    bits 0-2 are the call decisions, bits 3-5 the bet decisions.
    """
    policies = []
    for mask in range(64):
        bits = [(mask >> bit) & 1 for bit in range(6)]
        policies.append(np.array(bits, dtype=np.float64))
    return policies
# Module-level caches of every deterministic policy for each player, built once
# at import time so callers do not re-enumerate them on each use.
PURE1 = enumerate_pure_policies_p1()
PURE2 = enumerate_pure_policies_p2()

# --------------------------- α-Rank meta-solver ---------------------------
def _logistic_stable(x):
    z = np.clip(x, -50.0, 50.0)
    return 1.0/(1.0+np.exp(-z))

def alpha_rank_meta(M, alpha=10.0, tol=1e-12, max_steps=200000):
    """Compute per-player meta-strategies from a payoff matrix via a Markov chain.

    A chain is built over joint profiles ``(i, j)``.  From a profile, one
    player switches to one of its other strategies with a probability given by
    a logistic function of ``alpha`` times that player's payoff change (the row
    player receives ``M``, the column player ``-M``).  The stationary
    distribution is found by power iteration and marginalised per player.

    Args:
        M: ``(K1, K2)`` array-like of row-player payoffs.
        alpha: selection intensity inside the logistic.
        tol: L1 change between iterates at which power iteration stops.
        max_steps: hard cap on power-iteration steps; the last iterate is used
            if the tolerance was not reached.

    Returns:
        tuple: ``(s1, s2)`` — marginal distributions over row and column
        strategies (lengths ``K1`` and ``K2``).
    """
    M = np.asarray(M, dtype=np.float64)
    K1, K2 = M.shape
    if K1 == 0 or K2 == 0:
        return np.ones(K1)/max(1, K1), np.ones(K2)/max(1, K2)
    if K1 == 1 and K2 == 1:
        return np.array([1.0]), np.array([1.0])

    # Share of transition mass given to each player's deviations; a player
    # with a single strategy cannot deviate, so the other one gets all of it.
    if K1 > 1 and K2 > 1:
        mu_row = mu_col = 0.5
    elif K1 > 1:
        mu_row, mu_col = 1.0, 0.0
    else:
        mu_row, mu_col = 0.0, 1.0

    N = K1*K2
    P = np.zeros((N, N), dtype=np.float64)
    for i in range(K1):
        for j in range(K2):
            s = i*K2 + j
            ui = M[i, j]
            out_mass = 0.0
            if K1 > 1:
                for ip in range(K1):
                    if ip == i:
                        continue
                    w = mu_row*_logistic_stable(alpha*(M[ip, j]-ui))/(K1-1)
                    P[s, ip*K2+j] += w
                    out_mass += w
            if K2 > 1:
                for jp in range(K2):
                    if jp == j:
                        continue
                    # the column player gains when the row payoff drops
                    w = mu_col*_logistic_stable(alpha*(ui-M[i, jp]))/(K2-1)
                    P[s, i*K2+jp] += w
                    out_mass += w
            # Remaining mass stays put.  Every row therefore has positive
            # total mass, so the normalisation below never divides by zero.
            P[s, s] = max(0.0, 1.0-out_mass)

    P /= P.sum(axis=1, keepdims=True)

    # Tiny uniform mixing keeps the chain irreducible.
    eps = 1e-12
    P = (1.0-eps)*P + eps/N

    pi = np.ones(N)/N
    for _ in range(max_steps):
        nxt = pi @ P
        converged = np.linalg.norm(nxt-pi, 1) < tol
        pi = nxt
        if converged:
            break
    pi = pi/(pi.sum()+1e-12)

    # State s = i*K2 + j, so a (K1, K2) reshape exposes both marginals.
    joint = pi.reshape(K1, K2)
    s1 = joint.sum(axis=1)
    s2 = joint.sum(axis=0)
    s1 /= s1.sum()+1e-12
    s2 /= s2.sum()+1e-12
    return s1, s2

# --------------------------- PPO (tabular) ---------------------------
class TabularPPO:
    """Tabular PPO learner with one Bernoulli logit and one value per slot.

    The policy is six independent logits ``theta``; slot ``k`` takes action 1
    with probability ``sigmoid(theta[k])``.  ``v`` holds one value estimate per
    slot.  Both tables are trained with Adam.

    Args:
        role: stored on the instance as ``self.role``.
        args: namespace providing ``ppo_epochs``, ``ppo_batch``, ``ppo_lr``,
            ``clip``, ``ent_beta``, ``max_grad_norm``, ``prob_eps`` and
            ``logit_cap``.
    """

    def __init__(self, role, args):
        self.args = args
        self.role = role
        self.theta = np.zeros(6, dtype=np.float64)
        self.v = np.zeros(6, dtype=np.float64)
        # Adam first/second moment buffers for the policy and value tables
        self.m_t = np.zeros_like(self.theta)
        self.v_t = np.zeros_like(self.theta)
        self.m_v = np.zeros_like(self.v)
        self.v_v = np.zeros_like(self.v)
        # total number of Adam updates applied, across both tables
        self.t_step = 0
        # per-table update counts, keyed by the identity of the first-moment
        # buffer; bias correction must use the table's own count
        self._adam_steps = {}

    def _sigm(self, x):
        """Sigmoid with the logit clamped to ``[-logit_cap, logit_cap]``."""
        cap = self.args.logit_cap
        return 1.0/(1.0+np.exp(-np.clip(x, -cap, cap)))

    def _logp(self, idx, a):
        """Log-probability of action ``a`` (1 or anything else) at slot ``idx``."""
        p = np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return math.log(p) if a == 1 else math.log(1.0-p)

    def _entropy(self, idx):
        """Entropy of the Bernoulli policy at slot ``idx``."""
        p = np.clip(self._sigm(self.theta[idx]), self.args.prob_eps, 1.0-self.args.prob_eps)
        return -(p*math.log(p) + (1-p)*math.log(1-p))

    def _adam_update(self, param, grad, m, v, lr, b1=0.9, b2=0.999, eps=1e-8, max_norm=None):
        """Apply one in-place Adam step to ``param``.

        Args:
            param: parameter array, updated in place.
            grad: gradient of the loss with respect to ``param``.
            m, v: first/second moment buffers for ``param``, updated in place.
                The same ``m`` array must be passed on every call for a given
                table, because it identifies the table's step counter.
            lr: learning rate.
            b1, b2, eps: Adam coefficients.
            max_norm: if given, the gradient is rescaled so its L2 norm does
                not exceed this value before it enters the moment estimates.
        """
        self.t_step += 1
        t = self._adam_steps.get(id(m), 0) + 1
        self._adam_steps[id(m)] = t

        grad = np.asarray(grad, dtype=np.float64)
        if max_norm is not None:
            gnorm = np.linalg.norm(grad)
            if gnorm > max_norm:
                grad = grad*(max_norm/gnorm)

        m[:] = b1*m + (1-b1)*grad
        v[:] = b2*v + (1-b2)*(grad*grad)
        mhat = m/(1-b1**t)
        vhat = v/(1-b2**t)
        param[:] = param - lr*mhat/(np.sqrt(vhat)+eps)

    def ppo_update(self, batch):
        """Run ``ppo_epochs`` passes of minibatch PPO over ``batch``.

        Args:
            batch: dict of equal-length sequences ``idx`` (slot index),
                ``act`` (0/1 action), ``logp_old`` (log-prob under the data
                collecting policy), ``ret`` (return) and ``adv`` (advantage).
                Advantages are standardised over the whole batch.

        Does nothing for an empty batch.  Uses NumPy's global RNG to shuffle.
        """
        n_total = len(batch["idx"])
        if n_total == 0:
            return
        a = self.args
        idx = np.asarray(batch["idx"], dtype=np.intp)
        act = np.asarray(batch["act"], dtype=np.float64)
        logp_old = np.asarray(batch["logp_old"], dtype=np.float64)
        ret = np.asarray(batch["ret"], dtype=np.float64)
        adv = np.asarray(batch["adv"], dtype=np.float64)
        adv = (adv-adv.mean())/(adv.std()+1e-8)

        order = np.arange(n_total)
        for _ in range(a.ppo_epochs):
            np.random.shuffle(order)
            for start in range(0, n_total, a.ppo_batch):
                mb = order[start:start+a.ppo_batch]
                mb_idx, mb_act, mb_adv = idx[mb], act[mb], adv[mb]
                n_mb = len(mb)

                p = self.probs()[mb_idx]
                cur_logp = np.where(mb_act == 1, np.log(p), np.log(1.0-p))
                ratio = np.exp(cur_logp-logp_old[mb])

                # Clipped surrogate min(r*A, clip(r)*A): wherever the clipped
                # term is the smaller one it is constant in theta, so that
                # sample contributes no policy gradient.
                on_clipped_branch = (((mb_adv > 0) & (ratio > 1.0+a.clip))
                                     | ((mb_adv < 0) & (ratio < 1.0-a.clip)))
                dlogp_dtheta = mb_act-p  # d log pi(a) / d theta for a sigmoid-Bernoulli
                surrogate_grad = np.where(on_clipped_branch, 0.0, mb_adv*ratio*dlogp_dtheta)
                grad_theta = -surrogate_grad/n_mb

                # entropy bonus: dH/dtheta = dH/dp * dp/dtheta
                dH_dtheta = -np.log(p/(1.0-p))*(p*(1.0-p))
                grad_theta -= (a.ent_beta/n_mb)*dH_dtheta

                # scatter-add per-sample gradients into the six table slots
                g_theta = np.zeros_like(self.theta)
                np.add.at(g_theta, mb_idx, grad_theta)
                g_v = np.zeros_like(self.v)
                np.add.at(g_v, mb_idx, (self.v[mb_idx]-ret[mb])/n_mb)

                self._adam_update(self.theta, g_theta, self.m_t, self.v_t,
                                  lr=a.ppo_lr, max_norm=a.max_grad_norm)
                self._adam_update(self.v, g_v, self.m_v, self.v_v,
                                  lr=a.ppo_lr*0.5, max_norm=a.max_grad_norm)

    def probs(self):
        """Return the six action-1 probabilities, clipped to ``[prob_eps, 1-prob_eps]``."""
        p = self._sigm(self.theta)
        return np.clip(p, self.args.prob_eps, 1.0-self.args.prob_eps)

# --------------------------- Episodes ---------------------------
def draw_cards():
    """Deal two distinct cards from ``{0, 1, 2}`` using the stdlib RNG.

    Returns:
        tuple: ``(first, second)`` with ``first != second``.
    """
    first = random.randrange(3)
    second = random.randrange(2)
    # skip over the first card so the second can never equal it
    return first, (second + 1 if second >= first else second)

def simulate_episode(learner_role, theta_probs, opp_probs):
    """Play one hand and record the learner's decisions.

    Args:
        learner_role: ``0`` if the learner is player 1; any other value means
            the learner is player 2.
        theta_probs: the learner's six action probabilities.
        opp_probs: the opponent's six action probabilities.
            Player-1 vectors hold bet probs at 0-2 and call probs at 3-5;
            player-2 vectors hold call probs at 0-2 and bet probs at 3-5.

    Returns:
        tuple: ``(reward, decs)`` — the learner's payoff and the list of
        ``(slot, action)`` pairs the learner chose during the hand.
    """
    c1, c2 = draw_cards()
    decs = []
    learner_is_p1 = (learner_role == 0)
    if learner_is_p1:
        p1_probs, p2_probs = theta_probs, opp_probs
    else:
        p1_probs, p2_probs = opp_probs, theta_probs
    big_pot = 2.0 if c1 > c2 else -2.0
    small_pot = 1.0 if c1 > c2 else -1.0

    def decide(by_p1, slot):
        # sample a 0/1 action for the acting player; log it if that is the learner
        table = p1_probs if by_p1 else p2_probs
        choice = 1 if random.random() < table[slot] else 0
        if by_p1 == learner_is_p1:
            decs.append((slot, choice))
        return choice

    if decide(True, 0+c1) == 1:
        # P1 bet: P2 calls into a showdown or folds
        r_p1 = big_pot if decide(False, 0+c2) == 1 else 1.0
    elif decide(False, 3+c2) == 1:
        # P1 checked and P2 bet: P1 calls into a showdown or folds
        r_p1 = big_pot if decide(True, 3+c1) == 1 else -1.0
    else:
        # both checked
        r_p1 = small_pot

    return (r_p1 if learner_is_p1 else -r_p1), decs

# --------------------------- Runner ---------------------------
class AlphaPSRORunner:
    """One seeded α-PSRO run: grow per-player policy pools with PPO responses.

    ``self.Z[0]`` and ``self.Z[1]`` are the policy pools of player 1 and
    player 2; every policy is a length-6 probability vector.

    Args:
        args: parsed command-line namespace.
        seed: RNG seed applied when the runner is constructed.
    """

    def __init__(self, args, seed: int):
        self.args = args
        self.seed = seed
        self._setup_state()

    def _setup_state(self):
        """Seed the RNGs and start both pools with the uniform policy."""
        _seed_everything(self.seed)
        self.Z = [[np.array([0.5]*6, dtype=np.float64)],
                  [np.array([0.5]*6, dtype=np.float64)]]

    def ev_matrix(self):
        """Return the ``(K1, K2)`` matrix of exact P1 payoffs between pool members."""
        K1, K2 = len(self.Z[0]), len(self.Z[1])
        M = np.zeros((K1, K2), dtype=np.float64)
        for i, row_pol in enumerate(self.Z[0]):
            for j, col_pol in enumerate(self.Z[1]):
                M[i, j] = ev_p1_vs(row_pol, col_pol)
        return M

    def nashconv(self, s1, s2):
        """Exploitability of the meta-mixture ``(s1, s2)`` and its value.

        The payoff is linear in each player's *mixture weights*, so the best
        response value against a mixture is the weighted sum of payoffs
        against its members.  Averaging the members' probability vectors is
        not equivalent for player 1: its payoff contains bet×call products,
        so an averaged vector plays a different strategy than the mixture.

        Returns:
            tuple: ``(nashconv, value)`` with ``nashconv`` floored at 0 and
            ``value`` the P1 payoff ``s1 @ M @ s2``.
        """
        M = self.ev_matrix()
        val = float(s1 @ M @ s2)
        br1 = max(
            sum(w*ev_p1_vs(pure, opp) for w, opp in zip(s2, self.Z[1]))
            for pure in PURE1
        )
        br2min = min(
            sum(w*ev_p1_vs(own, pure) for w, own in zip(s1, self.Z[0]))
            for pure in PURE2
        )
        return max(0.0, br1-br2min), val

    def ensure_match_mix(self, learner_role, mix):
        """Fit ``mix`` to the opponent pool size and renormalise it.

        The vector is zero-padded or truncated to the opponent pool length;
        a non-finite or non-positive total falls back to the uniform mix.
        """
        K = len(self.Z[1-learner_role])
        m = np.asarray(mix, dtype=np.float64).copy()
        if len(m) != K:
            m = np.pad(m, (0, max(0, K-len(m))))[:K]
        s = m.sum()
        if not np.isfinite(s) or s <= 0:
            return np.ones(K)/K
        return m/s

    def collect_br_data(self, learner_role, sigma_opp, rollouts):
        """Roll out a fresh PPO agent against opponents sampled from ``sigma_opp``.

        Returns:
            tuple: ``(agent, data)`` where ``data`` holds per-decision lists
            ``idx``, ``act``, ``logp_old``, ``ret`` and ``adv``.
        """
        sigma_opp = self.ensure_match_mix(learner_role, sigma_opp)
        data = {"idx": [], "act": [], "logp_old": [], "ret": [], "adv": []}
        agent = TabularPPO(learner_role, self.args)
        opp_pool = self.Z[1-learner_role]
        # the agent is not updated while collecting, so its policy is fixed
        p_now = agent.probs()
        for _ in range(rollouts):
            j = np.random.choice(len(opp_pool), p=sigma_opp)
            r, decs = simulate_episode(learner_role, p_now, opp_pool[j])
            for (idx, a) in decs:
                data["idx"].append(idx)
                data["act"].append(a)
                data["logp_old"].append(agent._logp(idx, a))
                data["ret"].append(r)
                data["adv"].append(r-agent.v[idx])
        return agent, data

    def train_br_with_ppo(self, role, sigma_opp):
        """Train a response for ``role`` and return its six probabilities.

        Returns the uniform policy if the rollouts produced no decisions.
        """
        sigma_opp = self.ensure_match_mix(role, sigma_opp)
        agent, batch = self.collect_br_data(role, sigma_opp, self.args.ppo_rollouts)
        if len(batch["idx"]) == 0:
            return np.full(6, 0.5)
        agent.ppo_update(batch)
        return agent.probs()

    def maybe_evict_least_mass(self, p, mix):
        """If player ``p``'s pool is at the ``kmax`` cap, drop its lowest-mass policy."""
        if self.args.kmax <= 0:
            return
        if len(self.Z[p]) < self.args.kmax:
            return
        self.Z[p].pop(int(np.argmin(mix)))

    def add_anchor(self, p, vec):
        """Append ``vec`` (clipped to ``[prob_eps, 1-prob_eps]``) to player ``p``'s pool."""
        vec = np.clip(vec, self.args.prob_eps, 1.0-self.args.prob_eps)
        self.Z[p].append(vec.copy())

    def run(self, csv_path: str, header=True):
        """Execute ``args.iters`` outer iterations, logging each to ``csv_path``.

        Each iteration solves the meta-game, trains and adds a player-1
        response, re-solves, trains and adds a player-2 response, then
        re-solves once more to score the resulting mixture.

        Args:
            csv_path: output CSV file (parent directory is created if needed).
            header: whether to print the one-line run banner first.

        Returns:
            dict: per-iteration arrays ``nc``, ``val``, ``dt``, ``mem``,
            ``n1``, ``n2`` plus the last ``mem_type`` label.
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[α-PSRO] seed={self.seed} | PPO LR={self.args.ppo_lr} | alpha={self.args.alpha}")
        hist_nc, hist_val, hist_dt, hist_mem, hist_n1, hist_n2 = [], [], [], [], [], []
        mem_type_rec = "n/a"
        alpha = self.args.alpha
        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["iter", "timestamp", "n_strats_p1", "n_strats_p2", "time_sec",
                        "mem_mb", "mem_type", "nashconv", "mix_ev_p1"])
            for it in range(1, self.args.iters+1):
                # monotonic clock: wall-clock adjustments must not skew durations
                t0 = time.perf_counter()

                s1, s2 = alpha_rank_meta(self.ev_matrix(), alpha=alpha)
                self.maybe_evict_least_mass(0, s1)
                br1 = self.train_br_with_ppo(role=0, sigma_opp=s2)
                self.add_anchor(0, br1)

                s1, s2 = alpha_rank_meta(self.ev_matrix(), alpha=alpha)
                self.maybe_evict_least_mass(1, s2)
                br2 = self.train_br_with_ppo(role=1, sigma_opp=s1)
                self.add_anchor(1, br2)

                s1, s2 = alpha_rank_meta(self.ev_matrix(), alpha=alpha)
                nc, val = self.nashconv(s1, s2)

                dt = time.perf_counter()-t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype
                hist_nc.append(nc); hist_val.append(val); hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(len(self.Z[0])); hist_n2.append(len(self.Z[1]))
                print(f"[α-PSRO] seed={self.seed} iter {it}/{self.args.iters} | P1={len(self.Z[0])} P2={len(self.Z[1])} | "
                      f"NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB")
                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            len(self.Z[0]), len(self.Z[1]),
                            f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                # flush per row so partial results survive an interrupted run
                fcsv.flush()
        return {
            "nc": np.array(hist_nc), "val": np.array(hist_val),
            "dt": np.array(hist_dt), "mem": np.array(hist_mem),
            "n1": np.array(hist_n1), "n2": np.array(hist_n2),
            "mem_type": mem_type_rec
        }

# --------------------------- Aggregate helpers ---------------------------
def _save_meanstd_csv(path, T, mem_type, dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std, n1_mean, n1_std, n2_mean, n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path,"w",newline="") as f:
        w=csv.writer(f)
        w.writerow(["iter","dt_mean","dt_std","mem_mb_mean","mem_mb_std",
                    "mix_ev_mean","mix_ev_std","nashconv_mean","nashconv_std",
                    "n_strats_p1_mean","n_strats_p1_std","n_strats_p2_mean","n_strats_p2_std","mem_type"])
        for i in range(T):
            w.writerow([i+1, f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                        f"{mem_mean[i]:.3f}", f"{mem_std[i]:.3f}",
                        f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                        f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                        f"{n1_mean[i]:.3f}", f"{n1_std[i]:.3f}",
                        f"{n2_mean[i]:.3f}", f"{n2_std[i]:.3f}",
                        mem_type])

# --------------------------- Entry ---------------------------
def main():
    """Run α-PSRO once per seed, aggregate across seeds, and report.

    Writes one CSV per seed and a mean/std CSV into ``--outdir``, prints a
    summary of the final iteration, and shows two matplotlib figures unless
    ``--no_plots`` is given.

    Raises:
        RuntimeError: if a run returned a history whose length differs from
            the requested number of iterations.
    """
    args = parse_args()
    seeds = _parse_seeds(args.seeds)
    os.makedirs(args.outdir, exist_ok=True)

    per = []
    for run_no, s in enumerate(seeds):
        r = AlphaPSRORunner(args, seed=s)
        csv_path = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        # banner only for the first run, even if a seed value is repeated
        per.append(r.run(csv_path, header=(run_no == 0)))

    T = max(0, args.iters)
    for rec in per:
        # explicit check rather than assert, which is stripped under -O
        if len(rec["nc"]) != T or len(rec["val"]) != T:
            raise RuntimeError(
                f"run history has {len(rec['nc'])} iterations, expected {T}")

    # The sample standard deviation (ddof=1) is undefined for a single run and
    # would yield NaN plus a warning; report a spread of zero in that case.
    ddof = 1 if len(per) > 1 else 0
    stats = {}
    for key in ("nc", "val", "dt", "mem", "n1", "n2"):
        stack = np.stack([rec[key] for rec in per], 0)
        stats[key] = (stack.mean(0), stack.std(0, ddof=ddof))
    nc_mean, nc_std = stats["nc"]
    ev_mean, ev_std = stats["val"]
    dt_mean, dt_std = stats["dt"]
    mem_mean, mem_std = stats["mem"]
    n1_mean, n1_std = stats["n1"]
    n2_mean, n2_std = stats["n2"]

    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, per[0]["mem_type"], dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std, n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")

    if T == 0:
        # nothing was run, so there is no final iteration to summarise or plot
        return

    print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
          f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
          f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")

    if not args.no_plots:
        its = np.arange(1, T+1)
        plt.figure(); plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (α-PSRO) — mean ± std")

        plt.figure(); plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value (α-PSRO) — mean ± std")
        plt.show()

# Run the experiment only when this file is executed as a script, not on import.
if __name__=="__main__":
    main()
