#!/usr/bin/env python3
import argparse, time, csv, os, random, math
import numpy as np
import torch, torch.nn as nn
import imageio.v2 as imageio

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per ``--flag`` declared below,
        populated with the defaults shown when a flag is omitted.
    """
    parser = argparse.ArgumentParser("Chess-only GEMS+PPO (GIF + CSV)")
    add = parser.add_argument

    # Run length, latent dimensionality and RNG seed.
    add("--iters", type=int, default=50)
    add("--zdim", type=int, default=8)
    add("--seed", type=int, default=0)

    # PPO / rollout hyper-parameters.
    add("--rollout_min_steps", type=int, default=1600)
    add("--ppo_epochs", type=int, default=10)
    add("--ppo_batch", type=int, default=1024)
    add("--gamma", type=float, default=0.99)
    add("--gae_lambda", type=float, default=0.95)
    add("--lr", type=float, default=3e-4)
    add("--clip", type=float, default=0.2)
    add("--ent_beta", type=float, default=1e-3)

    # Meta-strategy step size, its schedule, and Monte-Carlo budgets.
    add("--eta", type=float, default=0.35)
    add("--eta_sched", type=str, default="sqrt",
        choices=["const", "sqrt", "harmonic"])
    add("--mc_ni", type=int, default=8, help="episodes for vhat per iter")
    add("--mc_B", type=int, default=16, help="episodes for rbar per iter")
    add("--grow", type=float, default=0.0,
        help="sqrt(t) growth factor for MC budgets")

    # Candidate pool sizes and oracle settings.
    add("--pool_mut", type=int, default=2)
    add("--pool_rand", type=int, default=1)
    add("--oracle_nz", type=int, default=1)
    add("--oracle_m", type=int, default=1)

    # Output artefacts.
    add("--csv", type=str, default="gems_results.csv")
    add("--video", type=str, default="gems_last.gif")
    add("--fps", type=int, default=30)

    # Compute device selection.
    add("--device", type=str, default="auto",
        choices=["auto", "cuda", "cpu"])

    return parser.parse_args()

# Parse the command line once at import time; the rest of the module reads
# its configuration from this global namespace.
args = parse_args()

# When no DISPLAY variable is present, default SDL's video driver to "dummy".
# setdefault() leaves any SDL_VIDEODRIVER value already in the environment alone.
if "DISPLAY" not in os.environ:
    os.environ.setdefault("SDL_VIDEODRIVER", "dummy")

def _seed_everything(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

def _init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=1.0)
        nn.init.zeros_(m.bias)

def viz_seed(it: int):
    """Derive a deterministic episode seed for iteration ``it``.

    Combines the module-level ``SEED`` and ``it`` with two fixed multipliers
    and keeps the low 31 bits, so the result is a non-negative int.
    """
    mixed = SEED * 9973 + it * 7919
    return mixed & 0x7FFFFFFF

try:
    import psutil
    def _ram_mb():
        return psutil.Process().memory_info().rss / (1024**2)
except Exception:
    try:
        import resource, sys
        if sys.platform == "darwin":
            def _ram_mb():
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2)
        else:
            def _ram_mb():
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0
    except Exception:
        def _ram_mb():
            return float("nan")

def _mem_mb():
    try: return float(_ram_mb()), "rss"
    except Exception: return float("nan"), "n/a"

def softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x`` computed over all of its elements.

    The maximum is subtracted before exponentiating to avoid overflow, and
    1e-12 is added to the normaliser.
    """
    weights = np.exp(x - x.max())
    return weights / (weights.sum() + 1e-12)

def eta_t(eta0: float, t: int, sched: str) -> float:
    """Return the step size at iteration ``t`` under schedule ``sched``.

    * ``"sqrt"``     -> ``eta0 / max(1, sqrt(t))``
    * ``"harmonic"`` -> ``eta0 / (1 + 0.5 * t)``
    * ``"const"`` or any other value -> ``eta0`` unchanged
    """
    if sched == "sqrt":
        denom = max(1.0, math.sqrt(t))
        return eta0 / denom
    if sched == "harmonic":
        denom = 1.0 + 0.5 * t
        return eta0 / denom
    return eta0

