import argparse, time, csv, sys
from copy import deepcopy
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt

# ────────────────────────────────────────────────────────────────────────────────
# Arg-parser
# ────────────────────────────────────────────────────────────────────────────────
def parse_args(argv=None) -> argparse.Namespace:
    """Build the CLI parser, parse *argv*, and validate cross-argument constraints.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed namespace. The outer-iteration count is stored as
        ``psro_iters`` (the attribute the rest of the script reads) and is also
        mirrored as ``iters`` for backward compatibility.

    Exits via ``ArgumentParser.error`` (SystemExit) when ``--arms-means`` does
    not have exactly ``--K`` entries, ``--bad-arm-target`` is not a valid arm
    index, ``--lookahead-d`` lies outside [0, 1], or a size/count argument
    that is used as a divisor or batch size is not positive.
    """
    p = argparse.ArgumentParser(
        description="One-step A-PSRO with PPO candidates/oracles (multi-seed avg)"
    )

    # Environment
    p.add_argument("--K", type=int, default=5)
    p.add_argument("--M", type=int, default=3)
    p.add_argument("--arms-means", type=float, nargs="+",
                   default=[0.2,0.5,0.8,0.4,0.1], metavar="μ")
    p.add_argument("--bad-arm-target", type=int, default=0)
    p.add_argument("--top-seed", type=int, default=0)

    # Main loop. Consumers read ``args.psro_iters``, so the value is stored
    # under that dest; ``--iters`` stays the primary flag.
    p.add_argument("--iters", "--psro-iters", dest="psro_iters", type=int, default=6,
                   help="Number of outer A-PSRO iterations")
    p.add_argument("--fp-iters", type=int, default=200,
                   help="Fictitious-play iterations for meta-solver")
    p.add_argument("--num-repeats", type=int, default=3,
                   help="Number of A-PSRO exploration repeats per outer iter")
    p.add_argument("--lookahead-d", type=float, default=0.1,
                   help="Lookahead interpolation coefficient d")

    # Candidate BR training
    p.add_argument("--cand-epochs", type=int, default=40)
    p.add_argument("--cand-lr", type=float, default=3e-4)

    # Final oracle training
    p.add_argument("--oracle-epochs", type=int, default=120)
    p.add_argument("--oracle-lr", type=float, default=3e-4)

    # PPO params
    p.add_argument("--ppo-batch", type=int, default=128)
    p.add_argument("--ppo-epochs", type=int, default=2)
    p.add_argument("--ppo-minibatch", type=int, default=64)
    p.add_argument("--ppo-clip", type=float, default=0.2)
    p.add_argument("--vf-coef", type=float, default=0.5)
    p.add_argument("--ent-coef", type=float, default=0.01)

    # Eval & device
    p.add_argument("--eval-episodes", type=int, default=100)
    p.add_argument("--device", choices=["cpu","cuda"], default="cpu")
    p.add_argument("--log-csv", default=None,
                   help="Aggregated CSV filename (default a_psro_avg_<seedspan>.csv)")

    # Multi-seed controls
    p.add_argument("--seeds", type=int, nargs="+", default=None,
                   help="Explicit list of top seeds, e.g. --seeds 0 1 2 3 4")
    p.add_argument("--num-seeds", type=int, default=5,
                   help="If --seeds not given, use range(top_seed, top_seed+num_seeds)")

    # Plotting
    p.add_argument("--no-plot", action="store_true",
                   help="Disable matplotlib plotting")

    args = p.parse_args(argv)
    args.iters = args.psro_iters  # backward-compatible alias

    # Arm means are indexed by the receiver's chosen arm in [0, K).
    if len(args.arms_means) != args.K:
        p.error(f"--arms-means has {len(args.arms_means)} values but --K is {args.K}")
    if not 0 <= args.bad_arm_target < args.K:
        p.error(f"--bad-arm-target must be in [0, {args.K}), got {args.bad_arm_target}")
    # d interpolates between the meta-strategy and the candidate.
    if not 0.0 <= args.lookahead_d <= 1.0:
        p.error(f"--lookahead-d must be in [0, 1], got {args.lookahead_d}")
    # These are used as network sizes, batch sizes, or averaging divisors.
    for name in ("M", "ppo_batch", "ppo_minibatch", "eval_episodes"):
        if getattr(args, name) <= 0:
            p.error(f"--{name.replace('_', '-')} must be positive")
    return args

