import argparse, time, csv, sys
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical

# ---------------------------------------------------------------------------
#                           Argument parser
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace.  The number of PSRO outer iterations is exposed
        as ``psro_iters`` (the attribute the rest of this script reads) and is
        mirrored to ``iters`` (the name derived from the ``--iters`` flag) so
        either spelling works for callers.
    """
    p = argparse.ArgumentParser(
        description="One-step sender/receiver PSRO with PPO oracles (multi-seed avg)"
    )
    # Environment
    p.add_argument("--K", type=int, default=5,
                   help="Number of arms (also size of sender obs)")
    p.add_argument("--M", type=int, default=3,
                   help="Message dimension (size of receiver obs)")
    p.add_argument("--arms-means", type=float, nargs="+",
                   default=[0.2, 0.5, 0.8, 0.4, 0.1],
                   metavar="μ", help="Space-separated arm means")
    p.add_argument("--bad-arm-target", type=int, default=0,
                   help="Arm index sender prefers receiver to pull")
    p.add_argument("--top-seed", type=int, default=0, help="Base seed if --seeds not given")

    # PSRO loop
    # Stored as ``psro_iters`` because that is the attribute run_single_seed()
    # and main() read; ``--psro-iters`` is accepted as an alias of ``--iters``.
    p.add_argument("--iters", "--psro-iters", dest="psro_iters", type=int, default=6,
                   help="Number of PSRO outer iterations")
    p.add_argument("--eval-episodes", type=int, default=400,
                   help="Rollouts per payoff entry")
    p.add_argument("--oracle-epochs", type=int, default=200,
                   help="PPO epochs for each new oracle")
    p.add_argument("--oracle-lr", type=float, default=1e-3,
                   help="Learning rate for PPO oracles")

    # PPO hyper-params
    p.add_argument("--ppo-batch", type=int, default=256)
    p.add_argument("--ppo-epochs", type=int, default=4)
    p.add_argument("--ppo-minibatch", type=int, default=64)
    p.add_argument("--ppo-clip", type=float, default=0.2)
    p.add_argument("--vf-coef", type=float, default=0.5)
    p.add_argument("--ent-coef", type=float, default=0.01)

    # Device & logging
    p.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    p.add_argument("--log-csv", default=None,
                   help="CSV filename for aggregated (mean/std) log. "
                        "Default: psro_avg_<seedspan>.csv")

    # Multi-seed controls
    p.add_argument("--seeds", type=int, nargs="+", default=None,
                   help="Explicit list of top seeds to average over, e.g. --seeds 0 1 2 3 4")
    p.add_argument("--num-seeds", type=int, default=5,
                   help="If --seeds not given, run seeds from top_seed to top_seed+num_seeds-1")

    # Plotting
    p.add_argument("--no_plot", action="store_true",
                   help="Disable matplotlib plotting")

    args = p.parse_args()
    # Keep the flag-derived attribute name available as well.
    args.iters = args.psro_iters
    return args

# ---------------------------------------------------------------------------
#                           Helper: RAM usage
# ---------------------------------------------------------------------------

# get_ram_mb() is defined by the first probe that can be set up:
#   1. psutil        -> resident set size of this process
#   2. resource      -> ru_maxrss from getrusage (scaled per platform)
#   3. neither works -> NaN
try:
    import psutil
    _proc = psutil.Process()

    def get_ram_mb():
        """Return this process's RSS, as reported by psutil, divided by 2**20."""
        rss = _proc.memory_info().rss
        return rss / 1_048_576
except Exception:
    try:
        import resource
    except Exception:
        def get_ram_mb():
            """Return NaN because no memory probe could be set up."""
            return float("nan")
    else:
        if sys.platform != "darwin":
            def get_ram_mb():
                """Return ru_maxrss scaled by 1/1024 (non-macOS platforms)."""
                usage = resource.getrusage(resource.RUSAGE_SELF)
                return usage.ru_maxrss / 1024.0
        else:
            def get_ram_mb():
                """Return ru_maxrss scaled by 1/2**20 (macOS)."""
                usage = resource.getrusage(resource.RUSAGE_SELF)
                return usage.ru_maxrss / 1_048_576

