import argparse
import time
import csv
import sys
from copy import deepcopy
from typing import Dict, Any, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
#                           Argument parser
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Build the command-line parser, parse ``sys.argv`` and validate it.

    Besides per-flag type checking done by ``argparse``, two cross-field
    constraints are enforced:

    * ``--arms-means`` must contain exactly ``--K`` values.
    * ``--bad-arm-target`` must be a valid arm index in ``[0, K)``.

    In the first case the sender network (which expects ``K`` inputs) would
    otherwise fail only at its first forward pass. In the second case the
    sender would silently never be rewarded.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits:
        Via ``ArgumentParser.error`` (``SystemExit`` with status 2) on invalid
        input.
    """
    p = argparse.ArgumentParser(
        description="GEMS run on Deceptive Messages"
    )

    # Bandit & env
    p.add_argument("--K",            type=int,   default=5,  help="Number of arms")
    p.add_argument("--M",            type=int,   default=3,  help="Message dimension")
    p.add_argument("--arms-means",   type=float, nargs="+",
                   default=[0.2, 0.5, 0.8, 0.4, 0.1],
                   metavar="mew", help="Space-separated list of arm means")
    p.add_argument("--bad-arm-target", type=int, default=0,
                   help="Arm index sender wants receiver to choose")
    p.add_argument("--top-seed",     type=int,   default=0,  help="(Legacy single-run) RNG seed")

    # Main loop
    p.add_argument("--iters",          type=int,   default=6,
                   help="Outer PSRO iterations")
    p.add_argument("--meta-eval-n",    type=int,   default=64,
                   help="Per-anchor MC samples for value estimation under current mix")
    p.add_argument("--eb-ucb-n",       type=int,   default=32,
                   help="Rollouts per candidate when scoring EB-UCB")
    p.add_argument("--candidate-pool-size", type=int, default=32)

    # Oracle / ABR-TR
    p.add_argument("--oracle-steps", type=int,   default=400,
                   help="(kept for parity; not used explicitly)")
    p.add_argument("--oracle-lr",    type=float, default=1e-3,
                   help="(kept for parity; not used explicitly)")
    p.add_argument("--abr-steps",    type=int,   default=200)
    p.add_argument("--abr-lr",       type=float, default=1e-3)
    p.add_argument("--adv-ema",      type=float, default=0.1,
                   help="Advantage baseline EMA coefficient (0<ema<=1)")

    # Misc knobs
    p.add_argument("--jac-eps",     type=float, default=1e-3)
    p.add_argument("--lambda-jac",  type=float, default=1e-2)
    p.add_argument("--delta0",      type=float, default=0.1)
    p.add_argument("--mwu-eta",     type=float, default=0.1,
                   help="Learning rate for OMWU meta updates")
    p.add_argument("--inner-eval",  type=int,   default=2)
    p.add_argument("--beta-kl",     type=float, default=1e-2)
    p.add_argument("--zdim",        type=int,   default=8,
                   help="Latent dimension fed to generators")

    # Device & logging
    p.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    p.add_argument("--log-csv", default=None,
                   help="Filename for CSV instrumentation log "
                        "(default gems_<seed>.csv)")

    # Latent-space ablation
    p.add_argument("--log-latent", action="store_true",
                   help="Log latent geometry & Jacobian metrics per iteration")
    p.add_argument("--latent-max-points", type=int, default=256,
                   help="Max anchors to use when computing geometry metrics")
    p.add_argument("--save-latent-npz", default=None,
                   help="If set, save anchors & PCA traces to this .npz")

    # -------- Multi-seed controls --------
    p.add_argument("--seeds", type=int, nargs="+",
                   default=[0, 1, 2, 3, 4],
                   help="Space-separated RNG seeds to run, e.g. --seeds 0 1 2 3 4")
    p.add_argument("--no-plot", action="store_true",
                   help="Disable matplotlib plots (recommended for multi-seed)")
    p.add_argument("--agg-prefix", default="gems_agg",
                   help="Prefix for aggregate CSV (mean/std)")

    args = p.parse_args()

    # The sender observes a one-hot of length len(arms_means), but its network
    # is sized for K inputs; reject the mismatch here with a clear message.
    if len(args.arms_means) != args.K:
        p.error(f"--arms-means has {len(args.arms_means)} values but --K is {args.K}")
    # A target outside [0, K) can never be chosen by the receiver.
    if not 0 <= args.bad_arm_target < args.K:
        p.error(f"--bad-arm-target must be in [0, {args.K - 1}], got {args.bad_arm_target}")
    return args

# ---------------------------------------------------------------------------
#                           Helper: RAM usage
# ---------------------------------------------------------------------------

# Pick the best available memory probe once, at import time.  The resulting
# ``get_ram_mb()`` always returns a float (MiB), or NaN if nothing is usable.
try:
    import psutil
    _proc = psutil.Process()

    def get_ram_mb():
        """Return the *current* resident set size of this process in MiB (psutil)."""
        return _proc.memory_info().rss / (1024 ** 2)
except Exception:
    try:
        import resource
        # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
        _RSS_DIVISOR = (1024 ** 2) if sys.platform == "darwin" else 1024.0

        def get_ram_mb():
            """Return the *peak* resident set size so far in MiB (resource fallback).

            NOTE: unlike the psutil path this is a high-water mark, so it never
            decreases between calls.
            """
            return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / _RSS_DIVISOR
    except Exception:
        def get_ram_mb():
            """No memory probe available on this platform: always NaN."""
            return float("nan")

# ---------------------------------------------------------------------------
#                           Small nets
# ---------------------------------------------------------------------------

class GeneratorSender(nn.Module):
    """Sender policy network: (observation, latent code) -> message logits.

    Args:
        K: length of the observation vector (number of arms).
        M: number of output logits (message vocabulary size).
        zdim: length of the latent code concatenated to the observation.
    """

    def __init__(self, K, M, zdim):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(K + zdim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, M),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs, z):
        # The latent conditions the policy by plain concatenation on the last axis.
        joint = torch.cat((obs, z), dim=-1)
        return self.net(joint)

class GeneratorReceiver(nn.Module):
    """Receiver policy network: (message vector, latent code) -> arm logits.

    Args:
        K: number of output logits (number of arms).
        M: length of the message vector fed as observation.
        zdim: length of the latent code concatenated to the message.
    """

    def __init__(self, K, M, zdim):
        super().__init__()
        hidden = 64
        layers = [
            nn.Linear(M + zdim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, K),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs, z):
        # Same conditioning scheme as the sender: concatenate on the last axis.
        joint = torch.cat((obs, z), dim=-1)
        return self.net(joint)

# ---------------------------------------------------------------------------
#                           Core helpers
# ---------------------------------------------------------------------------

def sample_once(sgen, rgen, z_s, z_r, BEST_ARM, ARMS_MEANS, BAD_ARM_TARGET, device):
    """Play a single sender -> receiver round of the deceptive-message game.

    The sender sees a one-hot of ``BEST_ARM`` and samples a message. The
    receiver sees a one-hot of that message and samples an arm. The receiver
    is paid a Bernoulli(``ARMS_MEANS[arm]``) reward. The sender is paid 1.0
    exactly when the receiver picks ``BAD_ARM_TARGET``.

    Args:
        sgen, rgen: sender / receiver generators called as ``gen(obs, z)``;
            ``sgen.net[-1].out_features`` is read to size the message one-hot.
        z_s, z_r: latent codes (NumPy arrays or tensors).
        BEST_ARM: arm index revealed to the sender.
        ARMS_MEANS: per-arm success probabilities.
        BAD_ARM_TARGET: arm the sender is rewarded for inducing.
        device: torch device for all created tensors.

    Returns:
        ``(s_reward, r_reward, s_logp, r_logp)``: two floats and the
        log-probabilities (tensors) of the sampled message and action.
    """
    def _as_latent(code):
        if isinstance(code, np.ndarray):
            return torch.tensor(code, dtype=torch.float32, device=device)
        return code.to(device).float()

    zs_t = _as_latent(z_s)
    zr_t = _as_latent(z_r)

    # Sender: observe the best arm, emit a message.
    sender_view = torch.zeros(len(ARMS_MEANS), device=device)
    sender_view[BEST_ARM] = 1.0
    msg_logits = sgen(sender_view.unsqueeze(0), zs_t.unsqueeze(0)).squeeze(0)
    message = torch.distributions.Categorical(logits=msg_logits).sample()

    # Receiver: observe the message, pick an arm.
    n_messages = sgen.net[-1].out_features
    receiver_view = torch.zeros(n_messages, device=device)
    receiver_view[int(message)] = 1.0
    act_logits = rgen(receiver_view.unsqueeze(0), zr_t.unsqueeze(0)).squeeze(0)
    action = torch.distributions.Categorical(logits=act_logits).sample()

    arm = int(action)
    receiver_reward = float(np.random.rand() < ARMS_MEANS[arm])
    sender_reward = 1.0 if arm == BAD_ARM_TARGET else 0.0

    sender_logp = torch.log_softmax(msg_logits, dim=-1)[message]
    receiver_logp = torch.log_softmax(act_logits, dim=-1)[action]
    return sender_reward, receiver_reward, sender_logp, receiver_logp

def empirical_bernstein_ucb(mu, var, n, delta):
    """Empirical-Bernstein upper confidence bound on a mean.

    ucb = mu + sqrt(2 * var * log(2/delta) / n) + 3 * log(2/delta) / n

    The ``3 * log(2/delta) / n`` range term assumes rewards bounded in [0, 1]
    (as produced in this file). With ``n <= 1`` there is no variance estimate,
    so a fixed optimistic bonus of 1.0 is added instead.
    """
    if n > 1:
        log_term = np.log(2.0 / delta)
        variance_part = np.sqrt(2 * var * log_term / n)
        range_part = 3.0 * log_term / n
        return mu + (variance_part + range_part)
    return mu + 1.0

def approx_jacobian_penalty(gen, z, obs_inputs, eps, device):
    """Cheap finite-difference sensitivity of ``gen``'s logits to its latent code.

    For each observation, ``z`` is shifted by ``eps`` in *every* coordinate at
    once (a step along the all-ones direction). The norm of the resulting
    logit change is divided by ``eps``, and the values are averaged over
    ``obs_inputs``. This is a directional proxy for the Jacobian norm, not the
    full Jacobian (see ``jacobian_metrics`` for that).

    Args:
        gen: callable ``gen(obs_batch, z_batch) -> logits``.
        z: latent code, NumPy array or tensor (unsqueezed once to form a batch
            of one, so presumably 1-D).
        obs_inputs: sequence of observation tensors (each unsqueezed once).
        eps: finite-difference step size.
        device: device the latent is placed on.

    Returns:
        The mean penalty as a Python float; 0.0 when ``obs_inputs`` is empty.
    """
    if len(obs_inputs) == 0:
        return 0.0
    if isinstance(z, np.ndarray):
        z = torch.tensor(z, dtype=torch.float32, device=device)
    # Also move tensors that arrive on a different device, not just cast dtype.
    z = z.to(device).float().detach()
    z_batch = z.unsqueeze(0)
    zp_batch = (z + eps).unsqueeze(0)
    total = 0.0
    # Only the numeric value is needed: skip building an autograd graph.
    with torch.no_grad():
        for obs in obs_inputs:
            obs_batch = obs.unsqueeze(0)
            l0 = gen(obs_batch, z_batch)
            l1 = gen(obs_batch, zp_batch)
            total += (l1 - l0).norm().item() / (eps + 1e-12)
    return total / len(obs_inputs)

def build_candidate_pool(anchor_list, zdim, pool_size):
    """Draw ``pool_size`` float32 latent candidates of length ``zdim``.

    Each candidate is either a uniformly chosen anchor plus N(0, 0.1^2)
    jitter, or a fresh N(0, I) draw. The first option is taken with
    probability 0.5, and only when ``anchor_list`` is non-empty. The coin
    flip is drawn before the emptiness check, so the global NumPy RNG is
    consumed identically either way.
    """
    def _draw_one():
        near_anchor = np.random.rand() < 0.5 and anchor_list
        if not near_anchor:
            return np.random.randn(zdim).astype(np.float32)
        base = anchor_list[np.random.randint(len(anchor_list))]
        return (base + 0.1 * np.random.randn(zdim)).astype(np.float32)

    return [_draw_one() for _ in range(pool_size)]

def alpha_for_delta(t, delta0):
    """Confidence level used by EB-UCB at outer iteration ``t``: delta0 / (t + 1).

    NOTE: this keeps the originally provided 1/t decay; a delta0 * t**-2
    schedule is the alternative the author mentions for the analysis.
    """
    denominator = t + 1.0
    return delta0 / denominator

def kl_between_logits(old_logits, new_logits):
    """Mean KL(old || new) between categorical distributions given as logits.

    The last axis indexes categories. The KL is summed over it and averaged
    over any leading axes. Callers in this file pass ``old_logits`` detached,
    so gradients flow only through ``new_logits``.

    Both log-probabilities are taken with ``log_softmax``, not
    ``log(softmax(x) + 1e-12)``. This stays exact for tiny probabilities and
    removes the epsilon bias.
    """
    old_logp = torch.log_softmax(old_logits, dim=-1)
    new_logp = torch.log_softmax(new_logits, dim=-1)
    old_p = old_logp.exp()
    return (old_p * (old_logp - new_logp)).sum(dim=-1).mean()

# ---------------- Latent-space metrics (NumPy-only: no SciPy needed) ----

def _pairwise_dists(A):
    if len(A) < 2:
        return None
    X = A.astype(np.float64)
    G = X @ X.T
    sq = np.clip(np.diag(G)[:, None] + np.diag(G)[None, :] - 2.0 * G, 0.0, None)
    D = np.sqrt(sq, dtype=np.float64)
    return D

def pairwise_stats(A):
    """Summarise how spread out the rows of ``A`` are.

    Returns:
        dict with ``mean_pair_dist`` (mean over all off-diagonal Euclidean
        distances) and ``min_nn_dist`` (mean nearest-neighbour distance).
        Both are NaN when ``A`` has fewer than two rows.
    """
    if len(A) < 2:
        nan = float("nan")
        return dict(mean_pair_dist=nan, min_nn_dist=nan)
    dists = _pairwise_dists(A)
    # Mask self-distances so they neither enter the mean nor win the min.
    np.fill_diagonal(dists, np.inf)
    off_diagonal = dists[np.isfinite(dists)]
    return dict(
        mean_pair_dist=float(np.mean(off_diagonal)),
        min_nn_dist=float(np.mean(dists.min(axis=1))),
    )

def cov_eigs_stats(A):
    """Extreme eigenvalues and condition number of the sample covariance of ``A``.

    Rows of ``A`` are samples and columns are dimensions. Eigenvalues made
    slightly negative by round-off are clipped to 0 before taking max/min.

    Returns:
        dict with ``cov_eig_max``, ``cov_eig_min`` and
        ``cov_cond = max / (min + 1e-12)``. All three are NaN when ``A`` has
        fewer than two rows or the eigen-solver raises ``LinAlgError``.
    """
    nan = float("nan")
    if len(A) < 2:
        return dict(cov_eig_max=nan, cov_eig_min=nan, cov_cond=nan)
    # np.cov collapses a single-column input to a 0-d scalar, which eigvalsh
    # rejects; atleast_2d turns it into the 1x1 covariance matrix it really is.
    C = np.atleast_2d(np.cov(A, rowvar=False))
    try:
        w = np.maximum(np.linalg.eigvalsh(C), 0.0)
    except np.linalg.LinAlgError:
        return dict(cov_eig_max=nan, cov_eig_min=nan, cov_cond=nan)
    lam_max = float(np.max(w))
    lam_min = float(np.min(w))
    return dict(cov_eig_max=lam_max, cov_eig_min=lam_min,
                cov_cond=float(lam_max / (lam_min + 1e-12)))

def jacobian_metrics(gen, z_np, obs_tensors, device):
    """Exact logits-vs-latent Jacobian statistics, averaged over observations.

    For each observation, J = d gen(obs, z) / d z is assembled with autograd,
    one row per output logit. Three quantities are recorded per observation:
    the Frobenius norm, the spectral norm (largest singular value) and the
    condition number (largest / smallest singular value).

    Returns:
        dict with ``jac_frob`` and ``jac_spec`` (plain means) and ``jac_cond``
        (NaN-ignoring mean). The condition number is NaN for a J with a
        single singular value, and all three are NaN if the SVD fails.
    """
    z = torch.tensor(z_np, dtype=torch.float32, device=device, requires_grad=True).unsqueeze(0)

    def _summarise(J):
        try:
            sv = np.linalg.svd(J, compute_uv=False)
            frob = float(np.linalg.norm(J, ord='fro'))
            spec = float(sv.max() if sv.size > 0 else 0.0)
            cond = float((sv.max() / (sv.min() + 1e-12)) if sv.size > 1 else np.nan)
        except np.linalg.LinAlgError:
            return float("nan"), float("nan"), float("nan")
        return frob, spec, cond

    summaries = []
    for obs in obs_tensors:
        logits = gen(obs.to(device).unsqueeze(0), z)
        n_out = logits.shape[-1]
        # retain_graph: the same forward graph is differentiated once per logit.
        rows = [
            torch.autograd.grad(logits[0, k], z, retain_graph=True, create_graph=False)[0]
            .detach().cpu().numpy()
            for k in range(n_out)
        ]
        summaries.append(_summarise(np.vstack(rows)))

    frobs = [s[0] for s in summaries]
    specs = [s[1] for s in summaries]
    conds = [s[2] for s in summaries]
    return dict(jac_frob=np.mean(frobs), jac_spec=np.mean(specs), jac_cond=np.nanmean(conds))

def mwu_entropy(p):
    """Shannon entropy (nats) and effective support 1 / sum(p_i^2) of a weight vector.

    ``p`` is renormalised first, so unnormalised non-negative weights are fine.

    Returns:
        ``(entropy, effective_support)`` as Python floats.
    """
    w = np.asarray(p, dtype=np.float64)
    w = w / (w.sum() + 1e-12)
    entropy = -np.sum(w * np.log(w + 1e-12))
    inverse_simpson = 1.0 / (np.sum(w ** 2) + 1e-12)
    return float(entropy), float(inverse_simpson)

def pca_2d(A):
    """Project the rows of ``A`` onto their top-two principal directions.

    Returns a float32 array that is always of shape ``(len(A), 2)``.

    * Fewer than two rows gives all zeros.
    * If the data admit fewer than two components (a single feature column),
      the missing coordinate is zero-padded instead of returning a narrower
      array.

    Component signs follow the SVD and are therefore arbitrary.
    """
    n = len(A)
    if n < 2:
        return np.zeros((n, 2), dtype=np.float32)
    X = A - A.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    Z = X @ Vt[:2].T
    if Z.shape[1] < 2:
        Z = np.pad(Z, ((0, 0), (0, 2 - Z.shape[1])))
    return Z.astype(np.float32)

# ---------------------------------------------------------------------------
#                   Meta-game estimation + OMWU (NEW)
# ---------------------------------------------------------------------------

def estimate_anchor_values_sender(sgen, rgen, sender_anchors, receiver_anchors,
                                  receiver_mix, n_per_anchor, BEST_ARM, ARMS_MEANS,
                                  BAD_ARM_TARGET, device):
    """Unbiased MC estimate of E_{zr~σR}[ payoff_s(zs, zr) ] for each sender anchor.

    For every sender anchor, ``n_per_anchor`` receiver anchors are drawn from
    ``receiver_mix`` (renormalised) and one round of ``sample_once`` is played
    per draw. The sender's reward is averaged.

    Returns:
        float64 array with one estimate per sender anchor (all zeros when
        ``n_per_anchor`` <= 0).
    """
    n_receivers = len(receiver_anchors)
    # The sampling distribution does not change inside the loops.
    receiver_p = receiver_mix / receiver_mix.sum()
    vals = np.zeros(len(sender_anchors), dtype=np.float64)
    # Only rewards are used; the log-prob graphs built by sample_once are not needed.
    with torch.no_grad():
        for i, zs in enumerate(sender_anchors):
            s = 0.0
            for _ in range(n_per_anchor):
                zr = receiver_anchors[np.random.choice(n_receivers, p=receiver_p)]
                s_r, _, *_ = sample_once(sgen, rgen, zs, zr, BEST_ARM, ARMS_MEANS,
                                         BAD_ARM_TARGET, device)
                s += s_r
            vals[i] = s / max(1, n_per_anchor)
    return vals

def estimate_anchor_values_receiver(sgen, rgen, sender_anchors, receiver_anchors,
                                    sender_mix, n_per_anchor, BEST_ARM, ARMS_MEANS,
                                    BAD_ARM_TARGET, device):
    """Unbiased MC estimate of E_{zs~σS}[ payoff_r(zs, zr) ] for each receiver anchor.

    For every receiver anchor, ``n_per_anchor`` sender anchors are drawn from
    ``sender_mix`` (renormalised) and one round of ``sample_once`` is played
    per draw. The receiver's reward is averaged.

    Returns:
        float64 array with one estimate per receiver anchor (all zeros when
        ``n_per_anchor`` <= 0).
    """
    n_senders = len(sender_anchors)
    # The sampling distribution does not change inside the loops.
    sender_p = sender_mix / sender_mix.sum()
    vals = np.zeros(len(receiver_anchors), dtype=np.float64)
    # Only rewards are used; the log-prob graphs built by sample_once are not needed.
    with torch.no_grad():
        for j, zr in enumerate(receiver_anchors):
            s = 0.0
            for _ in range(n_per_anchor):
                zs = sender_anchors[np.random.choice(n_senders, p=sender_p)]
                _, r_r, *_ = sample_once(sgen, rgen, zs, zr, BEST_ARM, ARMS_MEANS,
                                         BAD_ARM_TARGET, device)
                s += r_r
            vals[j] = s / max(1, n_per_anchor)
    return vals

def omwu_update(curr_vals, prev_vals, mix, eta):
    """Optimistic MWU: p_{t+1,i} ∝ p_{t,i} * exp(η * (2*v_t[i] - v_{t-1}[i]))

    Args:
        curr_vals: payoff estimates v_t, one per strategy.
        prev_vals: previous estimates v_{t-1}, or None (treated as zeros).
        mix: current mixture p_t (NumPy array).
        eta: learning rate η.

    Returns:
        ``(new_mix, stored_prev)``. ``new_mix`` is the normalised p_{t+1}, and
        ``stored_prev`` is a float64 copy of ``curr_vals`` to pass back in as
        ``prev_vals`` next time.
    """
    v_now = np.asarray(curr_vals, dtype=np.float64)
    v_before = np.zeros_like(v_now) if prev_vals is None else prev_vals
    optimistic_gain = 2.0 * v_now - v_before
    scores = np.log(mix + 1e-12) + eta * optimistic_gain
    # Subtract the max before exponentiating so the largest weight is exp(0) = 1.
    weights = np.exp(scores - scores.max())
    return weights / weights.sum(), v_now.copy()

# ---------------------------------------------------------------------------
#                           Single seed runner
# ---------------------------------------------------------------------------

def _seeded_path(path: str, seed: int, multi_seed: bool) -> str:
    """Return ``path`` with ``_seed<seed>`` before its extension when several seeds share it."""
    if not multi_seed:
        return path
    import os
    root, ext = os.path.splitext(path)
    return f"{root}_seed{seed}{ext}"


def _eb_ucb_select(side, sgen, rgen, own_anchors, opp_anchors, opp_mix, obs_inputs,
                   args, delta, best_arm, arms_means, device):
    """Return the candidate latent with the best Jacobian-penalised EB-UCB score, or None.

    ``side`` is "sender" or "receiver". Candidates are drawn around
    ``own_anchors``, and each one is scored with ``args.eb_ucb_n`` rollouts
    against opponents sampled from ``opp_mix`` over ``opp_anchors``.
    """
    gen = sgen if side == "sender" else rgen
    opp_p = opp_mix / opp_mix.sum()
    best_z, best_score = None, -np.inf
    for zc in build_candidate_pool(own_anchors, args.zdim, args.candidate_pool_size):
        rewards = []
        with torch.no_grad():  # rollouts only need rewards, not gradients
            for _ in range(args.eb_ucb_n):
                z_opp = opp_anchors[np.random.choice(len(opp_anchors), p=opp_p)]
                if side == "sender":
                    s_r, _, *_ = sample_once(sgen, rgen, zc, z_opp, best_arm, arms_means,
                                             args.bad_arm_target, device)
                    rewards.append(s_r)
                else:
                    _, r_r, *_ = sample_once(sgen, rgen, z_opp, zc, best_arm, arms_means,
                                             args.bad_arm_target, device)
                    rewards.append(r_r)
        mu = np.mean(rewards)
        var = np.var(rewards, ddof=1) if len(rewards) > 1 else 0.0
        ucb = empirical_bernstein_ucb(mu, var, len(rewards), delta)
        jac = approx_jacobian_penalty(gen, zc, obs_inputs, args.jac_eps, device)
        score = ucb - args.lambda_jac * jac
        if score > best_score:
            best_score, best_z = score, zc
    return best_z


def _append_anchor(anchors, mix, vals_prev, z_new):
    """Append ``z_new`` to ``anchors``; extend the mix (5% on the newcomer) and OMWU memory."""
    anchors.append(z_new)
    # Extend the mix with small mass on the new anchor to avoid a full reset.
    mix = np.append(mix * 0.95, 0.05)
    mix = mix / mix.sum()
    if vals_prev is None:
        vals_prev = np.zeros(len(mix) - 1)
    return mix, np.append(vals_prev, 0.0)


def _latent_snapshot(args, t, sgen, rgen, sender_anchors, receiver_anchors,
                     sender_mix, receiver_mix, best_arm, delta, device):
    """Compute one iteration's latent-geometry CSV row plus anchor and PCA snapshots."""
    maxN = args.latent_max_points
    Sa = np.array(sender_anchors, dtype=np.float32)
    Ra = np.array(receiver_anchors, dtype=np.float32)
    if len(Sa) > maxN:
        Sa = Sa[np.random.choice(len(Sa), maxN, replace=False)]
    if len(Ra) > maxN:
        Ra = Ra[np.random.choice(len(Ra), maxN, replace=False)]

    S_pair = pairwise_stats(Sa); S_cov = cov_eigs_stats(Sa)
    R_pair = pairwise_stats(Ra); R_cov = cov_eigs_stats(Ra)

    s_idx = np.random.randint(len(Sa))
    r_idx = np.random.randint(len(Ra))
    S_jac = jacobian_metrics(
        sgen, Sa[s_idx],
        [torch.tensor(np.eye(args.K)[best_arm], dtype=torch.float32)],
        device,
    )
    R_jac = jacobian_metrics(
        rgen, Ra[r_idx],
        [torch.tensor(np.eye(args.M)[i], dtype=torch.float32) for i in range(args.M)],
        device,
    )

    Hs, effS = mwu_entropy(sender_mix)
    Hr, effR = mwu_entropy(receiver_mix)

    S_all = np.array(sender_anchors, dtype=np.float32)
    R_all = np.array(receiver_anchors, dtype=np.float32)

    # Distance from the newest anchor to its nearest predecessor.
    S_drift = float("nan"); R_drift = float("nan")
    if t > 1:
        if len(S_all) > 1:
            S_drift = float(np.min(np.linalg.norm(S_all[:-1] - S_all[-1], axis=1)))
        if len(R_all) > 1:
            R_drift = float(np.min(np.linalg.norm(R_all[:-1] - R_all[-1], axis=1)))

    row = {
        "iter": t,
        # Sender geometry
        "S_mean_pair": S_pair["mean_pair_dist"], "S_min_nn": S_pair["min_nn_dist"],
        "S_cov_max": S_cov["cov_eig_max"], "S_cov_min": S_cov["cov_eig_min"], "S_cov_cond": S_cov["cov_cond"],
        # Sender Jacobian
        "S_jac_frob": S_jac["jac_frob"], "S_jac_spec": S_jac["jac_spec"], "S_jac_cond": S_jac["jac_cond"],
        # Receiver geometry
        "R_mean_pair": R_pair["mean_pair_dist"], "R_min_nn": R_pair["min_nn_dist"],
        "R_cov_max": R_cov["cov_eig_max"], "R_cov_min": R_cov["cov_eig_min"], "R_cov_cond": R_cov["cov_cond"],
        # Receiver Jacobian
        "R_jac_frob": R_jac["jac_frob"], "R_jac_spec": R_jac["jac_spec"], "R_jac_cond": R_jac["jac_cond"],
        # Mix usage
        "S_mix_entropy": Hs, "S_eff_support": effS,
        "R_mix_entropy": Hr, "R_eff_support": effR,
        # Drift
        "S_new_anchor_drift": S_drift, "R_new_anchor_drift": R_drift,
        # Bookkeeping
        "S_count": len(sender_anchors), "R_count": len(receiver_anchors),
        "delta": delta, "lambda_jac": args.lambda_jac, "beta_kl": args.beta_kl,
    }
    return row, {"S": S_all, "R": R_all}, {"S": pca_2d(S_all), "R": pca_2d(R_all)}


