import argparse, time, csv, os, random
import numpy as np
import torch, torch.nn as nn
import imageio.v2 as imageio
from gymnasium.spaces import Discrete, Box

# ------------------------------------------------------------------------------------
# Env selection
# ------------------------------------------------------------------------------------
def _load_env(env_name):
    assert env_name in ["simple_spread_v3", "simple_tag_v3"]
    if env_name == "simple_spread_v3":
        try:
            from pettingzoo.mpe2 import simple_spread_v3 as simple_spread_v3
        except Exception:
            from pettingzoo.mpe import simple_spread_v3 as simple_spread_v3
        return simple_spread_v3
    else:
        try:
            from pettingzoo.mpe2 import simple_tag_v3 as simple_tag_v3
        except Exception:
            from pettingzoo.mpe import simple_tag_v3 as simple_tag_v3
        return simple_tag_v3

# ------------------------------------------------------------------------------------
# Determinism + init (make identical across scripts)
# ------------------------------------------------------------------------------------
def _seed_everything(seed: int):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=1.0)
        nn.init.zeros_(m.bias)

def viz_seed(it: int):
    """Deterministic, non-negative 31-bit seed derived from the global ``SEED`` and *it*.

    The same formula is used in both scripts, so visualisation/eval episodes line up.
    """
    mixed = SEED * 9973 + it * 7919
    return mixed & 0x7FFFFFFF

# ------------------------------------------------------------------------------------
# Memory helper
# ------------------------------------------------------------------------------------
try:
    import psutil
    def _ram_mb():
        return psutil.Process().memory_info().rss / (1024**2)
except Exception:
    try:
        import resource, sys
        if sys.platform == "darwin":
            def _ram_mb():
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024**2)
        else:
            def _ram_mb():
                return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0
    except Exception:
        def _ram_mb():
            return float("nan")

def _mem_mb():
    """Return ``(megabytes, label)`` describing this process's memory usage.

    The number comes from ``_ram_mb()``. With psutil that is the current RSS;
    with the ``resource`` fallback it is ``ru_maxrss`` (peak). The label is
    ``"rss"`` when a real number was obtained. It is ``"n/a"`` when
    ``_ram_mb`` raised or returned NaN (no memory backend available), so CSV
    consumers can tell missing measurements apart from real ones.
    """
    try:
        mb = float(_ram_mb())
    except Exception:
        return float("nan"), "n/a"
    if np.isnan(mb):
        return mb, "n/a"
    return mb, "rss"

