import argparse, time, csv, os, random, math, sys
import numpy as np
import matplotlib.pyplot as plt

# --------------------------- Args ---------------------------
def parse_args():
    """Parse the command-line options for the multi-seed Double Oracle experiment.

    Returns the argparse namespace. The title is passed as ``description`` so the
    usage line shows the real program name instead of the title text.
    """
    p = argparse.ArgumentParser(description="Double Oracle for Kuhn Poker — multi-seed")
    p.add_argument("--iters", type=int, default=40, help="outer Double Oracle iterations per seed")
    p.add_argument("--meta_loops", type=int, default=300, help="MWU steps per meta-game solve")
    p.add_argument("--eta", type=float, default=0.3, help="base MWU step size")
    p.add_argument("--eta_sched", choices=["const","sqrt","harmonic"], default="harmonic",
                   help="how the step size decays with the outer iteration")
    p.add_argument("--kmax", type=int, default=0, help="0=unbounded; >0 cap per-player pool (evict least-mass)")
    p.add_argument("--prob_eps", type=float, default=1e-12)
    p.add_argument("--logit_cap", type=float, default=50.0, help="logits are clipped to +/- this before softmax")
    p.add_argument("--outdir", type=str, default=".", help="directory for CSV outputs")
    p.add_argument("--csv_base", type=str, default="double_oracle_kuhn", help="filename prefix for CSV outputs")
    p.add_argument("--seeds", type=str, default="0,1,2,3,4", help="comma-separated list of integer seeds")
    p.add_argument("--no_plots", action="store_true", help="skip the matplotlib plots")
    return p.parse_args()

# --------------------------- Utils ---------------------------
def _seed_everything(s:int):
    random.seed(s); np.random.seed(s)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform=="darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _parse_seeds(s:str):
    out=[int(t.strip()) for t in s.split(",") if t.strip()!='']
    return out if out else [0,1,2,3,4]

# --------------------------- Game ---------------------------
CARDS=[0,1,2]
def ev_p1_vs(p1,p2):
    """Expected payoff to P1 over the 6 ordered deals of two distinct cards, averaged.

    p1 = [bet prob on card 0..2, call prob (after check then P2 bet) on card 0..2]
    p2 = [call prob (facing a bet) on card 0..2, bet prob (after P1 check) on card 0..2]
    A called bet goes to showdown for +/-2, a check-check showdown is +/-1, and
    a fold gives +/-1 to the player who did not fold.
    """
    bet1, call1 = p1[:3], p1[3:]
    call2, bet2 = p2[:3], p2[3:]
    deals = [(mine, theirs) for mine in CARDS for theirs in CARDS if mine != theirs]
    total = 0.0
    for mine, theirs in deals:
        p1_wins = mine > theirs
        s4 = 2.0 if p1_wins else -2.0   # showdown after a called bet
        s2 = 1.0 if p1_wins else -1.0   # showdown after check-check
        # P1 bets: P2 calls (showdown) or folds (+1 to P1).
        bet_line = bet1[mine]*(call2[theirs]*s4 + (1-call2[theirs])*1.0)
        # P1 checks: P2 bets (P1 calls -> showdown, folds -> -1) or checks back (small showdown).
        check_line = (1-bet1[mine])*(bet2[theirs]*(call1[mine]*s4 + (1-call1[mine])*(-1.0)) + (1-bet2[theirs])*s2)
        total += bet_line + check_line
    return total/6.0

def _bit_vector(mask):
    """Return the 6 low bits of ``mask`` as a float64 0/1 vector (bit k -> entry k)."""
    return np.array([(mask >> k) & 1 for k in range(6)], dtype=np.float64)

def enumerate_pure_policies_p1():
    """All 64 deterministic P1 policies in ev_p1_vs layout: [bet on card 0..2, call on card 0..2].

    Policy ``mask`` has entry k equal to bit k of ``mask``.
    """
    return [_bit_vector(mask) for mask in range(64)]