def _abr_tr_update(args, sgen, rgen, sender_anchors, receiver_anchors, sender_mix,
                   receiver_mix, baseline_s, baseline_r, best_arm, arms_means, device):
    """Run ``args.abr_steps`` advantage policy-gradient steps with a KL trust region; return new baselines."""
    ema = float(args.adv_ema)
    # Frozen copies of the pre-update policies anchor the KL trust region.
    # Copy once per call instead of rebuilding a module every step: rebuilding
    # re-ran weight initialisation and so consumed the torch RNG each step.
    old_s = deepcopy(sgen).eval().requires_grad_(False)
    old_r = deepcopy(rgen).eval().requires_grad_(False)
    opt_s = optim.Adam(sgen.parameters(), lr=args.abr_lr)
    opt_r = optim.Adam(rgen.parameters(), lr=args.abr_lr)

    s_obs_best = torch.tensor(np.eye(args.K)[best_arm], dtype=torch.float32, device=device).unsqueeze(0)
    # NOTE(review): the receiver's trust region is measured at message 0 only.
    r_obs_e0 = torch.tensor(np.eye(args.M)[0], dtype=torch.float32, device=device).unsqueeze(0)
    s_p = sender_mix / sender_mix.sum()
    r_p = receiver_mix / receiver_mix.sum()

    for _ in range(args.abr_steps):
        zs = sender_anchors[np.random.choice(len(sender_anchors), p=s_p)]
        zr = receiver_anchors[np.random.choice(len(receiver_anchors), p=r_p)]
        s_r, r_r, s_logp, r_logp = sample_once(
            sgen, rgen, zs, zr, best_arm, arms_means, args.bad_arm_target, device
        )

        # EMA baselines (reward scale is [0, 1]); updated before taking the advantage.
        baseline_s = (1.0 - ema) * baseline_s + ema * s_r
        baseline_r = (1.0 - ema) * baseline_r + ema * r_r
        adv_s = s_r - baseline_s
        adv_r = r_r - baseline_r

        zs_t = torch.tensor(zs, dtype=torch.float32, device=device).unsqueeze(0)
        zr_t = torch.tensor(zr, dtype=torch.float32, device=device).unsqueeze(0)

        # sender (KL TR vs snapshot)
        opt_s.zero_grad()
        with torch.no_grad():
            old_logits_s = old_s(s_obs_best, zs_t)
        kl_s = kl_between_logits(old_logits_s, sgen(s_obs_best, zs_t))
        loss_s = -s_logp * adv_s + args.beta_kl * kl_s
        loss_s.backward()
        opt_s.step()

        # receiver (KL TR vs snapshot)
        opt_r.zero_grad()
        with torch.no_grad():
            old_logits_r = old_r(r_obs_e0, zr_t)
        kl_r = kl_between_logits(old_logits_r, rgen(r_obs_e0, zr_t))
        loss_r = -r_logp * adv_r + args.beta_kl * kl_r
        loss_r.backward()
        opt_r.step()

    return baseline_s, baseline_r