def scale_count(base: int, grow: float, t: int) -> int:
    """Scale an episode budget with the iteration index.

    Returns ``round(base * (1 + grow * sqrt(max(1, t))))`` as an int, never
    less than 1.
    """
    factor = 1.0 + grow * math.sqrt(max(1, t))
    scaled = int(round(base * factor))
    return scaled if scaled > 1 else 1

def empirical_sigma_rng(rng: np.random.Generator, pvec: np.ndarray, N: int):
    """Turn a probability vector into a list of ``N`` strategy indices.

    For ``N <= 1`` a single index is sampled from ``pvec`` with ``rng``.
    Otherwise each index ``i`` is allotted ``floor(N * pvec[i])`` slots and
    the remaining slots go to the indices with the largest fractional parts
    (largest-remainder rounding). The returned list is grouped by index in
    ascending order; ``rng`` is not consulted on this path unless the
    allocation comes out empty.

    Returns:
        list[int]: strategy indices.
    """
    n_choices = len(pvec)
    if N <= 1:
        return [int(rng.choice(n_choices, p=pvec))]

    ideal = N * pvec
    base = np.floor(ideal).astype(int)
    shortfall = int(N - base.sum())
    if shortfall > 0:
        leftovers = ideal - base
        # The tiny reversed-index term breaks ties between equal fractional
        # parts deterministically (the higher index sorts first).
        tie_break = 1e-12 * np.arange(n_choices)[::-1]
        ranking = np.argsort(-leftovers + tie_break)
        for slot in range(shortfall):
            base[ranking[slot]] += 1

    seq = [idx for idx, cnt in enumerate(base.tolist()) for _ in range(cnt)]
    if not seq:
        seq = [int(rng.choice(n_choices, p=pvec))]
    return seq

from pettingzoo.classic import chess_v6
def make_env(render=False, mode=None):
    """Create a ``chess_v6`` environment.

    Args:
        render: When falsy the environment is built with ``render_mode=None``
            regardless of ``mode``.
        mode: Render mode forwarded to ``chess_v6.env`` when ``render`` is truthy.
    """
    render_mode = mode if render else None
    return chess_v6.env(render_mode=render_mode)

# Global seeding: the stdlib/NumPy/PyTorch generators plus a dedicated NumPy
# Generator (`_rng`) used for meta-strategy sampling.
SEED = args.seed
_seed_everything(SEED)
_rng = np.random.default_rng(SEED)

# Probe one environment instance to discover the agent ids and the size of a
# flattened observation. The probe is closed afterwards so it does not stay
# open for the lifetime of the process.
env_probe = make_env(False, None)
try:
    env_probe.reset(seed=SEED)
    AGENT_IDS = list(env_probe.agents)
    o0 = env_probe.observe(AGENT_IDS[0])
    OBS_DIM = int(o0["observation"].size)
finally:
    env_probe.close()

# Action-space size: taken from `raw_move_mapping` when that name can be
# imported, otherwise a fixed fallback.
# NOTE(review): 4672 is presumably chess_v6's discrete action count — confirm.
try:
    from pettingzoo.classic.chess_v6 import raw_move_mapping
    ACT_DIM = len(raw_move_mapping)
except Exception:
    ACT_DIM = 4672
N_AGENTS = 2
ZDIM = args.zdim

def write_video(frames, path, fps):
    """Save ``frames`` as an animated GIF and return the file path.

    The extension of ``path`` is replaced by ``.gif``. Nothing is written and
    ``None`` is returned when ``frames`` is empty.

    Args:
        frames: Sequence of images handed to ``imageio.mimsave``.
        path: Requested output path; only its stem is kept.
        fps: Frames per second; values below 1 are treated as 1.
    """
    if not frames:
        return None
    stem, _ext = os.path.splitext(path)
    gif_path = stem + ".gif"
    # NOTE(review): `duration` is passed as 1/fps; whether imageio reads it as
    # seconds or milliseconds depends on the installed version — confirm.
    per_frame = 1.0 / max(fps, 1)
    imageio.mimsave(gif_path, frames, duration=per_frame)
    return gif_path

from gymnasium.spaces import Discrete

