import argparse, time, csv, sys
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical

# ────────────────────────────────────────────────────────────────────────────────
# Argument parser
# ────────────────────────────────────────────────────────────────────────────────
def parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate the command-line options for the Double-Oracle run.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (via ``ArgumentParser.error``, i.e. ``SystemExit(2)``) when the
    options are mutually inconsistent: the number of arm means must equal
    ``--K``, the bad-arm target must be a valid arm index, and the episode /
    batch sizes that are used as divisors or range steps must be positive.
    """
    p = argparse.ArgumentParser(
        description="One-step sender/receiver Double-Oracle with PPO BRs (multi-seed avg)"
    )

    # Environment
    p.add_argument("--K", type=int, default=5)
    p.add_argument("--M", type=int, default=3)
    p.add_argument("--arms-means", type=float, nargs="+",
                   default=[0.2,0.5,0.8,0.4,0.1], metavar="μ")
    p.add_argument("--bad-arm-target", type=int, default=0)
    p.add_argument("--top-seed", type=int, default=0)

    # Loop & evaluation
    p.add_argument("--iters", type=int, default=8,
                   help="Max Double-Oracle iterations (per seed)")
    p.add_argument("--eval-episodes", type=int, default=400)
    p.add_argument("--tol", type=float, default=1e-3,
                   help="Improvement threshold to accept a BR")

    # PPO params for best-response training
    p.add_argument("--oracle-epochs", type=int, default=200)
    p.add_argument("--oracle-lr", type=float, default=1e-3)
    p.add_argument("--ppo-batch", type=int, default=256)
    p.add_argument("--ppo-epochs", type=int, default=4)
    p.add_argument("--ppo-minibatch", type=int, default=64)
    p.add_argument("--ppo-clip", type=float, default=0.2)
    p.add_argument("--vf-coef", type=float, default=0.5)
    p.add_argument("--ent-coef", type=float, default=0.01)

    # Device & logging
    p.add_argument("--device", choices=["cpu","cuda"], default="cpu")
    p.add_argument("--log-csv", default=None,
                   help="Aggregated CSV (default double_oracle_avg_<seedspan>.csv)")
    p.add_argument("--no-plot", action="store_true",
                   help="Disable matplotlib plotting")

    # Multi-seed controls
    p.add_argument("--seeds", type=int, nargs="+", default=None,
                   help="Explicit list of top seeds, e.g. --seeds 0 1 2 3 4")
    p.add_argument("--num-seeds", type=int, default=5,
                   help="If --seeds not given, use range(top_seed, top_seed+num_seeds)")

    # Optional heartbeat logs
    p.add_argument("--log-interval", type=int, default=0,
                   help="Print PPO progress every N epochs (0=off)")

    args = p.parse_args(argv)

    # argparse validates each flag in isolation; these checks catch
    # combinations that would otherwise fail deep inside training (IndexError
    # on arm lookups, ZeroDivisionError in evaluation, zero range steps).
    if args.K < 1 or args.M < 1:
        p.error("--K and --M must both be >= 1")
    if len(args.arms_means) != args.K:
        p.error(f"--arms-means has {len(args.arms_means)} values but --K is {args.K}")
    if not 0 <= args.bad_arm_target < args.K:
        p.error(f"--bad-arm-target must be in [0, {args.K - 1}], got {args.bad_arm_target}")
    for name in ("eval_episodes", "ppo_batch", "ppo_minibatch"):
        if getattr(args, name) < 1:
            p.error(f"--{name.replace('_', '-')} must be >= 1")
    if args.seeds is None and args.num_seeds < 1:
        p.error("--num-seeds must be >= 1 when --seeds is not given")

    return args

# ────────────────────────────────────────────────────────────────────────────────
# Helpers: RAM usage
# ────────────────────────────────────────────────────────────────────────────────
# Pick a memory probe once at import time. The preferred probe (psutil) reports
# the *current* resident set size; the ``resource`` fallback can only report the
# *peak* resident set size so far; if neither is importable the probe yields NaN.
try:
    import psutil
    _proc = psutil.Process()

    def get_ram_mb():
        """Return this process's current resident set size in MiB (via psutil)."""
        rss_bytes = _proc.memory_info().rss
        return rss_bytes / 1_048_576