# ------------------------------------------------------------------------------------
# Args
# ------------------------------------------------------------------------------------
def parse_args(argv=None):
    """Parse the command-line options of the PSRO training script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with every option below.

    Exits through ``ArgumentParser.error`` (status 2) when ``--oracle_every``,
    ``--ppo_batch`` or ``--eval_episodes`` is smaller than 1. These are used as a
    modulus, a ``range`` step and the number of stacked eval episodes, so a
    non-positive value would otherwise crash mid-run.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--env", type=str, default="simple_spread_v3",
                   choices=["simple_spread_v3", "simple_tag_v3"])
    p.add_argument("--iters", type=int, default=100)
    p.add_argument("--agents", type=int, default=3,
                   help="simple_spread: #agents (homogeneous). Ignored for simple_tag.")
    p.add_argument("--seed", type=int, default=0)

    # Episode/time controls
    p.add_argument("--max_cycles", type=int, default=50,
                   help="steps per episode (50-100 recommended)")
    p.add_argument("--eval_episodes", type=int, default=3,
                   help="avg eval over K episodes per iteration")

    # PPO / BR training budget
    p.add_argument("--rollout_min_steps", type=int, default=600)
    p.add_argument("--ppo_epochs", type=int, default=2)
    p.add_argument("--ppo_batch", type=int, default=256)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--gae_lambda", type=float, default=0.95)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--clip", type=float, default=0.2)
    p.add_argument("--ent_beta", type=float, default=2e-3)

    # Meta-solver (MWU over sampled payoffs)
    p.add_argument("--eta", type=float, default=0.25, help="MWU learning rate")
    p.add_argument("--mc_B", type=int, default=2, help="meta-estimation outer samples")
    p.add_argument("--mc_m", type=int, default=1, help="meta-estimation inner samples per profile")

    # PSRO oracle cadence and BR init
    p.add_argument("--oracle_every", type=int, default=1, help="run BR/oracle every K iterations (1 = every iter)")
    p.add_argument("--br_init", type=str, default="random", choices=["random","clone_last"],
                   help="initialize BR net randomly or by cloning last policy")

    # Logging / device
    p.add_argument("--csv", type=str, default="psro_results.csv")
    p.add_argument("--video", type=str, default="psro_last.gif")  # keep name; writer forces GIF
    p.add_argument("--fps", type=int, default=30)
    p.add_argument("--device", type=str, default="auto", choices=["auto","cuda","cpu"])
    p.add_argument("--continuous_actions", action="store_true",
                   help="Use continuous actions if env supports it")

    # simple_tag knobs
    p.add_argument("--tag_adversaries", type=int, default=3, help="#taggers (adversaries)")
    p.add_argument("--tag_runners", type=int, default=1, help="#runners (good agents)")
    p.add_argument("--tag_obstacles", type=int, default=2, help="#obstacles")

    parsed = p.parse_args(argv)
    # Reject values that would only fail deep inside training:
    # oracle_every -> `it % oracle_every`, ppo_batch -> range() step,
    # eval_episodes -> np.stack over that many episode returns.
    for name in ("oracle_every", "ppo_batch", "eval_episodes"):
        value = getattr(parsed, name)
        if value < 1:
            p.error(f"--{name} must be >= 1 (got {value})")
    return parsed

args = parse_args()

# Headless robustness: without an X display, point SDL (presumably used by the
# env renderer) at its dummy video driver unless one was already chosen.
if "DISPLAY" not in os.environ:
    os.environ.setdefault("SDL_VIDEODRIVER", "dummy")

# Device: "auto" resolves to CUDA when available, otherwise CPU.
dev = args.device
if dev == "auto":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Seeds: every RNG used below is derived from --seed.
SEED = args.seed
_seed_everything(SEED)

# ------------------------------------------------------------------------------------
# Env factories
# ------------------------------------------------------------------------------------
EnvCls = _load_env(args.env)

def make_env(render=False, mode=None):
    """Build a PettingZoo parallel env for ``args.env`` from the CLI options.

    ``mode`` is forwarded as ``render_mode`` only when ``render`` is true;
    otherwise the env is created with ``render_mode=None``.
    """
    shared = {
        "max_cycles": args.max_cycles,
        "continuous_actions": args.continuous_actions,
        "render_mode": mode if render else None,
    }
    if args.env != "simple_spread_v3":
        return EnvCls.parallel_env(
            num_good=args.tag_runners,
            num_adversaries=args.tag_adversaries,
            num_obstacles=args.tag_obstacles,
            **shared,
        )
    return EnvCls.parallel_env(N=args.agents, **shared)

# Probe env: build one throwaway environment to discover the agent ids,
# per-agent observation sizes and action spaces, then close it. Rollouts and
# evaluation create their own environments via make_env.
_probe_env = make_env(False, None)
try:
    _init_obs, _ = _probe_env.reset(seed=SEED)
    AGENT_IDS = list(_probe_env.agents)
    _obs_dims = {aid: _probe_env.observation_space(aid).shape[0] for aid in AGENT_IDS}
    _act_spaces = {aid: _probe_env.action_space(aid) for aid in AGENT_IDS}
finally:
    _probe_env.close()

_is_all_discrete = all(isinstance(_act_spaces[aid], Discrete) for aid in AGENT_IDS)
# Either every agent is Discrete or every agent is Box; mixtures are rejected.
# An explicit raise (unlike assert) is not stripped under `python -O`.
if not _is_all_discrete and not all(isinstance(_act_spaces[aid], Box) for aid in AGENT_IDS):
    raise ValueError("Mixed action spaces not supported")

# Team index maps (simple_tag): positions in AGENT_IDS of good agents
# ("agent_*") and adversaries ("adversary_*").
GOOD_IDX = [i for i, a in enumerate(AGENT_IDS) if a.startswith("agent_")]
BAD_IDX = [i for i, a in enumerate(AGENT_IDS) if a.startswith("adversary_")]

N_AGENTS = len(AGENT_IDS)

# ------------------------------------------------------------------------------------
# Utils
# ------------------------------------------------------------------------------------
def write_video(frames, path, fps):
    """Save *frames* as an animated GIF derived from *path* and return its path.

    The extension of *path* is always replaced by ``.gif`` (standardised across
    scripts). Nothing is written and ``None`` is returned when *frames* is empty.
    """
    if not frames:
        return None
    stem, _ext = os.path.splitext(path)
    gif_path = stem + ".gif"
    # NOTE(review): `duration` is given in seconds per frame here. Newer imageio
    # releases may route GIF writing through the Pillow plugin and read
    # `duration` in milliseconds; confirm against the installed imageio if
    # playback speed looks wrong.
    seconds_per_frame = 1.0 / max(fps, 1)
    imageio.mimsave(gif_path, frames, duration=seconds_per_frame)
    return gif_path

# ------------------------------------------------------------------------------------
# Networks & Agent wrappers
# ------------------------------------------------------------------------------------
class CategoricalHead(nn.Module):
    """Discrete-action policy head: a 64-32 ReLU MLP producing unnormalised logits."""

    def __init__(self, in_dim, act_dim):
        super().__init__()
        hidden = (64, 32)
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]),
            nn.ReLU(),
            nn.Linear(hidden[1], act_dim),
        )
        self.apply(_init_weights)

    def forward(self, x):
        """Return one logit per action for each row of *x*."""
        logits = self.net(x)
        return logits

class GaussianHead(nn.Module):
    """Continuous-action policy head: observation-dependent mean, learned global log-std."""

    def __init__(self, in_dim, act_dim):
        super().__init__()
        # Mean network: a single 64-unit ReLU hidden layer.
        self.mu = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, act_dim))
        # One log standard deviation per action dimension, starting at 0 (std = 1).
        self.logstd = nn.Parameter(torch.zeros(act_dim))
        self.apply(_init_weights)

    def forward(self, x):
        """Return ``(mean, log_std)``, with ``log_std`` broadcast to the mean's shape."""
        mean = self.mu(x)
        return mean, self.logstd.expand_as(mean)

