import argparse, time, csv, os, random, math, sys, glob
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from scipy.stats import ttest_ind
except Exception:
    ttest_ind = None



def parse_args(argv=None):
    """Build the command-line parser and parse the run configuration.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what ``main`` relies on.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # The title must be passed as ``description``: the first positional
    # parameter of ArgumentParser is ``prog``, which would put the whole
    # sentence into the usage line in place of the script name.
    p = argparse.ArgumentParser(
        description="P2SRO / Pipeline PSRO for Kuhn Poker — multi-seed")

    # Outer loop size and population cap.
    p.add_argument("--iters", type=int, default=40)
    p.add_argument("--kmax", type=int, default=41)

    # Pipeline shape.
    # NOTE(review): --freeze_every and --min_fixed do not appear to be read
    # anywhere else in this file; confirm before relying on them.
    p.add_argument("--levels", type=int, default=4, help="Max number of active levels in the pipeline.")
    p.add_argument("--freeze_every", type=int, default=1, help="Freeze lowest active policy every N outer iters.")

    # Best-response training of the active policies.
    p.add_argument("--br_steps", type=int, default=50, help="Gradient steps per active policy per outer iter.")
    p.add_argument("--br_lr", type=float, default=5e-4)
    p.add_argument("--tau", type=float, default=1.0)
    p.add_argument("--init_std", type=float, default=1.0)
    p.add_argument("--clip_grad", type=float, default=1.0)
    p.add_argument("--prob_eps", type=float, default=1e-6)

    # Meta-game solver.
    p.add_argument("--fp_iters", type=int, default=400, help="Fictitious-play iterations for meta Nash.")
    p.add_argument("--fp_smooth", type=float, default=1e-9, help="Smoothing for FP averages (avoid zeros).")

    p.add_argument("--min_fixed", type=int, default=1, help="Minimum fixed policies before freezing begins.")

    # Output, seeding and device selection.
    p.add_argument("--outdir", type=str, default="results")
    p.add_argument("--csv_base", type=str, default="p2psro_kuhn")
    p.add_argument("--seeds", type=str, default="0,1,2,3,4")
    p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    p.add_argument("--no_plots", action="store_true")

    # Meta-strategy refresh cadence and plateau-based freezing.
    p.add_argument("--meta_every", type=int, default=10)
    p.add_argument("--plateau_check_every", type=int, default=10)
    p.add_argument("--plateau_eps", type=float, default=1e-4)
    p.add_argument("--plateau_patience", type=int, default=3)

    p.add_argument("--ttest_against_glob", type=str, default=None,
                   help='Glob for baseline seed CSVs (same iters), e.g. "results/psro_kuhn_seed*.csv"')

    return p.parse_args(argv)



def _seed_everything(s: int):
    """Seed the Python, NumPy and PyTorch random generators with ``s``.

    The CUDA generators are seeded as well when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(s)

def _pick_device(argdev: str):
    """Turn the ``--device`` choice into a ``torch.device``.

    ``"auto"`` selects CUDA when it is available and CPU otherwise; any other
    value is handed to ``torch.device`` unchanged.
    """
    if argdev != "auto":
        return torch.device(argdev)
    resolved = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(resolved)

def _mem_mb():
    """Report this process's memory use in MB together with its source.

    Returns:
        ``(megabytes, kind)`` where ``kind`` is ``"rss"`` when psutil supplied
        the value, ``"ru_maxrss"`` when the ``resource`` module did, and
        ``"n/a"`` (with a NaN value) when neither worked.
    """
    # Preferred source: psutil's resident set size (bytes).
    try:
        import psutil
        rss_bytes = psutil.Process().memory_info().rss
        return rss_bytes / (1024**2), "rss"
    except Exception:
        pass

    # Fallback: ru_maxrss from the resource module. The divisor differs by
    # platform (presumably bytes on macOS, kilobytes elsewhere).
    try:
        import resource
        if sys.platform == "darwin":
            return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
    except Exception:
        return float("nan"), "n/a"

def bernoulli_kl(p, q, eps=1e-6):
    """Elementwise KL divergence KL(Bernoulli(p) || Bernoulli(q)).

    Both probability tensors are clamped to ``[eps, 1 - eps]`` first so the
    logarithms stay finite.
    """
    lo, hi = eps, 1.0-eps
    p = torch.clamp(p, lo, hi)
    q = torch.clamp(q, lo, hi)
    success_term = p * torch.log(p/q)
    failure_term = (1.0-p) * torch.log((1.0-p)/(1.0-q))
    return success_term + failure_term



