import argparse, time, csv, os, random, math, sys, glob
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from scipy.stats import ttest_ind
except Exception:
    ttest_ind = None



def parse_args():
    """Parse the command-line options for a multi-seed NeuPL run on Kuhn poker.

    Returns:
        argparse.Namespace with one attribute per option registered below.

    Note: ``--eta_sched``, ``--zdim``, ``--emb_init_std``, ``--logit_cap`` and
    ``--train_weight`` are accepted but not read anywhere else in this script.
    """
    parser = argparse.ArgumentParser("NeuPL (conditional network population) for Kuhn — multi-seed")
    opt = parser.add_argument

    # Outer-loop length and population cap.
    opt("--iters", type=int, default=40)
    opt("--kmax", type=int, default=41)

    # Meta-solver (MWU) settings.
    opt("--meta_loops", type=int, default=200)
    opt("--eta", type=float, default=0.1)
    opt("--eta_sched", choices=["const", "sqrt", "harmonic"], default="harmonic")

    # Conditional-network settings.
    opt("--zdim", type=int, default=16)
    opt("--tau", type=float, default=1.0)
    opt("--emb_init_std", type=float, default=1.0)

    # Approximate best-response (ABR) training.
    opt("--abr_steps", type=int, default=50)
    opt("--abr_lr", type=float, default=5e-4)
    opt("--beta_kl", type=float, default=0.0,
        help="Optional trust-region penalty; set 0 for pure NeuPL.")
    opt("--train_weight", choices=["sigma", "uniform"], default="sigma",
        help="How to weight per-policy ABR objective across the population.")

    # Numerical safeguards.
    opt("--clip_grad", type=float, default=1.0)
    opt("--logit_cap", type=float, default=50.0)
    opt("--prob_eps", type=float, default=1e-6)

    # Output, seeds and hardware.
    opt("--outdir", type=str, default="results")
    opt("--csv_base", type=str, default="neupl_kuhn")
    opt("--seeds", type=str, default="0,1,2,3,4")
    opt("--device", choices=["auto", "cpu", "cuda"], default="auto")
    opt("--no_plots", action="store_true")

    # Optional statistical comparison against a baseline.
    opt("--ttest_against_glob", type=str, default=None,
        help='Glob for baseline seed CSVs (same iters), e.g. "results/psro_kuhn_seed*.csv"')

    return parser.parse_args()



def _seed_everything(s: int):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

def _pick_device(argdev: str):
    if argdev == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(argdev)

def _mem_mb():
    try:
        import psutil
        return psutil.Process().memory_info().rss / (1024**2), "rss"
    except Exception:
        try:
            import resource
            if sys.platform == "darwin":
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2), "ru_maxrss"
            else:
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0, "ru_maxrss"
        except Exception:
            return float("nan"), "n/a"

def _softmax_np(x, cap=50.0):
    x = np.clip(x, -cap, cap)
    z = x - np.max(x)
    e = np.exp(z)
    return e / (e.sum() + 1e-12)

def _eta_t(eta0: float, t: int, sched: str):
    if sched == "const": return eta0
    if sched == "sqrt": return eta0 / max(1.0, math.sqrt(t))
    if sched == "harmonic": return eta0 / (1.0 + 0.5*t)
    return eta0

def bernoulli_kl(p, q, eps=1e-6):
    """Elementwise KL(Bernoulli(p) || Bernoulli(q)) for probability tensors.

    Both inputs are clamped to [eps, 1 - eps] first so the logs stay finite.
    """
    lo, hi = eps, 1.0 - eps
    p = p.clamp(lo, hi)
    q = q.clamp(lo, hi)
    p_not, q_not = 1.0 - p, 1.0 - q
    return p * torch.log(p / q) + p_not * torch.log(p_not / q_not)



# Kuhn deck: three cards identified by rank index; the higher index wins a showdown.
CARDS = [0, 1, 2]


def _kuhn_deals():
    """All ordered (P1 card, P2 card) pairs with distinct cards, in nested-loop order."""
    return [(mine, theirs) for mine in CARDS for theirs in CARDS if mine != theirs]