class VNet(nn.Module):
    """State-value critic: a 64-32 ReLU MLP with a single scalar output."""

    def __init__(self, in_dim):
        super().__init__()
        widths = (in_dim, 64, 32)
        self.net = nn.Sequential(
            nn.Linear(widths[0], widths[1]), nn.ReLU(),
            nn.Linear(widths[1], widths[2]), nn.ReLU(),
            nn.Linear(widths[2], 1),
        )
        self.apply(_init_weights)

    def forward(self, x):
        """Return value estimates with the trailing size-1 dimension dropped."""
        out = self.net(x)
        return out.squeeze(-1)

class AC(nn.Module):
    """Actor-critic pair: a policy head chosen by action-space type plus a value net.

    ``Discrete`` spaces get a ``CategoricalHead`` with ``act_space.n`` logits;
    anything else gets a ``GaussianHead`` sized by ``act_space.shape[0]``.
    """

    def __init__(self, obs_dim, act_space):
        super().__init__()
        self.discrete = isinstance(act_space, Discrete)
        self.obs_dim = obs_dim
        head_cls, out_dim = (
            (CategoricalHead, act_space.n) if self.discrete
            else (GaussianHead, act_space.shape[0])
        )
        self.pi = head_cls(obs_dim, out_dim)
        self.v = VNet(obs_dim)
        self.apply(_init_weights)

    def forward(self, obs):
        """Return ``(policy_output, value)``.

        ``policy_output`` is the logits tensor for discrete heads or the
        ``(mu, logstd)`` tuple for Gaussian heads.
        """
        pi_out = self.pi(obs)
        return pi_out, self.v(obs)

class Policy:
    """Standalone policy (no latent).

    Wraps an ``AC`` actor-critic network on the module-level ``device`` together
    with its own Adam optimiser (learning rate ``args.lr``).
    """
    def __init__(self, obs_dim, act_space):
        self.discrete = isinstance(act_space, Discrete)
        self.act_space = act_space
        self.net = AC(obs_dim, act_space).to(device)
        self.opt = torch.optim.Adam(self.net.parameters(), lr=args.lr)

    def clone_from(self, other):
        """Copy *other*'s network weights into this policy (optimiser state is not copied)."""
        self.net.load_state_dict(other.net.state_dict())

    @torch.no_grad()
    def act(self, obs_np):
        """Sample an action for a single observation.

        Returns ``(action, log_prob, entropy, value)``. ``action`` is a Python int
        for discrete spaces and a NumPy array otherwise; the other three are
        scalar tensors.

        For ``Box`` spaces the sampled action is clipped to the space bounds.
        The log-prob is then evaluated at that clipped action, which is the
        action callers execute and store and later pass to ``evaluate``. Scoring
        the raw sample instead would make the stored old log-prob disagree with
        ``evaluate`` and corrupt the PPO probability ratio.
        """
        obs = torch.tensor(obs_np, dtype=torch.float32, device=device).unsqueeze(0)
        pi_out, v = self.net(obs)
        if self.discrete:
            d = torch.distributions.Categorical(logits=pi_out)
            a = d.sample()
            return a.item(), d.log_prob(a).squeeze(0), d.entropy().squeeze(0), v.squeeze(0)
        mu, logstd = pi_out
        d = torch.distributions.Independent(torch.distributions.Normal(mu, logstd.exp()), 1)
        a = d.sample()
        a_np = a.squeeze(0).cpu().numpy()
        if isinstance(self.act_space, Box):
            a_np = np.clip(a_np, self.act_space.low, self.act_space.high)
            # Re-score the action that is actually executed/stored (see docstring).
            a = torch.as_tensor(a_np, dtype=torch.float32, device=device).unsqueeze(0)
        return a_np, d.log_prob(a).squeeze(0), d.entropy().squeeze(0), v.squeeze(0)

    def evaluate(self, obs_t, act_t):
        """Return ``(log_prob, entropy, value)`` of actions *act_t* at observations *obs_t* (with grad)."""
        pi_out, v = self.net(obs_t)
        if self.discrete:
            d = torch.distributions.Categorical(logits=pi_out)
        else:
            mu, logstd = pi_out; std = logstd.exp()
            d = torch.distributions.Independent(torch.distributions.Normal(mu, std), 1)
        logp = d.log_prob(act_t)
        ent = d.entropy()
        return logp, ent, v