class CategoricalHead(nn.Module):
    """Policy MLP: ``in_dim -> 64 -> 32 -> act_dim`` with ReLU between layers.

    ``forward`` returns raw (unnormalised) logits, one per action.
    """

    def __init__(self, in_dim, act_dim):
        super().__init__()
        widths = (in_dim, 64, 32, act_dim)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)
        # Re-initialise every Linear layer via the shared initialiser.
        self.apply(_init_weights)

    def forward(self, x):
        return self.net(x)

class VNet(nn.Module):
    """Value MLP: ``in_dim -> 64 -> 32 -> 1`` with ReLU between layers.

    ``forward`` drops the trailing size-1 dimension of the output.
    """

    def __init__(self, in_dim):
        super().__init__()
        hidden_a = nn.Linear(in_dim, 64)
        hidden_b = nn.Linear(64, 32)
        out = nn.Linear(32, 1)
        self.net = nn.Sequential(hidden_a, nn.ReLU(), hidden_b, nn.ReLU(), out)
        # Re-initialise every Linear layer via the shared initialiser.
        self.apply(_init_weights)

    def forward(self, x):
        value = self.net(x)
        return value.squeeze(-1)

class AC(nn.Module):
    """Actor-critic pair whose input is the observation joined with a latent ``z``.

    The policy head (``pi``) and value net (``v``) are separate MLPs that both
    consume the same concatenated input of width ``obs_dim + zdim``.
    """

    def __init__(self, obs_dim, act_dim, zdim):
        super().__init__()
        joint_dim = obs_dim + zdim
        self.pi = CategoricalHead(joint_dim, act_dim)
        self.v = VNet(joint_dim)

    def forward(self, obs, z):
        """Return ``(logits, value)`` for ``obs`` and ``z`` joined on the last axis."""
        joint = torch.cat([obs, z], -1)
        return self.pi(joint), self.v(joint)

def mask_logits(logits, mask_t):
    """Suppress illegal actions by setting their logits to ``-1e9``.

    Args:
        logits: Tensor of action logits.
        mask_t: Tensor in which entries ``<= 0`` mark illegal actions, or
            ``None`` to return ``logits`` unchanged.

    Returns:
        ``logits`` itself when ``mask_t`` is ``None``; otherwise a new tensor
        (the input is not modified in place).
    """
    if mask_t is None:
        return logits
    return logits.masked_fill(mask_t <= 0.0, -1e9)

class Agent:
    """One player's actor-critic network together with its Adam optimiser."""

    def __init__(self, obs_dim, act_dim, zdim, lr):
        # `device` is the module-level torch device.
        self.net = AC(obs_dim, act_dim, zdim).to(device)
        self.opt = torch.optim.Adam(self.net.parameters(), lr=lr)

    @torch.no_grad()
    def act(self, obs_np, z_np, mask_np=None):
        """Sample one action for a single observation (no gradients tracked).

        Args:
            obs_np: Flattened observation.
            z_np: Latent vector to condition on.
            mask_np: Optional legality mask; entries ``<= 0`` are excluded.

        Returns:
            ``(action, log_prob, entropy, value)`` where ``action`` is a Python
            int and the rest are tensors with the batch axis squeezed out.
        """
        def as_row(values):
            # Convert to a float32 tensor on `device` with a leading batch axis of 1.
            return torch.tensor(values, dtype=torch.float32, device=device).unsqueeze(0)

        logits, value = self.net(as_row(obs_np), as_row(z_np))
        if mask_np is not None:
            logits = mask_logits(logits, as_row(mask_np))
        policy = torch.distributions.Categorical(logits=logits)
        action = policy.sample()
        return (
            action.item(),
            policy.log_prob(action).squeeze(0),
            policy.entropy().squeeze(0),
            value.squeeze(0),
        )

    def evaluate(self, obs_t, z_t, act_t, mask_t=None):
        """Score a batch of stored actions under the current network.

        Returns:
            ``(log_prob, entropy, value)`` tensors for the batch.
        """
        logits, value = self.net(obs_t, z_t)
        if mask_t is not None:
            logits = mask_logits(logits, mask_t)
        policy = torch.distributions.Categorical(logits=logits)
        return policy.log_prob(act_t), policy.entropy(), value