def _dense_eval(args, sgen, rgen, sender_anchors, receiver_anchors, best_arm, arms_means, device):
    """Mean sender / receiver payoff over every (sender, receiver) anchor pair (logging only)."""
    eval_n = max(8, args.inner_eval)
    sp = np.zeros((len(sender_anchors), len(receiver_anchors)))
    rp = np.zeros_like(sp)
    with torch.no_grad():
        for i, zs in enumerate(sender_anchors):
            for j, zr in enumerate(receiver_anchors):
                ss = rr = 0.0
                for _ in range(eval_n):
                    s_r, r_r, *_ = sample_once(sgen, rgen, zs, zr, best_arm, arms_means,
                                               args.bad_arm_target, device)
                    ss += s_r; rr += r_r
                sp[i, j] = ss / eval_n
                rp[i, j] = rr / eval_n
    return float(sp.mean()), float(rp.mean())


def _save_latent_outputs(args, seed, multi_seed, latent_rows, anchors_snapshots, pca_traces):
    """Write the latent-metrics CSV and the optional NPZ of snapshots; return the CSV filename."""
    latent_csv = f"gems_latent_{seed}.csv"
    with open(latent_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(latent_rows[0].keys()))
        w.writeheader()
        w.writerows(latent_rows)
    print(f"[seed {seed}] Saved latent metrics to {latent_csv}")

    if args.save_latent_npz:
        npz_path = _seeded_path(args.save_latent_npz, seed, multi_seed)
        # The anchor sets grow every iteration, so the snapshots are ragged and
        # cannot be stacked into a single array (np.savez raises ValueError on
        # modern NumPy). Store one key per iteration instead, e.g. anchors_S_001.
        arrays = {}
        for row, snap, pca in zip(latent_rows, anchors_snapshots, pca_traces):
            t = row["iter"]
            arrays[f"anchors_S_{t:03d}"] = snap["S"]
            arrays[f"anchors_R_{t:03d}"] = snap["R"]
            arrays[f"pca_S_{t:03d}"] = pca["S"]
            arrays[f"pca_R_{t:03d}"] = pca["R"]
        np.savez_compressed(npz_path, **arrays)
        print(f"[seed {seed}] Saved latent snapshots to {npz_path}")
    return latent_csv