# ------------------------------------------------------------------------------------
# PSRO state
# ------------------------------------------------------------------------------------
# PSRO / PPO hyper-parameters, copied out of `args` once for brevity.
ETA = args.eta
GAMMA = args.gamma
LAMBDA = args.gae_lambda
PPO_EPOCHS = args.ppo_epochs
PPO_BATCH = args.ppo_batch
MIN_STEPS_PER_BR = args.rollout_min_steps
GEMS_ITERS = args.iters
MC_B = args.mc_B
MC_m = args.mc_m
ENT_BETA = args.ent_beta

# Per-agent policy pools and the matching meta-strategy weights (sigma).
Pools = [list() for _ in range(N_AGENTS)]
Sigma = []

def init_psro():
    """Give every player one freshly initialised policy with meta-weight 1.0."""
    for slot, aid in enumerate(AGENT_IDS):
        Pools[slot].append(Policy(_obs_dims[aid], _act_spaces[aid]))
        Sigma.append([1.0])


init_psro()

# ------------------------------------------------------------------------------------
# Mixture sampling helpers
# ------------------------------------------------------------------------------------
def sample_profile():
    """Draw one pool index per player from that player's (renormalised) ``Sigma`` weights."""
    def _draw(weights):
        w = np.array(weights, dtype=np.float64)
        return np.random.choice(len(w), p=w / w.sum())

    return [_draw(Sigma[player]) for player in range(N_AGENTS)]

def act_with_profile(obs_dict, prof):
    """Return ``{agent_id: action}`` using pool policy ``Pools[p][prof[p]]`` for each agent."""
    def _choose(agent_id):
        player = AGENT_IDS.index(agent_id)
        return Pools[player][prof[player]].act(obs_dict[agent_id])[0]

    return {agent_id: _choose(agent_id) for agent_id in obs_dict}

# ------------------------------------------------------------------------------------
# Episodes / Evaluation
# ------------------------------------------------------------------------------------
def run_episode(prof=None, override=None, render=False, seed=None, record_rgb=False):
    """Play one episode and return per-player undiscounted returns.

    If override is a dict {p: Policy}, use it for those players.

    Args:
        prof: per-player pool indices (player ``p`` acts with ``Pools[p][prof[p]]``);
            ``None`` means every player uses pool entry 0.
        override: optional ``{p: Policy}``; these players ignore ``prof``.
        render: build the env with ``render_mode="rgb_array"``.
        seed: seed for ``env.reset``; drawn with ``random.randint`` when ``None``.
        record_rgb: also build with ``"rgb_array"`` and capture ``env.render()``
            after reset and after every step.

    Returns:
        ``(rets, frames)``: float32 array of length ``N_AGENTS`` ordered like
        ``AGENT_IDS``, and the captured frames (empty unless ``record_rgb``).
    """
    mode = "rgb_array" if (render or record_rgb) else None
    env = make_env(render=(render or record_rgb), mode=mode)
    # agent id -> player index, built once instead of list.index() per lookup.
    slot = {aid: i for i, aid in enumerate(AGENT_IDS)}
    frames = []
    rets = np.zeros(N_AGENTS, dtype=np.float32)

    try:
        obs, _ = env.reset(seed=seed if seed is not None else random.randint(0, 1<<30))
        done = False

        if record_rgb:
            frame = env.render()
            if frame is not None:
                frames.append(frame)

        while env.agents and not done:
            acts = {}
            for aid in env.agents:
                p = slot[aid]
                if (override is not None) and (p in override):
                    a, _, _, _ = override[p].act(obs[aid])
                else:
                    k = prof[p] if prof is not None else 0
                    a, _, _, _ = Pools[p][k].act(obs[aid])
                acts[aid] = a

            obs, r, term, trunc, _ = env.step(acts)

            if record_rgb:
                frame = env.render()
                if frame is not None:
                    frames.append(frame)

            for aid, v in r.items():
                rets[slot[aid]] += float(v)
            done = all(term.values()) or all(trunc.values())
    finally:
        # Release the env (and any render resources) even if a policy or
        # env.step raises.
        env.close()
    return rets, frames