except Exception:
    try:
        import resource

        def get_ram_mb():
            """Return this process's *peak* resident set size in MiB.

            ``ru_maxrss`` is divided by 2**20 on macOS and by 1024 elsewhere,
            matching the platform's reporting unit (bytes vs. kibibytes).
            """
            divisor = 1_048_576 if sys.platform == "darwin" else 1024.0
            return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / divisor
    except Exception:
        def get_ram_mb():
            """Return NaN: no memory probe is available on this platform."""
            return float("nan")

# ────────────────────────────────────────────────────────────────────────────────
# Networks
# ────────────────────────────────────────────────────────────────────────────────
class Sender(nn.Module):
    """Sender policy network: K-dim observation -> M message logits.

    Architecture: ``Linear(K, 32) -> ReLU -> Linear(32, M)`` stored as ``net``.
    """

    def __init__(self,K,M):
        super().__init__()
        hidden = 32
        layers = [nn.Linear(K, hidden), nn.ReLU(), nn.Linear(hidden, M)]
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        """Return unnormalised message logits for the observation batch ``x``."""
        logits = self.net(x)
        return logits

class Receiver(nn.Module):
    """Receiver policy network: M-dim message encoding -> K arm logits.

    Architecture: ``Linear(M, 32) -> ReLU -> Linear(32, K)`` stored as ``net``.
    """

    def __init__(self,M,K):
        super().__init__()
        hidden = 32
        layers = [nn.Linear(M, hidden), nn.ReLU(), nn.Linear(hidden, K)]
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        """Return unnormalised arm logits for the message batch ``x``."""
        logits = self.net(x)
        return logits

class SenderPPO(nn.Module):
    """Actor-critic sender used for PPO best-response training.

    ``actor``:  ``Linear(K, 32) -> ReLU -> Linear(32, M)`` message logits.
    ``critic``: ``Linear(K, 32) -> ReLU -> Linear(32, 1)`` state value.
    """

    def __init__(self,K,M):
        super().__init__()
        width = 32
        self.actor = nn.Sequential(nn.Linear(K, width), nn.ReLU(), nn.Linear(width, M))
        self.critic = nn.Sequential(nn.Linear(K, width), nn.ReLU(), nn.Linear(width, 1))

    def forward(self,x):
        """Return ``(logits, value)``; the critic's trailing unit dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

class ReceiverPPO(nn.Module):
    """Actor-critic receiver used for PPO best-response training.

    ``actor``:  ``Linear(M, 32) -> ReLU -> Linear(32, K)`` arm logits.
    ``critic``: ``Linear(M, 32) -> ReLU -> Linear(32, 1)`` state value.
    """

    def __init__(self,M,K):
        super().__init__()
        width = 32
        self.actor = nn.Sequential(nn.Linear(M, width), nn.ReLU(), nn.Linear(width, K))
        self.critic = nn.Sequential(nn.Linear(M, width), nn.ReLU(), nn.Linear(width, 1))

    def forward(self,x):
        """Return ``(logits, value)``; the critic's trailing unit dim is squeezed."""
        logits = self.actor(x)
        value = self.critic(x).squeeze(-1)
        return logits, value

# ────────────────────────────────────────────────────────────────────────────────
# Core utilities
# ────────────────────────────────────────────────────────────────────────────────
def logits_from(agent,obs):
    """Run ``agent`` on one unbatched observation and return its logits.

    A batch dimension is added before the forward pass and removed afterwards.
    If the agent returns a tuple/list (actor-critic modules return
    ``(logits, value)``), only the first element is used.
    """
    output = agent(obs.unsqueeze(0))
    if isinstance(output, (tuple, list)):
        output = output[0]
    return output.squeeze(0)

def set_seed(s):
    """Seed both numpy's and torch's global RNGs with the same value ``s``."""
    np.random.seed(s)
    torch.manual_seed(s)

def sample_episode(sender,receiver,env,device):
    """Play one round of the signalling game and return both rewards.

    The sender sees a one-hot encoding of the best arm and samples a message;
    the receiver sees that message one-hot encoded and samples an arm. The
    receiver's reward is 1.0 with probability equal to the chosen arm's mean;
    the sender's reward is 1.0 only when the receiver picks the bad arm.

    Args:
        sender, receiver: policies accepted by ``logits_from``.
        env: ``(K, M, BEST, BAD, MEANS)`` environment tuple.
        device: torch device for the observation tensors.

    Returns:
        ``(sender_reward, receiver_reward)`` as Python floats.
    """
    n_arms, n_msgs, best_arm, bad_arm, arm_means = env

    sender_view = torch.zeros(n_arms, device=device)
    sender_view[best_arm] = 1.0
    message = Categorical(logits=logits_from(sender, sender_view)).sample()

    receiver_view = torch.zeros(n_msgs, device=device)
    receiver_view[message] = 1.0
    arm = int(Categorical(logits=logits_from(receiver, receiver_view)).sample())

    receiver_reward = float(np.random.rand() < arm_means[arm])
    sender_reward = float(arm == bad_arm)
    return sender_reward, receiver_reward