# ---------------------------------------------------------------------------
#                           Networks
# ---------------------------------------------------------------------------

class Sender(nn.Module):
    """Two-layer MLP mapping a K-feature input to M output logits."""

    def __init__(self, K, M):
        super().__init__()
        hidden = 32
        layers = [nn.Linear(K, hidden), nn.ReLU(), nn.Linear(hidden, M)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return its output."""
        out = self.net(x)
        return out

class Receiver(nn.Module):
    """Two-layer MLP mapping an M-feature input to K output logits."""

    def __init__(self, M, K):
        super().__init__()
        hidden = 32
        layers = [nn.Linear(M, hidden), nn.ReLU(), nn.Linear(hidden, K)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return its output."""
        out = self.net(x)
        return out

class SenderPPO(nn.Module):
    """Actor-critic sender: separate MLP heads for M logits and a scalar value."""

    def __init__(self, K, M):
        super().__init__()
        hidden = 32
        # Actor is built before the critic so parameter initialisation order
        # (and hence seeded weights) is actor first, critic second.
        actor_layers = [nn.Linear(K, hidden), nn.ReLU(), nn.Linear(hidden, M)]
        self.actor = nn.Sequential(*actor_layers)
        critic_layers = [nn.Linear(K, hidden), nn.ReLU(), nn.Linear(hidden, 1)]
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x):
        """Return ``(logits, value)``; the value's trailing size-1 dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

class ReceiverPPO(nn.Module):
    """Actor-critic receiver: separate MLP heads for K logits and a scalar value."""

    def __init__(self, M, K):
        super().__init__()
        hidden = 32
        # Actor is built before the critic so parameter initialisation order
        # (and hence seeded weights) is actor first, critic second.
        actor_layers = [nn.Linear(M, hidden), nn.ReLU(), nn.Linear(hidden, K)]
        self.actor = nn.Sequential(*actor_layers)
        critic_layers = [nn.Linear(M, hidden), nn.ReLU(), nn.Linear(hidden, 1)]
        self.critic = nn.Sequential(*critic_layers)

    def forward(self, x):
        """Return ``(logits, value)``; the value's trailing size-1 dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

# ---------------------------------------------------------------------------
#                           Utilities
# ---------------------------------------------------------------------------

def logits_from_agent(agent, obs):
    """Call ``agent`` on ``obs`` and return its logits with dim 0 squeezed.

    A 1-D ``obs`` is given a leading batch dimension before the call.  If the
    agent returns a tuple or list (e.g. ``(logits, value)``), only the first
    element is used.
    """
    batched = obs if obs.dim() != 1 else obs.unsqueeze(0)
    result = agent(batched)
    if isinstance(result, (tuple, list)):
        result = result[0]
    return result.squeeze(0)

def set_seed(s):
    """Seed NumPy's global RNG and torch's RNG with ``s``."""
    np.random.seed(s)
    torch.manual_seed(s)

# ---------------------------------------------------------------------------
#                           One-step episode
# ---------------------------------------------------------------------------

def sample_episode(sender, receiver, K, M, BEST_ARM, BAD_ARM_TARGET,
                   ARMS_MEANS, device):
    """Play one sender -> receiver round and return both rewards.

    The sender sees a one-hot vector of length ``K`` marking ``BEST_ARM`` and
    samples a message; the receiver sees a one-hot vector of length ``M``
    marking that message and samples an arm.

    Args:
        sender: Agent called on the sender observation (via logits_from_agent).
        receiver: Agent called on the receiver observation.
        K: Length of the sender observation.
        M: Length of the receiver observation.
        BEST_ARM: Index set to 1.0 in the sender observation.
        BAD_ARM_TARGET: Arm index that yields sender reward 1.0 when chosen.
        ARMS_MEANS: Indexable of per-arm success probabilities.
        device: Torch device on which the observations are created.

    Returns:
        ``(s_reward, r_reward)`` as floats: the sender gets 1.0 iff the chosen
        arm equals ``BAD_ARM_TARGET``; the receiver gets a Bernoulli draw with
        probability ``ARMS_MEANS[chosen]``.
    """
    # Nothing here is differentiated, so skip building autograd graphs.
    # Sampling consumes the torch RNG exactly as it would with grad enabled.
    with torch.no_grad():
        s_obs = torch.zeros(K, device=device)
        s_obs[BEST_ARM] = 1.0
        msg = Categorical(logits=logits_from_agent(sender, s_obs)).sample()

        r_obs = torch.zeros(M, device=device)
        r_obs[msg] = 1.0
        act = Categorical(logits=logits_from_agent(receiver, r_obs)).sample()

    chosen = int(act)
    r_reward = float(np.random.rand() < ARMS_MEANS[chosen])
    s_reward = float(chosen == BAD_ARM_TARGET)
    return s_reward, r_reward