def record_episode(prof, seed, path, fps):
    """Play one rendered episode with profile *prof* and save its frames via ``write_video``.

    Returns ``(returns, gif_path_or_None)``.
    """
    episode_returns, frame_list = run_episode(prof=prof, render=True, seed=seed, record_rgb=True)
    return episode_returns, write_video(frame_list, path, fps)

# ------------------------------------------------------------------------------------
# Meta-estimation (estimate values for each policy against opponent mixtures)
# ------------------------------------------------------------------------------------
def meta_estimate():
    """Monte-Carlo estimate of each pool policy's value and of each player's mixture value.

    Deviation values: for each of ``MC_B`` profiles sampled from ``Sigma``,
    every player ``p`` and pool entry ``k``, play ``MC_m`` episodes where ``p``
    uses ``k`` and the others keep the sampled profile. ``vhat[p][k]`` is the
    mean of player ``p``'s return over all of them.

    Baseline: ``rbar[p]`` is player ``p``'s mean return over ``MC_B * MC_m``
    freshly sampled profiles.

    Returns:
        ``(vhat, rbar)``: list of float32 arrays (one per player, one entry per
        pool policy) and a float32 array of length ``N_AGENTS``.
    """
    inner_div = max(MC_m, 1)
    outer_div = max(MC_B, 1)
    baseline_n = MC_B * MC_m

    vhat = [np.zeros(len(Sigma[p]), dtype=np.float32) for p in range(N_AGENTS)]
    rbar = np.zeros(N_AGENTS, dtype=np.float32)

    # Deviation values: one shared anchor profile per outer sample.
    for _outer in range(MC_B):
        anchor = sample_profile()
        for p in range(N_AGENTS):
            for k in range(len(Pools[p])):
                deviation = list(anchor)
                deviation[p] = k
                acc = 0.0
                for _inner in range(MC_m):
                    episode_returns, _frames = run_episode(deviation)
                    acc += episode_returns[p]
                vhat[p][k] += acc / inner_div

    # Baseline: returns under independently sampled profiles.
    for _ in range(baseline_n):
        episode_returns, _frames = run_episode(sample_profile())
        rbar += episode_returns
    rbar /= max(baseline_n, 1)

    for estimates in vhat:
        estimates /= outer_div
    return vhat, rbar

def mwu_update(vhat, rbar):
    """Multiplicative-weights update of every player's meta-strategy ``Sigma[p]``.

    New weights are proportional to
    ``Sigma[p][k] * exp(ETA * (vhat[p][k] - rbar[p]))``. The update is computed
    in log space in float64, and the largest logit is subtracted before
    exponentiating. That way large payoff gaps cannot overflow ``exp`` to
    ``inf`` (which previously turned ``Sigma`` into NaN), and the best entry
    always keeps weight 1 before normalisation. Shifting all logits by
    ``rbar[p]`` or by the max cancels in the normalisation, so the result
    matches the plain formula.

    Args:
        vhat: per-player arrays of estimated pool-policy values (from ``meta_estimate``).
        rbar: per-player baseline returns.

    Side effects:
        Replaces ``Sigma[p]`` with a normalised Python list for every player.
    """
    for p in range(N_AGENTS):
        weights = np.asarray(Sigma[p], dtype=np.float64)
        gains = np.asarray(vhat[p], dtype=np.float64) - float(rbar[p])
        # log(0) -> -inf for entries whose weight already underflowed; they stay 0.
        with np.errstate(divide="ignore"):
            logits = np.log(weights) + ETA * gains
        logits -= logits.max()
        s = np.exp(logits)
        s /= s.sum()
        Sigma[p] = s.tolist()