def eval_payoffs(S,R,n_eval,seed,env,device):
    """Monte-Carlo estimate of both payoff matrices over all policy pairings.

    Args:
        S: sequence of sender policies (matrix rows).
        R: sequence of receiver policies (matrix columns).
        n_eval: episodes sampled per (sender, receiver) pair; must be >= 1.
        seed: if not None, numpy/torch global RNGs are reseeded first so the
            estimate is reproducible.
        env: environment tuple forwarded unchanged to ``sample_episode``.
        device: torch device forwarded to ``sample_episode``.

    Returns:
        ``(sp, rp)``: two ``len(S) x len(R)`` float arrays holding the mean
        sender and receiver reward for each pairing.

    Raises:
        ValueError: if ``n_eval < 1`` (the averages would divide by zero).
    """
    if n_eval < 1:
        raise ValueError(f"n_eval must be >= 1, got {n_eval}")
    if seed is not None:
        set_seed(seed)
    sp = np.zeros((len(S), len(R)))
    rp = np.zeros((len(S), len(R)))
    # Evaluation is pure inference: disabling autograd avoids building (and
    # discarding) a graph for each of the len(S)*len(R)*n_eval forward passes.
    # Sampling draws from the same RNG streams, so results are unchanged.
    with torch.no_grad():
        for i, s in enumerate(S):
            for j, r in enumerate(R):
                ss = rr = 0.0
                for _ in range(n_eval):
                    s_r, r_r = sample_episode(s, r, env, device)
                    ss += s_r
                    rr += r_r
                sp[i, j] = ss / n_eval
                rp[i, j] = rr / n_eval
    return sp, rp

def fictitious_play(sp,rp,iters=500):
    """Approximate an equilibrium of the empirical bimatrix game by fictitious play.

    Both players start from a pseudo-count of one per pure strategy. In each
    round both simultaneously best-respond (ties go to the lowest index) to
    the opponent's current empirical mixture, and the chosen strategies'
    counts are incremented.

    Args:
        sp: sender (row-player) payoff matrix, shape ``[n_senders, n_receivers]``.
        rp: receiver (column-player) payoff matrix with the same shape.
        iters: number of fictitious-play rounds.

    Returns:
        ``(s_mix, r_mix)``: the final empirical mixtures over rows and columns.
    """
    n_rows, n_cols = sp.shape
    row_counts, col_counts = np.ones(n_rows), np.ones(n_cols)
    for _ in range(iters):
        # Both beliefs are snapshotted before either count changes.
        row_belief = row_counts / row_counts.sum()
        col_belief = col_counts / col_counts.sum()
        best_row = np.argmax(sp.dot(col_belief))
        best_col = np.argmax(rp.T.dot(row_belief))
        row_counts[best_row] += 1
        col_counts[best_col] += 1
    return row_counts / row_counts.sum(), col_counts / col_counts.sum()