def run_single(args: argparse.Namespace, seed: int, show_plot: bool) -> Dict[str, Any]:
    """Run one full experiment for ``seed`` and write its logs.

    Each outer iteration performs five steps:

    1. Estimate per-anchor payoffs under the current mixes and apply
       optimistic MWU.
    2. Add one sender and one receiver anchor chosen by Jacobian-penalised
       EB-UCB (skipped if the candidate pool is empty).
    3. Optionally log latent geometry.
    4. Train both generators with ABR-TR.
    5. Evaluate mean payoffs over all anchor pairs.

    Files written:

    * ``gems_<seed>.csv``, or ``--log-csv`` when given.
    * With ``--log-latent``: ``gems_latent_<seed>.csv``, plus the
      ``--save-latent-npz`` file (one key per iteration) when that flag is set.

    When more than one seed is configured, user-given ``--log-csv`` /
    ``--save-latent-npz`` paths get a ``_seed<seed>`` suffix so seeds do not
    overwrite each other.

    Returns:
        dict with ``seed``, per-iteration float64 arrays ``sender_mean``,
        ``receiver_mean``, ``time_sec``, ``ram_mb``, and the written
        ``csv_file`` / ``latent_csv`` (None when latent logging is off).
    """
    device = torch.device(args.device)
    multi_seed = len(getattr(args, "seeds", None) or []) > 1

    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        try:
            torch.cuda.manual_seed_all(seed)
        except Exception:
            pass  # best effort: a CUDA seeding failure should not abort the run

    ARMS_MEANS = np.array(args.arms_means, dtype=np.float32)
    BEST_ARM = int(np.argmax(ARMS_MEANS))

    sgen = GeneratorSender(args.K, args.M, args.zdim).to(device)
    rgen = GeneratorReceiver(args.K, args.M, args.zdim).to(device)

    sender_anchors   = [np.random.randn(args.zdim).astype(np.float32) for _ in range(2)]
    receiver_anchors = [np.random.randn(args.zdim).astype(np.float32) for _ in range(2)]
    sender_mix   = np.ones(len(sender_anchors))   / len(sender_anchors)
    receiver_mix = np.ones(len(receiver_anchors)) / len(receiver_anchors)

    # OMWU keeps the previous payoff-vector estimates.
    sender_vals_prev = None
    receiver_vals_prev = None

    # instrumentation
    hist_s_mean, hist_r_mean = [], []
    it_times, it_ram, it_ts = [], [], []
    it_S, it_R, it_delta = [], [], []
    latent_rows, anchors_snapshots, pca_traces = [], [], []

    # Advantage baselines (EMA), carried across outer iterations.
    baseline_s = 0.0
    baseline_r = 0.0

    # Observations used by the Jacobian penalty (loop-invariant).
    obs_inputs_s = [torch.tensor(np.eye(args.K)[BEST_ARM], dtype=torch.float32, device=device)]
    obs_inputs_r = [torch.tensor(np.eye(args.M)[i], dtype=torch.float32, device=device)
                    for i in range(args.M)]

    for t in range(1, args.iters + 1):
        t0 = time.time()
        it_ts.append(time.strftime("%Y-%m-%d %H:%M:%S"))

        # --------- Meta-game estimation (sampling under current mixes) ---------
        n_per_anchor = int(args.meta_eval_n)
        sender_vals = estimate_anchor_values_sender(
            sgen, rgen, sender_anchors, receiver_anchors, receiver_mix,
            n_per_anchor, BEST_ARM, ARMS_MEANS, args.bad_arm_target, device
        )
        receiver_vals = estimate_anchor_values_receiver(
            sgen, rgen, sender_anchors, receiver_anchors, sender_mix,
            n_per_anchor, BEST_ARM, ARMS_MEANS, args.bad_arm_target, device
        )

        # --------- OMWU updates (optimistic) ---------
        sender_mix, sender_vals_prev = omwu_update(sender_vals, sender_vals_prev, sender_mix, args.mwu_eta)
        receiver_mix, receiver_vals_prev = omwu_update(receiver_vals, receiver_vals_prev, receiver_mix, args.mwu_eta)

        delta = alpha_for_delta(t, args.delta0)

        # ----- EB-UCB expansion: sender first; the receiver then plays against the extended sender mix -----
        z_new = _eb_ucb_select("sender", sgen, rgen, sender_anchors, receiver_anchors,
                               receiver_mix, obs_inputs_s, args, delta, BEST_ARM,
                               ARMS_MEANS, device)
        if z_new is not None:
            sender_mix, sender_vals_prev = _append_anchor(
                sender_anchors, sender_mix, sender_vals_prev, z_new)

        z_new = _eb_ucb_select("receiver", sgen, rgen, receiver_anchors, sender_anchors,
                               sender_mix, obs_inputs_r, args, delta, BEST_ARM,
                               ARMS_MEANS, device)
        if z_new is not None:
            receiver_mix, receiver_vals_prev = _append_anchor(
                receiver_anchors, receiver_mix, receiver_vals_prev, z_new)

        # -------- Latent logging (optional) --------
        if args.log_latent:
            row, snap, pca = _latent_snapshot(args, t, sgen, rgen, sender_anchors,
                                              receiver_anchors, sender_mix, receiver_mix,
                                              BEST_ARM, delta, device)
            latent_rows.append(row)
            anchors_snapshots.append(snap)
            pca_traces.append(pca)

        # ----- ABR-TR (advantage-based with KL TR) -----
        baseline_s, baseline_r = _abr_tr_update(
            args, sgen, rgen, sender_anchors, receiver_anchors, sender_mix,
            receiver_mix, baseline_s, baseline_r, BEST_ARM, ARMS_MEANS, device
        )

        # ----- evaluation & logs (dense, for visibility only) -----
        s_mean, r_mean = _dense_eval(args, sgen, rgen, sender_anchors, receiver_anchors,
                                     BEST_ARM, ARMS_MEANS, device)
        hist_s_mean.append(s_mean)
        hist_r_mean.append(r_mean)

        it_times.append(time.time() - t0)
        it_ram.append(get_ram_mb())
        it_S.append(len(sender_anchors))
        it_R.append(len(receiver_anchors))
        it_delta.append(delta)

        print(f"[seed {seed}] iter {t:2d} | S={it_S[-1]:2d} R={it_R[-1]:2d} "
              f"| sender {hist_s_mean[-1]:.4f} receiver {hist_r_mean[-1]:.4f}")

    # -----------------------------------------------------------------------
    #                           Visualise & CSV
    # -----------------------------------------------------------------------
    if show_plot:
        plt.figure()
        plt.plot(hist_s_mean, label="sender mean")
        plt.plot(hist_r_mean, label="receiver mean")
        plt.legend(); plt.grid(True); plt.title(f"Seed {seed}")
        plt.show()

    csv_file = _seeded_path(args.log_csv, seed, multi_seed) if args.log_csv else f"gems_{seed}.csv"
    with open(csv_file, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["iter", "timestamp", "sender_mean", "receiver_mean",
                    "time_sec", "ram_mb", "seed", "S", "R", "delta",
                    "candidate_pool_size", "abr_steps"])
        per_iter = zip(it_ts, hist_s_mean, hist_r_mean, it_times, it_ram, it_S, it_R, it_delta)
        for i, (ts, s_m, r_m, dt, ram, n_s, n_r, dlt) in enumerate(per_iter, start=1):
            w.writerow([i, ts, s_m, r_m, dt, ram, seed, n_s, n_r, dlt,
                        args.candidate_pool_size, args.abr_steps])
    print(f"\n[seed {seed}] Saved instrumentation log to {csv_file}")

    latent_csv = None
    if args.log_latent and latent_rows:
        latent_csv = _save_latent_outputs(args, seed, multi_seed, latent_rows,
                                          anchors_snapshots, pca_traces)

    return dict(
        seed=seed,
        sender_mean=np.array(hist_s_mean, dtype=np.float64),
        receiver_mean=np.array(hist_r_mean, dtype=np.float64),
        time_sec=np.array(it_times, dtype=np.float64),
        ram_mb=np.array(it_ram, dtype=np.float64),
        csv_file=csv_file,
        latent_csv=latent_csv
    )