# Card ranks of the three-card deck; the EV functions below treat the higher
# value as the showdown winner.
CARDS = [0, 1, 2]

def ev_p1_vs_numpy(p1_vec6, p2_vec6):
    """Expected payoff to player 1 for a pair of behavioural policies.

    Args:
        p1_vec6: Player-1 vector; entries 0-2 are the per-card bet
            probabilities, entries 3-5 the per-card call probabilities used
            after P1 checks and P2 bets.
        p2_vec6: Player-2 vector; entries 0-2 are the per-card call
            probabilities facing a bet, entries 3-5 the per-card bet
            probabilities after a check.

    Returns:
        The payoff averaged over the six deals with distinct cards.
    """
    p1_bet, p1_call = p1_vec6[:3], p1_vec6[3:]
    p2_call, p2_bet = p2_vec6[:3], p2_vec6[3:]
    total = 0.0
    for own in CARDS:
        for other in CARDS:
            if own == other:
                continue
            # Showdown payoffs: +-2 after a called bet, +-1 after check-check.
            if own > other:
                big, small = 2.0, 1.0
            else:
                big, small = -2.0, -1.0
            # P1 bets: P2 calls (showdown for 2) or folds (P1 wins 1).
            after_bet = p1_bet[own] * (p2_call[other]*big + (1.0-p2_call[other])*1.0)
            # P1 checks: P2 bets (P1 calls for 2 or folds for -1) or checks.
            after_check = (1.0-p1_bet[own]) * ( p2_bet[other]*( p1_call[own]*big + (1.0-p1_call[own])*(-1.0) ) + (1.0-p2_bet[other])*small )
            total += after_bet + after_check
    return total / 6.0

def ev_p1_vs_torch(p1_probs6, p2_probs6):
    """Differentiable expected payoff to player 1 for two policy tensors.

    Uses the same layout as the NumPy version: P1 is ``[bet x3, call x3]``
    and P2 is ``[call x3, bet x3]``, indexed by card.

    Returns:
        A 0-dim float32 tensor on the device of ``p1_probs6``.
    """
    p1_bet, p1_call = p1_probs6[:3], p1_probs6[3:]
    p2_call, p2_bet = p2_probs6[:3], p2_probs6[3:]
    total = torch.zeros([], dtype=torch.float32, device=p1_probs6.device)
    for own in CARDS:
        for other in CARDS:
            if own == other:
                continue
            # Showdown payoffs: +-2 after a called bet, +-1 after check-check.
            if own > other:
                big, small = 2.0, 1.0
            else:
                big, small = -2.0, -1.0
            after_bet = p1_bet[own] * (p2_call[other]*big + (1.0-p2_call[other])*1.0)
            after_check = (1.0-p1_bet[own]) * ( p2_bet[other]*( p1_call[own]*big + (1.0-p1_call[own])*(-1.0) ) + (1.0-p2_bet[other])*small )
            # Out-of-place accumulation keeps the autograd graph intact.
            total = total + after_bet + after_check
    return total / 6.0

def _enumerate_pure_policies_p1():
    """List all 64 deterministic player-1 policy vectors.

    Bit ``k`` of the mask becomes entry ``k`` of the vector, so bits 0-2 fill
    the bet slots and bits 3-5 the call slots.
    """
    policies = []
    for mask in range(64):
        bet_bits = [(mask >> k) & 1 for k in (0, 1, 2)]
        call_bits = [(mask >> k) & 1 for k in (3, 4, 5)]
        policies.append(np.array(bet_bits + call_bits, dtype=np.float64))
    return policies

def _enumerate_pure_policies_p2():
    """List all 64 deterministic player-2 policy vectors.

    Bit ``k`` of the mask becomes entry ``k`` of the vector, so bits 0-2 fill
    the call slots and bits 3-5 the bet slots.
    """
    policies = []
    for mask in range(64):
        call_bits = [(mask >> k) & 1 for k in (0, 1, 2)]
        bet_bits = [(mask >> k) & 1 for k in (3, 4, 5)]
        policies.append(np.array(call_bits + bet_bits, dtype=np.float64))
    return policies