# Resolve the torch device: "auto" picks CUDA when torch reports it as
# available and CPU otherwise; an explicit "cuda"/"cpu" is used as given.
if args.device == "auto":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
else:
    dev = args.device
device = torch.device(dev)
# Startup banner; the GPU name is appended only when running on CUDA.
print(f"[GEMS] env=chess_v6 agents=2 device={device.type}" +
      (f" gpu={torch.cuda.get_device_name(0)}" if device.type=='cuda' else ""))

# Module-level aliases for CLI hyper-parameters.
ENT_BETA = args.ent_beta; GAMMA = args.gamma; LAMBDA = args.gae_lambda
PPO_EPOCHS, PPO_BATCH = args.ppo_epochs, args.ppo_batch
MIN_STEPS_PER_ABR = args.rollout_min_steps
GEMS_ITERS = args.iters

# Shared population state, indexed by player p:
#   Z[p]      - list of latent vectors in player p's population
#   LOGS[p]   - array of logits over Z[p]; softmaxed by sigma_list()
#   G_PREV[p] - gains from the previous mwu_update_omwu() call
Z = [[] for _ in range(N_AGENTS)]
LOGS = []
G_PREV = []

def init_population():
    """Give every player one initial latent and matching meta-strategy state.

    For each player this appends a standard-normal float32 latent of length
    ``ZDIM`` to ``Z[p]``, a single zero logit to ``LOGS`` and a single zero
    gain to ``G_PREV``. Draws from NumPy's global RNG.
    """
    for player in range(N_AGENTS):
        first_latent = np.random.normal(0, 1, size=(ZDIM,)).astype(np.float32)
        Z[player].append(first_latent)
        LOGS.append(np.zeros(1, dtype=np.float64))
        G_PREV.append(np.zeros(1, dtype=np.float64))

# Seed each player's population with one latent, then build one learner
# (network + optimiser) per player.
init_population()
agents = [Agent(OBS_DIM, ACT_DIM, ZDIM, args.lr) for _ in range(N_AGENTS)]

def sigma_list():
    """Return each player's mixed strategy: the softmax of ``LOGS[p]``."""
    mixtures = []
    for player in range(N_AGENTS):
        mixtures.append(softmax(LOGS[player]))
    return mixtures

def sample_profile_from(sigma_seq, rng):
    """Draw one strategy index per player from the given mixed strategies.

    Args:
        sigma_seq: Per-player probability vectors.
        rng: ``numpy.random.Generator`` used for the draws.

    Returns:
        list[int]: one index per player, drawn in player order.
    """
    profile = []
    for player in range(N_AGENTS):
        probs = sigma_seq[player]
        profile.append(int(rng.choice(len(probs), p=probs)))
    return profile

def stratified_profile_batches(N, rng):
    """Build ``N`` joint profiles from each player's current mixed strategy.

    Each player's strategy is first converted to an index list by
    ``empirical_sigma_rng``; profile ``i`` then takes element ``i`` (modulo
    the list length) from every player's list.

    NOTE(review): the per-player lists are paired position-by-position rather
    than combined independently — confirm this coupling is intended.
    """
    mixtures = sigma_list()
    pools = [empirical_sigma_rng(rng, mixtures[p], N) for p in range(N_AGENTS)]
    batches = []
    for i in range(N):
        batches.append([pool[i % len(pool)] for pool in pools])
    return batches

def _flatten_obs_chess(obs_dict):
    o = obs_dict["observation"].astype(np.float32).flatten()
    m = obs_dict["action_mask"].astype(np.float32)
    return o, m

