import argparse, time, csv, sys
from copy import deepcopy
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

# ────────────────────────────────────────────────────────────────────────────────
# Argument parser
# ────────────────────────────────────────────────────────────────────────────────
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser, parse ``argv`` and validate the result.

    Args:
        argv: argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. The PSRO iteration count is accepted as ``--iters``
        or ``--psro-iters`` and stored as ``psro_iters`` (the attribute the rest of
        the script reads); ``iters`` is kept as an alias for backward compatibility.

    Exits (via ``ArgumentParser.error``) when the arguments are inconsistent, e.g.
    ``--arms-means`` does not have exactly ``--K`` entries.
    """
    p = argparse.ArgumentParser(
        description="Sender/receiver Alpha-PSRO with PPO best-responses (multi-seed avg)"
    )

    # Environment
    p.add_argument("--K", type=int, default=5)
    p.add_argument("--M", type=int, default=3)
    p.add_argument("--arms-means", type=float, nargs="+",
                   default=[0.2,0.5,0.8,0.4,0.1], metavar="μ")
    p.add_argument("--bad-arm-target", type=int, default=0)
    p.add_argument("--top-seed", type=int, default=0)

    # PSRO loop
    # dest="psro_iters": run_single_seed()/main() read args.psro_iters.
    p.add_argument("--iters", "--psro-iters", dest="psro_iters", type=int, default=6)
    p.add_argument("--eval-episodes", type=int, default=200)

    # Alpha-Rank meta
    p.add_argument("--alpha-rank-strength", type=float, default=8.0)
    p.add_argument("--power-iters", type=int, default=2000)
    p.add_argument("--alpha-rank-tol", type=float, default=1e-12)
    p.add_argument("--inner-eval", type=int, default=8,
                   help="Rollouts per entry when estimating payoff matrix (unused by default)")

    # PPO oracle training
    p.add_argument("--oracle-epochs", type=int, default=200)
    p.add_argument("--oracle-lr", type=float, default=3e-4)
    p.add_argument("--ppo-batch", type=int, default=256)
    p.add_argument("--ppo-epochs", type=int, default=4)
    p.add_argument("--ppo-minibatch", type=int, default=64)
    p.add_argument("--ppo-clip", type=float, default=0.2)
    p.add_argument("--vf-coef", type=float, default=0.5)
    p.add_argument("--ent-coef", type=float, default=0.01)

    # Device & logging
    p.add_argument("--device", choices=["cpu","cuda"], default="cpu")
    p.add_argument("--log-csv", default=None,
                   help="Aggregated CSV filename (default alpha_psro_avg_<seedspan>.csv)")

    # Multi-seed controls
    p.add_argument("--seeds", type=int, nargs="+", default=None,
                   help="Explicit list of top seeds, e.g. --seeds 0 1 2 3 4")
    p.add_argument("--num-seeds", type=int, default=5,
                   help="If --seeds not given, use range(top_seed, top_seed+num_seeds)")

    # Plotting
    p.add_argument("--no-plot", action="store_true",
                   help="Disable matplotlib plotting")

    args = p.parse_args(argv)
    args.iters = args.psro_iters  # backward-compatible alias

    # The environment indexes one-hot vectors of size K with the argmax of the
    # arm means and compares actions against the bad-arm target, so both must
    # agree with K or the run crashes / silently never rewards the sender.
    if len(args.arms_means) != args.K:
        p.error(f"--arms-means has {len(args.arms_means)} values but --K is {args.K}")
    if not 0 <= args.bad_arm_target < args.K:
        p.error(f"--bad-arm-target must be in [0, {args.K}), got {args.bad_arm_target}")
    # Zero values here would divide by zero, stack an empty batch, or use a
    # zero range() step later on.
    for name in ("eval_episodes", "ppo_batch", "ppo_minibatch"):
        if getattr(args, name) < 1:
            p.error(f"--{name.replace('_', '-')} must be >= 1")
    return args

# ────────────────────────────────────────────────────────────────────────────────
# RAM helper
# ────────────────────────────────────────────────────────────────────────────────
try:
    import psutil
    _proc = psutil.Process()

    def get_ram_mb():
        """Return this process's current resident set size in MiB (via psutil)."""
        return _proc.memory_info().rss / 1_048_576