# ────────────────────────────────────────────────────────────────────────────────
# PPO BR trainers
# ────────────────────────────────────────────────────────────────────────────────
def _ppo_update(model, opt, states, acts, old_lp, rets, args, device):
    """Apply clipped-PPO updates to ``model`` from one collected rollout.

    Runs ``args.ppo_epochs`` passes, each over a fresh random permutation of
    the ``args.ppo_batch`` samples split into minibatches of
    ``args.ppo_minibatch``. The advantage is the one-step return minus the
    detached critic value; the loss is clipped surrogate
    + ``vf_coef`` * value MSE - ``ent_coef`` * entropy.

    Args:
        model: actor-critic module returning ``(logits, value)`` for a batch.
        opt: optimiser over ``model``'s parameters.
        states: ``args.ppo_batch`` observation vectors (numpy arrays).
        acts: sampled action indices, one per observation.
        old_lp: log-probabilities of ``acts`` under the rollout policy.
        rets: observed one-step rewards, used directly as returns.
        args: namespace with the PPO hyper-parameters listed above.
        device: torch device for the batch tensors.
    """
    obs = torch.tensor(np.stack(states), dtype=torch.float32, device=device)
    acts_t = torch.tensor(acts, dtype=torch.int64, device=device)
    old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=device)
    returns = torch.tensor(rets, dtype=torch.float32, device=device)
    lo, hi = 1 - args.ppo_clip, 1 + args.ppo_clip

    for _ in range(args.ppo_epochs):
        idx = np.random.permutation(args.ppo_batch)
        for st in range(0, args.ppo_batch, args.ppo_minibatch):
            mb = idx[st:st + args.ppo_minibatch]
            logits, val = model(obs[mb])
            d = Categorical(logits=logits)
            new_lp = d.log_prob(acts_t[mb])
            ent = d.entropy().mean()
            adv = returns[mb] - val.detach()
            ratio = torch.exp(new_lp - old_lp_t[mb])
            pol = -torch.min(ratio * adv, torch.clamp(ratio, lo, hi) * adv).mean()
            v = ((val - returns[mb]) ** 2).mean()
            loss = pol + args.vf_coef * v - args.ent_coef * ent
            opt.zero_grad()
            loss.backward()
            opt.step()


def train_sender_ppo(r_mix,R_pop,args,env,device,seed_base):
    """Train a PPO best-response sender against a mixture of receivers.

    Each rollout step samples a receiver from ``R_pop`` with probabilities
    ``r_mix`` (renormalised; uniform if it sums to zero), samples a message
    from the learner, lets that receiver pick an arm, and rewards the sender
    with 1.0 iff the chosen arm is the environment's bad-arm target.

    Args:
        r_mix: numpy weights over ``R_pop``.
        R_pop: list of frozen receiver policies.
        args: parsed CLI namespace (``oracle_epochs``, ``oracle_lr``,
            ``log_interval`` and the PPO hyper-parameters).
        env: ``(K, M, BEST, BAD, MEANS)`` environment tuple.
        device: torch device.
        seed_base: epoch ``ep`` reseeds numpy/torch with ``seed_base + ep``.

    Returns:
        The trained ``SenderPPO`` model.
    """
    K, M, BEST, BAD, _ = env
    r_mix = r_mix / r_mix.sum() if r_mix.sum() > 0 else np.ones(len(R_pop)) / len(R_pop)
    model = SenderPPO(K, M).to(device)
    opt = optim.Adam(model.parameters(), lr=args.oracle_lr)

    # The sender always observes the one-hot best arm: the observation is constant.
    s_obs = torch.zeros(K, device=device)
    s_obs[BEST] = 1.0
    s_obs_np = s_obs.cpu().numpy()

    for ep in range(1, args.oracle_epochs + 1):
        set_seed(seed_base + ep)
        states, acts, old_lp, rets = [], [], [], []
        # Rollout collection is pure inference; no autograd graph is needed.
        with torch.no_grad():
            # Constant observation and frozen weights during the rollout, so a
            # single forward pass gives the rollout policy for every step.
            logits, _ = model(s_obs.unsqueeze(0))
            dist = Categorical(logits=logits.squeeze(0))
            for _ in range(args.ppo_batch):
                recv = R_pop[np.random.choice(len(R_pop), p=r_mix)]
                msg = dist.sample()
                logp = dist.log_prob(msg)
                r_obs = torch.zeros(M, device=device)
                r_obs[msg] = 1.0
                chosen = int(Categorical(logits=logits_from(recv, r_obs)).sample())
                states.append(s_obs_np)
                acts.append(int(msg))
                old_lp.append(float(logp))
                rets.append(float(chosen == BAD))
        _ppo_update(model, opt, states, acts, old_lp, rets, args, device)

        if args.log_interval and (ep==1 or ep%args.log_interval==0 or ep==args.oracle_epochs):
            print(f"[sender PPO] seed_base={seed_base} epoch {ep}/{args.oracle_epochs}", flush=True)

    return model


