# algos.py
import math, numpy as np, torch
import torch.nn as nn, torch.nn.functional as F

# ---------- Common utilities ----------
def mlp(sizes, act=nn.ReLU, out_act=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``sizes`` give the input/output width of each
    ``nn.Linear``.  Every linear layer is followed by a fresh ``act()``
    module, except the last one, which is followed by ``out_act()``.
    Fewer than two sizes yields an empty ``nn.Sequential``.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        modules.append(out_act() if idx == final_idx else act())
    return nn.Sequential(*modules)

class RunningNorm:
    """Running per-feature mean/variance with clipped standardisation.

    Statistics are kept in float64 and merged batch by batch with the
    parallel mean/variance combination formula.

    Args:
        shape: Shape of a single sample (anything ``np.zeros`` accepts).
        eps: Initial pseudo-count; keeps the first merge well defined.
        clip: Normalised values are clipped to ``[-clip, clip]``.
    """

    def __init__(self, shape, eps=1e-5, clip=10.0):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = eps
        self.clip = clip

    def update(self, x):
        """Merge a batch of samples (first axis = batch) into the statistics.

        Any array-like is accepted.  An input with the same rank as a single
        sample is treated as a batch of one, and an empty batch is ignored
        instead of turning the statistics into NaN.
        """
        x = np.asarray(x, dtype=np.float64)
        if x.ndim == self.mean.ndim:
            # A lone sample: add the batch axis so mean/var reduce over it.
            x = x[None]
        n = x.shape[0]
        if n == 0:
            return
        batch_mean = x.mean(0)
        batch_var = x.var(0)
        delta = batch_mean - self.mean
        total = self.count + n
        # Combine the two sums of squared deviations (old stats + new batch).
        m2 = self.var * self.count + batch_var * n + delta ** 2 * self.count * n / total
        self.mean = self.mean + delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, x):
        """Standardise ``x`` with the running statistics and clip the result."""
        z = (np.asarray(x) - self.mean) / (np.sqrt(self.var) + 1e-8)
        return np.clip(z, -self.clip, self.clip)

class ReplayBuffer:
    """Fixed-capacity circular store of (obs, act, rew, next obs, done) rows.

    All fields are float32 arrays with ``size`` rows.  Once full, new
    transitions overwrite the oldest ones.
    """

    def __init__(self, obs_dim, act_dim, size: int):
        def blank(width):
            return np.zeros((size, width), np.float32)

        self.obs = blank(obs_dim)
        self.obs2 = blank(obs_dim)
        self.act = blank(act_dim)
        self.rew = blank(1)
        self.done = blank(1)
        self.ptr = 0
        self.size = 0
        self.max = size

    def store(self, o, a, r, o2, d):
        """Write one transition at the cursor, then advance it (wrapping)."""
        slot = self.ptr
        self.obs[slot] = o
        self.act[slot] = a
        self.rew[slot] = r
        self.obs2[slot] = o2
        self.done[slot] = d
        self.ptr = (slot + 1) % self.max
        if self.size < self.max:
            self.size += 1

    def sample(self, bs=256):
        """Draw ``bs`` stored rows uniformly, with replacement, as float32 tensors."""
        picks = np.random.randint(0, self.size, size=bs)
        fields = {
            "obs": self.obs,
            "obs2": self.obs2,
            "act": self.act,
            "rew": self.rew,
            "done": self.done,
        }
        return {
            key: torch.as_tensor(arr[picks], dtype=torch.float32)
            for key, arr in fields.items()
        }

# ---------- PPO ----------
def discount_cumsum(x, gamma):
    """Return the discounted suffix sums of ``x`` along its first axis.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``.  The result
    is a float32 array shaped like ``x``.
    """
    out = np.zeros_like(x, np.float32)
    acc = 0.0
    t = len(x) - 1
    # Walk backwards so each entry reuses the already-discounted tail.
    while t >= 0:
        acc = x[t] + gamma * acc
        out[t] = acc
        t -= 1
    return out