# ------------------------------------------------------------------------------------
# PPO BR training (train a fresh policy for agent p vs opponents' σ)
# ------------------------------------------------------------------------------------
def collect_rollouts_BR(p, pol_p):
    """Collect on-policy experience for player ``p``'s candidate best response.

    Plays whole episodes in a fresh non-rendering env until at least
    ``MIN_STEPS_PER_BR`` of player ``p``'s steps are gathered. Player ``p``
    acts with ``pol_p``. Every other player ``i`` acts with
    ``Pools[i][prof[i]]``, where ``prof`` is re-sampled from ``Sigma`` once per
    episode. Advantages are computed per episode with GAE(``GAMMA``, ``LAMBDA``).

    Args:
        p: index of the learning player in ``AGENT_IDS``.
        pol_p: the ``Policy`` being trained.

    Returns:
        ``(O, A, LP, R, ADV)`` tensors on ``device``: observations, actions
        (long for discrete spaces, float32 otherwise), behaviour log-probs,
        value targets (advantage + recorded value), and batch-standardised
        advantages.
    """
    O, A, LP, R, ADV = [], [], [], [], []
    steps = 0
    env = make_env(False, None)
    try:
        while steps < MIN_STEPS_PER_BR:
            obs, _ = env.reset(seed=random.randint(0, 1<<30))
            # One entry per player-p step: [obs, action, log_prob, value, reward].
            traj = []
            done = False
            # Opponents' pool indices, fixed for the whole episode.
            prof = sample_profile()
            while env.agents and not done:
                acts = {}
                for aid in env.agents:
                    i = AGENT_IDS.index(aid)
                    if i == p:
                        a, lp, _, v = pol_p.act(obs[aid])
                        acts[aid] = a
                        # Reward slot (index 4) is filled in after env.step.
                        traj.append([obs[aid], a, lp.item(), v.item(), 0.0])
                    else:
                        k = prof[i]
                        a,_,_,_ = Pools[i][k].act(obs[aid])
                        acts[aid] = a
                obs, r, term, trunc, _ = env.step(acts)
                # Attach this step's reward to player p's latest transition.
                if traj:
                    traj[-1][4] = float(r[AGENT_IDS[p]])
                done = all(term.values()) or all(trunc.values())

            # NOTE(review): the value after the last step is bootstrapped as 0.0
            # whether the episode terminated or was truncated. If episodes end
            # only by the max_cycles time limit, this treats the time limit as
            # terminal; confirm that is intended.
            vals = [x[3] for x in traj] + [0.0]
            rews = [x[4] for x in traj]
            advs, G = [], 0.0
            # GAE, backwards: delta_t = r_t + gamma*V_{t+1} - V_t,
            # A_t = delta_t + gamma*lambda*A_{t+1}.
            for t in reversed(range(len(rews))):
                delta = rews[t] + GAMMA * vals[t+1] - vals[t]
                G = delta + GAMMA * LAMBDA * G
                advs.append(G)
            advs = list(reversed(advs))
            # Value-function targets: advantage plus the value recorded at collection time.
            rets = [advs[t] + vals[t] for t in range(len(rews))]

            for (o,a,lp,_v,_r), R_t, Adv in zip(traj, rets, advs):
                O.append(o); A.append(a); LP.append(lp); R.append(R_t); ADV.append(Adv)
            steps += len(traj)
    finally:
        env.close()

    O  = torch.tensor(np.array(O), dtype=torch.float32, device=device)
    # Action dtype follows the player's action space: indices vs. real vectors.
    if Pools[p][0].discrete:
        A = torch.tensor(np.array(A), dtype=torch.long, device=device)
    else:
        A = torch.tensor(np.array(A), dtype=torch.float32, device=device)
    LP = torch.tensor(np.array(LP), dtype=torch.float32, device=device)
    R  = torch.tensor(np.array(R), dtype=torch.float32, device=device)
    ADV= torch.tensor(np.array(ADV), dtype=torch.float32, device=device)
    # Standardise advantages over the whole batch (all episodes together).
    # NOTE(review): with fewer than two collected steps the mean/std here are
    # NaN; confirm MIN_STEPS_PER_BR / max_cycles settings rule that out.
    ADV= (ADV - ADV.mean()) / (ADV.std() + 1e-8)
    return (O, A, LP, R, ADV)