def run_episode(prof, z_override=None, render=False, seed=None, record_rgb=False):
    """Play one full game with the current agents and return the outcome.

    Args:
        prof: Per-player index into ``Z[p]``; used only when ``z_override``
            is ``None``.
        z_override: Optional per-player latent vectors used instead of ``prof``.
        render: When true the environment is created in ``"rgb_array"`` mode.
        seed: Environment reset seed; a random one is drawn from the stdlib
            ``random`` module when ``None``.
        record_rgb: When true a frame is captured after reset and after every
            step (this also forces ``"rgb_array"`` mode).

    Returns:
        ``(rets, frames)``: ``rets`` is a float32 array with each player's
        summed reward; ``frames`` is the list of captured frames (empty
        unless ``record_rgb`` is true).
    """
    want_frames = render or record_rgb
    env = make_env(render=want_frames, mode=("rgb_array" if want_frames else None))
    frames, rets = [], np.zeros(N_AGENTS, dtype=np.float32)

    # try/finally so the environment is closed even if acting, stepping or
    # rendering raises part-way through the game.
    try:
        env.reset(seed=seed if seed is not None else random.randint(0, 1 << 30))

        if record_rgb:
            f = env.render()
            if f is not None:
                frames.append(f)

        while env.agents:
            aid = env.agent_selection
            p = AGENT_IDS.index(aid)

            if env.terminations.get(aid, False) or env.truncations.get(aid, False):
                # A finished agent must be stepped with None to be removed.
                env.step(None)
            else:
                o, m = _flatten_obs_chess(env.observe(aid))
                z = z_override[p] if z_override is not None else Z[p][prof[p]]
                a, _, _, _ = agents[p].act(o, z, m)
                env.step(a)

            # Add whatever rewards the environment reports after this step.
            for k, r in env.rewards.items():
                rets[AGENT_IDS.index(k)] += float(r)

            if record_rgb:
                f = env.render()
                if f is not None:
                    frames.append(f)
    finally:
        env.close()

    return rets, frames

def record_episode(prof, z_override, seed, path, fps):
    """Play one recorded episode and save its frames via ``write_video``.

    Returns:
        ``(rets, out)``: the per-player returns and whatever path
        ``write_video`` returned (``None`` when no frames were captured).
    """
    returns, captured = run_episode(
        prof, z_override, render=True, seed=seed, record_rgb=True
    )
    saved_to = write_video(captured, path, fps)
    return returns, saved_to

def meta_estimate(it):
    """Monte-Carlo estimate of per-strategy values and baseline returns.

    Two independent batches of episodes are played with profiles from
    ``stratified_profile_batches``: the first (size derived from ``--mc_ni``)
    fills ``vhat``, the second (size derived from ``--mc_B``) fills ``rbar``.

    Returns:
        ``(vhat, rbar)``: ``vhat[p][k]`` is the mean return of player ``p``
        over the episodes in which it used strategy ``k`` (0 where ``k`` was
        never played); ``rbar[p]`` is player ``p``'s mean return over the
        second batch.
    """
    n_value_eps = scale_count(args.mc_ni, args.grow, it)
    n_base_eps = scale_count(args.mc_B, args.grow, it)

    vhat, vcnt = [], []
    for p in range(N_AGENTS):
        vhat.append(np.zeros(len(Z[p]), dtype=np.float64))
        vcnt.append(np.zeros(len(Z[p]), dtype=np.int64))
    rbar = np.zeros(N_AGENTS, dtype=np.float64)

    # Pass 1: accumulate returns per (player, strategy) and average them.
    for prof in stratified_profile_batches(n_value_eps, _rng):
        latents = [Z[p][k] for p, k in enumerate(prof)]
        rets, _ = run_episode(prof, latents, render=False)
        for p, k in enumerate(prof):
            vhat[p][k] += rets[p]
            vcnt[p][k] += 1
    # Unvisited strategies keep a count of 0; dividing by max(1, count) leaves them at 0.
    vhat = [total / np.maximum(1, count) for total, count in zip(vhat, vcnt)]

    # Pass 2: average return per player under the current mixture.
    for prof in stratified_profile_batches(n_base_eps, _rng):
        latents = [Z[p][k] for p, k in enumerate(prof)]
        rets, _ = run_episode(prof, latents, render=False)
        rbar += rets
    rbar /= max(1, n_base_eps)
    return vhat, rbar