def enumerate_pure_policies_p2():
    """All 64 deterministic P2 policies in ev_p1_vs layout: [call on card 0..2, bet on card 0..2].

    Policy ``mask`` has entry k equal to bit k of ``mask``.
    """
    return [_bit_vector(mask) for mask in range(64)]

# Pure-strategy sets searched by the best-response oracles; index 0 is the all-zeros policy.
PURE1=enumerate_pure_policies_p1()
PURE2=enumerate_pure_policies_p2()

# --------------------------- Runner ---------------------------
class DORunner:
    """Double Oracle on Kuhn poker for one seed, with an MWU meta-solver.

    State:
      Z[p]    -- list of pure behaviour vectors (copies of PURE1/PURE2 entries) in
                 player p's pool.
      LOGS[p] -- *logits* over Z[p]; the meta-strategy is softmax_np(LOGS[p]).
                 Every method reads and writes LOGS as logits. Storing probabilities
                 here and then applying softmax would flatten every meta-strategy
                 towards uniform.

    One outer iteration: solve the restricted meta-game, add P1's best response,
    re-solve, add P2's best response, re-solve, then report the full-game
    NashConv and value of the meta-mixture.
    """
    def __init__(self, args, seed:int):
        # args: parsed CLI namespace; this class reads iters, meta_loops, eta,
        # eta_sched, kmax, prob_eps and logit_cap from it.
        self.args=args; self.seed=seed; self._setup_state()

    def _setup_state(self):
        """Seed the RNGs and reset both pools to the all-zeros pure strategy (uniform logits)."""
        _seed_everything(self.seed)
        # Start with trivial pure strategies (mask 0 => all zeros).
        self.Z=[ [PURE1[0].copy()], [PURE2[0].copy()] ]
        self.LOGS=[ np.zeros(1, dtype=np.float64), np.zeros(1, dtype=np.float64) ]

    def ev_matrix(self):
        """Return the K1 x K2 matrix of P1's expected payoff for every pair of pool members."""
        K1,K2=len(self.Z[0]), len(self.Z[1]); M=np.zeros((K1,K2), dtype=np.float64)
        for i in range(K1):
            for j in range(K2):
                M[i,j]=ev_p1_vs(self.Z[0][i], self.Z[1][j])
        return M

    @staticmethod
    def _softmax(x, cap):
        """Stable softmax of logits ``x`` after clipping them to [-cap, cap]."""
        x=np.clip(x, -cap, cap); z=x-x.max(); e=np.exp(z)
        return e/(e.sum()+1e-12)

    def softmax_np(self,x):
        """Softmax of logits ``x`` using the configured ``args.logit_cap``."""
        return self._softmax(x, self.args.logit_cap)

    def eta_t(self,t):
        """MWU step size for outer iteration ``t`` under ``args.eta_sched``."""
        s=self.args.eta_sched; eta0=self.args.eta
        if s=="const": return eta0
        if s=="sqrt": return eta0/max(1.0, math.sqrt(t))
        if s=="harmonic": return eta0/(1.0+0.5*t)
        return eta0

    @classmethod
    def _mwu_on_matrix(cls, M, logits1, logits2, eta, loops, logit_cap, prob_eps):
        """Run ``loops`` simultaneous MWU steps on zero-sum matrix game ``M`` (P1 maximizes).

        Returns ``(logits1, logits2, avg1, avg2)``:
          - the final logits, used to warm-start the next solve;
          - the time-averaged mixed strategies of the iterates actually played.
        The average is returned because the last MWU iterate does not converge
        to equilibrium in zero-sum games, whereas the average of no-regret
        iterates does. If ``loops <= 0``, the current softmax is returned instead.
        """
        l1=np.asarray(logits1, dtype=np.float64); l2=np.asarray(logits2, dtype=np.float64)
        if loops<=0:
            return l1, l2, cls._softmax(l1, logit_cap), cls._softmax(l2, logit_cap)
        acc1=np.zeros(l1.shape); acc2=np.zeros(l2.shape)
        for _ in range(loops):
            s1=cls._softmax(l1, logit_cap); s2=cls._softmax(l2, logit_cap)
            acc1+=s1; acc2+=s2
            v1=M @ s2       # P1 payoff of each P1 pool member vs s2
            v2=-(s1 @ M)    # P2 payoff (negated P1 payoff) of each P2 member vs s1
            # Hedge update in log space. prob_eps floors the log-probabilities and the
            # max-shift keeps logits <= 0. The result stays a logit vector.
            l1=np.log(s1+prob_eps)+eta*v1; l1=l1-l1.max()
            l2=np.log(s2+prob_eps)+eta*v2; l2=l2-l2.max()
        return l1, l2, acc1/loops, acc2/loops

    def mwu_meta(self,it):
        """Solve the restricted meta-game with MWU at step size ``eta_t(it)``.

        Stores the final logits back into ``self.LOGS`` (warm start) and returns
        the time-averaged meta-strategies ``(s1, s2)`` over ``args.meta_loops`` steps.
        """
        M=self.ev_matrix()  # the pools do not change during the solve: build M once
        self.LOGS[0], self.LOGS[1], s1, s2 = self._mwu_on_matrix(
            M, self.LOGS[0], self.LOGS[1], self.eta_t(it),
            self.args.meta_loops, self.args.logit_cap, self.args.prob_eps)
        return s1, s2

    def nashconv(self, s1, s2):
        """Return ``(NashConv, value)`` of the meta-mixture ``(s1, s2)`` in the full game.

        NashConv is P1's best pure payoff vs s2 minus P2's best (lowest) P1 payoff
        vs s1, clamped at 0. The value is ``s1 @ M @ s2`` on the restricted matrix.
        """
        M=self.ev_matrix()
        val=float(s1 @ M @ s2)
        _, br1=self.pure_br1_against_mix(s2)
        _, br2min=self.pure_br2_against_mix(s1)
        return max(0.0, br1 - br2min), val

    def ensure_mix_len(self, mix, K):
        """Coerce ``mix`` to a length-K probability vector.

        Zero-pads or truncates as needed, then normalizes. Falls back to uniform
        when the sum is non-finite or non-positive.
        """
        m=np.asarray(mix, dtype=np.float64)
        if len(m)!=K: m=np.pad(m, (0,max(0,K-len(m))))[:K]
        s=m.sum()
        if not np.isfinite(s) or s<=0: return np.ones(K)/K
        return m/s

    def pure_br1_against_mix(self, s2):
        """Best pure P1 strategy (first maximizer in PURE1) against P2's meta-mixture s2.

        Averaging P2's pool vectors is exact. In ev_p1_vs, P2's call and bet
        entries enter linearly and are never multiplied together, so the EV of
        the averaged vector equals the averaged EV.
        """
        K2=len(self.Z[1]); s2=self.ensure_mix_len(s2, K2)
        mix_p2=sum(s2[j]*self.Z[1][j] for j in range(K2))
        best=None; best_v=-1e18
        for pi in PURE1:
            v=ev_p1_vs(pi, mix_p2)
            if v>best_v: best_v=v; best=pi
        return best, best_v

    def _p1_value_vs_pure2(self, s1, pj):
        """Exact E[P1 payoff] when P1 plays pool member Z[0][i] with prob s1[i] and P2 plays pj.

        P1's pool vectors must NOT be averaged. In ev_p1_vs, P1's call entry is
        multiplied by (1 - bet entry) because it is only reached after P1 checks,
        so the EV is nonlinear in P1's vector. Average the payoffs instead.
        """
        return float(sum(float(w)*ev_p1_vs(z, pj) for w, z in zip(s1, self.Z[0])))

    def pure_br2_against_mix(self, s1):
        """Best pure P2 strategy (first minimizer of P1's payoff in PURE2) against s1."""
        K1=len(self.Z[0]); s1=self.ensure_mix_len(s1, K1)
        best=None; best_v=+1e18
        for pj in PURE2:
            v=self._p1_value_vs_pure2(s1, pj)  # P1 payoff; P2 minimizes
            if v<best_v: best_v=v; best=pj
        return best, best_v

    def maybe_evict_least_mass(self,p,s):
        """If ``args.kmax > 0`` and player p's pool is at capacity, drop the member with least mass in s."""
        if self.args.kmax<=0 or len(self.Z[p])<self.args.kmax: return
        idx=int(np.argmin(s)); self.Z[p].pop(idx); self.LOGS[p]=np.delete(self.LOGS[p], idx, axis=0)

    def _in_pool(self, p, pure_vec):
        """True if an identical vector is already in player p's pool."""
        return any(np.array_equal(z, pure_vec) for z in self.Z[p])

    def add_pure_if_new(self,p,pure_vec,eps=1e-6):
        """Append ``pure_vec`` to player p's pool unless an identical vector is already present.

        The newcomer enters the meta-strategy with probability ``eps``, and the
        existing members are scaled by 1-eps. LOGS[p] is rewritten as
        log-probabilities, which are valid logits. Returns True iff the vector
        was added.
        """
        if self._in_pool(p, pure_vec): return False
        self.Z[p].append(pure_vec.copy())
        if self.LOGS[p].size==0:
            # The pool was emptied (e.g. kmax=1 eviction), so the newcomer takes all the mass.
            self.LOGS[p]=np.zeros(1, dtype=np.float64)
            return True
        mix=np.append(self.softmax_np(self.LOGS[p])*(1-eps), eps)
        # Floor at the smallest positive float so log() never produces -inf.
        self.LOGS[p]=np.log(np.maximum(mix, np.finfo(np.float64).tiny))
        return True

    def _admit(self, p, pure_vec, s):
        """Add a best response to player p's pool, evicting the least-mass member only if needed.

        Eviction runs only when ``pure_vec`` is genuinely new. A pool at capacity
        therefore does not lose a member when the best response is already in it.
        Returns True iff the vector was added.
        """
        if self._in_pool(p, pure_vec): return False
        self.maybe_evict_least_mass(p, s)
        return self.add_pure_if_new(p, pure_vec)

    def run(self, csv_path:str, header=True):
        """Run ``args.iters`` DO iterations and stream one CSV row per iteration to ``csv_path``.

        Returns a dict of per-iteration numpy histories keyed nc/val/dt/mem/n1/n2,
        plus the last memory-measurement kind under "mem_type".
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header: print(f"[DO] seed={self.seed}")
        hist_nc,hist_val,hist_dt,hist_mem,hist_n1,hist_n2=[],[],[],[],[],[]
        mem_type_rec="n/a"
        with open(csv_path,"w",newline="") as fcsv:
            w=csv.writer(fcsv)
            w.writerow(["iter","timestamp","n_strats_p1","n_strats_p2","time_sec","mem_mb","mem_type","nashconv","mix_ev_p1"])
            for it in range(1, self.args.iters+1):
                t0=time.perf_counter()
                s1,s2=self.mwu_meta(it)
                br1,_=self.pure_br1_against_mix(s2); added1=self._admit(0, br1, s1)
                s1,s2=self.mwu_meta(it)
                br2,_=self.pure_br2_against_mix(s1); added2=self._admit(1, br2, s2)
                s1,s2=self.mwu_meta(it)
                nc,val=self.nashconv(s1,s2)
                dt=time.perf_counter()-t0; mem,mtype=_mem_mb(); mem_type_rec=mtype
                hist_nc.append(nc); hist_val.append(val); hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(len(self.Z[0])); hist_n2.append(len(self.Z[1]))
                print(f"[DO] seed={self.seed} iter {it}/{self.args.iters} | P1={len(self.Z[0])} P2={len(self.Z[1])} | "
                      f"NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB | added P1={added1} P2={added2}")
                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            len(self.Z[0]), len(self.Z[1]),
                            f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                fcsv.flush()
        return {"nc":np.array(hist_nc), "val":np.array(hist_val),
                "dt":np.array(hist_dt), "mem":np.array(hist_mem),
                "n1":np.array(hist_n1), "n2":np.array(hist_n2),
                "mem_type":mem_type_rec}

# --------------------------- Aggregate ---------------------------
def _save_meanstd_csv(path,T,mem_type,dt_mean,dt_std,mem_mean,mem_std,ev_mean,ev_std,nc_mean,nc_std,n1_mean,n1_std,n2_mean,n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path,"w",newline="") as f:
        w=csv.writer(f)
        w.writerow(["iter","dt_mean","dt_std","mem_mb_mean","mem_mb_std","mix_ev_mean","mix_ev_std",
                    "nashconv_mean","nashconv_std","n_strats_p1_mean","n_strats_p1_std",
                    "n_strats_p2_mean","n_strats_p2_std","mem_type"])
        for i in range(T):
            w.writerow([i+1, f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                        f"{mem_mean[i]:.3f}", f"{mem_std[i]:.3f}",
                        f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                        f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                        f"{n1_mean[i]:.3f}", f"{n1_std[i]:.3f}",
                        f"{n2_mean[i]:.3f}", f"{n2_std[i]:.3f}", mem_type])

def _mean_std(per, key, ddof):
    """Stack per-seed histories under ``key`` into (n_seeds, T) and return (mean, std) over seeds."""
    stack=np.stack([r[key] for r in per],0)
    return stack.mean(0), stack.std(0,ddof=ddof)

def main():
    """Run DO for every seed, write per-seed and mean/std CSVs, and optionally plot.

    NOTE(review): DORunner appears to draw no random numbers (the RNGs are only
    seeded), so seeds likely give identical NashConv/value/pool-size trajectories,
    and only the timing/memory columns would vary. Confirm whether this is intended.
    """
    args=parse_args(); seeds=_parse_seeds(args.seeds)
    os.makedirs(args.outdir, exist_ok=True)
    per=[]
    for s in seeds:
        r=DORunner(args, seed=s)
        csv_path=os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        per.append(r.run(csv_path, header=(s==seeds[0])))
    T=args.iters
    for rec in per: assert len(rec["nc"])==T and len(rec["val"])==T
    # Sample std (ddof=1) is undefined for one seed (NaN + RuntimeWarning); report 0 spread instead.
    ddof=1 if len(per)>1 else 0
    nc_mean,nc_std=_mean_std(per,"nc",ddof)
    ev_mean,ev_std=_mean_std(per,"val",ddof)
    dt_mean,dt_std=_mean_std(per,"dt",ddof)
    mem_mean,mem_std=_mean_std(per,"mem",ddof)
    n1_mean,n1_std=_mean_std(per,"n1",ddof)
    n2_mean,n2_std=_mean_std(per,"n2",ddof)
    agg_csv=os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv,T, per[0]["mem_type"], dt_mean,dt_std,mem_mean,mem_std,ev_mean,ev_std,nc_mean,nc_std,n1_mean,n1_std,n2_mean,n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    if T==0:
        # Nothing to summarize or plot; indexing [-1] below would raise IndexError.
        print("[AGG] no iterations were run; skipping summary and plots")
        return
    print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
          f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
          f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")
    if not args.no_plots:
        its=np.arange(1,T+1)
        plt.figure(); plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (DO) — mean ± std")
        plt.figure(); plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value (DO) — mean ± std")
        plt.show()

if __name__ == "__main__":
    # Script entry point: run the multi-seed Double Oracle experiment.
    main()