# ---------------------------------------------------------------------------
#                           Payoff evaluation
# ---------------------------------------------------------------------------

def evaluate_payoffs(senders, receivers, n_eval, seed,
                     K, M, BEST_ARM, BAD_ARM_TARGET, ARMS_MEANS, device):
    """Estimate empirical payoff matrices for every sender/receiver pair.

    If ``seed`` is not None the global RNGs are reseeded first.  Each pair
    ``(i, j)`` is played for ``n_eval`` episodes and the mean rewards are
    stored at ``[i, j]``.

    Returns:
        ``(sp, rp)``: two ``len(senders) x len(receivers)`` arrays holding the
        mean sender reward and mean receiver reward respectively.
    """
    if seed is not None:
        set_seed(seed)

    n_senders = len(senders)
    n_receivers = len(receivers)
    sender_payoff = np.zeros((n_senders, n_receivers))
    receiver_payoff = np.zeros((n_senders, n_receivers))

    for row in range(n_senders):
        for col in range(n_receivers):
            sender_total = 0.0
            receiver_total = 0.0
            for _ in range(n_eval):
                rewards = sample_episode(senders[row], receivers[col],
                                         K, M, BEST_ARM, BAD_ARM_TARGET,
                                         ARMS_MEANS, device)
                sender_total += rewards[0]
                receiver_total += rewards[1]
            sender_payoff[row, col] = sender_total / n_eval
            receiver_payoff[row, col] = receiver_total / n_eval

    return sender_payoff, receiver_payoff

def fictitious_play(sp, rp, iters=200):
    """Run alternating fictitious play on a bimatrix game.

    Args:
        sp: 2-D array of sender payoffs, rows = sender strategies.
        rp: 2-D array of receiver payoffs with the same layout as ``sp``.
        iters: Number of best-response rounds.

    Returns:
        ``(s_mix, r_mix)``: empirical best-response frequencies for the sender
        and the receiver (uniform if ``iters`` is 0).
    """
    n_sender, n_receiver = sp.shape

    # Tiny initial counts keep the first normalisation well defined.
    sender_counts = np.full(n_sender, 1e-6)
    receiver_counts = np.full(n_receiver, 1e-6)
    sender_mix = np.ones(n_sender) / n_sender
    receiver_mix = np.ones(n_receiver) / n_receiver

    for _ in range(iters):
        # Sender best-responds to the receiver's current mixture ...
        sender_best = np.argmax(sp.dot(receiver_mix))
        sender_counts[sender_best] += 1
        sender_mix = sender_counts / sender_counts.sum()

        # ... then the receiver best-responds to the updated sender mixture.
        receiver_best = np.argmax(rp.T.dot(sender_mix))
        receiver_counts[receiver_best] += 1
        receiver_mix = receiver_counts / receiver_counts.sum()

    return sender_mix, receiver_mix

# ---------------------------------------------------------------------------
#                           PPO trainers
# ---------------------------------------------------------------------------