def train_receiver_ppo(s_mix,S_pop,args,env,device,seed_base):
    """Train a PPO best-response receiver against a mixture of senders.

    Each rollout step samples a sender from ``S_pop`` with probabilities
    ``s_mix`` (renormalised; uniform if it sums to zero), lets it emit a
    message about the best arm, samples an arm from the learner, and rewards
    the receiver with 1.0 with probability equal to that arm's mean.

    Args:
        s_mix: numpy weights over ``S_pop``.
        S_pop: list of frozen sender policies.
        args: parsed CLI namespace (``oracle_epochs``, ``oracle_lr``,
            ``log_interval`` and the PPO hyper-parameters).
        env: ``(K, M, BEST, BAD, MEANS)`` environment tuple.
        device: torch device.
        seed_base: epoch ``ep`` reseeds numpy/torch with ``seed_base + ep``.

    Returns:
        The trained ``ReceiverPPO`` model.
    """
    K, M, BEST, _, MEANS = env
    s_mix = s_mix / s_mix.sum() if s_mix.sum() > 0 else np.ones(len(S_pop)) / len(S_pop)
    model = ReceiverPPO(M, K).to(device)
    opt = optim.Adam(model.parameters(), lr=args.oracle_lr)

    # Senders always observe the one-hot best arm: the observation is constant.
    s_obs = torch.zeros(K, device=device)
    s_obs[BEST] = 1.0

    for ep in range(1, args.oracle_epochs + 1):
        set_seed(seed_base + ep)
        states, acts, old_lp, rets = [], [], [], []
        # Rollout collection is pure inference; no autograd graph is needed.
        with torch.no_grad():
            for _ in range(args.ppo_batch):
                send = S_pop[np.random.choice(len(S_pop), p=s_mix)]
                msg = Categorical(logits=logits_from(send, s_obs)).sample()
                r_obs = torch.zeros(M, device=device)
                r_obs[msg] = 1.0
                logits, _ = model(r_obs.unsqueeze(0))
                dist = Categorical(logits=logits.squeeze(0))
                act = dist.sample()
                logp = dist.log_prob(act)
                chosen = int(act)
                states.append(r_obs.cpu().numpy())
                acts.append(chosen)
                old_lp.append(float(logp))
                rets.append(float(np.random.rand() < MEANS[chosen]))
        _ppo_update(model, opt, states, acts, old_lp, rets, args, device)

        if args.log_interval and (ep==1 or ep%args.log_interval==0 or ep==args.oracle_epochs):
            print(f"[receiver PPO] seed_base={seed_base} epoch {ep}/{args.oracle_epochs}", flush=True)

    return model

# ────────────────────────────────────────────────────────────────────────────────
# Single-seed run (with early stop)
# ────────────────────────────────────────────────────────────────────────────────
def _derive_seed(base_seed, it, stream):
    """Map ``(base_seed, iteration, stream id)`` to a well-separated 31-bit seed.

    numpy's ``SeedSequence`` hashes the triple, so different seeds,
    iterations and stages get unrelated RNG streams with overwhelming
    probability. The value stays below 2**31 so a caller may add a small
    offset (e.g. a PPO epoch number) and still pass ``np.random.seed`` a
    valid (< 2**32) seed.
    """
    state = np.random.SeedSequence([base_seed, it, stream]).generate_state(1)[0]
    return int(state) & 0x7FFFFFFF