def _kuhn_deal_terms(bet1, call1, call2, bet2, mine, theirs):
    """P1's payoff contributions for one deal, split by P1's first action.

    Returns ``(after_bet, after_check)``:
      * P1 bets: P2 calls -> showdown worth +/-2, P2 folds -> P1 wins +1.
      * P1 checks: P2 bets -> P1 calls (showdown +/-2) or folds (-1);
        P2 checks -> showdown worth +/-1.
    """
    p1_wins = mine > theirs
    big_pot = 2.0 if p1_wins else -2.0
    small_pot = 1.0 if p1_wins else -1.0
    after_bet = bet1[mine] * (call2[theirs]*big_pot + (1.0-call2[theirs])*1.0)
    facing_bet = call1[mine]*big_pot + (1.0-call1[mine])*(-1.0)
    after_check = (1.0-bet1[mine]) * (bet2[theirs]*facing_bet + (1.0-bet2[theirs])*small_pot)
    return after_bet, after_check


def ev_p1_vs_numpy(p1_vec6, p2_vec6):
    """Expected payoff to P1 in Kuhn poker for behavioral policies given as 6-vectors.

    Layouts (indexed by card): P1 = [bet(3) | call-after-check-bet(3)],
    P2 = [call-vs-bet(3) | bet-after-check(3)]. The six deals are equally likely.
    """
    bet1, call1 = p1_vec6[:3], p1_vec6[3:]
    call2, bet2 = p2_vec6[:3], p2_vec6[3:]
    total = 0.0
    for mine, theirs in _kuhn_deals():
        after_bet, after_check = _kuhn_deal_terms(bet1, call1, call2, bet2, mine, theirs)
        total += after_bet + after_check
    return total / 6.0


def ev_p1_vs_torch(p1_probs6, p2_probs6):
    """Differentiable PyTorch twin of ``ev_p1_vs_numpy`` (same vector layouts).

    The accumulator starts as a float32 scalar on ``p1_probs6``'s device.
    """
    bet1, call1 = p1_probs6[:3], p1_probs6[3:]
    call2, bet2 = p2_probs6[:3], p2_probs6[3:]
    acc = torch.zeros([], dtype=torch.float32, device=p1_probs6.device)
    for mine, theirs in _kuhn_deals():
        after_bet, after_check = _kuhn_deal_terms(bet1, call1, call2, bet2, mine, theirs)
        acc = acc + after_bet + after_check
    return acc / 6.0

def _solve_nash_zerosum_mwu(U: np.ndarray, steps: int = 4000, eta: float = 0.8):
    """Approximate a Nash equilibrium of the zero-sum matrix game ``U``.

    The row player maximizes ``x @ U @ y`` and the column player minimizes it. Both run
    simultaneous multiplicative-weights updates with step size ``eta / sqrt(t)``, and
    the *time-averaged* strategies are returned. In two-player zero-sum games, the
    averaged strategies of no-regret dynamics converge to equilibrium, whereas the
    last MWU iterate need not (it can keep orbiting the equilibrium).

    Args:
        U: payoff matrix of shape (n, m) for the row player.
        steps: number of MWU iterations; with ``steps < 1`` the uniform initial
            strategies are returned.
        eta: base step size.

    Returns:
        (x, y): float64 probability vectors of length n and m.

    Raises:
        ValueError: if ``U`` has an empty dimension.
    """
    U = np.asarray(U, dtype=np.float64)
    n, m = U.shape
    if n == 0 or m == 0:
        raise ValueError("U must be non-empty.")
    if n == 1 and m == 1:
        return np.array([1.0], dtype=np.float64), np.array([1.0], dtype=np.float64)

    lx = np.zeros(n, dtype=np.float64)
    ly = np.zeros(m, dtype=np.float64)

    if steps < 1:
        # No updates requested: the uniform starting point is the best we have.
        return _softmax_np(lx, cap=50.0), _softmax_np(ly, cap=50.0)

    # Running sums of the played strategies, for the time-averaged output.
    x_sum = np.zeros(n, dtype=np.float64)
    y_sum = np.zeros(m, dtype=np.float64)

    for t in range(1, steps + 1):
        x = _softmax_np(lx, cap=50.0)
        y = _softmax_np(ly, cap=50.0)
        x_sum += x
        y_sum += y

        gx = U @ y        # row player's payoff per pure action
        gy = -(x @ U)     # column player's payoff (negated: it minimizes U)

        et = float(eta) / math.sqrt(t)

        lx = np.log(x + 1e-12) + et * gx
        ly = np.log(y + 1e-12) + et * gy

        # Shift for numerical stability; softmax is shift-invariant.
        lx -= lx.max()
        ly -= ly.max()

    return x_sum / x_sum.sum(), y_sum / y_sum.sum()