# ────────────────────────────────────────────────────────────────────────────────
# RAM helper
# ────────────────────────────────────────────────────────────────────────────────
def _build_ram_probe():
    """Select a memory probe for this process and return it as a zero-arg callable.

    Preference order:
      1. ``psutil``: current resident set size, in MiB.
      2. ``resource``: ``ru_maxrss`` (the *peak* RSS so far), in MiB. The value
         is treated as bytes on macOS and as KiB elsewhere, hence the
         platform-dependent divisor.
      3. A stub that always returns NaN.
    Options 1 and 2 report different quantities (current vs. peak), so logged
    RAM is not comparable across environments with and without psutil.
    """
    try:
        import psutil
        proc = psutil.Process()
    except Exception:
        proc = None

    if proc is not None:
        def get_ram_mb():
            return proc.memory_info().rss / 1_048_576
        return get_ram_mb

    try:
        import resource
    except Exception:
        def get_ram_mb():
            return float("nan")
        return get_ram_mb

    def get_ram_mb():
        peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        return peak / 1_048_576 if sys.platform == "darwin" else peak / 1024.0
    return get_ram_mb


get_ram_mb = _build_ram_probe()

# ────────────────────────────────────────────────────────────────────────────────
# Networks
# ────────────────────────────────────────────────────────────────────────────────
class SenderPPO(nn.Module):
    """Actor-critic network for the sender role.

    The actor maps a length-``K`` observation (the script feeds a one-hot of
    the best arm) to ``M`` message logits; the critic maps the same observation
    to a scalar value. The two are independent 2-layer MLPs with no shared trunk.

    Args:
        K: Observation size.
        M: Number of messages (actor output size).
        hidden: Width of each MLP's hidden layer; defaults to 32.
    """
    def __init__(self,K,M,hidden=32):
        super().__init__()
        self.actor=nn.Sequential(nn.Linear(K,hidden),nn.ReLU(),nn.Linear(hidden,M))
        self.critic=nn.Sequential(nn.Linear(K,hidden),nn.ReLU(),nn.Linear(hidden,1))

    def forward(self,x):
        """Return ``(logits, value)``; ``value`` has the last (size-1) dim squeezed."""
        return self.actor(x),self.critic(x).squeeze(-1)

class ReceiverPPO(nn.Module):
    """Actor-critic network for the receiver role.

    The actor maps a length-``M`` observation (the script feeds a one-hot of
    the received message) to ``K`` arm logits; the critic maps the same
    observation to a scalar value. The two are independent 2-layer MLPs.

    Args:
        M: Observation size (number of messages).
        K: Number of arms (actor output size).
        hidden: Width of each MLP's hidden layer; defaults to 32.
    """
    def __init__(self,M,K,hidden=32):
        super().__init__()
        self.actor=nn.Sequential(nn.Linear(M,hidden),nn.ReLU(),nn.Linear(hidden,K))
        self.critic=nn.Sequential(nn.Linear(M,hidden),nn.ReLU(),nn.Linear(hidden,1))

    def forward(self,x):
        """Return ``(logits, value)``; ``value`` has the last (size-1) dim squeezed."""
        return self.actor(x),self.critic(x).squeeze(-1)