def run_single_seed(args, base_seed: int, device: torch.device):
    """Run Double-Oracle for one top-level seed, stopping when no BR improves.

    Each iteration estimates the payoff matrices of the current populations,
    solves them with fictitious play, trains one PPO best response per side
    against the opponent's mixture, and adds a BR to its population only if
    it beats the current mixture value by more than ``args.tol``.

    Args:
        args: parsed CLI namespace (see ``parse_args``). The iteration budget
            is read from ``args.iters`` (``args.do_iters`` as a fallback).
        base_seed: top-level seed for this run.
        device: torch device for all networks.

    Returns:
        dict of per-iteration logs: ``sender_mean``/``receiver_mean`` (mean of
        the post-update payoff matrices), ``time_sec``, ``ram_mb``,
        ``timestamps``, ``added_brs``, ``sender_pop``/``receiver_pop`` sizes,
        and ``stop_iter`` (number of completed iterations).

    Raises:
        ValueError: if ``args.arms_means`` does not have ``args.K`` entries or
            ``args.bad_arm_target`` is not a valid arm index.
    """
    if len(args.arms_means) != args.K:
        raise ValueError(
            f"arms_means has {len(args.arms_means)} entries but K={args.K}")
    if not 0 <= args.bad_arm_target < args.K:
        raise ValueError(
            f"bad_arm_target={args.bad_arm_target} is not an arm index in [0, {args.K - 1}]")
    # parse_args defines --iters (dest ``iters``); ``do_iters`` is accepted
    # as a fallback for namespaces built elsewhere.
    n_iters = getattr(args, "iters", None)
    if n_iters is None:
        n_iters = args.do_iters

    ARMS=np.array(args.arms_means,dtype=np.float32)
    BEST=int(np.argmax(ARMS))
    env=(args.K,args.M,BEST,args.bad_arm_target,ARMS)

    # One RNG stream id per stochastic stage of an iteration, mixed with
    # (base_seed, it) by _derive_seed.
    PAYOFF_EVAL, SENDER_BR, RECEIVER_BR, SENDER_BR_EVAL, RECEIVER_BR_EVAL, POST_EVAL = range(6)

    set_seed(base_seed)
    sender_pop=[Sender(args.K,args.M).to(device)]
    receiver_pop=[Receiver(args.M,args.K).to(device)]

    hist_s,hist_r=[],[]
    log_times,log_ram,log_ts=[],[],[]
    log_added,log_s_pop,log_r_pop=[],[],[]

    for it in range(1, n_iters + 1):
        t0=time.time(); log_ts.append(time.strftime("%Y-%m-%d %H:%M:%S"))
        print(f"[seed {base_seed}] === DO iter {it} ===", flush=True)

        sp,rp=eval_payoffs(sender_pop,receiver_pop,args.eval_episodes,
                           seed=_derive_seed(base_seed, it, PAYOFF_EVAL),
                           env=env,device=device)
        s_mix,r_mix=fictitious_play(sp,rp)
        # Current equilibrium values for each side under the FP mixtures.
        val_s= (s_mix@(sp@r_mix)); val_r=(r_mix@(rp.T@s_mix))

        br_s=train_sender_ppo(r_mix,receiver_pop,args,env,device,
                              seed_base=_derive_seed(base_seed, it, SENDER_BR))
        br_r=train_receiver_ppo(s_mix,sender_pop,args,env,device,
                                seed_base=_derive_seed(base_seed, it, RECEIVER_BR))

        # Value of each BR against the opponent's current mixture.
        val_s_br=eval_payoffs([br_s],receiver_pop,args.eval_episodes,
                              seed=_derive_seed(base_seed, it, SENDER_BR_EVAL),
                              env=env,device=device)[0].dot(r_mix)[0]
        val_r_br=eval_payoffs(sender_pop,[br_r],args.eval_episodes,
                              seed=_derive_seed(base_seed, it, RECEIVER_BR_EVAL),
                              env=env,device=device)[1].T.dot(s_mix)[0]

        added=0
        if val_s_br>val_s+args.tol: sender_pop.append(br_s); added+=1
        if val_r_br>val_r+args.tol: receiver_pop.append(br_r); added+=1

        sp2,rp2=eval_payoffs(sender_pop,receiver_pop,args.eval_episodes,
                             seed=_derive_seed(base_seed, it, POST_EVAL),
                             env=env,device=device)
        hist_s.append(sp2.mean()); hist_r.append(rp2.mean())

        log_times.append(time.time()-t0); log_ram.append(get_ram_mb())
        log_added.append(added); log_s_pop.append(len(sender_pop)); log_r_pop.append(len(receiver_pop))

        print(f"[seed {base_seed}] iter {it:02d} | added {added} | "
              f"sender {hist_s[-1]:.4f} receiver {hist_r[-1]:.4f} | "
              f"S={log_s_pop[-1]} R={log_r_pop[-1]}", flush=True)

        if added==0:
            print(f"[seed {base_seed}] No improving BRs – stopping.", flush=True)
            break

    return {
        "sender_mean": np.array(hist_s, dtype=np.float64),
        "receiver_mean": np.array(hist_r, dtype=np.float64),
        "time_sec": np.array(log_times, dtype=np.float64),
        "ram_mb": np.array(log_ram, dtype=np.float64),
        "timestamps": log_ts,
        "added_brs": np.array(log_added, dtype=np.int64),
        "sender_pop": np.array(log_s_pop, dtype=np.float64),
        "receiver_pop": np.array(log_r_pop, dtype=np.float64),
        "stop_iter": len(hist_s),
    }