# All 64 deterministic (0/1-valued) policy vectors per player, built once at
# import time; nashconv() scans them to find exact best responses.
PURE1 = _enumerate_pure_policies_p1()
PURE2 = _enumerate_pure_policies_p2()


def _solve_meta_zero_sum(M: np.ndarray, iters: int, smooth: float):
    """Approximate the meta-game equilibrium of payoff matrix ``M``.

    Runs fictitious play and renormalises both returned mixtures (with a
    tiny constant in the denominator) before handing them back.

    Returns:
        ``(row_mixture, column_mixture)`` as NumPy vectors.
    """
    row_mix, col_mix = fictitious_play_zero_sum(M, iters=iters, smooth=smooth)
    tiny = 1e-12
    row_mix = row_mix / (row_mix.sum() + tiny)
    col_mix = col_mix / (col_mix.sum() + tiny)
    return row_mix, col_mix



class BernoulliPolicy(nn.Module):
    """Six independent action probabilities parameterised by learnable logits."""

    def __init__(self, init_std=1.0):
        """Create the six logits and draw them from N(0, ``init_std``^2)."""
        super().__init__()
        initial = torch.zeros(6, dtype=torch.float32)
        self.logits = nn.Parameter(initial)
        nn.init.normal_(self.logits, mean=0.0, std=init_std)

    def probs(self, tau: float, eps: float):
        """Return the action probabilities as a length-6 tensor.

        Args:
            tau: Temperature dividing the logits (floored at 1e-8).
            eps: Probabilities are clamped to ``[eps, 1 - eps]``.
        """
        temperature = max(1e-8, tau)
        squashed = torch.sigmoid(self.logits / temperature)
        # NaNs (e.g. from corrupted logits) fall back to a coin flip.
        cleaned = torch.nan_to_num(squashed, nan=0.5)
        return torch.clamp(cleaned, eps, 1.0-eps)



def fictitious_play_zero_sum(M: np.ndarray, iters: int, smooth: float = 1e-9):
    """Run fictitious play on a zero-sum matrix game.

    Args:
        M: Payoff matrix for the row (maximising) player.
        iters: Number of best-response rounds.
        smooth: Constant added to every play count before normalising.

    Returns:
        ``(row_mixture, column_mixture)``: the normalised play counts. Every
        count starts at one, i.e. from a uniform empirical mixture.
    """
    n_rows, n_cols = M.shape
    row_counts = np.ones(n_rows, dtype=np.float64)
    col_counts = np.ones(n_cols, dtype=np.float64)

    def _empirical(counts):
        # Smoothed, normalised copy of the counts (the counts are untouched).
        mix = counts + smooth
        mix /= mix.sum()
        return mix

    for _ in range(int(iters)):
        row_mix = _empirical(row_counts)
        col_mix = _empirical(col_counts)

        # Each side best-responds to the opponent's current empirical mix.
        best_row = int(np.argmax(M @ col_mix))
        best_col = int(np.argmin(row_mix @ M))

        row_counts[best_row] += 1.0
        col_counts[best_col] += 1.0

    return _empirical(row_counts), _empirical(col_counts)