def _enumerate_pure_policies_p1():
    out = []
    for mask in range(64):
        b = [(mask>>0)&1, (mask>>1)&1, (mask>>2)&1]
        c = [(mask>>3)&1, (mask>>4)&1, (mask>>5)&1]
        out.append(np.array(b+c, dtype=np.float64))
    return out

def _enumerate_pure_policies_p2():
    out = []
    for mask in range(64):
        c = [(mask>>0)&1, (mask>>1)&1, (mask>>2)&1]
        b = [(mask>>3)&1, (mask>>4)&1, (mask>>5)&1]
        out.append(np.array(c+b, dtype=np.float64))
    return out

PURE1 = _enumerate_pure_policies_p1()
PURE2 = _enumerate_pure_policies_p2()



def _init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=1.0)
        nn.init.zeros_(m.bias)

class CondPop(nn.Module):
    def __init__(self, max_k: int, hidden: int = 128):
        super().__init__()
        self.max_k = max_k
        self.net = nn.Sequential(
            nn.Linear(max_k, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 6)
        )
        self.apply(_init_weights)

    def probs_from_sigma(self, sigma_row: torch.Tensor, tau: float):
        logits = self.net(sigma_row) / max(1e-8, tau)
        return torch.sigmoid(logits)



class NeuPLRunner:
    """One NeuPL training run on Kuhn poker for a single seed.

    Each player owns one conditional network (``pop1`` / ``pop2``) that maps a row of
    its meta-strategy matrix to a 6-probability behavioral policy. Row ``i`` of
    ``SIGMA1`` weights P2's first ``K2`` policies: it is the opponent mixture that P1
    policy ``i`` is conditioned on and trained against. ``SIGMA2`` does the same for P2.

    Each outer iteration (``run``) does four things:
      1. Takes gradient steps on all distinct rows (approximate best responses).
      2. Rebuilds SIGMA from Nash equilibria of growing sub-games of the empirical
         payoff matrix.
      3. Activates one new policy per player.
      4. Logs NashConv.

    Policy vector layouts (per card): P1 = [bet(3) | call-after-check-bet(3)],
    P2 = [call-vs-bet(3) | bet-after-check(3)].
    """

    def __init__(self, args, seed: int, device: torch.device):
        """Store configuration and build the initial state.

        Args:
            args: parsed command-line namespace.
            seed: RNG seed applied in ``_setup_state``.
            device: torch device hosting both networks.
        """
        self.args = args
        self.seed = seed
        self.device = device
        self._setup_state()


    def _setup_state(self):
        """Seed all RNGs and create fresh networks, meta-strategy matrices and optimizer."""
        _seed_everything(self.seed)

        self.max_k = int(self.args.kmax)
        # Number of active policies per player; both start with a single policy.
        self.K1 = 1
        self.K2 = 1

        # SIGMA1[i, :K2] / SIGMA2[j, :K1]: conditioning (and training-opponent) mixtures.
        # Rows start as a point mass on opponent policy 0.
        self.SIGMA1 = np.zeros((self.max_k, self.max_k), dtype=np.float64)
        self.SIGMA2 = np.zeros((self.max_k, self.max_k), dtype=np.float64)
        self.SIGMA1[0, 0] = 1.0
        self.SIGMA2[0, 0] = 1.0

        self.pop1 = CondPop(self.max_k).to(self.device)
        self.pop2 = CondPop(self.max_k).to(self.device)

        # A single optimizer updates both players' networks.
        self.opt = torch.optim.Adam(
            list(self.pop1.parameters()) + list(self.pop2.parameters()),
            lr=self.args.abr_lr
        )

        # Policy snapshots from the start of the latest ABR phase; the anchor of the
        # optional KL trust region (only set when --beta_kl > 0).
        self._prev_pi1 = None
        self._prev_pi2 = None


    def _row_to_sigma_tensor(self, row_np: np.ndarray, active_len: int) -> torch.Tensor:
        """Normalise the first ``active_len`` entries of a SIGMA row into a network input.

        A row with non-positive mass falls back to a point mass on index 0. The result is
        zero-padded to length ``max_k`` (the conditioning width of the networks).
        """
        s = row_np[:active_len].copy()
        z = s.sum()
        if z <= 0:
            s[:] = 0.0
            s[0] = 1.0
        else:
            s /= z

        out = np.zeros(self.max_k, dtype=np.float32)
        out[:active_len] = s.astype(np.float32)
        return torch.tensor(out, device=self.device, dtype=torch.float32)

    def _unique_rows(self, S: np.ndarray, K: int, ncols=None):
        """Return ``(row_index, normalised_row)`` for each distinct row among ``S[:K]``.

        The networks are conditioned only on the normalised row, so rows that are equal
        (after rounding to 8 decimals) produce identical policies and are trained once.

        Args:
            S: meta-strategy matrix (SIGMA1 or SIGMA2).
            K: number of active rows, i.e. the owning player's population size.
            ncols: number of leading columns to keep, i.e. the *opponent's* population
                size. Defaults to ``K`` (the historical behaviour, exact while K1 == K2).
        """
        width = K if ncols is None else int(ncols)
        seen = set()
        uniq = []
        for i in range(K):
            r = S[i, :width].copy()
            z = r.sum()
            if z <= 0:
                r[:] = 0.0
                r[0] = 1.0
            else:
                r /= z
            key = tuple(np.round(r, 8))
            if key not in seen:
                seen.add(key)
                uniq.append((i, r))
        return uniq


    def _policy_probs(self, pop, sigma: torch.Tensor) -> torch.Tensor:
        """Policy of network ``pop`` for conditioning vector ``sigma``.

        NaNs are replaced by 0.5 and values clamped to [prob_eps, 1 - prob_eps].
        """
        pi = pop.probs_from_sigma(sigma, self.args.tau)
        eps = self.args.prob_eps
        return torch.clamp(torch.nan_to_num(pi, nan=0.5), eps, 1.0 - eps)

    def _policy_mats(self):
        """Stack every active policy: (K1, 6) for P1 and (K2, 6) for P2.

        Gradients flow through the result when called with grad enabled.
        """
        P1 = [self._policy_probs(self.pop1, self._row_to_sigma_tensor(self.SIGMA1[i], active_len=self.K2))
              for i in range(self.K1)]
        P2 = [self._policy_probs(self.pop2, self._row_to_sigma_tensor(self.SIGMA2[j], active_len=self.K1))
              for j in range(self.K2)]
        return torch.stack(P1, 0), torch.stack(P2, 0)

    @torch.no_grad()
    def policy_mats_from_graph(self):
        """Gradient-free ``_policy_mats``: tensors of shape (K1, 6) and (K2, 6)."""
        return self._policy_mats()

    @torch.no_grad()
    def ev_matrix(self):
        """Empirical payoff matrix U[i, j] = EV to P1 of P1 policy i vs P2 policy j.

        Returns:
            (U, P1n, P2n) with the policies as float64 NumPy arrays.
        """
        P1, P2 = self.policy_mats_from_graph()
        P1n = P1.detach().cpu().numpy().astype(np.float64)
        P2n = P2.detach().cpu().numpy().astype(np.float64)

        U = np.zeros((self.K1, self.K2), dtype=np.float64)
        for i in range(self.K1):
            for j in range(self.K2):
                U[i, j] = ev_p1_vs_numpy(P1n[i], P2n[j])
        return U, P1n, P2n


    def build_sigma_psro_nash(self, U: np.ndarray):
        """Rebuild both SIGMA matrices from the empirical payoff matrix ``U`` (K1 x K2).

        Row 0 stays a point mass on opponent policy 0. For 1 <= i < min(K1, K2):
          * P1 row i becomes P2's Nash strategy of the sub-game ``U[:i, :i]``.
          * P2 row i becomes P1's Nash strategy of that sub-game.
        Any remaining rows (only when K1 != K2) are reset to a point mass on index 0.
        """
        self.SIGMA1[:] = 0.0
        self.SIGMA2[:] = 0.0
        self.SIGMA1[0, 0] = 1.0
        self.SIGMA2[0, 0] = 1.0

        K = min(self.K1, self.K2)
        for i in range(1, K):
            subU = U[:i, :i]
            x, y = _solve_nash_zerosum_mwu(
                subU,
                steps=max(1000, int(self.args.meta_loops) * 10),
                eta=float(self.args.eta)
            )
            self.SIGMA1[i, :i] = y
            self.SIGMA2[i, :i] = x

        for i in range(K, self.K1):
            self.SIGMA1[i, 0] = 1.0
        for j in range(K, self.K2):
            self.SIGMA2[j, 0] = 1.0


    def abr_train(self):
        """Run ``abr_steps`` joint gradient steps of approximate best responses.

        Each distinct P1 row maximizes its EV against the SIGMA-weighted mixture of
        (frozen) P2 policies, and each distinct P2 row minimizes P1's EV against the
        mixture of frozen P1 policies. Every distinct row contributes with unit weight
        (``--train_weight`` is not read here). With ``--beta_kl > 0``, a KL penalty
        keeps the policies close to their snapshot from the start of this call.
        """
        self.pop1.train()
        self.pop2.train()

        # SIGMA1 rows weight P2 policies (K2 columns) and vice versa.
        uniq1 = self._unique_rows(self.SIGMA1, self.K1, self.K2)
        uniq2 = self._unique_rows(self.SIGMA2, self.K2, self.K1)

        beta_kl = float(self.args.beta_kl)
        eps = float(self.args.prob_eps)
        if beta_kl > 0.0:
            # Trust-region anchor: the policies before this ABR phase. (Taking "prev"
            # and "cur" from the same parameters would make the KL identically 0.)
            P1_prev, P2_prev = self.policy_mats_from_graph()
            self._prev_pi1, self._prev_pi2 = P1_prev, P2_prev

        params = list(self.pop1.parameters()) + list(self.pop2.parameters())

        for _ in range(int(self.args.abr_steps)):
            # Opponent policies are treated as fixed targets for this step (no grad).
            P1_all, P2_all = self.policy_mats_from_graph()

            loss = torch.zeros([], dtype=torch.float32, device=self.device)

            for _, row_np in uniq1:
                sigma1 = self._row_to_sigma_tensor(row_np, active_len=self.K2)
                pi1 = self._policy_probs(self.pop1, sigma1)

                w = torch.tensor(row_np, dtype=torch.float32, device=self.device)
                # P2's decisions are each the first P2 decision on their path, so a
                # plain weighted average of P2 policies is realization-equivalent.
                mix2 = (w[:, None] * P2_all).sum(0)

                loss = loss - ev_p1_vs_torch(pi1, mix2)

            for _, row_np in uniq2:
                sigma2 = self._row_to_sigma_tensor(row_np, active_len=self.K1)
                pi2 = self._policy_probs(self.pop2, sigma2)

                w = torch.tensor(row_np, dtype=torch.float32, device=self.device)
                mix1 = self._mix_p1(w, P1_all)

                loss = loss + ev_p1_vs_torch(mix1, pi2)

            if beta_kl > 0.0:
                P1_cur, P2_cur = self._policy_mats()  # with grad, unlike the snapshot
                kl = bernoulli_kl(P1_cur, P1_prev, eps=eps).mean() \
                   + bernoulli_kl(P2_cur, P2_prev, eps=eps).mean()
                loss = loss + beta_kl * kl

            self.opt.zero_grad(set_to_none=True)
            loss.backward()
            if self.args.clip_grad and self.args.clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(params, float(self.args.clip_grad))
            self.opt.step()


    def _activate_new(self, role: int):
        """Add one policy for P1 (``role == 0``) or P2 (otherwise), up to ``max_k``.

        The new row starts as a point mass on opponent policy 0.
        """
        if role == 0:
            if self.K1 >= self.max_k:
                return
            self.K1 += 1
            self.SIGMA1[self.K1 - 1, :] = 0.0
            self.SIGMA1[self.K1 - 1, 0] = 1.0
        else:
            if self.K2 >= self.max_k:
                return
            self.K2 += 1
            self.SIGMA2[self.K2 - 1, :] = 0.0
            self.SIGMA2[self.K2 - 1, 0] = 1.0


    def _nash_mixture_on_current_game(self, U: np.ndarray):
        """Approximate Nash mixtures (x over P1 policies, y over P2 policies) of ``U``."""
        x, y = _solve_nash_zerosum_mwu(U, steps=max(2000, int(self.args.meta_loops) * 10), eta=float(self.args.eta))
        return x, y

    @staticmethod
    def _mix_p1(weights, P1):
        """Behavioral P1 policy that is realization-equivalent to a mixture of policies.

        P1's call decision at card c is reached only after P1 checked. Its mixed
        probability is therefore weighted by each policy's check probability (1 - bet):
            call_mix[c] = sum_i w_i (1 - bet_i[c]) call_i[c] / sum_i w_i (1 - bet_i[c]).
        With this weighting, ``ev_p1_vs_*(mix, opp)`` equals the w-weighted average of
        ``ev_p1_vs_*(P1[i], opp)``. A plain average of the 6-vectors does not, because
        the EV contains the product (1 - bet) * call.

        Args:
            weights: (K,) mixture weights (NumPy array or torch tensor).
            P1: (K, 6) P1 policies of the same array type.

        Returns:
            A length-6 array/tensor of the same type as ``P1``.
        """
        w = weights[:, None]
        bet = P1[:, :3]
        call = P1[:, 3:]
        bet_mix = (w * bet).sum(0)
        reach = w * (1.0 - bet)  # mass arriving at each check->bet decision
        # If no mass reaches a decision, its value is irrelevant to the EV; 1e-12 avoids 0/0.
        call_mix = (reach * call).sum(0) / (reach.sum(0) + 1e-12)
        if isinstance(P1, torch.Tensor):
            return torch.cat([bet_mix, call_mix])
        return np.concatenate([bet_mix, call_mix])

    def nashconv(self):
        """Exploitability of the Nash mixture over the current populations.

        Returns:
            (nashconv, value) where the mixtures are the solver's Nash strategies of the
            empirical game, ``value = x @ U @ y``, and ``nashconv`` is the best pure P1
            response vs P2's mixture minus the best pure P2 response vs P1's mixture,
            floored at 0.
        """
        U, P1, P2 = self.ev_matrix()
        x, y = self._nash_mixture_on_current_game(U)
        val = float(x @ U @ y)

        mix_p1 = self._mix_p1(x, P1)
        # Plain averaging is exact for P2 (see abr_train).
        mix_p2 = (y[:, None] * P2).sum(0)

        br1 = max(ev_p1_vs_numpy(pi, mix_p2) for pi in PURE1)
        br2min = min(ev_p1_vs_numpy(mix_p1, pj) for pj in PURE2)

        return max(0.0, br1 - br2min), val


    def run(self, csv_path: str, header=True):
        """Run ``args.iters`` outer iterations and log one CSV row per iteration.

        Each iteration runs ABR training, rebuilds SIGMA from the empirical game,
        activates one new policy per player and evaluates NashConv.

        NOTE(review): a newly activated policy keeps the point-mass row set by
        ``_activate_new``, which is the same conditioning as row 0. It receives its
        own Nash row only at the next ``build_sigma_psro_nash``, one iteration after
        activation. Confirm this lag is intended.

        Args:
            csv_path: destination CSV (parent directories are created).
            header: if True, print a one-line run summary before starting.

        Returns:
            dict of per-iteration float64 arrays ``nc``, ``val``, ``dt``, ``mem``,
            ``n1``, ``n2`` plus the last ``mem_type`` string.
        """
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        if header:
            print(f"[NeuPL] seed={self.seed} | device={self.device} | abr_steps={self.args.abr_steps} | kmax={self.max_k}")

        hist_nc, hist_val, hist_dt, hist_mem = [], [], [], []
        hist_n1, hist_n2 = [], []
        mem_type_rec = "n/a"

        with open(csv_path, "w", newline="") as fcsv:
            w = csv.writer(fcsv)
            w.writerow(["iter","timestamp","n_strats_p1","n_strats_p2","time_sec","mem_mb","mem_type","nashconv","mix_ev_p1"])

            for it in range(1, int(self.args.iters) + 1):
                t0 = time.time()

                self.abr_train()

                U, _, _ = self.ev_matrix()
                self.build_sigma_psro_nash(U)

                self._activate_new(0)
                self._activate_new(1)

                nc, val = self.nashconv()
                dt = time.time() - t0
                mem, mtype = _mem_mb()
                mem_type_rec = mtype

                hist_nc.append(nc); hist_val.append(val)
                hist_dt.append(dt); hist_mem.append(mem)
                hist_n1.append(self.K1); hist_n2.append(self.K2)

                print(f"[NeuPL] seed={self.seed} iter {it}/{self.args.iters} | P1={self.K1} P2={self.K2} | "
                      f"NashConv={nc:.6f} val={val:+.6f} | {dt:.2f}s mem={mem:.1f}MB")

                w.writerow([it, time.strftime("%Y-%m-%d %H:%M:%S"),
                            self.K1, self.K2, f"{dt:.3f}", f"{mem:.2f}", mtype,
                            f"{nc:.6f}", f"{val:+.6f}"])
                fcsv.flush()

        return {
            "nc": np.array(hist_nc, dtype=np.float64),
            "val": np.array(hist_val, dtype=np.float64),
            "dt": np.array(hist_dt, dtype=np.float64),
            "mem": np.array(hist_mem, dtype=np.float64),
            "n1": np.array(hist_n1, dtype=np.float64),
            "n2": np.array(hist_n2, dtype=np.float64),
            "mem_type": mem_type_rec,
        }