def mwu_update_omwu(vhat, rbar, it):
    """Apply one optimistic multiplicative-weights step to every player's logits.

    For each player the gain of strategy ``k`` is ``vhat[p][k] - rbar[p]``;
    the logits move by ``eta * (2 * gains - previous_gains)`` and the gains
    are stored in ``G_PREV`` for the next call. If the population has grown
    beyond the logit vector, the new entries start 5.0 below the current
    minimum logit with a previous gain of 0.

    Args:
        vhat: Per-player arrays of estimated strategy values.
        rbar: Per-player baseline returns.
        it: Iteration index used by the step-size schedule.
    """
    step = eta_t(args.eta, it, args.eta_sched)
    for p in range(N_AGENTS):
        missing = len(Z[p]) - LOGS[p].shape[0]
        if missing > 0:
            pad_logits = np.full(missing, LOGS[p].min() - 5.0, dtype=np.float64)
            pad_gains = np.zeros(missing, dtype=np.float64)
            LOGS[p] = np.concatenate([LOGS[p], pad_logits], axis=0)
            G_PREV[p] = np.concatenate([G_PREV[p], pad_gains], axis=0)

        gains = np.array(vhat[p], dtype=np.float64) - float(rbar[p])
        optimistic = 2.0 * gains - G_PREV[p]
        LOGS[p] = LOGS[p] + step * optimistic
        G_PREV[p] = gains

def oracle_select(p, it):
    """Grow player ``p``'s population with the best-scoring candidate latents.

    Candidates are ``--pool_mut`` Gaussian perturbations (std 0.25) of the
    newest latent in ``Z[p]`` plus ``--pool_rand`` fresh standard-normal
    latents. Each is scored by player ``p``'s mean return over ``--oracle_m``
    episodes against opponents sampled from the current mixture. The top
    ``--oracle_nz`` candidates are appended to ``Z[p]``; their logits start
    5.0 below the current minimum and their previous gains at 0.

    Args:
        p: Index of the player whose population is extended.
        it: Iteration index (not used by the current implementation).
    """
    parent = Z[p][-1]
    cand = [
        parent + np.random.normal(0, 0.25, size=(ZDIM,)).astype(np.float32)
        for _ in range(args.pool_mut)
    ]
    cand += [
        np.random.normal(0, 1, size=(ZDIM,)).astype(np.float32)
        for _ in range(args.pool_rand)
    ]

    mixtures = sigma_list()
    scores = []
    for zc in cand:
        total = 0.0
        for _ in range(args.oracle_m):
            prof = sample_profile_from(mixtures, _rng)
            # Player p uses the candidate; everyone else uses their sampled latent.
            latents = [zc if q == p else Z[q][prof[q]] for q in range(N_AGENTS)]
            rets, _ = run_episode(prof, latents, render=False)
            total += rets[p]
        scores.append(total / max(1, args.oracle_m))

    ranking = np.argsort(scores)[::-1]  # best score first
    n_keep = min(args.oracle_nz, len(ranking))
    if n_keep <= 0:
        return
    for j in range(n_keep):
        Z[p].append(cand[ranking[j]].copy())
    fresh_logits = np.full(n_keep, LOGS[p].min() - 5.0, dtype=np.float64)
    LOGS[p] = np.concatenate([LOGS[p], fresh_logits], axis=0)
    G_PREV[p] = np.concatenate([G_PREV[p], np.zeros(n_keep, dtype=np.float64)], axis=0)

# Entropy-bonus coefficient read by ppo_update(). This repeats the identical
# assignment made with the other hyper-parameter aliases earlier in the module.
ENT_BETA = args.ent_beta