# ────────────────────────────────────────────────────────────────────────────────
# Core helpers
# ────────────────────────────────────────────────────────────────────────────────
def set_seed(s):
    """Seed NumPy's global RNG, then PyTorch's default generator, with *s*."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(s)

def sample_episode(sender,receiver,env,device):
    """Play one sender -> receiver round and return ``(sender_reward, receiver_reward)``.

    The sender observes a one-hot of arm ``BEST`` (length ``K``) and samples a
    message index; the receiver observes a one-hot of that message (length
    ``M``) and samples an arm. The receiver gets a Bernoulli reward with success
    probability ``MEANS[arm]``; the sender gets 1.0 iff the chosen arm is ``BAD``.

    Args:
        sender, receiver: Callables returning a tuple whose first item is logits
            (the value head is ignored).
        env: Tuple ``(K, M, BEST, BAD, MEANS)``.
        device: Device on which the observation tensors are created.

    The forward passes run under ``torch.no_grad()``. Nothing is
    back-propagated from these episodes, so tracking gradients would only cost
    time and memory. The random draws are the same either way.
    """
    K,M,BEST,BAD,MEANS=env
    with torch.no_grad():
        s_obs=torch.zeros(K,device=device); s_obs[BEST]=1.0
        msg=Categorical(logits=sender(s_obs.unsqueeze(0))[0].squeeze(0)).sample()
        r_obs=torch.zeros(M,device=device); r_obs[msg]=1.0
        act=Categorical(logits=receiver(r_obs.unsqueeze(0))[0].squeeze(0)).sample()
    chosen=int(act)
    r_r=float(np.random.rand()<MEANS[chosen])
    s_r=float(chosen==BAD)
    return s_r,r_r

def eval_payoffs(S,R,n,seed,env,device):
    """Estimate empirical payoff tables by Monte-Carlo play.

    Every (sender, receiver) pair plays ``n`` episodes via ``sample_episode``.
    Cell ``[i, j]`` of each returned ``(len(S), len(R))`` float array holds the
    mean sender / receiver reward of ``S[i]`` against ``R[j]``. When ``seed`` is
    not None, the global RNGs are reseeded via ``set_seed`` before any play.
    """
    if seed is not None:
        set_seed(seed)

    def _mean_rewards(snd,rcv):
        # Sequential accumulation, then one division per table cell.
        tot_s=tot_r=0.0
        for _ in range(n):
            got_s,got_r=sample_episode(snd,rcv,env,device)
            tot_s+=got_s
            tot_r+=got_r
        return tot_s/n,tot_r/n

    shape=(len(S),len(R))
    sender_tab=np.zeros(shape)
    receiver_tab=np.zeros(shape)
    for row,snd in enumerate(S):
        for col,rcv in enumerate(R):
            sender_tab[row,col],receiver_tab[row,col]=_mean_rewards(snd,rcv)
    return sender_tab,receiver_tab

def fictitious_play(sp,rp,iters,init_s=None,init_r=None,prior_weight=100.0):
    """Compute meta-strategies for the bimatrix game ``(sp, rp)`` by fictitious play.

    Both players start with a pseudo-count of 1 per pure strategy, optionally
    increased by ``prior_weight * init_*`` (a warm-start prior). In each
    iteration, both players simultaneously add one count to their pure best
    response against the opponent's current empirical mixture. Ties go to the
    lowest index, via ``argmax``.

    Args:
        sp: Sender payoffs, shape ``(S, R)`` (row = sender strategy).
        rp: Receiver payoffs, same shape as ``sp``.
        iters: Number of iterations; 0 returns the (normalized) prior.
        init_s, init_r: Optional prior weights over sender / receiver strategies.
        prior_weight: Multiplier turning the priors into pseudo-counts
            (default 100).

    Returns:
        ``(s_mix, r_mix)``: normalized empirical strategy frequencies.

    Raises:
        ValueError: If ``sp`` and ``rp`` have different shapes.
    """
    sp=np.asarray(sp,dtype=float); rp=np.asarray(rp,dtype=float)
    if sp.shape!=rp.shape:
        raise ValueError(f"payoff shapes differ: {sp.shape} vs {rp.shape}")
    S,R=sp.shape
    s_cnt=np.ones(S); r_cnt=np.ones(R)
    if init_s is not None: s_cnt+=np.asarray(init_s)*prior_weight
    if init_r is not None: r_cnt+=np.asarray(init_r)*prior_weight
    for _ in range(iters):
        s_mix=s_cnt/s_cnt.sum(); r_mix=r_cnt/r_cnt.sum()
        # Sender's expected payoff per row vs r_mix; receiver's per column vs s_mix.
        s_cnt[np.argmax(sp.dot(r_mix))]+=1
        r_cnt[np.argmax(rp.T.dot(s_mix))]+=1
    return s_cnt/s_cnt.sum(),r_cnt/r_cnt.sum()

# ────────────────────────────────────────────────────────────────────────────────
# PPO training against a mixture (shared by sender & receiver)
# ────────────────────────────────────────────────────────────────────────────────
def train_ppo(policy_cls,opp_pop,opp_mix,role,args,env,device,seed_base,epochs,lr):
    """Train a fresh PPO policy as a best response to a mixture of opponents.

    Shared by both roles. With ``role="sender"``, a ``policy_cls(K, M)`` model
    plays the sender against receivers from ``opp_pop``. With
    ``role="receiver"``, a ``policy_cls(M, K)`` model plays the receiver
    against senders.

    Each epoch:
      1. reseeds the RNGs with ``seed_base + epoch``;
      2. collects ``args.ppo_batch`` one-step episodes, drawing the opponent
         from ``opp_mix`` for each episode;
      3. runs ``args.ppo_epochs`` passes of clipped-PPO minibatch updates.
    Episodes are a single step, so the return is the immediate reward.

    Args:
        policy_cls: Network class whose ``forward`` returns ``(logits, value)``.
        opp_pop: Opponent networks (called as ``opp(obs)[0]`` for logits).
        opp_mix: Weights over ``opp_pop``, normalized here. ``None`` or a
            non-positive total falls back to a uniform mixture.
        role: ``"sender"`` or ``"receiver"``.
        args: Namespace with ``ppo_batch``, ``ppo_epochs``, ``ppo_minibatch``,
            ``ppo_clip``, ``vf_coef`` and ``ent_coef``.
        env: Tuple ``(K, M, BEST, BAD, MEANS)``.
        device: Torch device for the model and tensors.
        seed_base: Base value for the per-epoch seeds.
        epochs: Number of collect-then-update epochs.
        lr: Adam learning rate.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``role`` is unknown or ``opp_mix`` length differs from ``opp_pop``.
    """
    if role not in ("sender","receiver"):
        raise ValueError(f"role must be 'sender' or 'receiver', got {role!r}")
    K,M,BEST,BAD,MEANS=env
    is_sender=role=="sender"
    n_opp=len(opp_pop)

    weights=None if opp_mix is None else np.asarray(opp_mix,dtype=np.float64)
    if weights is not None and len(weights)!=n_opp:
        raise ValueError(f"opp_mix has {len(weights)} entries for {n_opp} opponents")
    if weights is not None and weights.sum()>0:
        opp_mix=weights/weights.sum()
    else:
        opp_mix=np.ones(n_opp)/n_opp

    model=(policy_cls(K,M) if is_sender else policy_cls(M,K)).to(device)
    opt=optim.Adam(model.parameters(),lr=lr)
    # The sender-side observation (one-hot of the best arm) never changes.
    s_obs=torch.zeros(K,device=device); s_obs[BEST]=1.0

    for ep in range(1,epochs+1):
        set_seed(seed_base+ep)
        states,acts,old_lp,rets=[],[],[],[]
        # Rollouts only need sampled actions and their log-probs as plain
        # numbers, so skip autograd bookkeeping while collecting them.
        with torch.no_grad():
            for _ in range(args.ppo_batch):
                opp=opp_pop[np.random.choice(n_opp,p=opp_mix)]
                if is_sender:        # model acts as sender
                    dist=Categorical(logits=model(s_obs.unsqueeze(0))[0].squeeze(0))
                    msg=dist.sample(); logp=dist.log_prob(msg)
                    r_obs=torch.zeros(M,device=device); r_obs[msg]=1.0
                    act=Categorical(logits=opp(r_obs.unsqueeze(0))[0].squeeze(0)).sample()
                    states.append(s_obs.cpu().numpy()); acts.append(int(msg))
                    rets.append(float(int(act)==BAD))
                else:                # model acts as receiver
                    msg=Categorical(logits=opp(s_obs.unsqueeze(0))[0].squeeze(0)).sample()
                    r_obs=torch.zeros(M,device=device); r_obs[msg]=1.0
                    dist=Categorical(logits=model(r_obs.unsqueeze(0))[0].squeeze(0))
                    act=dist.sample(); logp=dist.log_prob(act)
                    chosen=int(act)
                    states.append(r_obs.cpu().numpy()); acts.append(chosen)
                    rets.append(float(np.random.rand()<MEANS[chosen]))
                old_lp.append(float(logp))

        batch=torch.tensor(np.stack(states),dtype=torch.float32,device=device)
        acts_t=torch.tensor(acts,dtype=torch.int64,device=device)
        old_lp_t=torch.tensor(old_lp,dtype=torch.float32,device=device)
        returns=torch.tensor(rets,dtype=torch.float32,device=device)

        for _ in range(args.ppo_epochs):
            idx=np.random.permutation(args.ppo_batch)
            for st in range(0,args.ppo_batch,args.ppo_minibatch):
                mb=idx[st:st+args.ppo_minibatch]
                logits,val=model(batch[mb]); d=Categorical(logits=logits)
                new_lp=d.log_prob(acts_t[mb]); ent=d.entropy().mean()
                # One-step episodes: advantage = reward - value baseline.
                adv=returns[mb]-val.detach()
                ratio=torch.exp(new_lp-old_lp_t[mb])
                pol=-torch.min(ratio*adv,
                               torch.clamp(ratio,1-args.ppo_clip,1+args.ppo_clip)*adv).mean()
                v=((val-returns[mb])**2).mean()
                loss=pol+args.vf_coef*v-args.ent_coef*ent
                opt.zero_grad(); loss.backward(); opt.step()
    return model

# ────────────────────────────────────────────────────────────────────────────────
# Single-seed run
# ────────────────────────────────────────────────────────────────────────────────
def run_single_seed(args, base_seed: int, device: torch.device):
    """Run one A-PSRO trajectory for a single top-level seed.

    Each outer iteration:
      1. Estimates the payoff tables ``sp`` / ``rp`` (rows = senders,
         columns = receivers) with ``args.eval_episodes`` episodes per cell.
      2. For each of ``args.num_repeats`` repeats:
         - draws Dirichlet priors and solves the meta-game by fictitious play;
         - trains short PPO candidate best responses for both roles;
         - scores each meta-strategy with a one-step lookahead. The candidate
           is mixed into the current meta-strategy with weight
           ``args.lookahead_d``, the opposing population best-responds, and
           the candidate side's payoff is recorded (ties resolved in its favour).
      3. Trains full-length PPO oracles against the best-scoring opponent
         meta-strategies and appends them to the populations.
      4. Re-evaluates the enlarged populations and logs the mean payoffs, wall
         time, RAM and population sizes.

    Returns:
        Dict of per-iteration numpy arrays: ``sender_mean``, ``receiver_mean``,
        ``time_sec``, ``ram_mb``, ``sender_pop`` and ``receiver_pop``, plus
        ``timestamps`` (a list of strings).
    """
    ARMS=np.array(args.arms_means,dtype=np.float32); BEST=int(np.argmax(ARMS))
    env=(args.K,args.M,BEST,args.bad_arm_target,ARMS)
    d=args.lookahead_d
    set_seed(base_seed)

    S_pop=[SenderPPO(args.K,args.M).to(device)]
    R_pop=[ReceiverPPO(args.M,args.K).to(device)]
    hist_s,hist_r=[],[]
    log_t,log_time,log_ram=[],[],[]
    s_pop_hist,r_pop_hist=[],[]

    for it in range(1,args.psro_iters+1):
        t0=time.time(); log_t.append(time.strftime("%Y-%m-%d %H:%M:%S"))
        print(f"[seed {base_seed}] === A-PSRO iter {it} ===")

        sp,rp=eval_payoffs(S_pop,R_pop,args.eval_episodes,
                           seed=base_seed+it,env=env,device=device)

        # Sender-candidate score selects the receiver mix for the sender oracle,
        # and vice versa.
        best_adv_s,best_theta_r=-np.inf,None
        best_adv_r,best_theta_s=-np.inf,None

        for rep in range(args.num_repeats):
            init_s=np.random.dirichlet(np.ones(len(S_pop)))
            init_r=np.random.dirichlet(np.ones(len(R_pop)))
            theta_s,theta_r=fictitious_play(sp,rp,args.fp_iters,
                                            init_s,init_r)

            # Sender candidate: BR to theta_r, then score the lookahead sender
            # mixture (1-d)*theta_s + d*cand_s. Both payoff rows come from the
            # same evaluation episodes, so the receiver tie test and the sender
            # payoff describe the same joint outcomes. Averaging eval_episodes
            # episodes also keeps the score from being single-episode noise.
            cand_s=train_ppo(SenderPPO,R_pop,theta_r,"sender",args,env,device,
                             seed_base=1000*it+rep + base_seed,
                             epochs=args.cand_epochs,lr=args.cand_lr)
            cs_sp,cs_rp=eval_payoffs([cand_s],R_pop,args.eval_episodes,
                                     seed=None,env=env,device=device)
            sender_mix=(1-d)*theta_s.dot(sp)+d*cs_sp[0]     # one entry per receiver
            rec_pay_mix=(1-d)*theta_s.dot(rp)+d*cs_rp[0]    # one entry per receiver
            adv=sender_mix[np.isclose(rec_pay_mix,rec_pay_mix.max())].max()
            if adv>best_adv_s:
                best_adv_s,best_theta_r=adv,theta_r.copy()

            # Receiver candidate: symmetric, payoffs indexed by sender.
            cand_r=train_ppo(ReceiverPPO,S_pop,theta_s,"receiver",args,env,device,
                             seed_base=2000*it+rep + base_seed,
                             epochs=args.cand_epochs,lr=args.cand_lr)
            cr_sp,cr_rp=eval_payoffs(S_pop,[cand_r],args.eval_episodes,
                                     seed=None,env=env,device=device)
            send_pay_mix=(1-d)*sp.dot(theta_r)+d*cr_sp[:,0]
            # rp is (|S|, |R|): the receiver mixture's payoff against sender i is
            # rp[i] @ theta_r. rp.T.dot(theta_r) would weight senders by
            # receiver probabilities.
            rec_mix=(1-d)*rp.dot(theta_r)+d*cr_rp[:,0]
            adv_r=rec_mix[np.isclose(send_pay_mix,send_pay_mix.max())].max()
            if adv_r>best_adv_r:
                best_adv_r,best_theta_s=adv_r,theta_s.copy()

        if best_theta_r is None or best_theta_s is None:
            # No repeat produced a score (e.g. num_repeats <= 0): fall back to
            # the plain fictitious-play meta-strategies instead of crashing.
            fp_s,fp_r=fictitious_play(sp,rp,args.fp_iters)
            if best_theta_s is None: best_theta_s=fp_s
            if best_theta_r is None: best_theta_r=fp_r

        # final oracles
        final_s=train_ppo(SenderPPO,R_pop,best_theta_r,"sender",args,env,device,
                          seed_base=3000*it + base_seed,epochs=args.oracle_epochs,lr=args.oracle_lr)
        final_r=train_ppo(ReceiverPPO,S_pop,best_theta_s,"receiver",args,env,device,
                          seed_base=4000*it + base_seed,epochs=args.oracle_epochs,lr=args.oracle_lr)
        S_pop.append(final_s); R_pop.append(final_r)

        sp2,rp2=eval_payoffs(S_pop,R_pop,args.eval_episodes,
                             seed=base_seed+100+it,env=env,device=device)
        hist_s.append(sp2.mean()); hist_r.append(rp2.mean())
        log_time.append(time.time()-t0); log_ram.append(get_ram_mb())
        s_pop_hist.append(len(S_pop)); r_pop_hist.append(len(R_pop))

        print(f"[seed {base_seed}] iter {it:02d} | sender {hist_s[-1]:.4f} "
              f"receiver {hist_r[-1]:.4f} | S={len(S_pop)} R={len(R_pop)} "
              f"| time {log_time[-1]:.1f}s RAM {log_ram[-1]:.1f}MB")

    return {
        "sender_mean": np.array(hist_s, dtype=np.float64),
        "receiver_mean": np.array(hist_r, dtype=np.float64),
        "time_sec": np.array(log_time, dtype=np.float64),
        "ram_mb": np.array(log_ram, dtype=np.float64),
        "timestamps": log_t,
        "sender_pop": np.array(s_pop_hist, dtype=np.int64),
        "receiver_pop": np.array(r_pop_hist, dtype=np.int64),
    }

# ────────────────────────────────────────────────────────────────────────────────
# Main
# ────────────────────────────────────────────────────────────────────────────────
def main():
    """Run A-PSRO for every requested seed, aggregate results, save a CSV, then plot.

    Seeds come from ``--seeds`` when given, otherwise from
    ``range(top_seed, top_seed + num_seeds)``. Per-iteration payoffs, wall
    time and RAM are averaged across seeds. The std is the sample std
    (``ddof=1``) when there is more than one seed, and zeros otherwise.

    The CSV is written *before* plotting. ``plt.show()`` can block or fail
    (e.g. without a display), and that must not lose the results of a long run.

    Raises:
        ValueError: If no seeds are specified.
        RuntimeError: If a seed run returns an unexpected number of iterations.
    """
    args=parse_args(); device=torch.device(args.device)

    # Determine seeds
    if args.seeds:
        seeds=list(args.seeds)
    else:
        seeds=list(range(args.top_seed,args.top_seed+args.num_seeds))
    if not seeds:
        raise ValueError("No seeds specified.")
    n_seeds=len(seeds)
    expected_iters=args.psro_iters

    per_seed={"sender_mean":[],"receiver_mean":[],"time_sec":[],"ram_mb":[]}
    pop_S=pop_R=None
    for s in seeds:
        res=run_single_seed(args,base_seed=s,device=device)
        if len(res["sender_mean"])!=expected_iters:
            raise RuntimeError("Unexpected iteration length in a seed run.")
        for key,rows in per_seed.items():
            rows.append(res[key])
        # Population sizes are taken from the first seed's run.
        if pop_S is None: pop_S=res["sender_pop"]
        if pop_R is None: pop_R=res["receiver_pop"]

    def mean_std(rows):
        # Mean over seeds, plus sample std (ddof=1); zeros for a single seed.
        stacked=np.stack(rows,axis=0)
        mean=stacked.mean(axis=0)
        std=stacked.std(axis=0,ddof=1) if n_seeds>1 else np.zeros_like(mean)
        return mean,std

    sender_mean_avg,sender_mean_std=mean_std(per_seed["sender_mean"])
    receiver_mean_avg,receiver_mean_std=mean_std(per_seed["receiver_mean"])
    time_mean,time_std=mean_std(per_seed["time_sec"])
    ram_mean,ram_std=mean_std(per_seed["ram_mb"])

    # CSV (written first so results survive any plotting problem)
    if args.log_csv is not None:
        csv_name=args.log_csv
    elif n_seeds>1:
        csv_name=f"a_psro_avg_{min(seeds)}to{max(seeds)}_n{n_seeds}.csv"
    else:
        csv_name=f"a_psro_avg_seed{seeds[0]}.csv"

    seed_str=",".join(str(s) for s in seeds)
    with open(csv_name,"w",newline="",encoding="utf-8") as f:
        w=csv.writer(f)
        w.writerow(
            ["iter","sender_mean_avg","sender_mean_std","receiver_mean_avg","receiver_mean_std",
             "time_sec_avg","time_sec_std","ram_mb_avg","ram_mb_std",
             "sender_pop","receiver_pop","seeds"])
        for i in range(expected_iters):
            w.writerow([
                i+1,
                float(sender_mean_avg[i]), float(sender_mean_std[i]),
                float(receiver_mean_avg[i]), float(receiver_mean_std[i]),
                float(time_mean[i]), float(time_std[i]),
                float(ram_mean[i]), float(ram_std[i]),
                int(pop_S[i]), int(pop_R[i]),
                seed_str
            ])
    print(f"\nSaved aggregated instrumentation log to {csv_name}")

    # Plot
    if not args.no_plot:
        x=np.arange(1,expected_iters+1)
        plt.figure()
        plt.plot(x, sender_mean_avg, label="sender (mean)")
        plt.fill_between(x, sender_mean_avg - sender_mean_std, sender_mean_avg + sender_mean_std, alpha=0.2, label="sender (±1σ)")
        plt.plot(x, receiver_mean_avg, label="receiver (mean)")
        plt.fill_between(x, receiver_mean_avg - receiver_mean_std, receiver_mean_avg + receiver_mean_std, alpha=0.2, label="receiver (±1σ)")
        plt.xlabel("A-PSRO iter"); plt.ylabel("avg payoff")
        plt.grid(); plt.legend(); plt.title(f"A-PSRO payoffs across {len(seeds)} seeds")
        plt.show()

# ────────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