def train_sender_ppo(r_mix, R_pop, args, device, seed_base):
    """Train a fresh PPO sender against a fixed mixture over receivers.

    Args:
        r_mix: Weights over ``R_pop``.  Renormalised to sum to 1; if the sum is
            not positive a uniform mixture is used instead.
        R_pop: Sequence of receiver agents; one is sampled per rollout step.
        args: Namespace providing ``K``, ``M``, ``BEST_ARM``, ``bad_arm_target``,
            ``oracle_epochs``, ``oracle_lr``, ``ppo_batch``, ``ppo_epochs``,
            ``ppo_minibatch``, ``ppo_clip``, ``vf_coef`` and ``ent_coef``.
        device: Torch device for the model and all tensors.
        seed_base: Outer epoch ``ep`` reseeds the RNGs with ``seed_base + ep``.

    Returns:
        The trained ``SenderPPO`` model.
    """
    r_mix = np.asarray(r_mix)
    mix_total = r_mix.sum()
    if mix_total > 0:
        r_mix = r_mix / mix_total
    else:
        r_mix = np.ones(len(R_pop)) / len(R_pop)

    model = SenderPPO(args.K, args.M).to(device)
    opt = optim.Adam(model.parameters(), lr=args.oracle_lr)

    # The sender's observation never changes: one-hot on the best arm.
    s_obs = torch.zeros(args.K, device=device)
    s_obs[args.BEST_ARM] = 1.0
    s_obs_np = s_obs.cpu().numpy()

    for ep in range(1, args.oracle_epochs + 1):
        set_seed(seed_base + ep)
        states, acts, old_logps, rets = [], [], [], []

        # Rollouts only record actions, log-probs and rewards as plain numbers,
        # so no autograd graph is needed while collecting them.
        with torch.no_grad():
            for _ in range(args.ppo_batch):
                r_idx = np.random.choice(len(R_pop), p=r_mix)
                recv = R_pop[r_idx]

                logits, _ = model(s_obs.unsqueeze(0))
                dist = Categorical(logits=logits.squeeze(0))
                msg = dist.sample()
                logp = dist.log_prob(msg)

                r_obs = torch.zeros(args.M, device=device)
                r_obs[msg] = 1.0
                act = Categorical(logits=logits_from_agent(recv, r_obs)).sample()
                chosen = int(act)

                # The receiver's Bernoulli reward is irrelevant to the sender,
                # but this draw is kept so the NumPy RNG stream (and therefore
                # seeded results) stays identical to earlier runs.
                np.random.rand()
                s_reward = float(chosen == args.bad_arm_target)

                states.append(s_obs_np)
                acts.append(int(msg))
                old_logps.append(float(logp))
                rets.append(s_reward)

        # PPO update --------------------------------------------------------
        batch = torch.tensor(np.stack(states), dtype=torch.float32, device=device)
        acts_t = torch.tensor(acts, dtype=torch.int64, device=device)
        old_lp = torch.tensor(old_logps, dtype=torch.float32, device=device)
        returns = torch.tensor(rets, dtype=torch.float32, device=device)

        for _ in range(args.ppo_epochs):
            idx = np.random.permutation(args.ppo_batch)
            for start in range(0, args.ppo_batch, args.ppo_minibatch):
                mb = idx[start:start + args.ppo_minibatch]
                logits, val = model(batch[mb])
                d = Categorical(logits=logits)
                new_lp = d.log_prob(acts_t[mb])
                ent = d.entropy().mean()

                # One-step episodes: the return is the reward, and the
                # advantage is return minus the (detached) critic estimate.
                adv = returns[mb] - val.detach()
                ratio = torch.exp(new_lp - old_lp[mb])
                clipped = torch.clamp(ratio, 1 - args.ppo_clip, 1 + args.ppo_clip)
                pol_loss = -torch.min(ratio * adv, clipped * adv).mean()
                v_loss = ((val - returns[mb]) ** 2).mean()
                loss = pol_loss + args.vf_coef * v_loss - args.ent_coef * ent

                opt.zero_grad()
                loss.backward()
                opt.step()
    return model