# ---------------------------------------------------------------------------
#                 Aggregate writer (mean/std across seeds)
# ---------------------------------------------------------------------------

def aggregate_runs(per_seed_stats: List[Dict[str, Any]], iters: int, outfile: str) -> None:
    """Write per-iteration mean and sample std (ddof=1) across seeds to ``outfile``.

    Every entry of ``per_seed_stats`` must hold equal-length arrays under
    ``sender_mean``, ``receiver_mean``, ``time_sec`` and ``ram_mb``. With a
    single seed, the std columns are written as zeros.
    """
    metrics = ("sender_mean", "receiver_mean", "time_sec", "ram_mb")
    n_seeds = len(per_seed_stats)

    agg = {}
    for key in metrics:
        stacked = np.stack([run[key] for run in per_seed_stats], axis=0)
        agg[f"{key}_mu"] = stacked.mean(axis=0)
        if stacked.shape[0] > 1:
            agg[f"{key}_std"] = stacked.std(axis=0, ddof=1)
        else:
            agg[f"{key}_std"] = np.zeros(iters)

    header = ["iter"]
    for key in metrics:
        header += [f"{key}_mu", f"{key}_std"]
    header.append("n_seeds")

    with open(outfile, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for i in range(iters):
            row = [i + 1]
            for key in metrics:
                row += [float(agg[f"{key}_mu"][i]), float(agg[f"{key}_std"][i])]
            row.append(n_seeds)
            writer.writerow(row)
    print(f"\nSaved aggregate (mean/std across seeds) to {outfile}")

# ---------------------------------------------------------------------------
#                           Entrypoint
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: run every requested seed, then aggregate if there are several.

    A single-seed run shows the matplotlib plot unless ``--no-plot`` is
    given; multi-seed runs never plot. With more than one seed, the aggregate
    CSV is written to ``<agg-prefix>_<min seed>-<max seed>.csv``.

    Raises:
        RuntimeError: if the runs report different numbers of iterations.
    """
    args = parse_args()

    # ``--seeds`` has a non-empty default and nargs="+", so in practice
    # ``--top-seed`` is only used if an empty seed list ever reaches here.
    seeds = args.seeds if args.seeds else [args.top_seed]
    show_plot = (not args.no_plot) and len(seeds) == 1

    per_seed_stats = [run_single(args, seed=int(s), show_plot=show_plot) for s in seeds]

    if len(per_seed_stats) > 1:
        iters = len(per_seed_stats[0]["sender_mean"])
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if any(len(d["sender_mean"]) != iters for d in per_seed_stats):
            raise RuntimeError("Mismatched iters across runs")
        agg_out = f"{args.agg_prefix}_{min(seeds)}-{max(seeds)}.csv"
        aggregate_runs(per_seed_stats, iters, agg_out)


if __name__ == "__main__":
    main()