def _save_meanstd_csv(path, T, mem_type,
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["iter",
                    "dt_mean","dt_std",
                    "mem_mb_mean","mem_mb_std",
                    "mix_ev_mean","mix_ev_std",
                    "nashconv_mean","nashconv_std",
                    "n_strats_p1_mean","n_strats_p1_std",
                    "n_strats_p2_mean","n_strats_p2_std",
                    "mem_type"])
        for i in range(T):
            w.writerow([i+1,
                        f"{dt_mean[i]:.6f}", f"{dt_std[i]:.6f}",
                        f"{mem_mean[i]:.6f}", f"{mem_std[i]:.6f}",
                        f"{ev_mean[i]:+.6f}", f"{ev_std[i]:.6f}",
                        f"{nc_mean[i]:.6f}", f"{nc_std[i]:.6f}",
                        f"{n1_mean[i]:.6f}", f"{n1_std[i]:.6f}",
                        f"{n2_mean[i]:.6f}", f"{n2_std[i]:.6f}",
                        mem_type])

def _read_final_nashconv(csv_path: str):
    with open(csv_path, "r") as f:
        rows = list(csv.reader(f))
    last = rows[-1]
    return float(last[7])

def _welch_ttest(a: np.ndarray, b: np.ndarray):
    if ttest_ind is None:
        return float("nan"), float("nan")
    t, p = ttest_ind(a, b, equal_var=False)
    return float(t), float(p)