def train_receiver_ppo(s_mix, S_pop, args, device, seed_base):
    """Train a fresh PPO receiver against a fixed mixture over senders.

    Args:
        s_mix: Weights over ``S_pop``.  Renormalised to sum to 1; if the sum is
            not positive a uniform mixture is used instead.
        S_pop: Sequence of sender agents; one is sampled per rollout step.
        args: Namespace providing ``K``, ``M``, ``BEST_ARM``, ``ARMS_MEANS``,
            ``oracle_epochs``, ``oracle_lr``, ``ppo_batch``, ``ppo_epochs``,
            ``ppo_minibatch``, ``ppo_clip``, ``vf_coef`` and ``ent_coef``.
        device: Torch device for the model and all tensors.
        seed_base: Outer epoch ``ep`` reseeds the RNGs with ``seed_base + ep``.

    Returns:
        The trained ``ReceiverPPO`` model.
    """
    s_mix = np.asarray(s_mix)
    mix_total = s_mix.sum()
    if mix_total > 0:
        s_mix = s_mix / mix_total
    else:
        s_mix = np.ones(len(S_pop)) / len(S_pop)

    model = ReceiverPPO(args.M, args.K).to(device)
    opt = optim.Adam(model.parameters(), lr=args.oracle_lr)

    # Every sampled sender sees the same observation: one-hot on the best arm.
    s_obs = torch.zeros(args.K, device=device)
    s_obs[args.BEST_ARM] = 1.0

    for ep in range(1, args.oracle_epochs + 1):
        set_seed(seed_base + ep)
        states, acts, old_lp, rets = [], [], [], []

        # Rollouts only record observations, actions, log-probs and rewards as
        # plain numbers, so no autograd graph is needed while collecting them.
        with torch.no_grad():
            for _ in range(args.ppo_batch):
                s_idx = np.random.choice(len(S_pop), p=s_mix)
                send = S_pop[s_idx]

                msg = Categorical(logits=logits_from_agent(send, s_obs)).sample()

                r_obs = torch.zeros(args.M, device=device)
                r_obs[msg] = 1.0
                logits, _ = model(r_obs.unsqueeze(0))
                dist = Categorical(logits=logits.squeeze(0))
                act = dist.sample()
                logp = dist.log_prob(act)

                chosen = int(act)
                r_reward = float(np.random.rand() < args.ARMS_MEANS[chosen])

                states.append(r_obs.cpu().numpy())
                acts.append(int(act))
                old_lp.append(float(logp))
                rets.append(r_reward)

        # PPO update --------------------------------------------------------
        batch = torch.tensor(np.stack(states), dtype=torch.float32, device=device)
        acts_t = torch.tensor(acts, dtype=torch.int64, device=device)
        old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=device)
        returns = torch.tensor(rets, dtype=torch.float32, device=device)

        for _ in range(args.ppo_epochs):
            idx = np.random.permutation(args.ppo_batch)
            for start in range(0, args.ppo_batch, args.ppo_minibatch):
                mb = idx[start:start + args.ppo_minibatch]
                logits, val = model(batch[mb])
                d = Categorical(logits=logits)
                new_lp = d.log_prob(acts_t[mb])
                ent = d.entropy().mean()

                # One-step episodes: the return is the reward, and the
                # advantage is return minus the (detached) critic estimate.
                adv = returns[mb] - val.detach()
                ratio = torch.exp(new_lp - old_lp_t[mb])
                clipped = torch.clamp(ratio, 1 - args.ppo_clip, 1 + args.ppo_clip)
                pol_loss = -torch.min(ratio * adv, clipped * adv).mean()
                v_loss = ((val - returns[mb]) ** 2).mean()
                loss = pol_loss + args.vf_coef * v_loss - args.ent_coef * ent

                opt.zero_grad()
                loss.backward()
                opt.step()
    return model

# ---------------------------------------------------------------------------
#                           Single-seed PSRO run
# ---------------------------------------------------------------------------