class P2PSRORunner:
    def __init__(self, args, seed: int, device: torch.device):
        self.args = args
        self.seed = seed
        self.device = device
        self._setup_state()

    def _setup_state(self):
        _seed_everything(self.seed)

        self.max_k = int(self.args.kmax)
        self.levels = int(self.args.levels)

        self.F1, self.F2 = [], []
        self.F1.append(np.clip(np.random.rand(6), 1e-3, 1-1e-3).astype(np.float64))
        self.F2.append(np.clip(np.random.rand(6), 1e-3, 1-1e-3).astype(np.float64))

        self.A1, self.A2 = [], []
        self.optA1, self.optA2 = [], []

        self.meta_cache = {}

        self.lowest_plateau = {
            "last_check_iter": 0,
            "last_score": None,
            "bad_count": 0,
        }

        self.payoff_cache = {}

        self._spawn_active()
        self._rebuild_opts()

        self._recompute_all_meta(it=0, force=True)

    def _spawn_active(self):
        if len(self.A1) >= self.levels:
            return
        p1 = BernoulliPolicy(init_std=self.args.init_std).to(self.device)
        p2 = BernoulliPolicy(init_std=self.args.init_std).to(self.device)
        self.A1.append(p1)
        self.A2.append(p2)

    def _rebuild_opts(self):
        self.optA1 = [torch.optim.Adam(p.parameters(), lr=self.args.br_lr) for p in self.A1]
        self.optA2 = [torch.optim.Adam(p.parameters(), lr=self.args.br_lr) for p in self.A2]

    @torch.no_grad()
    def _snapshot_policy(self, pol: BernoulliPolicy):
        return pol.probs(self.args.tau, self.args.prob_eps).detach().cpu().numpy().astype(np.float64)

    def _ev_cached(self, p1_np: np.ndarray, p2_np: np.ndarray) -> float:
        k = (p1_np.tobytes(), p2_np.tobytes())
        if k in self.payoff_cache:
            return self.payoff_cache[k]
        v = float(ev_p1_vs_numpy(p1_np, p2_np))
        self.payoff_cache[k] = v
        return v

    def _matrix_from_sets(self, P1_list, P2_list):
        K1, K2 = len(P1_list), len(P2_list)
        M = np.zeros((K1, K2), dtype=np.float64)
        for i in range(K1):
            for j in range(K2):
                M[i, j] = self._ev_cached(P1_list[i], P2_list[j])
        return M

    def _support_set_for_level(self, j: int):
        P1 = list(self.F1)
        P2 = list(self.F2)
        if j > 0:
            with torch.no_grad():
                for k in range(j):
                    P1.append(self._snapshot_policy(self.A1[k]))
                    P2.append(self._snapshot_policy(self.A2[k]))
        return P1, P2

    def _meta_for_level(self, j: int, it: int, force: bool = False):
        meta_every = int(self.args.meta_every)
        rec = self.meta_cache.get(j, None)
        if (not force) and (rec is not None):
            _, _, _, _, last_it = rec
            if (it - last_it) < meta_every:
                return rec[0], rec[1], rec[2], rec[3]

        P1, P2 = self._support_set_for_level(j)
        M = self._matrix_from_sets(P1, P2)
        s1, s2 = _solve_meta_zero_sum(M, iters=int(self.args.fp_iters), smooth=float(self.args.fp_smooth))

        self.meta_cache[j] = (P1, P2, s1, s2, it)
        return P1, P2, s1, s2

    def _recompute_all_meta(self, it: int, force: bool = False):
        for j in range(len(self.A1)):
            self._meta_for_level(j, it=it, force=force)

    def _sample_index(self, probs: np.ndarray) -> int:
        p = np.asarray(probs, dtype=np.float64)
        p = np.maximum(p, 0)
        p = p / (p.sum() + 1e-12)
        return int(np.random.choice(len(p), p=p))

    def _train_active_level(self, j: int, it: int):
        P1_set, P2_set, s1, s2 = self._meta_for_level(j, it=it, force=False)

        p1 = self.A1[j]
        opt1 = self.optA1[j]
        p1.train()
        for _ in range(int(self.args.br_steps)):
            opp_idx = self._sample_index(s2)
            opp_np = P2_set[opp_idx]
            opp = torch.tensor(opp_np, dtype=torch.float32, device=self.device)

            probs1 = p1.probs(self.args.tau, self.args.prob_eps)
            u1 = ev_p1_vs_torch(probs1, opp.detach())
            loss = -u1

            opt1.zero_grad(set_to_none=True)
            loss.backward()
            if self.args.clip_grad and float(self.args.clip_grad) > 0:
                torch.nn.utils.clip_grad_norm_(p1.parameters(), float(self.args.clip_grad))
            opt1.step()
        p1.eval()

        p2 = self.A2[j]
        opt2 = self.optA2[j]
        p2.train()
        for _ in range(int(self.args.br_steps)):
            opp_idx = self._sample_index(s1)
            opp_np = P1_set[opp_idx]
            opp = torch.tensor(opp_np, dtype=torch.float32, device=self.device)

            probs2 = p2.probs(self.args.tau, self.args.prob_eps)
            u1m = ev_p1_vs_torch(opp.detach(), probs2)  # u1(opp, p2)
            u2 = -u1m
            loss = -u2

            opt2.zero_grad(set_to_none=True)
            loss.backward()
            if self.args.clip_grad and float(self.args.clip_grad) > 0:
                torch.nn.utils.clip_grad_norm_(p2.parameters(), float(self.args.clip_grad))
            opt2.step()
        p2.eval()

    @torch.no_grad()
    def _all_policies_np(self):
        P1 = self.F1 + [self._snapshot_policy(p) for p in self.A1]
        P2 = self.F2 + [self._snapshot_policy(p) for p in self.A2]
        return P1, P2

    @torch.no_grad()
    def _overall_meta(self):
        P1, P2 = self._all_policies_np()
        M = self._matrix_from_sets(P1, P2)
        s1, s2 = _solve_meta_zero_sum(M, iters=int(self.args.fp_iters), smooth=float(self.args.fp_smooth))
        val = float(s1 @ M @ s2)
        mix1 = np.sum(np.stack(P1, 0) * s1[:, None], 0)
        mix2 = np.sum(np.stack(P2, 0) * s2[:, None], 0)
        return P1, P2, s1, s2, val, mix1, mix2

    @torch.no_grad()
    def nashconv(self):
        P1, P2, s1, s2, val, mix1, mix2 = self._overall_meta()
        br1 = max(ev_p1_vs_numpy(pi, mix2) for pi in PURE1)
        br2min = min(ev_p1_vs_numpy(mix1, pj) for pj in PURE2)
        return max(0.0, br1 - br2min), val

    @torch.no_grad()
    def _lowest_training_score(self, it: int) -> float:
        if len(self.A1) == 0:
            return 0.0
        P1_set, P2_set, s1, s2 = self._meta_for_level(0, it=it, force=False)
        mix2 = np.sum(np.stack(P2_set, 0) * s2[:, None], 0)
        p1_np = self._snapshot_policy(self.A1[0])
        return float(ev_p1_vs_numpy(p1_np, mix2))

    def _check_plateau_and_freeze(self, it: int):
        if len(self.A1) == 0:
            return
        check_every = int(self.args.plateau_check_every)
        if (it - self.lowest_plateau["last_check_iter"]) < check_every:
            return

        score = self._lowest_training_score(it=it)
        last = self.lowest_plateau["last_score"]
        self.lowest_plateau["last_check_iter"] = it

        if last is None:
            self.lowest_plateau["last_score"] = score
            self.lowest_plateau["bad_count"] = 0
            return

        improvement = score - last
        self.lowest_plateau["last_score"] = score

        if improvement < float(self.args.plateau_eps):
            self.lowest_plateau["bad_count"] += 1
        else:
            self.lowest_plateau["bad_count"] = 0

        if self.lowest_plateau["bad_count"] >= int(self.args.plateau_patience):
            self._freeze_lowest()
            self.lowest_plateau = {"last_check_iter": it, "last_score": None, "bad_count": 0}
            self._recompute_all_meta(it=it, force=True)

    def _freeze_lowest(self):
        if len(self.A1) == 0:
            return

        with torch.no_grad():
            p1_np = self._snapshot_policy(self.A1[0])
            p2_np = self._snapshot_policy(self.A2[0])

        if len(self.F1) < self.max_k:
            self.F1.append(p1_np)
        if len(self.F2) < self.max_k:
            self.F2.append(p2_np)

        self.A1.pop(0)
        self.A2.pop(0)

        if (len(self.F1) < self.max_k) or (len(self.F2) < self.max_k):
            self._spawn_active()

        self._rebuild_opts()

        self.meta_cache = {}

    def run(self, csv_path: str, header=True):
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[P2PSRO] seed={self.seed} | device={self.device} | levels={self.levels} | br_steps={self.args.br_steps}")

        hist_nc, hist_val, hist_dt, hist_mem = [], [], [], []
        hist_n1, hist_n2 = [], []
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["iter","timestamp","n_strats_p1","n_strats_p2","time_sec","mem_mb","mem_type","nashconv","mix_ev_p1"])

            for it in range(1, int(self.args.iters)+1):
                t0 = time.time()

                if (it % int(self.args.meta_every) == 0):
                    self._recompute_all_meta(it=it, force=True)

                for j in range(len(self.A1)):
                    self._train_active_level(j, it=it)

                self._check_plateau_and_freeze(it=it)

                nc, val = self.nashconv()
                dt = time.time() - t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype

                P1_all, P2_all = self._all_policies_np()
                hist_nc.append(nc); hist_val.append(val)
                hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(len(P1_all)); hist_n2.append(len(P2_all))

                print(f"[P2PSRO] iter {it}/{self.args.iters} | "
                      f"|Π1|={len(P1_all)} |Π2|={len(P2_all)} | "
                      f"NashConv={nc:.5f} val={val:+.5f} | {dt:.2f}s mem={mem:.1f}MB")

                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            len(P1_all), len(P2_all), f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                fcsv.flush()

        return {
            "nc": np.array(hist_nc, dtype=np.float64),
            "val": np.array(hist_val, dtype=np.float64),
            "dt": np.array(hist_dt, dtype=np.float64),
            "mem": np.array(hist_mem, dtype=np.float64),
            "n1": np.array(hist_n1, dtype=np.float64),
            "n2": np.array(hist_n2, dtype=np.float64),
            "mem_type": mem_type_rec,
        }