def _mean_std_over_seeds(stack: np.ndarray):
    """Per-iteration (mean, std) across seeds (axis 0) of a (n_seeds, T) array.

    The sample std (ddof=1) is used with two or more seeds. With a single seed it
    would be NaN, so the population std (0) is reported instead.
    """
    ddof = 1 if stack.shape[0] > 1 else 0
    return stack.mean(0), stack.std(0, ddof=ddof)


def _report_final_nashconv_ttest(pattern: str, seed_csvs):
    """Print a Welch t-test of final NashConv for this run's CSVs vs baseline CSVs.

    Baseline files are those matching ``pattern``, sorted. If the counts differ, only
    the first ``min`` files of each list are compared. The test is skipped when fewer
    than two files per group are available, since a sample std needs at least two.
    """
    base_paths = sorted(glob.glob(pattern))
    if len(base_paths) != len(seed_csvs):
        print(f"[TTEST] WARNING: found {len(base_paths)} baseline CSVs, expected {len(seed_csvs)}. "
              f"Continuing with min-count.")
    n = min(len(base_paths), len(seed_csvs))
    if n < 2:
        print(f"[TTEST] skipped: need at least 2 CSVs per group, have {n}.")
        return
    a = np.array([_read_final_nashconv(p) for p in seed_csvs[:n]], dtype=np.float64)
    b = np.array([_read_final_nashconv(p) for p in base_paths[:n]], dtype=np.float64)
    t, pval = _welch_ttest(a, b)
    print(f"[TTEST] Welch t-test on final NashConv (NeuPL vs baseline): "
          f"NeuPL mean={a.mean():.4f}±{a.std(ddof=1):.4f}, "
          f"baseline mean={b.mean():.4f}±{b.std(ddof=1):.4f} | t={t:.4f}, p={pval:.4g}")