def ppo_update_BR(p, pol_p, batch):
    """Run clipped-PPO updates on *pol_p* using a batch from ``collect_rollouts_BR``.

    For ``PPO_EPOCHS`` epochs the sample order is reshuffled and minibatches of
    ``PPO_BATCH`` are used to minimise:
    clipped surrogate loss - ``args.ent_beta`` * entropy + 0.5 * value MSE.
    ``p`` is accepted for interface symmetry but not used.
    """
    obs_all, act_all, logp_old_all, ret_all, adv_all = batch
    order = np.arange(obs_all.shape[0])
    lo, hi = 1.0 - args.clip, 1.0 + args.clip
    for _epoch in range(PPO_EPOCHS):
        np.random.shuffle(order)
        for start in range(0, len(order), PPO_BATCH):
            mb = order[start:start + PPO_BATCH]
            logp, ent, val = pol_p.evaluate(obs_all[mb], act_all[mb])
            adv = adv_all[mb]
            ratio = torch.exp(logp - logp_old_all[mb])
            surrogate = torch.min(ratio * adv, torch.clamp(ratio, lo, hi) * adv)
            value_loss = (ret_all[mb] - val).pow(2).mean()
            loss = -surrogate.mean() - args.ent_beta * ent.mean() + 0.5 * value_loss
            pol_p.opt.zero_grad(set_to_none=True)
            loss.backward()
            pol_p.opt.step()

def make_new_policy(p):
    """Create a new best-response candidate for player *p*.

    With ``--br_init clone_last`` (and a non-empty pool) the network weights
    are copied from the player's most recent pool policy; otherwise the network
    keeps its fresh random initialisation.
    """
    aid = AGENT_IDS[p]
    fresh = Policy(_obs_dims[aid], _act_spaces[aid])
    wants_clone = args.br_init == "clone_last"
    if wants_clone and Pools[p]:
        fresh.clone_from(Pools[p][-1])
    return fresh

# ------------------------------------------------------------------------------------
# Eval helpers
# ------------------------------------------------------------------------------------
def eval_latest_profile(K):
    """Average per-player return of the newest policy of every player over *K* episodes.

    Episode ``k`` uses the fixed seed ``viz_seed(100000 + k)``, so every call
    evaluates on the same set of start states.

    Returns:
        float32 array of length ``N_AGENTS``. When ``K <= 0`` no episode is
        played and an all-NaN array is returned. Previously ``np.stack`` raised
        on the empty list at the end of a training iteration.
    """
    if K <= 0:
        return np.full(N_AGENTS, np.nan, dtype=np.float32)
    prof = [len(Pools[i]) - 1 for i in range(N_AGENTS)]
    per_episode = [run_episode(prof=prof, seed=viz_seed(100000 + k))[0] for k in range(K)]
    return np.mean(np.stack(per_episode, axis=0), axis=0)

# ------------------------------------------------------------------------------------
# Training loop
# ------------------------------------------------------------------------------------
# Ensure the output directories for the CSV log and the GIF exist.
os.makedirs(os.path.dirname(args.csv) or ".", exist_ok=True)
os.makedirs(os.path.dirname(args.video) or ".", exist_ok=True)

print(f"[PSRO] env={args.env} agents={N_AGENTS} device={device.type}" +
      (f" gpu={torch.cuda.get_device_name(0)}" if device.type=='cuda' else ""))