def run_single_seed(args, base_seed: int, device: torch.device):
    """Run the PSRO loop for one top-level seed and return per-iteration logs.

    Each iteration evaluates the current populations, solves the empirical
    game with fictitious play, trains one new PPO sender and one new PPO
    receiver against the resulting mixtures, appends both to the populations,
    and re-evaluates the enlarged payoff matrices.

    Side effect: sets ``args.ARMS_MEANS`` and ``args.BEST_ARM`` so the PPO
    trainers can read them.

    Args:
        args: Parsed CLI namespace.  The iteration count is read from
            ``args.psro_iters`` and, if that is absent, from ``args.iters``.
        base_seed: Seed for the initial models; per-iteration seeds are
            derived from it.
        device: Torch device for all models and tensors.

    Returns:
        Dict with per-iteration arrays ``sender_mean``, ``receiver_mean``
        (mean over the whole payoff matrix), ``time_sec``, ``ram_mb``,
        ``sender_pop``, ``receiver_pop`` and the list ``timestamps``.
    """
    # Make args values accessible to helpers
    args.ARMS_MEANS = np.array(args.arms_means, dtype=np.float32)
    args.BEST_ARM = int(np.argmax(args.ARMS_MEANS))

    # The CLI flag is --iters; accept either attribute name so this function
    # works regardless of which one the namespace carries.
    n_iters = getattr(args, "psro_iters", None)
    if n_iters is None:
        n_iters = args.iters

    # Seed model inits etc. so each run genuinely differs across seeds
    set_seed(base_seed)

    sender_pop = [Sender(args.K, args.M).to(device)]
    receiver_pop = [Receiver(args.M, args.K).to(device)]

    hist_s_mean, hist_r_mean = [], []
    it_time, it_ram, it_ts, it_s_pop, it_r_pop = [], [], [], [], []

    for it in range(1, n_iters + 1):
        t0 = time.time()
        it_ts.append(time.strftime("%Y-%m-%d %H:%M:%S"))

        # Empirical game on the current populations -> meta-strategies.
        sp, rp = evaluate_payoffs(sender_pop, receiver_pop, args.eval_episodes,
                                  seed=base_seed + it, K=args.K, M=args.M,
                                  BEST_ARM=args.BEST_ARM, BAD_ARM_TARGET=args.bad_arm_target,
                                  ARMS_MEANS=args.ARMS_MEANS, device=device)
        s_mix, r_mix = fictitious_play(sp, rp)

        # Both oracles train against the populations as they were before this
        # iteration's additions.
        new_sender = train_sender_ppo(r_mix, receiver_pop, args, device,
                                      seed_base=1000 * it + base_seed)
        new_receiver = train_receiver_ppo(s_mix, sender_pop, args, device,
                                          seed_base=2000 * it + base_seed)
        sender_pop.append(new_sender)
        receiver_pop.append(new_receiver)

        # Re-evaluate with the enlarged populations for logging.
        sp2, rp2 = evaluate_payoffs(sender_pop, receiver_pop, args.eval_episodes,
                                    seed=base_seed + 500 + it, K=args.K, M=args.M,
                                    BEST_ARM=args.BEST_ARM, BAD_ARM_TARGET=args.bad_arm_target,
                                    ARMS_MEANS=args.ARMS_MEANS, device=device)
        hist_s_mean.append(sp2.mean())
        hist_r_mean.append(rp2.mean())

        it_time.append(time.time() - t0)
        it_ram.append(get_ram_mb())
        it_s_pop.append(len(sender_pop))
        it_r_pop.append(len(receiver_pop))

        print(f"[seed {base_seed}] iter {it:02d} | sender {hist_s_mean[-1]:.4f} "
              f"receiver {hist_r_mean[-1]:.4f} | S={it_s_pop[-1]} R={it_r_pop[-1]}")

    return {
        "sender_mean": np.array(hist_s_mean, dtype=np.float64),
        "receiver_mean": np.array(hist_r_mean, dtype=np.float64),
        "time_sec": np.array(it_time, dtype=np.float64),
        "ram_mb": np.array(it_ram, dtype=np.float64),
        "timestamps": it_ts,
        "sender_pop": np.array(it_s_pop, dtype=np.int64),
        "receiver_pop": np.array(it_r_pop, dtype=np.int64),
    }

# ---------------------------------------------------------------------------
#                           Main
# ---------------------------------------------------------------------------