# ────────────────────────────────────────────────────────────────────────────────
# Main (multi-seed aggregation with early-stop handling)
# ────────────────────────────────────────────────────────────────────────────────
def _nan_mean_std(a):
    """Column-wise nan-aware mean and sample std (ddof=1) of a 2-D array.

    Columns with fewer than two non-NaN entries get ``std = NaN`` (a sample
    std is undefined there), and columns with none get ``mean = NaN``.
    numpy's "Mean of empty slice" / "Degrees of freedom <= 0" RuntimeWarnings
    for those columns are silenced because NaN is the intended result.
    """
    import warnings

    a = np.asarray(a, dtype=np.float64)
    n_valid = np.sum(~np.isnan(a), axis=0)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean = np.nanmean(a, axis=0)
        std = np.nanstd(a, axis=0, ddof=1)
    return mean, np.where(n_valid >= 2, std, np.nan)


def main():
    """Run Double-Oracle for every requested seed, then aggregate and report.

    Per-iteration metrics from each seed are NaN-padded to the longest run
    (seeds may stop early), summarised with nan-aware mean / sample std,
    written to a CSV, and finally (unless ``--no-plot``) plotted. The CSV is
    written before plotting so a blocking or failing GUI backend cannot lose
    the results.

    Raises:
        ValueError: if the resolved seed list is empty.
    """
    args = parse_args()
    device = torch.device(args.device)

    # Determine seeds
    if args.seeds:
        seeds = list(args.seeds)
    else:
        seeds = list(range(args.top_seed, args.top_seed + args.num_seeds))
    if not seeds:
        raise ValueError("No seeds specified.")

    results = [run_single_seed(args, base_seed=s, device=device) for s in seeds]
    max_iters = max(len(r["sender_mean"]) for r in results)

    def pad_stack(key):
        """Stack per-seed series ``key`` into [n_seeds, max_iters], NaN-padding short runs."""
        rows = []
        for r in results:
            x = np.asarray(r[key], dtype=np.float64)
            rows.append(np.concatenate([x, np.full(max_iters - len(x), np.nan)]))
        return np.stack(rows, axis=0)

    # Order matters: it fixes the CSV column order.
    metric_keys = ("sender_mean", "receiver_mean", "time_sec", "ram_mb",
                   "added_brs", "sender_pop", "receiver_pop")
    padded = {key: pad_stack(key) for key in metric_keys}
    stats = {key: _nan_mean_std(padded[key]) for key in metric_keys}

    # Effective seed count per iteration (seeds that had not stopped yet)
    seeds_eff = np.sum(~np.isnan(padded["sender_mean"]), axis=0).astype(int)

    # CSV
    if args.log_csv is not None:
        csv_name = args.log_csv
    else:
        csv_name = (
            f"double_oracle_avg_{min(seeds)}to{max(seeds)}_n{len(seeds)}.csv"
            if len(seeds)>1 else f"double_oracle_avg_seed{seeds[0]}.csv"
        )

    header = ["iter", "seeds_eff"]
    for key in metric_keys:
        header += [f"{key}_avg", f"{key}_std"]
    header.append("seeds")

    seed_str = ",".join(str(s) for s in seeds)
    with open(csv_name, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(header)
        for i in range(max_iters):
            if seeds_eff[i] == 0:
                continue  # no data at this iteration
            row = [i + 1, int(seeds_eff[i])]
            for key in metric_keys:
                avg, std = stats[key]
                row += [float(avg[i]), float(std[i])]
            row.append(seed_str)
            w.writerow(row)
    print(f"\nSaved aggregated instrumentation log to {csv_name}", flush=True)

    # Plot (only where we have data)
    if not args.no_plot and max_iters > 0:
        x = np.arange(1, max_iters + 1)
        mask = seeds_eff > 0
        plt.figure()
        for key, who in (("sender_mean", "sender"), ("receiver_mean", "receiver")):
            avg, std = stats[key]
            plt.plot(x[mask], avg[mask], label=f"{who} (mean)")
            plt.fill_between(x[mask],
                             avg[mask] - std[mask],
                             avg[mask] + std[mask],
                             alpha=0.2, label=f"{who} (±1σ)")
        plt.xlabel("Double-Oracle iteration")
        plt.ylabel("avg payoff")
        plt.grid(); plt.legend()
        plt.title(f"Double-Oracle payoffs across {len(seeds)} seeds")
        plt.show()

# ────────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script entry point: run the multi-seed Double-Oracle experiment.
    main()