def _save_meanstd_plots(outdir: str, csv_base: str, T: int,
                        nc_mean, nc_std, ev_mean, ev_std):
    """Save mean ± std curves of NashConv and mixture value as 200-dpi PNGs in ``outdir``."""
    its = np.arange(1, T + 1)
    nash_path_png = os.path.join(outdir, f"{csv_base}_nashconv_meanstd.png")
    ev_path_png = os.path.join(outdir, f"{csv_base}_mixev_meanstd.png")

    panels = (
        (nash_path_png, nc_mean, nc_std, "NashConv (mean)", "Kuhn – Exploitability (NeuPL) — mean ± std"),
        (ev_path_png, ev_mean, ev_std, "E[P1 payoff] (mean)", "Kuhn – Mixture Value (NeuPL) — mean ± std"),
    )
    for path, mean, std, label, title in panels:
        # Explicit figure handles: never rely on pyplot's "current figure".
        fig, ax = plt.subplots()
        ax.plot(its, mean, label=label)
        ax.fill_between(its, mean - std, mean + std, alpha=0.25, label="±1 std")
        ax.grid(True)
        ax.legend()
        ax.set_xlabel("iter")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path, dpi=200)
        plt.close(fig)
        print(f"[PLOTS] saved: {path}")


def main():
    """Run NeuPL on Kuhn poker for every requested seed and aggregate the results.

    Writes ``<outdir>/<csv_base>_seed<s>.csv`` per seed and the across-seed table
    ``<csv_base>_meanstd.csv``. Optionally runs a Welch t-test on final NashConv
    against baseline CSVs. Unless ``--no_plots`` is given, it saves mean ± std plots.

    Raises:
        ValueError: if ``--seeds`` lists no seed or ``--iters`` is below 1.
        RuntimeError: if a run's history length differs from ``--iters``.
    """
    args = parse_args()
    device = _pick_device(args.device)
    seeds = [int(x.strip()) for x in args.seeds.split(",") if x.strip() != ""]
    if not seeds:
        raise ValueError("--seeds must list at least one integer seed.")
    T = int(args.iters)
    if T < 1:
        raise ValueError("--iters must be at least 1.")

    os.makedirs(args.outdir, exist_ok=True)

    per = []
    seed_csvs = []
    for s in seeds:
        r = NeuPLRunner(args, seed=s, device=device)
        csv_path = os.path.join(args.outdir, f"{args.csv_base}_seed{s}.csv")
        seed_csvs.append(csv_path)
        per.append(r.run(csv_path, header=(s == seeds[0])))

    for s, rec in zip(seeds, per):
        if len(rec["nc"]) != T or len(rec["val"]) != T:
            raise RuntimeError(f"seed {s}: expected {T} iterations, got {len(rec['nc'])}.")

    # Stack each per-seed history into (n_seeds, T) and reduce over seeds.
    agg = {key: _mean_std_over_seeds(np.stack([rec[key] for rec in per], 0))
           for key in ("nc", "val", "dt", "mem", "n1", "n2")}
    nc_mean, nc_std = agg["nc"]
    ev_mean, ev_std = agg["val"]
    dt_mean, dt_std = agg["dt"]
    mem_mean, mem_std = agg["mem"]
    n1_mean, n1_std = agg["n1"]
    n2_mean, n2_std = agg["n2"]

    agg_csv = os.path.join(args.outdir, f"{args.csv_base}_meanstd.csv")
    _save_meanstd_csv(agg_csv, T, per[0]["mem_type"],
                      dt_mean, dt_std, mem_mean, mem_std,
                      ev_mean, ev_std, nc_mean, nc_std,
                      n1_mean, n1_std, n2_mean, n2_std)
    print(f"[AGG] wrote mean/std CSV: {agg_csv}")
    print(f"[AGG] final (iter {T}) — EV mean={ev_mean[-1]:+.4f} ± {ev_std[-1]:.4f} | "
          f"NashConv mean={nc_mean[-1]:.4f} ± {nc_std[-1]:.4f} | "
          f"P1 strats ~{n1_mean[-1]:.2f}, P2 strats ~{n2_mean[-1]:.2f}")

    if args.ttest_against_glob is not None:
        _report_final_nashconv_ttest(args.ttest_against_glob, seed_csvs)

    # Previously the saving code also ran with --no_plots and wrote an empty figure.
    if args.no_plots:
        print("[PLOTS] skipped (--no_plots).")
    else:
        _save_meanstd_plots(args.outdir, args.csv_base, T, nc_mean, nc_std, ev_mean, ev_std)



if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