def collect_rollouts(p, z_anchor):
    """Collect on-policy transitions for player ``p`` and compute GAE targets.

    Full games are played until at least ``--rollout_min_steps`` of player
    ``p``'s own moves have been gathered. Player ``p`` conditions on
    ``z_anchor``; every other player uses the newest latent in its population.

    A reward that the environment assigns to player ``p`` as a result of
    another player's move (for example the game-ending move) is credited to
    player ``p``'s most recent transition, so outcomes decided on the
    opponent's turn are not lost.

    Args:
        p: Index of the learning player.
        z_anchor: Latent vector the learner conditions on.

    Returns:
        Tuple ``(O, Zs, A, LP, R, ADV, MASKS)`` of tensors on ``device``:
        observations, latents, actions, behaviour log-probs, return targets,
        normalised advantages and action masks.
    """
    learner_id = AGENT_IDS[p]
    O, Zs, A, LP, R, ADV, MASKS = [], [], [], [], [], [], []
    steps = 0
    while steps < args.rollout_min_steps:
        env = make_env(False, None)
        traj = []
        # try/finally so the environment is closed even if a step raises.
        try:
            env.reset(seed=random.randint(0, 1 << 30))
            while env.agents:
                aid = env.agent_selection
                if env.terminations.get(aid, False) or env.truncations.get(aid, False):
                    # A finished agent must be stepped with None to be removed.
                    env.step(None)
                    continue

                i = AGENT_IDS.index(aid)
                o, m = _flatten_obs_chess(env.observe(aid))
                z = z_anchor if i == p else Z[i][-1]
                a, lp, _, v = agents[i].act(o, z, m)
                env.step(a)

                # Reward the environment reports for the learner after this move.
                r_p = float(env.rewards.get(learner_id, 0.0))
                if i == p:
                    traj.append([o, m, z, a, lp.item(), v.item(), r_p])
                elif traj:
                    # Consequence of the opponent's reply: attribute it to the
                    # learner's last action.
                    traj[-1][6] += r_p
        finally:
            env.close()

        if not traj:
            continue

        # Generalised advantage estimation; the value after the last stored
        # step is taken to be 0.
        vals = [step[5] for step in traj] + [0.0]
        rews = [step[6] for step in traj]
        advs, G = [], 0.0
        for t in reversed(range(len(rews))):
            delta = rews[t] + args.gamma * vals[t + 1] - vals[t]
            G = delta + args.gamma * args.gae_lambda * G
            advs.append(G)
        advs.reverse()
        rets = [adv + val for adv, val in zip(advs, vals)]

        for (o, m, z, a, lp, _val, _rew), ret_t, adv_t in zip(traj, rets, advs):
            O.append(o)
            MASKS.append(m)
            Zs.append(z)
            A.append(a)
            LP.append(lp)
            R.append(ret_t)
            ADV.append(adv_t)
        steps += len(traj)

    O = torch.tensor(np.array(O), dtype=torch.float32, device=device)
    Zs = torch.tensor(np.array(Zs), dtype=torch.float32, device=device)
    A = torch.tensor(np.array(A), dtype=torch.long, device=device)
    LP = torch.tensor(np.array(LP), dtype=torch.float32, device=device)
    R = torch.tensor(np.array(R), dtype=torch.float32, device=device)
    ADV = torch.tensor(np.array(ADV), dtype=torch.float32, device=device)
    # std() of a single element is NaN, so only normalise when there are
    # at least two samples.
    if ADV.numel() > 1:
        ADV = (ADV - ADV.mean()) / (ADV.std() + 1e-8)
    MASKS = torch.tensor(np.array(MASKS), dtype=torch.float32, device=device)
    return (O, Zs, A, LP, R, ADV, MASKS)

def ppo_update(p, batch):
    """Run clipped-surrogate PPO epochs on player ``p``'s network.

    Args:
        p: Index of the player being trained.
        batch: Tuple ``(O, Zs, A, LP_old, R_t, ADV, MASKS)`` as produced by
            ``collect_rollouts``.

    The loss per minibatch is the negated clipped surrogate, minus
    ``ENT_BETA`` times the mean entropy, plus half the mean squared value
    error. Sample order is reshuffled each epoch with NumPy's global RNG.
    """
    O, Zs, A, LP_old, R_t, ADV, MASKS = batch
    learner = agents[p]
    n_samples = O.shape[0]
    order = np.arange(n_samples)

    for _ in range(args.ppo_epochs):
        np.random.shuffle(order)
        for start in range(0, n_samples, args.ppo_batch):
            sel = order[start:start + args.ppo_batch]
            logp, ent, val = learner.evaluate(O[sel], Zs[sel], A[sel], mask_t=MASKS[sel])

            adv = ADV[sel]
            ratio = torch.exp(logp - LP_old[sel])
            unclipped = ratio * adv
            clipped = torch.clamp(ratio, 1.0 - args.clip, 1.0 + args.clip) * adv

            policy_loss = -torch.min(unclipped, clipped).mean()
            entropy_term = ENT_BETA * ent.mean()
            value_loss = 0.5 * (R_t[sel] - val).pow(2).mean()
            loss = policy_loss - entropy_term + value_loss

            learner.opt.zero_grad(set_to_none=True)
            loss.backward()
            learner.opt.step()