except Exception:
    # psutil is missing or unusable: fall back to the stdlib ``resource``
    # module. Note that ``ru_maxrss`` is the *peak* RSS, not the current one,
    # so values from this branch are not directly comparable with psutil's.
    try:
        import resource
    except Exception:
        def get_ram_mb():
            """No RAM probe is available on this platform; always NaN."""
            return float("nan")
    else:
        def get_ram_mb():
            """Return the peak resident set size in MiB via ``getrusage``."""
            peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
            # macOS reports ru_maxrss in bytes, Linux in KiB.
            return peak / (1_048_576 if sys.platform == "darwin" else 1024.0)

# ────────────────────────────────────────────────────────────────────────────────
# Networks
# ────────────────────────────────────────────────────────────────────────────────
class SenderPPO(nn.Module):
    """Sender actor-critic: K-dim input -> (M message logits, scalar value).

    In this file the input is a one-hot encoding of the best arm. Actor and
    critic are separate one-hidden-layer MLPs (32 units, ReLU).
    """

    def __init__(self, K, M):
        super().__init__()
        hidden = 32
        # Actor is built before critic so parameter init order (and thus the
        # weights drawn from the torch RNG) matches the rest of the file.
        self.actor = nn.Sequential(
            nn.Linear(K, hidden),
            nn.ReLU(),
            nn.Linear(hidden, M),
        )
        self.critic = nn.Sequential(
            nn.Linear(K, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return ``(logits, value)``; the critic's trailing size-1 dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

class ReceiverPPO(nn.Module):
    """Receiver actor-critic: M-dim input -> (K arm logits, scalar value).

    In this file the input is a one-hot encoding of the sender's message.
    Actor and critic are separate one-hidden-layer MLPs (32 units, ReLU).
    """

    def __init__(self, M, K):
        super().__init__()
        hidden = 32
        # Actor first, then critic: keeps parameter creation order unchanged.
        self.actor = nn.Sequential(
            nn.Linear(M, hidden),
            nn.ReLU(),
            nn.Linear(hidden, K),
        )
        self.critic = nn.Sequential(
            nn.Linear(M, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return ``(logits, value)``; the critic's trailing size-1 dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

# ────────────────────────────────────────────────────────────────────────────────
# Core helpers
# ────────────────────────────────────────────────────────────────────────────────
def set_seed(s):
    """Seed NumPy's global RNG and PyTorch's default generator with ``s``."""
    np.random.seed(s)
    torch.manual_seed(s)

def sample_episode(sender,receiver,env,device):
    """Play one sender -> receiver round and return ``(sender_reward, receiver_reward)``.

    Args:
        sender: module whose forward returns ``(logits, value)``; it is fed a
            one-hot vector of size ``K`` marking arm ``BEST``.
        receiver: module whose forward returns ``(logits, value)``; it is fed a
            one-hot vector of size ``M`` marking the sampled message.
        env: ``(K, M, BEST, BAD, MEANS)`` tuple.
        device: torch device for the observation tensors.

    Returns:
        ``sender_reward`` is 1.0 iff the receiver picks arm ``BAD``;
        ``receiver_reward`` is 1.0 with probability ``MEANS[arm]`` (drawn from
        NumPy's global RNG), else 0.0.
    """
    K, M, BEST, BAD, MEANS = env
    # Only sampled actions are used here, so skip building an autograd graph
    # for every evaluation episode. Sampling (and RNG consumption) is unchanged.
    with torch.no_grad():
        s_obs = torch.zeros(K, device=device)
        s_obs[BEST] = 1.0
        msg = Categorical(logits=sender(s_obs.unsqueeze(0))[0].squeeze(0)).sample()
        r_obs = torch.zeros(M, device=device)
        r_obs[msg] = 1.0
        act = Categorical(logits=receiver(r_obs.unsqueeze(0))[0].squeeze(0)).sample()
    chosen = int(act)
    r_r = float(np.random.rand() < MEANS[chosen])
    s_r = float(chosen == BAD)
    return s_r, r_r

def payoff_matrix(S,R,n,seed,env,device):
    """Monte-Carlo estimate of both players' payoffs for every policy pairing.

    Args:
        S: sender policies (matrix rows).
        R: receiver policies (matrix columns).
        n: episodes sampled per (sender, receiver) pair.
        seed: when not None, the global RNGs are reseeded with it first.
        env, device: forwarded unchanged to ``sample_episode``.

    Returns:
        ``(sp, rp)``: ``len(S) x len(R)`` arrays of mean sender / receiver reward.
    """
    if seed is not None:
        set_seed(seed)
    shape = (len(S), len(R))
    sender_pay = np.zeros(shape)
    receiver_pay = np.zeros(shape)
    for row, sender in enumerate(S):
        for col, receiver in enumerate(R):
            totals = [0.0, 0.0]  # [sender sum, receiver sum]
            for _ in range(n):
                rewards = sample_episode(sender, receiver, env, device)
                totals[0] += rewards[0]
                totals[1] += rewards[1]
            sender_pay[row, col] = totals[0] / n
            receiver_pay[row, col] = totals[1] / n
    return sender_pay, receiver_pay

def _fermi(x):
    """Logistic function ``1 / (1 + exp(-x))`` written via tanh.

    Mathematically identical to the exp form, but cannot overflow (and emit
    RuntimeWarnings) when ``alpha * payoff_gap`` is large in magnitude.
    """
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def alpha_rank(sp,rp,alpha,strength_iters,tol):
    """Alpha-Rank-style meta-solver for a two-population (sender/receiver) game.

    Builds a Markov chain over pure-strategy profiles ``(i, j)``. From each
    profile, the sender moves to ``(i2, j)`` with probability
    ``0.5/(S-1) * fermi(alpha * (sp[i2, j] - sp[i, j]))``, and the receiver
    moves to ``(i, j2)`` with ``0.5/(R-1) * fermi(alpha * (rp[i, j2] - rp[i, j]))``.
    The leftover mass stays on ``(i, j)``. The stationary distribution is
    approximated by power iteration and marginalised per population.

    Args:
        sp: ``(S, R)`` sender payoff matrix.
        rp: ``(S, R)`` receiver payoff matrix (same shape as ``sp``).
        alpha: selection strength; larger values favour better deviations more sharply.
        strength_iters: maximum number of power-iteration steps.
        tol: stop once the L1 change between successive iterates is below this.

    Returns:
        ``(s_mix, r_mix)``: probability vectors of length ``S`` and ``R``
        (uniform fallback if a marginal has zero mass).

    Raises:
        ValueError: if ``sp`` and ``rp`` are not 2-D arrays of equal shape.
    """
    sp = np.asarray(sp, dtype=np.float64)
    rp = np.asarray(rp, dtype=np.float64)
    if sp.ndim != 2 or sp.shape != rp.shape:
        raise ValueError(f"sp and rp must be 2-D with equal shapes, got {sp.shape} and {rp.shape}")
    S, R = sp.shape
    P = S * R  # profile (i, j) is stored at flat index i * R + j
    T = np.zeros((P, P))
    for i in range(S):
        for j in range(R):
            p = i * R + j
            out = 0.0
            for i2 in range(S):
                if i2 == i:
                    continue
                prob = 0.5 / (S - 1) * _fermi(alpha * (sp[i2, j] - sp[i, j]))
                T[p, i2 * R + j] = prob
                out += prob
            for j2 in range(R):
                if j2 == j:
                    continue
                prob = 0.5 / (R - 1) * _fermi(alpha * (rp[i, j2] - rp[i, j]))
                T[p, i * R + j2] = prob
                out += prob
            T[p, p] = max(0.0, 1.0 - out)
    T = T / T.sum(axis=1, keepdims=True)

    v = np.ones(P) / P
    for _ in range(strength_iters):
        v_next = v @ T
        converged = np.abs(v_next - v).sum() < tol
        v = v_next  # keep the newest iterate even when we stop
        if converged:
            break

    marginal = v.reshape(S, R)
    s_mix = marginal.sum(axis=1)
    r_mix = marginal.sum(axis=0)
    s_mix = s_mix / s_mix.sum() if s_mix.sum() > 0 else np.ones(S) / S
    r_mix = r_mix / r_mix.sum() if r_mix.sum() > 0 else np.ones(R) / R
    return s_mix, r_mix

# ────────────────────────────────────────────────────────────────────────────────
# PPO training against mixture
# ────────────────────────────────────────────────────────────────────────────────
def train_ppo(policy_cls,opp_pop,opp_mix,role,args,env,device,seed_base):
    """Train a fresh PPO best response against a fixed mixture of opponents.

    Each episode is a single sender-message / receiver-action exchange, so a
    step's return is just its immediate reward. The sender is rewarded 1.0 when
    the receiver picks ``BAD``. The receiver gets Bernoulli(``MEANS[arm]``).

    Args:
        policy_cls: module class, built as ``policy_cls(K, M)`` for the sender
            and ``policy_cls(M, K)`` for the receiver; forward must return
            ``(logits, value)``.
        opp_pop: opponent policies with the same ``(logits, value)`` forward.
        opp_mix: non-negative weights over ``opp_pop`` (uniform if they sum to 0).
        role: ``"sender"`` or ``"receiver"`` -- the role being trained.
        args: namespace with ``oracle_epochs``, ``oracle_lr``, ``ppo_batch``,
            ``ppo_epochs``, ``ppo_minibatch``, ``ppo_clip``, ``vf_coef`` and
            ``ent_coef``; an optional ``vf_coefficient`` overrides ``vf_coef``.
        env: ``(K, M, BEST, BAD, MEANS)`` tuple.
        device: torch device for the new model and rollout tensors.
        seed_base: global RNGs are reseeded with ``seed_base + ep`` at the start
            of each (1-based) oracle epoch ``ep``.

    Returns:
        The trained model.

    Raises:
        ValueError: unknown ``role``, empty ``opp_pop``, or ``opp_mix`` whose
            length differs from ``opp_pop``.
    """
    K, M, BEST, BAD, MEANS = env
    if role not in ("sender", "receiver"):
        raise ValueError(f"role must be 'sender' or 'receiver', got {role!r}")
    if len(opp_pop) == 0:
        raise ValueError("opp_pop must contain at least one policy")
    weights = np.asarray(opp_mix, dtype=np.float64)
    if weights.shape != (len(opp_pop),):
        raise ValueError(f"opp_mix has shape {weights.shape}, expected ({len(opp_pop)},)")
    total = weights.sum()
    mix = weights / total if total > 0 else np.ones(len(opp_pop)) / len(opp_pop)
    vf_coef = args.vf_coefficient if hasattr(args, "vf_coefficient") else args.vf_coef

    model = (policy_cls(K, M) if role == "sender" else policy_cls(M, K)).to(device)
    opt = optim.Adam(model.parameters(), lr=args.oracle_lr)

    for ep in range(1, args.oracle_epochs + 1):
        set_seed(seed_base + ep)

        # ---- Rollout. Log-probs and values are stored as constants, so no
        # autograd graph is needed here.
        states, acts, old_lp, old_val, rets = [], [], [], [], []
        with torch.no_grad():
            for _ in range(args.ppo_batch):
                opp = opp_pop[np.random.choice(len(opp_pop), p=mix)]
                s_obs = torch.zeros(K, device=device)
                s_obs[BEST] = 1.0
                if role == "sender":
                    logits, value = model(s_obs.unsqueeze(0))
                    dist = Categorical(logits=logits.squeeze(0))
                    action = dist.sample()
                    r_obs = torch.zeros(M, device=device)
                    r_obs[action] = 1.0
                    opp_act = Categorical(logits=opp(r_obs.unsqueeze(0))[0].squeeze(0)).sample()
                    obs, reward = s_obs, float(int(opp_act) == BAD)
                else:
                    msg = Categorical(logits=opp(s_obs.unsqueeze(0))[0].squeeze(0)).sample()
                    r_obs = torch.zeros(M, device=device)
                    r_obs[msg] = 1.0
                    logits, value = model(r_obs.unsqueeze(0))
                    dist = Categorical(logits=logits.squeeze(0))
                    action = dist.sample()
                    obs, reward = r_obs, float(np.random.rand() < MEANS[int(action)])
                states.append(obs)
                acts.append(int(action))
                old_lp.append(float(dist.log_prob(action)))
                old_val.append(value.item())
                rets.append(reward)

        batch = torch.stack(states)
        acts_t = torch.tensor(acts, dtype=torch.int64, device=device)
        old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=device)
        returns = torch.tensor(rets, dtype=torch.float32, device=device)
        # Advantages are estimated once, using the critic that was active during
        # data collection, and stay fixed over all update epochs (as in PPO).
        advantages = returns - torch.tensor(old_val, dtype=torch.float32, device=device)

        # ---- Clipped-surrogate updates over shuffled minibatches.
        for _ in range(args.ppo_epochs):
            perm = torch.as_tensor(np.random.permutation(args.ppo_batch), device=device)
            for start in range(0, args.ppo_batch, args.ppo_minibatch):
                mb = perm[start:start + args.ppo_minibatch]
                logits, val = model(batch[mb])
                dist = Categorical(logits=logits)
                new_lp = dist.log_prob(acts_t[mb])
                ratio = torch.exp(new_lp - old_lp_t[mb])
                adv = advantages[mb]
                clipped = torch.clamp(ratio, 1 - args.ppo_clip, 1 + args.ppo_clip)
                pol_loss = -torch.min(ratio * adv, clipped * adv).mean()
                v_loss = ((val - returns[mb]) ** 2).mean()
                loss = pol_loss + vf_coef * v_loss - args.ent_coef * dist.entropy().mean()
                opt.zero_grad()
                loss.backward()
                opt.step()
    return model

# ────────────────────────────────────────────────────────────────────────────────
# Single-seed run
# ────────────────────────────────────────────────────────────────────────────────
def _derive_seed(args, base_seed, it, purpose):
    """Map (top seed, PSRO iteration, purpose) to a seed with a private window.

    ``purpose`` is 0 = pre-training eval, 1 = sender oracle, 2 = receiver oracle,
    3 = post-training eval. ``train_ppo`` reseeds with ``seed + ep`` for ``ep`` in
    ``1..oracle_epochs``, so every slot reserves ``oracle_epochs + 1`` consecutive
    integers. Slots are injective in ``(base_seed, it, purpose)`` for
    ``base_seed >= 0`` and ``1 <= it <= psro_iters``, so roles, iterations and top
    seeds never replay each other's RNG streams (as long as the product stays
    below 2**31). The ``% 2**31`` keeps ``seed + ep`` a valid NumPy seed.
    """
    slot = (base_seed * (args.psro_iters + 1) + it) * 4 + purpose
    return (slot * (args.oracle_epochs + 1)) % (2 ** 31)


def run_single_seed(args, base_seed: int, device: torch.device):
    """Run a full Alpha-PSRO experiment for one top-level seed.

    Each iteration:
    1. estimates payoffs on the current populations;
    2. computes Alpha-Rank mixtures;
    3. adds a PPO best response for each role;
    4. re-evaluates the enlarged populations.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        base_seed: top-level seed for this run.
        device: torch device for all networks.

    Returns:
        Dict of per-iteration logs:
        - ``sender_mean`` / ``receiver_mean``: mean of the post-training payoff
          matrices.
        - ``time_sec`` and ``ram_mb``.
        - ``timestamps``.
        - ``sender_pop`` / ``receiver_pop``: population sizes.

    Raises:
        ValueError: if ``arms_means`` does not have ``K`` entries or
            ``bad_arm_target`` is outside ``[0, K)``.
    """
    ARMS = np.array(args.arms_means, dtype=np.float32)
    if len(ARMS) != args.K:
        raise ValueError(f"arms_means has {len(ARMS)} entries but K={args.K}")
    if not 0 <= args.bad_arm_target < args.K:
        raise ValueError(f"bad_arm_target must be in [0, {args.K}), got {args.bad_arm_target}")
    BEST = int(np.argmax(ARMS))
    env = (args.K, args.M, BEST, args.bad_arm_target, ARMS)
    set_seed(base_seed)

    S_pop = [SenderPPO(args.K, args.M).to(device)]
    R_pop = [ReceiverPPO(args.M, args.K).to(device)]
    hist_s, hist_r = [], []
    log_t, log_time, log_ram = [], [], []
    s_pop_hist, r_pop_hist = [], []

    for it in range(1, args.psro_iters + 1):
        t0 = time.time()
        log_t.append(time.strftime("%Y-%m-%d %H:%M:%S"))
        print(f"[seed {base_seed}] === Alpha-PSRO iter {it} ===")

        sp, rp = payoff_matrix(S_pop, R_pop, args.eval_episodes,
                               seed=_derive_seed(args, base_seed, it, 0), env=env, device=device)
        s_mix, r_mix = alpha_rank(sp, rp, args.alpha_rank_strength,
                                  args.power_iters, args.alpha_rank_tol)

        # Disjoint seed windows per (seed, iteration, role): see _derive_seed.
        S_pop.append(train_ppo(SenderPPO, R_pop, r_mix, "sender", args, env, device,
                               seed_base=_derive_seed(args, base_seed, it, 1)))
        # The receiver responds to the sender population *before* this
        # iteration's new sender, matching s_mix's length.
        R_pop.append(train_ppo(ReceiverPPO, S_pop[:-1], s_mix, "receiver", args, env, device,
                               seed_base=_derive_seed(args, base_seed, it, 2)))

        sp2, rp2 = payoff_matrix(S_pop, R_pop, args.eval_episodes,
                                 seed=_derive_seed(args, base_seed, it, 3), env=env, device=device)
        hist_s.append(sp2.mean())
        hist_r.append(rp2.mean())
        log_time.append(time.time() - t0)
        log_ram.append(get_ram_mb())
        s_pop_hist.append(len(S_pop))
        r_pop_hist.append(len(R_pop))

        print(f"[seed {base_seed}] iter {it:02d} | sender {hist_s[-1]:.4f} "
              f"receiver {hist_r[-1]:.4f} | S={len(S_pop)} R={len(R_pop)} "
              f"time {log_time[-1]:.1f}s RAM {log_ram[-1]:.1f}MB")

    return {
        "sender_mean": np.array(hist_s, dtype=np.float64),
        "receiver_mean": np.array(hist_r, dtype=np.float64),
        "time_sec": np.array(log_time, dtype=np.float64),
        "ram_mb": np.array(log_ram, dtype=np.float64),
        "timestamps": log_t,
        "sender_pop": np.array(s_pop_hist, dtype=np.int64),
        "receiver_pop": np.array(r_pop_hist, dtype=np.int64),
    }

# ────────────────────────────────────────────────────────────────────────────────
# Main
# ────────────────────────────────────────────────────────────────────────────────
def _mean_std(runs):
    """Per-iteration mean and sample std (ddof=1) over the seed axis (axis 0).

    With a single seed the std is reported as zeros rather than NaN.
    """
    mean = runs.mean(axis=0)
    if runs.shape[0] > 1:
        return mean, runs.std(axis=0, ddof=1)
    return mean, np.zeros_like(mean)


def main():
    """Run Alpha-PSRO for every requested seed, then aggregate, save and plot.

    The aggregated CSV is written *before* the matplotlib window is shown.
    ``plt.show()`` may block, so this way the results are already on disk if
    plotting fails or the process is interrupted while the figure is open.
    """
    args = parse_args()
    device = torch.device(args.device)

    # Explicit --seeds wins; otherwise a contiguous range from --top-seed.
    if args.seeds:
        seeds = list(args.seeds)
    else:
        seeds = list(range(args.top_seed, args.top_seed + args.num_seeds))
    if not seeds:
        raise ValueError("No seeds specified.")

    expected_iters = args.psro_iters
    metrics = ("sender_mean", "receiver_mean", "time_sec", "ram_mb")
    per_seed = {k: [] for k in metrics}
    pop_S = pop_R = None
    for s in seeds:
        res = run_single_seed(args, base_seed=s, device=device)
        if len(res["sender_mean"]) != expected_iters:
            raise RuntimeError("Unexpected iteration length in a seed run.")
        for k in metrics:
            per_seed[k].append(res[k])
        # Each iteration adds one policy per role, so population sizes are the
        # same for every seed; record the first run's.
        if pop_S is None:
            pop_S, pop_R = res["sender_pop"], res["receiver_pop"]

    sender_mean_avg, sender_mean_std = _mean_std(np.stack(per_seed["sender_mean"]))
    receiver_mean_avg, receiver_mean_std = _mean_std(np.stack(per_seed["receiver_mean"]))
    time_mean, time_std = _mean_std(np.stack(per_seed["time_sec"]))
    ram_mean, ram_std = _mean_std(np.stack(per_seed["ram_mb"]))

    # CSV (saved first so it survives a failing or blocked plot)
    if args.log_csv is not None:
        csv_name = args.log_csv
    elif len(seeds) > 1:
        csv_name = f"alpha_psro_avg_{min(seeds)}to{max(seeds)}_n{len(seeds)}.csv"
    else:
        csv_name = f"alpha_psro_avg_seed{seeds[0]}.csv"

    seed_str = ",".join(str(s) for s in seeds)
    with open(csv_name, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(
            ["iter","sender_mean_avg","sender_mean_std","receiver_mean_avg","receiver_mean_std",
             "time_sec_avg","time_sec_std","ram_mb_avg","ram_mb_std",
             "sender_pop","receiver_pop","seeds"])
        for i in range(expected_iters):
            w.writerow([
                i+1,
                float(sender_mean_avg[i]), float(sender_mean_std[i]),
                float(receiver_mean_avg[i]), float(receiver_mean_std[i]),
                float(time_mean[i]), float(time_std[i]),
                float(ram_mean[i]), float(ram_std[i]),
                int(pop_S[i]), int(pop_R[i]),
                seed_str
            ])
    print(f"\nSaved aggregated instrumentation log to {csv_name}")

    # Plot
    if not args.no_plot:
        x = np.arange(1, expected_iters+1)
        plt.figure()
        plt.plot(x, sender_mean_avg, label="sender (mean)")
        plt.fill_between(x, sender_mean_avg - sender_mean_std, sender_mean_avg + sender_mean_std, alpha=0.2, label="sender (±1σ)")
        plt.plot(x, receiver_mean_avg, label="receiver (mean)")
        plt.fill_between(x, receiver_mean_avg - receiver_mean_std, receiver_mean_avg + receiver_mean_std, alpha=0.2, label="receiver (±1σ)")
        plt.xlabel("PSRO iter"); plt.ylabel("avg payoff")
        plt.grid(); plt.legend(); plt.title(f"Alpha-PSRO payoffs across {len(seeds)} seeds")
        plt.show()

# ────────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