# One CSV row per PSRO iteration. The file is truncated at start and flushed
# after every row, so an interrupted run still leaves a readable log.
with open(args.csv, "w", newline="") as f:
    w = csv.writer(f)

    base_header = ["iter","timestamp","time_sec","mem_mb","mem_type"] + \
                  [f"ret_{i}" for i in range(N_AGENTS)] + \
                  ["ret_mean","ret_sum","pool_sizes","video_path"]

    # simple_tag adds team aggregates, inserted before the last two columns so
    # that pool_sizes and video_path always stay at the end.
    if args.env == "simple_tag_v3":
        header = base_header[:-2] + ["good_avg","bad_avg","good_sum","bad_sum"] + base_header[-2:]
    else:
        header = base_header
    w.writerow(header)

    for it in range(GEMS_ITERS):
        t0 = time.time()

        # 1) Meta-estimate & update meta-policy (MWU)
        vhat, rbar = meta_estimate()
        mwu_update(vhat, rbar)

        # 2) PSRO oracle (BR)
        # Every `oracle_every` iterations, train one new PPO best response per
        # player against the current Sigma and add it to that player's pool.
        if (it % args.oracle_every) == 0:
            for p in range(N_AGENTS):
                pol = make_new_policy(p)
                batch = collect_rollouts_BR(p, pol)
                ppo_update_BR(p, pol, batch)
                Pools[p].append(pol)
                # New policy enters with a small weight (1e-3 before
                # renormalisation); mwu_update reweights it on later iterations.
                s = np.array(Sigma[p] + [1e-3], dtype=np.float64)
                s /= s.sum()
                Sigma[p] = s.tolist()

        # 3) Evaluation using latest policy per agent
        rets = eval_latest_profile(args.eval_episodes)

        # Wall time covers meta-estimation, oracle training and evaluation.
        dt = time.time() - t0
        mem, mtype = _mem_mb()
        overall_mean = float(np.mean(rets))
        overall_sum  = float(np.sum(rets))

        # Per-iteration rows leave video_path empty; only the final row has it.
        if args.env == "simple_tag_v3":
            good_avg = float(np.mean(rets[GOOD_IDX])) if GOOD_IDX else float('nan')
            bad_avg  = float(np.mean(rets[BAD_IDX]))  if BAD_IDX  else float('nan')
            good_sum = float(np.sum(rets[GOOD_IDX]))  if GOOD_IDX else float('nan')
            bad_sum  = float(np.sum(rets[BAD_IDX]))   if BAD_IDX  else float('nan')

            print(f"[PSRO] iter {it+1}/{GEMS_ITERS} time={dt:.2f}s "
                  f"good_avg={good_avg:.2f} bad_avg={bad_avg:.2f} "
                  f"(overall mean={overall_mean:.2f} sum={overall_sum:.2f}) "
                  f"pool={ [len(Pools[p]) for p in range(N_AGENTS)] }")

            row = [it+1, time.strftime("%Y-%m-%d %H:%M:%S"), f"{dt:.3f}", f"{mem:.2f}", mtype] + \
                  [f"{r:.3f}" for r in rets.tolist()] + \
                  [f"{overall_mean:.3f}", f"{overall_sum:.3f}",
                   f"{good_avg:.3f}", f"{bad_avg:.3f}", f"{good_sum:.3f}", f"{bad_sum:.3f}",
                   str([len(Pools[p]) for p in range(N_AGENTS)]), ""]
        else:
            print(f"[PSRO] iter {it+1}/{GEMS_ITERS} time={dt:.2f}s "
                  f"mean={overall_mean:.2f} sum={overall_sum:.2f} "
                  f"pool={ [len(Pools[p]) for p in range(N_AGENTS)] }")

            row = [it+1, time.strftime("%Y-%m-%d %H:%M:%S"), f"{dt:.3f}", f"{mem:.2f}", mtype] + \
                  [f"{r:.3f}" for r in rets.tolist()] + \
                  [f"{overall_mean:.3f}", f"{overall_sum:.3f}",
                   str([len(Pools[p]) for p in range(N_AGENTS)]), ""]
        w.writerow(row); f.flush()

# 4) Record last episode with standardized seed
# The newest policy of every player plays one rendered episode (seed
# viz_seed(GEMS_ITERS)), which is saved as a GIF and appended as a final CSV
# row. That row reuses iter=GEMS_ITERS, has time_sec "0.000" and fills
# video_path.
print("[PSRO] recording last iteration...")
seed_for_record = viz_seed(GEMS_ITERS)
last_prof = [len(Pools[i])-1 for i in range(N_AGENTS)]
rets, vpath = record_episode(last_prof, seed_for_record, args.video, args.fps)
# Reopen in append mode: the training-loop `with` block has already closed the file.
with open(args.csv, "a", newline="") as f2:
    w2 = csv.writer(f2)
    mem, mtype = _mem_mb()
    overall_mean = float(np.mean(rets)); overall_sum = float(np.sum(rets))
    if args.env == "simple_tag_v3":
        good_avg = float(np.mean(rets[GOOD_IDX])) if GOOD_IDX else float('nan')
        bad_avg  = float(np.mean(rets[BAD_IDX]))  if BAD_IDX  else float('nan')
        good_sum = float(np.sum(rets[GOOD_IDX]))  if GOOD_IDX else float('nan')
        bad_sum  = float(np.sum(rets[BAD_IDX]))   if BAD_IDX  else float('nan')
        row = [GEMS_ITERS, time.strftime("%Y-%m-%d %H:%M:%S"),
               f"0.000", f"{mem:.2f}", mtype] + \
              [f"{r:.3f}" for r in rets.tolist()] + \
              [f"{overall_mean:.3f}", f"{overall_sum:.3f}",
               f"{good_avg:.3f}", f"{bad_avg:.3f}", f"{good_sum:.3f}", f"{bad_sum:.3f}",
               str([len(Pools[p]) for p in range(N_AGENTS)]), vpath or ""]
    else:
        row = [GEMS_ITERS, time.strftime("%Y-%m-%d %H:%M:%S"),
               f"0.000", f"{mem:.2f}", mtype] + \
              [f"{r:.3f}" for r in rets.tolist()] + \
              [f"{overall_mean:.3f}", f"{overall_sum:.3f}",
               str([len(Pools[p]) for p in range(N_AGENTS)]), vpath or ""]
    w2.writerow(row)
# vpath is None when no frames were captured (write_video returns None then).
print(f"[PSRO] saved video at: {vpath}")

print("done")