# Make sure the output directories exist ("." when the path has no directory part).
os.makedirs(os.path.dirname(args.csv) or ".", exist_ok=True)
os.makedirs(os.path.dirname(args.video) or ".", exist_ok=True)

# Main training loop: one CSV row is written (and flushed) per iteration.
with open(args.csv, "w", newline="") as f:
    w = csv.writer(f)
    # ret_white / ret_black are filled from rets[0] / rets[1].
    # NOTE(review): assumes AGENT_IDS[0] is the white player — confirm.
    header = ["iter","timestamp","time_sec","mem_mb","mem_type",
              "ret_white","ret_black","ret_mean","ret_sum","pop_sizes","video_path"]
    w.writerow(header)

    for it in range(1, GEMS_ITERS + 1):
        t0 = time.time()

        # 1) Estimate per-strategy values and baseline returns by simulation.
        vhat, rbar = meta_estimate(it)

        # 2) Update each player's meta-strategy logits.
        mwu_update_omwu(vhat, rbar, it)

        # 3) Extend each population with the best-scoring new latent(s).
        for p in range(N_AGENTS):
            oracle_select(p, it)

        # 4) Train each player's network with PPO, conditioned on its newest latent.
        for p in range(N_AGENTS):
            batch = collect_rollouts(p, Z[p][-1])
            ppo_update(p, batch)

        # 5) One evaluation game between the newest latents, with a
        #    deterministic per-iteration seed.
        prof = [len(Z[p]) - 1 for p in range(N_AGENTS)]
        s_eval = viz_seed(it)
        rets, _ = run_episode(prof, [Z[q][prof[q]] for q in range(N_AGENTS)], render=False, seed=s_eval)
        dt = time.time() - t0
        mem, mtype = _mem_mb()
        ret_mean = float(np.mean(rets)); ret_sum = float(np.sum(rets))
        pop_sizes = str([len(Z[p]) for p in range(N_AGENTS)])

        print(f"[GEMS] iter {it}/{GEMS_ITERS} time={dt:.2f}s mean={ret_mean:.2f} sum={ret_sum:.2f} pop={pop_sizes}")

        # The video_path column stays empty for per-iteration rows.
        row = [it, time.strftime("%Y-%m-%d %H:%M:%S"), f"{dt:.3f}", f"{mem:.2f}", mtype,
               f"{rets[0]:.3f}", f"{rets[1]:.3f}", f"{ret_mean:.3f}", f"{ret_sum:.3f}", pop_sizes, ""]
        w.writerow(row); f.flush()

# After training: replay the newest latents with frame capture, using the
# same seed formula as the last iteration's evaluation game.
print("[GEMS] recording last iteration...")
seed_for_record = viz_seed(GEMS_ITERS)
prof = [len(Z[p]) - 1 for p in range(N_AGENTS)]
zov = [Z[q][prof[q]] for q in range(N_AGENTS)]
rets, vpath = record_episode(prof, zov, seed_for_record, args.video, args.fps)

# Append one extra row for the recorded game. It reuses the final iteration
# number, reports time_sec as "0.000" and carries the saved video path.
with open(args.csv, "a", newline="") as f2:
    w2 = csv.writer(f2)
    mem, mtype = _mem_mb()
    ret_mean = float(np.mean(rets)); ret_sum = float(np.sum(rets))
    row = [GEMS_ITERS, time.strftime("%Y-%m-%d %H:%M:%S"), "0.000", f"{mem:.2f}", mtype,
           f"{rets[0]:.3f}", f"{rets[1]:.3f}", f"{ret_mean:.3f}", f"{ret_sum:.3f}",
           str([len(Z[p]) for p in range(N_AGENTS)]), vpath or ""]
    w2.writerow(row)

print(f"[GEMS] saved video at: {vpath}")
print("done")