class PPOBuffer:
    """On-policy rollout storage with GAE-lambda advantage estimation.

    Each field is a float32 array with ``size`` rows.  Steps are appended
    with :meth:`store`, each trajectory (or trajectory fragment) is closed
    with :meth:`finish_path`, and :meth:`get` hands out the full batch once
    exactly ``size`` steps have been stored.

    Args:
        obs_dim: Width of one observation row.
        act_dim: Width of one action row.
        size: Number of steps the buffer holds.
        gamma: Reward discount factor.
        lam: GAE lambda.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs = np.zeros((size, obs_dim), np.float32)
        self.act = np.zeros((size, act_dim), np.float32)
        self.adv = np.zeros((size, 1), np.float32)
        self.rew = np.zeros((size, 1), np.float32)
        self.ret = np.zeros((size, 1), np.float32)
        self.val = np.zeros((size, 1), np.float32)
        self.logp = np.zeros((size, 1), np.float32)
        self.gamma = gamma
        self.lam = lam
        self.ptr = 0
        self.path_start = 0
        self.max = size

    def store(self, o, a, r, v, lp):
        """Append one step (obs, action, reward, value estimate, log-prob).

        Raises:
            IndexError: If the buffer is already full.
        """
        # Check before touching the cursor so a failed store leaves the
        # buffer in a state that get() still accepts.
        if self.ptr >= self.max:
            raise IndexError("PPOBuffer is full; call get() before storing more steps")
        i = self.ptr
        self.obs[i] = o
        self.act[i] = a
        self.rew[i] = r
        self.val[i] = v
        self.logp[i] = lp
        self.ptr = i + 1

    def finish_path(self, last_val=0.0):
        """Close the current path and fill its advantages and returns.

        Args:
            last_val: Value estimate of the state after the last stored step;
                0 for a terminal state, the critic's estimate when the path
                was cut off.
        """
        sl = slice(self.path_start, self.ptr)
        boot = np.asarray(last_val, dtype=np.float64).reshape(1, 1)
        # The bootstrap value is appended to the rewards as well, so the
        # rewards-to-go of a truncated path include the discounted tail.
        rews = np.append(self.rew[sl], boot, 0)
        vals = np.append(self.val[sl], boot, 0)
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv[sl] = discount_cumsum(deltas, self.gamma * self.lam)
        self.ret[sl] = discount_cumsum(rews, self.gamma)[:-1]
        self.path_start = self.ptr

    def get(self):
        """Return the whole batch as float32 tensors and reset the cursor.

        Advantages are standardised over the batch.  The buffer must be
        exactly full.
        """
        assert self.ptr == self.max, "PPOBuffer.get() requires a full buffer"
        self.ptr = 0
        self.path_start = 0
        adv = (self.adv - self.adv.mean()) / (self.adv.std() + 1e-8)
        to_t = lambda x: torch.as_tensor(x, dtype=torch.float32)
        return dict(obs=to_t(self.obs), act=to_t(self.act),
                    ret=to_t(self.ret), adv=to_t(adv), logp=to_t(self.logp))

class PPOActorCritic(nn.Module):
    """Tanh-squashed Gaussian policy with a separate state-value network.

    ``pi`` outputs the pre-squash mean, ``log_std`` is a state-independent
    learnable parameter, and ``vf`` is the value network.
    """

    def __init__(self, obs_dim, act_dim, hid=(256,256)):
        super().__init__()
        self.pi = mlp([obs_dim,*hid,act_dim], act=nn.Tanh, out_act=nn.Identity)
        self.log_std = nn.Parameter(torch.zeros(act_dim)-0.5)
        self.vf = mlp([obs_dim,*hid,1], act=nn.Tanh, out_act=nn.Identity)

    def _policy_dist(self, obs):
        """Gaussian over pre-squash actions for ``obs``."""
        return torch.distributions.Normal(self.pi(obs), torch.exp(self.log_std))

    @staticmethod
    def _squashed_logp(dist, act):
        """Log-probability of the squashed action ``act`` under ``dist``.

        The action is clamped before ``atanh`` so a saturated ``tanh``
        (|act| == 1 in float32) cannot produce an infinite pre-squash value.
        """
        u = torch.atanh(torch.clamp(act, -0.999, 0.999))
        return (dist.log_prob(u) - torch.log(1 - act.pow(2) + 1e-6)).sum(-1, True)

    def step(self, obs):
        """Sample an action without tracking gradients.

        Returns:
            ``(action, value, logp)``.  ``logp`` is computed with exactly the
            same formula as :meth:`evaluate`, so the PPO ratio starts at 1.
        """
        with torch.no_grad():
            dist = self._policy_dist(obs)
            a = torch.tanh(dist.sample())
            logp = self._squashed_logp(dist, a)
            v = self.vf(obs)
        return a, v, logp

    def act_mean(self, obs):
        """Deterministic action: the squashed policy mean."""
        return torch.tanh(self.pi(obs))

    def evaluate(self, obs, act):
        """Return ``(logp, entropy, value)`` of stored actions, with gradients.

        The entropy is that of the pre-squash Gaussian, summed over action
        dimensions.
        """
        dist = self._policy_dist(obs)
        logp = self._squashed_logp(dist, act)
        ent = dist.entropy().sum(-1, True)
        v = self.vf(obs)
        return logp, ent, v

class PPOAgent:
    """Clipped-surrogate PPO with separate policy and value optimisers.

    Args:
        obs_dim, act_dim: Observation and action widths.
        lr: Learning rate for both Adam optimisers.
        clip: PPO ratio clipping range.
        target_kl: Policy updates stop once the approximate KL exceeds
            ``1.5 * target_kl``.
        iters: Maximum number of full-batch update iterations.
        vf_coef: Weight on the value loss.
        ent_coef: Weight on the entropy bonus.
        max_grad: Gradient-norm clipping threshold.
        hid: Hidden layer sizes.
        device: Torch device for the networks.
    """

    def __init__(self, obs_dim, act_dim, lr=3e-4, clip=0.2, target_kl=0.015, iters=80, vf_coef=0.5, ent_coef=0.0, max_grad=0.5, hid=(256,256), device="cpu"):
        self.ac = PPOActorCritic(obs_dim, act_dim, hid).to(device)
        self.pi_opt = torch.optim.Adam(list(self.ac.pi.parameters()) + [self.ac.log_std], lr=lr)
        self.vf_opt = torch.optim.Adam(self.ac.vf.parameters(), lr=lr)
        self.clip = clip
        self.target_kl = target_kl
        self.iters = iters
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad = max_grad
        self.device = device

    def update(self, buf:PPOBuffer):
        """Run up to ``iters`` full-batch PPO iterations on the buffer contents."""
        data = {k: v.to(self.device) for k, v in buf.get().items()}
        # Only these parameters receive gradient from the policy loss.
        pi_params = list(self.ac.pi.parameters()) + [self.ac.log_std]
        for _ in range(self.iters):
            logp, ent, v = self.ac.evaluate(data['obs'], data['act'])
            # Early stopping is checked before stepping, so the policy is not
            # pushed one more step past the KL threshold.
            approx_kl = (data['logp'] - logp).mean().item()
            if approx_kl > 1.5 * self.target_kl:
                break
            ratio = torch.exp(logp - data['logp'])
            surr1 = ratio * data['adv']
            surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * data['adv']
            pi_loss = -(torch.min(surr1, surr2)).mean() - self.ent_coef * ent.mean()
            v_loss = F.mse_loss(v, data['ret'])

            self.pi_opt.zero_grad()
            pi_loss.backward()
            # Clip over the policy parameters only: the value network still
            # holds gradients from the previous iteration, and including them
            # would inflate the norm and shrink the policy step.
            nn.utils.clip_grad_norm_(pi_params, self.max_grad)
            self.pi_opt.step()

            self.vf_opt.zero_grad()
            (self.vf_coef * v_loss).backward()
            nn.utils.clip_grad_norm_(self.ac.vf.parameters(), self.max_grad)
            self.vf_opt.step()

    def act(self, o_np, deterministic=False):
        """Choose an action for a single observation.

        Returns:
            ``(action, value, logp)`` with ``action`` a NumPy array and
            ``value`` a float.  ``logp`` is a float, or ``None`` when
            ``deterministic`` is true.
        """
        o = torch.as_tensor(o_np, dtype=torch.float32, device=self.device).unsqueeze(0)
        if deterministic:
            with torch.no_grad():
                a = self.ac.act_mean(o).squeeze(0).cpu().numpy()
                v = self.ac.vf(o).item()
            lp = None
        else:
            a, v, lp = self.ac.step(o)
            a = a.squeeze(0).cpu().detach().numpy()
            v = v.item()
            lp = lp.item()
        return a, v, lp

# ---------- TD3 ----------
class ActorTD3(nn.Module):
    """Deterministic TD3 policy: an MLP whose output layer is a tanh."""

    def __init__(self, obs_dim, act_dim, hid=(256,256)):
        super().__init__()
        layer_sizes = [obs_dim, *hid, act_dim]
        self.net = mlp(layer_sizes, out_act=nn.Tanh)

    def forward(self, o):
        """Map observations to actions."""
        action = self.net(o)
        return action

class CriticTwin(nn.Module):
    """Two independent Q-networks evaluated on the same (obs, action) input."""

    def __init__(self, obs_dim, act_dim, hid=(256,256)):
        super().__init__()
        layer_sizes = [obs_dim + act_dim, *hid, 1]
        self.q1 = mlp(layer_sizes)
        self.q2 = mlp(layer_sizes)

    def forward(self, o, a):
        """Return ``(q1, q2)`` for observations and actions joined on the last axis."""
        joint = torch.cat((o, a), dim=-1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second

class TD3Agent:
    """TD3: twin critics, target policy smoothing and delayed actor updates."""

    def __init__(self, obs_dim, act_dim, actor_lr=3e-4, critic_lr=3e-4, gamma=0.99, tau=0.005,
                 pol_noise=0.2, noise_clip=0.5, pol_delay=2, hid=(256,256), device="cpu"):
        # Online and target networks; the targets start as exact copies.
        self.actor = ActorTD3(obs_dim, act_dim, hid).to(device)
        self.actor_t = ActorTD3(obs_dim, act_dim, hid).to(device)
        self.critic = CriticTwin(obs_dim, act_dim, hid).to(device)
        self.critic_t = CriticTwin(obs_dim, act_dim, hid).to(device)
        self.actor_t.load_state_dict(self.actor.state_dict())
        self.critic_t.load_state_dict(self.critic.state_dict())

        self.pi_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.q_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.g = gamma
        self.tau = tau
        self.pol_noise = pol_noise
        self.noise_clip = noise_clip
        self.pol_delay = pol_delay
        self.it = 0
        self.device = device

    def select_action(self, o_np, noise_std=0.1):
        """Return the actor's action for one observation as a NumPy array.

        When ``noise_std > 0``, Gaussian noise is added and the result is
        clipped to ``[-1, 1]``.
        """
        obs_t = torch.as_tensor(o_np, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(obs_t).cpu().detach().numpy().squeeze(0)
        if not noise_std > 0:
            return action
        noisy = action + np.random.normal(0, noise_std, action.shape)
        return np.clip(noisy, -1, 1)

    def train_step(self, rb:ReplayBuffer, bs=256):
        """One critic update; every ``pol_delay`` calls also actor and targets."""
        self.it += 1
        batch = rb.sample(bs)
        obs, act, rew, obs2, done = (
            batch[key].to(self.device) for key in ('obs', 'act', 'rew', 'obs2', 'done')
        )

        # Bellman target from the smoothed target policy and the smaller
        # of the two target critics.
        with torch.no_grad():
            noise = torch.randn_like(act) * self.pol_noise
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
            next_a = (self.actor_t(obs2) + noise).clamp(-1, 1)
            q1_next, q2_next = self.critic_t(obs2, next_a)
            q_next = torch.min(q1_next, q2_next)
            target = rew + (1 - done) * self.g * q_next

        q1, q2 = self.critic(obs, act)
        critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)
        self.q_opt.zero_grad()
        critic_loss.backward()
        self.q_opt.step()

        if self.it % self.pol_delay != 0:
            return

        # Delayed actor update through the first critic only.
        actor_loss = -self.critic.q1(torch.cat([obs, self.actor(obs)], -1)).mean()
        self.pi_opt.zero_grad()
        actor_loss.backward()
        self.pi_opt.step()

        # Polyak averaging of both target networks.
        with torch.no_grad():
            pairs = ((self.actor, self.actor_t), (self.critic, self.critic_t))
            for online, target_net in pairs:
                for p, pt in zip(online.parameters(), target_net.parameters()):
                    pt.data.mul_(1 - self.tau).add_(self.tau * p.data)

# ---------- SAC ----------
class GaussianActor(nn.Module):
    """Tanh-squashed Gaussian policy with a state-dependent log-std (SAC)."""

    def __init__(self, obs_dim, act_dim, hid=(256,256), log_std_min=-20, log_std_max=2):
        super().__init__()
        # One trunk outputs both the mean and the log-std (2 * act_dim).
        self.net = mlp([obs_dim, *hid, 2 * act_dim])
        self.ad = act_dim
        self.lmin = log_std_min
        self.lmax = log_std_max

    def forward(self, o):
        """Sample a reparameterised action.

        Returns:
            ``(action, logp, mean_action)``: the squashed sample, its
            log-probability summed over the last axis (kept as a size-1
            axis), and the squashed mean.
        """
        out = self.net(o)
        mu = out[..., :self.ad]
        log_std = out[..., self.ad:].clamp(self.lmin, self.lmax)
        dist = torch.distributions.Normal(mu, log_std.exp())
        pre_tanh = dist.rsample()
        a = torch.tanh(pre_tanh)
        # Change-of-variables correction for the tanh squashing.
        correction = torch.log(1 - a.pow(2) + 1e-6)
        logp = (dist.log_prob(pre_tanh) - correction).sum(-1, True)
        return a, logp, torch.tanh(mu)

    def act(self, o, det=False):
        """Gradient-free action: squashed mean if ``det``, else a sample."""
        pick = 2 if det else 0
        with torch.no_grad():
            outputs = self.forward(o)
        return outputs[pick]

class QCritic(nn.Module):
    """Pair of independent Q-networks over joined (obs, action) inputs."""

    def __init__(self, obs_dim, act_dim, hid=(256,256)):
        super().__init__()
        in_dim = obs_dim + act_dim
        self.q1 = mlp([in_dim, *hid, 1])
        self.q2 = mlp([in_dim, *hid, 1])

    def forward(self, o, a):
        """Return both Q estimates as a ``(q1, q2)`` tuple."""
        joint = torch.cat((o, a), dim=-1)
        return self.q1(joint), self.q2(joint)

class SACAgent:
    """Soft Actor-Critic with twin critics and optional temperature tuning."""

    def __init__(self, obs_dim, act_dim, g=0.99, tau=0.005, lr=3e-4, hid=(256,256),
                 target_entropy=None, auto_alpha=True, init_alpha=0.2, device="cpu"):
        self.actor = GaussianActor(obs_dim, act_dim, hid).to(device)
        self.q = QCritic(obs_dim, act_dim, hid).to(device)
        self.q_t = QCritic(obs_dim, act_dim, hid).to(device)
        # The target critic starts as an exact copy of the online critic.
        self.q_t.load_state_dict(self.q.state_dict())

        self.pi_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.q_opt = torch.optim.Adam(self.q.parameters(), lr=lr)

        self.g = g
        self.tau = tau
        self.device = device
        self.auto = auto_alpha
        # Default entropy target: minus the action dimensionality.
        self.target_entropy = -act_dim if target_entropy is None else target_entropy

        if auto_alpha:
            self.log_alpha = torch.tensor(math.log(init_alpha), requires_grad=True, device=device)
            self.alpha_opt = torch.optim.Adam([self.log_alpha], lr=lr)
        else:
            self.alpha = init_alpha

    def _current_alpha(self):
        """Temperature as a tensor: learned ``exp(log_alpha)`` or the fixed value."""
        if self.auto:
            return self.log_alpha.exp()
        return torch.tensor(self.alpha, device=self.device)

    def select_action(self, o_np, eval_mode=False):
        """Action for one observation as a NumPy array (mean if ``eval_mode``)."""
        obs_t = torch.as_tensor(o_np, dtype=torch.float32, device=self.device).unsqueeze(0)
        action = self.actor.act(obs_t, det=eval_mode)
        return action.cpu().detach().numpy().squeeze(0)

    def train_step(self, rb:ReplayBuffer, bs=256):
        """One update of critics, actor, temperature (if learned) and target."""
        batch = rb.sample(bs)
        obs, act, rew, obs2, done = (
            batch[key].to(self.device) for key in ('obs', 'act', 'rew', 'obs2', 'done')
        )

        # Soft Bellman target from the target critics and a fresh next action.
        with torch.no_grad():
            next_a, next_logp, _ = self.actor.forward(obs2)
            q1_next, q2_next = self.q_t(obs2, next_a)
            alpha = self._current_alpha()
            soft_next = torch.min(q1_next, q2_next) - alpha * next_logp
            target = rew + (1 - done) * self.g * soft_next

        q1, q2 = self.q(obs, act)
        critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)
        self.q_opt.zero_grad()
        critic_loss.backward()
        self.q_opt.step()

        # Actor: minimise alpha * logp - min(Q1, Q2) for actions sampled now.
        new_a, logp, _ = self.actor.forward(obs)
        q1_pi, q2_pi = self.q(obs, new_a)
        q_pi = torch.min(q1_pi, q2_pi)
        alpha = self._current_alpha()
        actor_loss = (alpha * logp - q_pi).mean()
        self.pi_opt.zero_grad()
        actor_loss.backward()
        self.pi_opt.step()

        if self.auto:
            entropy_gap = (logp + self.target_entropy).detach()
            alpha_loss = -(self.log_alpha * entropy_gap).mean()
            self.alpha_opt.zero_grad()
            alpha_loss.backward()
            self.alpha_opt.step()

        # Polyak averaging of the target critic.
        with torch.no_grad():
            for p, pt in zip(self.q.parameters(), self.q_t.parameters()):
                pt.data.mul_(1 - self.tau).add_(self.tau * p.data)

# ---------- Small shared helpers ----------
def clip_action_to_space(a, space, rescale=False):
    """Clip an action to ``[-1, 1]`` and optionally map it onto ``space``.

    Args:
        a: Action array (or array-like).
        space: Object exposing ``low`` and ``high`` bounds; only read when
            ``rescale`` is true.
        rescale: If true, affinely map the clipped action from ``[-1, 1]``
            to ``[space.low, space.high]``.  The default keeps the clipped
            action unchanged.

    Returns:
        The clipped (and optionally rescaled) action as a NumPy array.
    """
    # NOTE(review): the original comment says the action space is mostly
    # [-1, 1] and the third dimension ("ratio") is handled inside the
    # environment, hence no rescaling by default - confirm against the env.
    clipped = np.clip(a, -1, 1)
    if not rescale:
        return clipped
    low = np.asarray(space.low)
    high = np.asarray(space.high)
    return low + (clipped + 1.0) * 0.5 * (high - low)



# ====== Lagrangian PPO: buffers ======
class PPOCostBuffer:
    """On-policy buffer with a K-dimensional cost signal.

    Besides the usual PPO fields it stores, per step, a cost vector and the
    cost-value estimates, and computes one GAE advantage and one discounted
    return per cost dimension.  ``cost_dim=1`` is the single-constraint case.

    Args:
        obs_dim, act_dim: Observation and action widths.
        size: Number of steps the buffer holds.
        cost_dim: Number of cost signals (K).
        gamma, lam: Discount and GAE lambda for the reward.
        cost_gamma, cost_lam: Same for the costs; default to the reward values.
    """

    def __init__(self, obs_dim, act_dim, size, cost_dim=1, gamma=0.99, lam=0.95, cost_gamma=None, cost_lam=None):
        self.obs  = np.zeros((size, obs_dim), np.float32)
        self.act  = np.zeros((size, act_dim), np.float32)
        self.rew  = np.zeros((size, 1), np.float32)
        self.val  = np.zeros((size, 1), np.float32)
        self.logp = np.zeros((size, 1), np.float32)
        self.cost = np.zeros((size, cost_dim), np.float32)   # K-dimensional cost
        self.cval = np.zeros((size, cost_dim), np.float32)   # K cost-value estimates
        self.ret  = np.zeros((size, 1), np.float32)
        self.adv  = np.zeros((size, 1), np.float32)
        self.cret = np.zeros((size, cost_dim), np.float32)
        self.cadv = np.zeros((size, cost_dim), np.float32)

        self.gamma = gamma
        self.lam = lam
        self.cost_gamma = cost_gamma if cost_gamma is not None else gamma
        self.cost_lam = cost_lam if cost_lam is not None else lam

        self.ptr = 0
        self.path_start = 0
        self.max = size
        self.cost_dim = cost_dim

    def store(self, o, a, r, v, lp, c_vec, c_v_vec):
        """Append one step, including its cost vector and cost-value estimates.

        Raises:
            IndexError: If the buffer is already full.
        """
        # Check before moving the cursor so a failed store leaves the
        # buffer in a state that get() still accepts.
        if self.ptr >= self.max:
            raise IndexError("PPOCostBuffer is full; call get() before storing more steps")
        i = self.ptr
        self.obs[i] = o
        self.act[i] = a
        self.rew[i] = r
        self.val[i] = v
        self.logp[i] = lp
        self.cost[i] = c_vec
        self.cval[i] = c_v_vec
        self.ptr = i + 1

    def finish_path(self, last_val=0.0, last_cval=None):
        """Close the current path and fill reward/cost advantages and returns.

        Args:
            last_val: Reward-value estimate of the state after the last
                stored step (0 for a terminal state).
            last_cval: Cost-value estimates of that state, K values in any
                shape reshapeable to ``(1, K)``; ``None`` means zeros.
        """
        sl = slice(self.path_start, self.ptr)
        boot = np.asarray(last_val, dtype=np.float64).reshape(1, 1)
        # The bootstrap value is appended to the rewards too, so the
        # rewards-to-go of a truncated path include the discounted tail.
        rews = np.append(self.rew[sl], boot, axis=0)
        vals = np.append(self.val[sl], boot, axis=0)
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv[sl] = discount_cumsum(deltas, self.gamma * self.lam)
        self.ret[sl] = discount_cumsum(rews, self.gamma)[:-1]

        # GAE for every cost dimension at once; each column is independent
        # because the discounted sum runs along the time axis only.
        if last_cval is None:
            cboot = np.zeros((1, self.cost_dim), np.float32)
        else:
            cboot = np.asarray(last_cval, dtype=np.float32).reshape(1, self.cost_dim)
        costs = np.append(self.cost[sl], cboot, axis=0)
        cvals = np.append(self.cval[sl], cboot, axis=0)
        deltas_c = costs[:-1] + self.cost_gamma * cvals[1:] - cvals[:-1]
        self.cadv[sl] = discount_cumsum(deltas_c, self.cost_gamma * self.cost_lam)
        self.cret[sl] = discount_cumsum(costs, self.cost_gamma)[:-1]

        self.path_start = self.ptr

    def get(self):
        """Return the whole batch as float32 tensors and reset the cursor.

        The reward advantage is standardised over the batch; each cost
        advantage column is standardised separately.  The buffer must be
        exactly full.
        """
        assert self.ptr == self.max, "PPOCostBuffer.get() requires a full buffer"
        self.ptr = 0
        self.path_start = 0
        adv = (self.adv - self.adv.mean()) / (self.adv.std() + 1e-8)
        # Per-dimension standardisation of the cost advantages (the original
        # author marks this as optional).
        c_mean = self.cadv.mean(0, keepdims=True)
        c_std = self.cadv.std(0, keepdims=True) + 1e-8
        cadv = (self.cadv - c_mean) / c_std

        to_t = lambda x: torch.as_tensor(x, dtype=torch.float32)
        return dict(
            obs=to_t(self.obs), act=to_t(self.act),
            ret=to_t(self.ret), adv=to_t(adv),
            logp=to_t(self.logp),
            cost=to_t(self.cost), cval=to_t(self.cval),
            cret=to_t(self.cret), cadv=to_t(cadv),
        )

# ====== Lagrangian PPO: Actor-Critic ======
class LagrangianPPOActorCritic(nn.Module):
    """Single-constraint actor-critic: policy, reward critic and one cost critic.

    The policy is a tanh-squashed Gaussian with a state-independent
    ``log_std`` that is kept inside ``[-5, 2]``.
    """

    def __init__(self, obs_dim, act_dim, hid=(256,256)):
        super().__init__()
        self.pi  = mlp([obs_dim,*hid,act_dim], act=nn.Tanh, out_act=nn.Identity)
        self.log_std = nn.Parameter(torch.zeros(act_dim)-0.5)
        self.vf  = mlp([obs_dim,*hid,1], act=nn.Tanh, out_act=nn.Identity)
        self.vc  = mlp([obs_dim,*hid,1], act=nn.Tanh, out_act=nn.Identity)  # cost value

    def _policy_dist(self, obs):
        """Gaussian over pre-squash actions, after bounding ``log_std``."""
        # Project the parameter back into [-5, 2] in place.  The previous
        # out-of-place ``clamp`` discarded its result and bounded nothing.
        with torch.no_grad():
            self.log_std.clamp_(-5.0, 2.0)
        return torch.distributions.Normal(self.pi(obs), torch.exp(self.log_std))

    @staticmethod
    def _squashed_logp(dist, act):
        """Log-probability of the squashed action ``act`` under ``dist``.

        The action is clamped before ``atanh`` so a saturated ``tanh``
        (|act| == 1 in float32) cannot produce an infinite pre-squash value.
        """
        u = torch.atanh(torch.clamp(act, -0.999, 0.999))
        return (dist.log_prob(u) - torch.log(1 - act.pow(2) + 1e-6)).sum(-1, True)

    def step(self, obs):
        """Sample an action without tracking gradients.

        Returns:
            ``(action, value, cost_value, logp)``; ``logp`` uses the same
            formula as :meth:`evaluate`.
        """
        with torch.no_grad():
            dist = self._policy_dist(obs)
            a = torch.tanh(dist.sample())
            logp = self._squashed_logp(dist, a)
            v = self.vf(obs)
            c_v = self.vc(obs)
        return a, v, c_v, logp

    def act_mean(self, obs):
        """Deterministic action: the squashed policy mean."""
        return torch.tanh(self.pi(obs))

    def evaluate(self, obs, act):
        """Return ``(logp, entropy, value, cost_value)`` with gradients."""
        dist = self._policy_dist(obs)
        logp = self._squashed_logp(dist, act)
        ent = dist.entropy().sum(-1, True)
        v = self.vf(obs)
        c_v = self.vc(obs)
        return logp, ent, v, c_v

class LagrangianPPOMultiActorCritic(nn.Module):
    """Multi-constraint actor-critic: policy, reward critic and K cost critics.

    The policy is a tanh-squashed Gaussian with a state-independent
    ``log_std`` that is kept inside ``[-5, 2]``.  Each cost critic is its own
    network in ``vc_list``.
    """

    def __init__(self, obs_dim, act_dim, cost_dim, hid=(256,256)):
        super().__init__()
        self.pi  = mlp([obs_dim,*hid,act_dim], act=nn.Tanh, out_act=nn.Identity)
        self.log_std = nn.Parameter(torch.zeros(act_dim)-0.5)
        self.vf  = mlp([obs_dim,*hid,1], act=nn.Tanh, out_act=nn.Identity)
        self.vc_list = nn.ModuleList([mlp([obs_dim,*hid,1], act=nn.Tanh, out_act=nn.Identity) for _ in range(cost_dim)])
        self.cost_dim = cost_dim

    def _policy_dist(self, obs):
        """Gaussian over pre-squash actions, after bounding ``log_std``."""
        # Project the parameter back into [-5, 2] in place.  The previous
        # out-of-place ``clamp`` discarded its result and bounded nothing.
        with torch.no_grad():
            self.log_std.clamp_(-5.0, 2.0)
        return torch.distributions.Normal(self.pi(obs), torch.exp(self.log_std))

    @staticmethod
    def _squashed_logp(dist, act):
        """Log-probability of the squashed action ``act`` under ``dist``.

        The action is clamped before ``atanh`` so a saturated ``tanh``
        (|act| == 1 in float32) cannot produce an infinite pre-squash value.
        """
        u = torch.atanh(torch.clamp(act, -0.999, 0.999))
        return (dist.log_prob(u) - torch.log(1 - act.pow(2) + 1e-6)).sum(-1, True)

    def _cost_values(self, obs):
        """Outputs of all cost critics joined on the last axis."""
        return torch.cat([net(obs) for net in self.vc_list], dim=-1)

    def step(self, obs):
        """Sample an action without tracking gradients.

        Returns:
            ``(action, value, cost_values, logp)``; ``logp`` uses the same
            formula as :meth:`evaluate`.
        """
        with torch.no_grad():
            dist = self._policy_dist(obs)
            a = torch.tanh(dist.sample())
            logp = self._squashed_logp(dist, a)
            v = self.vf(obs)
            c_v = self._cost_values(obs)
        return a, v, c_v, logp

    def act_mean(self, obs):
        """Deterministic action: the squashed policy mean."""
        return torch.tanh(self.pi(obs))

    def evaluate(self, obs, act):
        """Return ``(logp, entropy, value, cost_values)`` with gradients."""
        dist = self._policy_dist(obs)
        logp = self._squashed_logp(dist, act)
        ent = dist.entropy().sum(-1, True)
        v = self.vf(obs)
        c_v = self._cost_values(obs)
        return logp, ent, v, c_v

# ====== Lagrangian PPO: Agents ======
class LagrangianPPOAgent:
    """Single-constraint Lagrangian PPO.

    The policy maximises the clipped reward surrogate minus ``lambda`` times
    a cost surrogate.  ``lambda`` is raised or lowered after each update by
    projected dual ascent on ``Jc - cost_limit``.

    Args:
        obs_dim, act_dim: Observation and action widths.
        cost_limit: Threshold the cost estimate ``Jc`` is compared against.
        lr: Learning rate for all Adam optimisers.
        clip: PPO ratio clipping range.
        target_kl: Updates stop once the approximate KL exceeds
            ``1.5 * target_kl``.
        iters: Maximum number of full-batch update iterations.
        vf_coef, ent_coef: Weights on the value loss and entropy bonus.
        max_grad: Gradient-norm clipping threshold.
        dual_lr: Step size of the multiplier update.
        lambda_init, lambda_max: Initial value and upper bound of ``lambda``.
        hid: Hidden layer sizes.
        device: Torch device.
    """

    def __init__(self, obs_dim, act_dim, cost_limit: float,
                 lr=3e-4, clip=0.2, target_kl=0.015, iters=80,
                 vf_coef=0.5, ent_coef=0.0, max_grad=0.5, dual_lr=0.05, lambda_init=0.0, lambda_max=1e4,
                 hid=(256,256), device="cpu"):
        self.ac = LagrangianPPOActorCritic(obs_dim, act_dim, hid).to(device)
        self.pi_opt = torch.optim.Adam(list(self.ac.pi.parameters())+[self.ac.log_std], lr=lr)
        self.vf_opt = torch.optim.Adam(self.ac.vf.parameters(), lr=lr)
        self.vc_opt = torch.optim.Adam(self.ac.vc.parameters(), lr=lr)
        self.clip = clip
        self.target_kl = target_kl
        self.iters = iters
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad = max_grad
        self.device = device
        # Dual variable (Lagrange multiplier), kept in [0, lambda_max].
        self.lam = torch.tensor(lambda_init, dtype=torch.float32, device=device)
        self.dual_lr = dual_lr
        self.lambda_max = lambda_max
        self.cost_limit = cost_limit

    def update(self, buf: PPOCostBuffer):
        """Update policy and critics on the buffer, then the multiplier.

        Returns:
            ``dict(Jc=..., lambda_value=...)`` where ``Jc`` is the batch mean
            of the stored discounted cost returns.
        """
        data = {k: v.to(self.device) for k, v in buf.get().items()}
        # Only these parameters receive gradient from the policy loss.
        pi_params = list(self.ac.pi.parameters()) + [self.ac.log_std]
        cadv = data['cadv'][:, :1]  # single constraint: first cost column
        cret = data['cret'][:, :1]

        # ------- policy / value updates -------
        for _ in range(self.iters):
            logp, ent, v, c_v = self.ac.evaluate(data['obs'], data['act'])
            # Early stopping is checked before stepping, so the policy is not
            # pushed one more step past the KL threshold.
            approx_kl = (data['logp'] - logp).mean().item()
            if approx_kl > 1.5 * self.target_kl:
                break

            ratio = torch.exp(logp - data['logp'])
            surr1 = ratio * data['adv']
            surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * data['adv']
            surr_cost = ratio * cadv  # unclipped cost surrogate
            pi_loss = -(torch.min(surr1, surr2) - self.lam * surr_cost).mean() - self.ent_coef * ent.mean()

            v_loss = F.mse_loss(v, data['ret'])
            c_loss = F.mse_loss(c_v, cret)

            self.pi_opt.zero_grad()
            pi_loss.backward()
            # Clip over the policy parameters only: the critics still hold
            # gradients from the previous iteration, and including them
            # would inflate the norm and shrink the policy step.
            nn.utils.clip_grad_norm_(pi_params, self.max_grad)
            self.pi_opt.step()

            self.vf_opt.zero_grad()
            (self.vf_coef * v_loss).backward()
            nn.utils.clip_grad_norm_(self.ac.vf.parameters(), self.max_grad)
            self.vf_opt.step()

            self.vc_opt.zero_grad()
            c_loss.backward()
            # Clipped like the reward critic, matching the multi-constraint agent.
            nn.utils.clip_grad_norm_(self.ac.vc.parameters(), self.max_grad)
            self.vc_opt.step()

        # ------- dual ascent -------
        # NOTE(review): Jc is the mean discounted cost-to-go over all stored
        # steps, not a per-episode cost - confirm cost_limit uses that scale.
        with torch.no_grad():
            Jc = cret.mean().item()
            new_lambda = (self.lam + self.dual_lr * (Jc - self.cost_limit)).clamp(min=0.0, max=self.lambda_max)
            self.lam.data.copy_(new_lambda)

        return dict(Jc=Jc, lambda_value=float(self.lam.item()))

    def act(self, o_np, deterministic=False):
        """Choose an action for a single observation.

        Returns:
            ``(action, value, cost_value, logp)``; ``logp`` is ``None`` when
            ``deterministic`` is true.
        """
        o = torch.as_tensor(o_np, dtype=torch.float32, device=self.device).unsqueeze(0)
        if deterministic:
            with torch.no_grad():
                a = self.ac.act_mean(o).squeeze(0).cpu().numpy()
                v = self.ac.vf(o).item()
                c_v = self.ac.vc(o).item()
            lp = None
        else:
            a, v, c_v, lp = self.ac.step(o)
            a = a.squeeze(0).cpu().detach().numpy()
            v = v.item()
            c_v = c_v.item()
            lp = lp.item()
        return a, v, c_v, lp

class LagrangianPPOMultiAgent:
    """Lagrangian PPO with K constraints and one multiplier per constraint.

    Args:
        obs_dim, act_dim: Observation and action widths.
        cost_limits: Length-K sequence of cost thresholds.
        lr: Learning rate for all Adam optimisers.
        clip: PPO ratio clipping range.
        target_kl: Updates stop once the approximate KL exceeds
            ``1.5 * target_kl``.
        iters: Maximum number of full-batch update iterations.
        vf_coef, ent_coef: Weights on the value loss and entropy bonus.
        max_grad: Gradient-norm clipping threshold.
        dual_lr: Step size of the multiplier update.
        lambda_init, lambda_max: Initial value and upper bound of each multiplier.
        hid: Hidden layer sizes.
        device: Torch device.
    """

    def __init__(self, obs_dim, act_dim, cost_limits: np.ndarray,
                 lr=3e-4, clip=0.2, target_kl=0.015, iters=80,
                 vf_coef=0.5, ent_coef=0.0, max_grad=0.5, dual_lr=0.05, lambda_init=0.0, lambda_max=1e4,
                 hid=(256,256), device="cpu"):
        K = len(cost_limits)
        self.ac = LagrangianPPOMultiActorCritic(obs_dim, act_dim, K, hid).to(device)
        self.pi_opt = torch.optim.Adam(list(self.ac.pi.parameters())+[self.ac.log_std], lr=lr)
        self.vf_opt = torch.optim.Adam(self.ac.vf.parameters(), lr=lr)
        self.vc_opts = [torch.optim.Adam(vc.parameters(), lr=lr) for vc in self.ac.vc_list]
        self.clip = clip
        self.target_kl = target_kl
        self.iters = iters
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad = max_grad
        self.device = device
        # Dual vector: one multiplier per constraint, kept in [0, lambda_max].
        self.lam = torch.ones(K, dtype=torch.float32, device=device) * lambda_init
        self.dual_lr = dual_lr
        self.lambda_max = lambda_max
        self.cost_limits = torch.as_tensor(cost_limits, dtype=torch.float32, device=device)

    def update(self, buf: PPOCostBuffer):
        """Update policy and critics on the buffer, then the multipliers.

        Returns:
            ``dict(Jc=..., lambdas=...)`` as NumPy arrays of length K, where
            ``Jc`` is the per-constraint batch mean of the stored discounted
            cost returns.
        """
        data = {k: v.to(self.device) for k, v in buf.get().items()}
        K = self.lam.numel()
        # Only these parameters receive gradient from the policy loss.
        pi_params = list(self.ac.pi.parameters()) + [self.ac.log_std]

        for _ in range(self.iters):
            # Fresh forward pass: log-probs of the current policy for the
            # ratio, plus the value and cost-value estimates.
            logp, ent, v, c_v = self.ac.evaluate(data['obs'], data['act'])
            # Early stopping is checked before stepping, so the policy is not
            # pushed one more step past the KL threshold.
            approx_kl = (data['logp'] - logp).mean().item()
            if approx_kl > 1.5 * self.target_kl:
                break

            ratio = torch.exp(logp - data['logp'])  # data['logp'] was stored at sampling time
            surr1 = ratio * data['adv']
            surr2 = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * data['adv']

            # sum_k lambda_k * E[ratio * cadv_k]; ratio broadcasts over columns.
            surr_cost_all = (ratio * data['cadv'][:, :K]).mean(0)

            pi_loss = -(torch.min(surr1, surr2).mean() - (self.lam * surr_cost_all).sum()) \
                    - self.ent_coef * ent.mean()
            v_loss = F.mse_loss(v, data['ret'])
            # The cost critics are separate networks, so one backward pass
            # through the summed loss gives each its own gradients.
            c_loss = torch.stack(
                [F.mse_loss(c_v[:, k:k+1], data['cret'][:, k:k+1]) for k in range(K)]
            ).sum()

            self.pi_opt.zero_grad()
            pi_loss.backward()
            # Clip over the policy parameters only: the critics still hold
            # gradients from the previous iteration, and including them
            # would inflate the norm and shrink the policy step.
            nn.utils.clip_grad_norm_(pi_params, self.max_grad)
            self.pi_opt.step()

            self.vf_opt.zero_grad()
            (self.vf_coef * v_loss).backward()
            nn.utils.clip_grad_norm_(self.ac.vf.parameters(), self.max_grad)
            self.vf_opt.step()

            for opt in self.vc_opts:
                opt.zero_grad()
            c_loss.backward()
            for vc, opt in zip(self.ac.vc_list, self.vc_opts):
                nn.utils.clip_grad_norm_(vc.parameters(), self.max_grad)
                opt.step()

        # ------- dual ascent -------
        # NOTE(review): Jc is the mean discounted cost-to-go over all stored
        # steps, not a per-episode cost - confirm cost_limits use that scale.
        with torch.no_grad():
            Jc = data['cret'].mean(0)  # (K,)
            new_lam = (self.lam + self.dual_lr * (Jc - self.cost_limits)).clamp(min=0.0, max=self.lambda_max)
            self.lam.data.copy_(new_lam)
        return dict(Jc=Jc.detach().cpu().numpy(), lambdas=self.lam.detach().cpu().numpy())

    def act(self, o_np, deterministic=False):
        """Return only the action (NumPy array) for a single observation.

        Unlike the single-constraint agent, no value, cost value or
        log-probability is returned.
        """
        o = torch.as_tensor(o_np, dtype=torch.float32, device=self.device).unsqueeze(0)
        if deterministic:
            with torch.no_grad():
                a = self.ac.act_mean(o).squeeze(0).cpu().numpy()
        else:
            a, _, _, _ = self.ac.step(o)
            a = a.squeeze(0).cpu().detach().numpy()
        return a