def main():
    """Run PSRO for every requested seed, then plot and log the aggregate.

    Seeds come from ``--seeds`` if given, otherwise from the range
    ``top_seed .. top_seed + num_seeds - 1``.  Per-iteration sender/receiver
    payoffs, wall time and RAM are averaged across seeds (sample std with
    ``ddof=1`` when more than one seed is run), optionally plotted, and
    written to a CSV file.

    Raises:
        ValueError: If no seeds result from the arguments.
        RuntimeError: If a seed run returns a history of unexpected length.
    """
    args = parse_args()
    device = torch.device(args.device)

    # Determine which seeds to run
    if args.seeds:
        seeds = list(args.seeds)
    else:
        seeds = list(range(args.top_seed, args.top_seed + args.num_seeds))
    if len(seeds) == 0:
        raise ValueError("No seeds specified.")

    # Run all seeds
    all_sender = []
    all_receiver = []
    all_time = []
    all_ram = []

    # The CLI flag is --iters; accept either attribute name so this works
    # regardless of which one the parsed namespace carries.
    expected_iters = getattr(args, "psro_iters", None)
    if expected_iters is None:
        expected_iters = args.iters

    # Population sizes are taken from the first seed only.
    pop_S = None
    pop_R = None

    for s in seeds:
        res = run_single_seed(args, base_seed=s, device=device)

        if len(res["sender_mean"]) != expected_iters or len(res["receiver_mean"]) != expected_iters:
            raise RuntimeError("Seed run returned unexpected iteration length.")

        all_sender.append(res["sender_mean"])
        all_receiver.append(res["receiver_mean"])
        all_time.append(res["time_sec"])
        all_ram.append(res["ram_mb"])

        if pop_S is None:
            pop_S = res["sender_pop"]
        if pop_R is None:
            pop_R = res["receiver_pop"]

    all_sender = np.stack(all_sender, axis=0)   # [n_seeds, iters]
    all_receiver = np.stack(all_receiver, axis=0)
    all_time = np.stack(all_time, axis=0)
    all_ram = np.stack(all_ram, axis=0)

    multi_seed = len(seeds) > 1

    def across_seeds(values):
        """Return (mean, std) over the seed axis; std is 0 for a single seed."""
        mean = values.mean(axis=0)
        std = values.std(axis=0, ddof=1) if multi_seed else np.zeros_like(mean)
        return mean, std

    # Aggregate across seeds per iteration
    sender_mean_avg, sender_mean_std = across_seeds(all_sender)
    receiver_mean_avg, receiver_mean_std = across_seeds(all_receiver)
    time_mean, time_std = across_seeds(all_time)
    ram_mean, ram_std = across_seeds(all_ram)

    # ---- plot (optional) ----
    if not args.no_plot:
        # Sender
        plt.figure()
        x = np.arange(1, expected_iters + 1)
        plt.plot(x, sender_mean_avg, label="sender (mean)")
        plt.fill_between(x,
                         sender_mean_avg - sender_mean_std,
                         sender_mean_avg + sender_mean_std,
                         alpha=0.2, label="sender (±1σ)")
        # Receiver
        plt.plot(x, receiver_mean_avg, label="receiver (mean)")
        plt.fill_between(x,
                         receiver_mean_avg - receiver_mean_std,
                         receiver_mean_avg + receiver_mean_std,
                         alpha=0.2, label="receiver (±1σ)")
        plt.xlabel("PSRO iteration")
        plt.ylabel("avg payoff")
        plt.grid()
        plt.legend()
        plt.title(f"PSRO payoffs across {len(seeds)} seeds")
        plt.show()

    # ---- csv (aggregated) ----
    if args.log_csv is not None:
        csv_name = args.log_csv
    elif len(seeds) == 1:
        csv_name = f"psro_avg_seed{seeds[0]}.csv"
    else:
        csv_name = f"psro_avg_{min(seeds)}to{max(seeds)}_n{len(seeds)}.csv"

    with open(csv_name, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "iter",
            "sender_mean_avg", "sender_mean_std",
            "receiver_mean_avg", "receiver_mean_std",
            "time_sec_avg", "time_sec_std",
            "ram_mb_avg", "ram_mb_std",
            "sender_pop", "receiver_pop",
            "seeds"
        ])
        seed_str = ",".join(str(s) for s in seeds)
        for i in range(expected_iters):
            w.writerow([
                i + 1,
                float(sender_mean_avg[i]), float(sender_mean_std[i]),
                float(receiver_mean_avg[i]), float(receiver_mean_std[i]),
                float(time_mean[i]), float(time_std[i]),
                float(ram_mean[i]), float(ram_std[i]),
                int(pop_S[i]), int(pop_R[i]),
                seed_str
            ])
    print(f"\nSaved aggregated instrumentation log to {csv_name}")

# ---------------------------------------------------------------------------
#                           Entrypoint
# ---------------------------------------------------------------------------

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__=="__main__":
    main()