def _save_meanstd_csv(path, T, mem_type,
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std):
    """Write the per-iteration mean/std table to ``path``.

    One row is written for each of the first ``T`` iterations (numbered from
    1). Every value is formatted with six decimals; the mixture-EV mean also
    carries an explicit sign. ``mem_type`` is repeated in the last column.
    The parent directory is created when missing.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    column_names = ["iter",
                    "dt_mean","dt_std",
                    "mem_mb_mean","mem_mb_std",
                    "mix_ev_mean","mix_ev_std",
                    "nashconv_mean","nashconv_std",
                    "n_strats_p1_mean","n_strats_p1_std",
                    "n_strats_p2_mean","n_strats_p2_std",
                    "mem_type"]
    # (mean series, std series, format of the mean) in column order.
    series = [
        (dt_mean, dt_std, "{:.6f}"),
        (mem_mean, mem_std, "{:.6f}"),
        (ev_mean, ev_std, "{:+.6f}"),
        (nc_mean, nc_std, "{:.6f}"),
        (n1_mean, n1_std, "{:.6f}"),
        (n2_mean, n2_std, "{:.6f}"),
    ]

    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(column_names)
        for i in range(T):
            row = [i+1]
            for mean, std, mean_fmt in series:
                row.append(mean_fmt.format(mean[i]))
                row.append("{:.6f}".format(std[i]))
            row.append(mem_type)
            writer.writerow(row)

def _read_final_nashconv(csv_path: str):
    """Return the NashConv value from the last data row of a seed CSV.

    The column is located through a header cell named ``nashconv``, so
    baseline files with a different column order are read correctly. Files
    without such a header cell fall back to column index 7, the position
    used by the per-seed CSVs written in this file. Blank rows (such as a
    trailing empty line) are ignored.

    Raises:
        ValueError: if the file holds no data row, or the selected cell is
            missing or not a number.
    """
    with open(csv_path, "r", newline="") as f:
        rows = [row for row in csv.reader(f) if any(cell.strip() for cell in row)]
    if not rows:
        raise ValueError(f"{csv_path}: file is empty")

    header = [cell.strip().lower() for cell in rows[0]]
    if "nashconv" in header:
        column = header.index("nashconv")
        data_rows = rows[1:]
    else:
        # No recognisable header: keep the positional behaviour.
        column = 7
        data_rows = rows
    if not data_rows:
        raise ValueError(f"{csv_path}: no data rows after the header")

    last = data_rows[-1]
    try:
        return float(last[column])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"{csv_path}: cannot read NashConv from column {column} of the last row"
        ) from err

def _welch_ttest(a: np.ndarray, b: np.ndarray):
    """Welch's unequal-variance t-test between two samples.

    Returns:
        ``(t_statistic, p_value)`` as floats, or ``(nan, nan)`` when SciPy's
        ``ttest_ind`` could not be imported.
    """
    if ttest_ind is not None:
        statistic, p_value = ttest_ind(a, b, equal_var=False)
        return float(statistic), float(p_value)
    return float("nan"), float("nan")



def main():
    """Run every seed, aggregate the histories and emit CSVs, stats and plots.

    Steps: parse the arguments, run one ``P2PSRORunner`` per seed (each
    writing its own CSV), write the across-seed mean/std CSV, optionally
    compare final NashConv against baseline CSVs with Welch's t-test, and,
    unless ``--no_plots`` is given, save the two mean ± std figures.

    Raises:
        ValueError: if ``--seeds`` contains no seed.
    """
    args = parse_args()
    device = _pick_device(args.device)
    seeds = [int(x.strip()) for x in args.seeds.split(",") if x.strip() != ""]
    if not seeds:
        raise ValueError("--seeds must contain at least one integer seed")

    os.makedirs(args.outdir, exist_ok=True)

    per = []
    seed_csvs = []
    for s in seeds:
        r = P2PSRORunner(args, seed=s, device=device)
        csv_path = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        seed_csvs.append(csv_path)
        per.append(r.run(csv_path, header=(s == seeds[0])))

    T = int(args.iters)

    # The sample standard deviation (ddof=1) is undefined for a single run
    # and would fill every std column with NaN; report 0 spread instead.
    ddof = 1 if len(per) > 1 else 0

    def _mean_std(key):
        # Across-seed mean and std of one recorded series, per iteration.
        stack = np.stack([rec[key] for rec in per], 0)
        return stack.mean(0), stack.std(0, ddof=ddof)

    nc_mean, nc_std = _mean_std("nc")
    ev_mean, ev_std = _mean_std("val")
    dt_mean, dt_std = _mean_std("dt")
    mem_mean, mem_std = _mean_std("mem")
    n1_mean, n1_std = _mean_std("n1")
    n2_mean, n2_std = _mean_std("n2")

    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, per[0]["mem_type"],
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    if T > 0:
        # With zero iterations there is no final row to summarise.
        print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
              f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
              f"|Π1| mean={n1_mean[-1]:.2f}, |Π2| mean={n2_mean[-1]:.2f}")

    if args.ttest_against_glob is not None:
        base_paths = sorted(glob.glob(args.ttest_against_glob))
        n = min(len(base_paths), len(seed_csvs))
        if n < 2:
            # Fewer than two samples per side gives no variance estimate.
            print(f"[TTEST] skipped: need at least 2 CSVs per side, "
                  f"found {len(base_paths)} baseline file(s) for {args.ttest_against_glob!r}")
        else:
            a = np.array([_read_final_nashconv(p) for p in seed_csvs[:n]], dtype=np.float64)
            b = np.array([_read_final_nashconv(p) for p in base_paths[:n]], dtype=np.float64)
            t, pval = _welch_ttest(a, b)
            print(f"[TTEST] Welch t-test on final NashConv (P2PSRO vs baseline): "
                  f"P2PSRO mean={a.mean():.4f}±{a.std(ddof=1):.4f}, "
                  f"baseline mean={b.mean():.4f}±{b.std(ddof=1):.4f} | t={t:.4f}, p={pval:.4g}")

    # Plotting and saving live in one branch so that --no_plots writes no
    # image files and prints no "saved" messages.
    if not args.no_plots:
        its = np.arange(1, T+1)
        nash_path_png = os.path.join(args.outdir, f"{args.csv_base}_nashconv_meanstd.png")
        ev_path_png   = os.path.join(args.outdir, f"{args.csv_base}_mixev_meanstd.png")

        fig_nc = plt.figure()
        plt.plot(its, nc_mean, label="NashConv (mean)")
        plt.fill_between(its, nc_mean-nc_std, nc_mean+nc_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Exploitability (P2PSRO) — mean ± std")
        fig_nc.tight_layout()
        fig_nc.savefig(nash_path_png, dpi=200)
        plt.close(fig_nc)

        fig_ev = plt.figure()
        plt.plot(its, ev_mean, label="E[P1 payoff] (mean)")
        plt.fill_between(its, ev_mean-ev_std, ev_mean+ev_std, alpha=0.25, label="±1 std")
        plt.grid(True); plt.legend(); plt.xlabel("iter"); plt.title("Kuhn – Mixture Value (P2PSRO) — mean ± std")
        fig_ev.tight_layout()
        fig_ev.savefig(ev_path_png, dpi=200)
        plt.close(fig_ev)

        print(f"[PLOTS] saved: {nash_path_png}")
        print(f"[PLOTS] saved: {ev_path_png}")



# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
