import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet        # 匿名化处理：Network → model
from util import compute_advantage          # 匿名化处理：util → utils


class VDPPO:
    """
    Value-Decomposed Proximal Policy Optimization (VD-PPO).

    Each agent maintains its own value function, and their sum approximates
    the global critic. Gradients are backpropagated through the sum of per-agent
    values, implicitly enforcing the decomposition constraint.

    Every agent also owns an independent actor that is trained with the
    clipped PPO surrogate, using one advantage signal shared by all agents
    (derived from the summed value estimate).
    """

    # Lambda argument forwarded to ``compute_advantage``.
    GAE_LAMBDA = 0.95
    # Maximum gradient norm applied to every actor and critic before a step.
    MAX_GRAD_NORM = 40.0
    # Floor used inside logarithms so zero probabilities cannot yield -inf/NaN.
    PROB_EPS = 1e-8

    def __init__(
        self,
        agent_num: int,
        state_dim_list: list[int],
        hidden_dim: int,
        action_num_list: list[int],
        actor_lr: float,
        critic_lr: float,
        epochs: int,
        eps: float,
        gamma: float,
        device: torch.device,
        sample_size: int = 1,
        entropy_soft=None,
    ) -> None:
        """Build one actor/critic pair (and their optimizers) per agent.

        Args:
            agent_num: Number of agents.
            state_dim_list: Observation size for each agent, indexed by agent.
            hidden_dim: Hidden width handed to every ``PolicyNet``/``ValueNet``.
            action_num_list: Number of discrete actions for each agent.
            actor_lr: Learning rate of the actor optimizers.
            critic_lr: Learning rate of the critic optimizers.
            epochs: Optimization passes per call to :meth:`update`.
            eps: PPO clipping range; the ratio is clamped to ``[1-eps, 1+eps]``.
            gamma: Discount factor.
            device: Device on which networks and batches are placed.
            sample_size: Accepted for interface compatibility; currently unused.
            entropy_soft: Accepted for interface compatibility; currently
                unused (the entropy coefficient is fixed at 0.01).
        """
        self.agent_num = agent_num
        self.device = device
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.entropy_coef = 0.01

        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.actor_opts = [
            torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors
        ]

        self.critics = [
            ValueNet(state_dim_list[i], hidden_dim).to(device) for i in range(agent_num)
        ]
        self.critic_opts = [
            torch.optim.AdamW(critic.parameters(), lr=critic_lr) for critic in self.critics
        ]

    def _to_device_tensor(self, state) -> torch.Tensor:
        """Return ``state`` as a tensor on ``self.device``.

        Tensors are only moved (dtype untouched); anything else (lists,
        NumPy arrays, ...) is converted using the default floating dtype.
        """
        if isinstance(state, torch.Tensor):
            return state.to(self.device)
        return torch.as_tensor(
            state, dtype=torch.get_default_dtype(), device=self.device
        )

    def take_action(self, state_list) -> list[torch.Tensor]:
        """Sample one action per agent from its current policy.

        Args:
            state_list: Per-agent observations, indexed by agent. Each entry
                may be a tensor or any array-like accepted by
                ``torch.as_tensor``.

        Returns:
            A list with one sampled action tensor per agent.
        """
        actions = []
        with torch.no_grad():
            for actor, state in zip(self.actors, state_list[: self.agent_num]):
                probs = actor(self._to_device_tensor(state))
                actions.append(torch.distributions.Categorical(probs).sample())
        return actions

    def _joint_value(
        self, per_agent_states
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Evaluate every critic and return ``(sum of values, per-agent values)``."""
        v_is = [self.critics[i](per_agent_states[i]) for i in range(self.agent_num)]
        v_tot = torch.stack(v_is, dim=0).sum(dim=0)
        return v_tot, v_is

    def _log_prob(self, probs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``actions`` under ``probs``, floored at ``PROB_EPS``."""
        chosen = probs.gather(1, actions)
        # Without the floor, a probability that underflows to 0 gives log = -inf
        # and NaN gradients through the importance ratio.
        return torch.log(torch.clamp(chosen, min=self.PROB_EPS))

    def _update_actor(
        self,
        agent_idx: int,
        states: torch.Tensor,
        actions: torch.Tensor,
        advantage: torch.Tensor,
    ) -> None:
        """Run ``self.epochs`` clipped-surrogate steps on one agent's actor."""
        actor = self.actors[agent_idx]
        optimizer = self.actor_opts[agent_idx]

        # Behaviour-policy log-probs are constants for the whole update.
        with torch.no_grad():
            old_log_pi = self._log_prob(actor(states), actions)

        for _ in range(self.epochs):
            pi = actor(states)
            ratio = torch.exp(self._log_prob(pi, actions) - old_log_pi)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            policy_loss = -torch.min(surr1, surr2).mean()

            # Entropy bonus encourages exploration.
            entropy = -(pi * torch.log(pi + self.PROB_EPS)).sum(dim=1).mean()
            actor_loss = policy_loss - self.entropy_coef * entropy

            optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), self.MAX_GRAD_NORM)
            optimizer.step()

    def _update_critics(self, per_agent_states, td_target: torch.Tensor) -> None:
        """Regress the summed per-agent values onto ``td_target``."""
        for _ in range(self.epochs):
            v_tot_pred, _ = self._joint_value(per_agent_states)
            critic_loss = F.mse_loss(v_tot_pred, td_target)

            for opt in self.critic_opts:
                opt.zero_grad()
            # One backward pass through the sum reaches every critic.
            critic_loss.backward()
            for critic, opt in zip(self.critics, self.critic_opts):
                torch.nn.utils.clip_grad_norm_(
                    critic.parameters(), self.MAX_GRAD_NORM
                )
                opt.step()

    def update(self, transition_dict: dict) -> None:
        """Update all actors and critics from one batch of transitions.

        Args:
            transition_dict: Mapping with tensor entries ``"states"`` and
                ``"next_states"`` (indexed as ``[step, agent, feature]``),
                ``"rewards"`` and ``"dones"`` (flattened to one column), and
                ``"actions"`` (indexed as ``[step, agent]``).
        """
        joint_states = transition_dict["states"]
        next_joint_states = transition_dict["next_states"]
        rewards = transition_dict["rewards"].view(-1, 1).to(self.device)
        dones = transition_dict["dones"].view(-1, 1).float().to(self.device)
        # Move the action batch to the device once instead of once per agent.
        actions = transition_dict["actions"].long().to(self.device)
        num_steps = joint_states.shape[0]

        per_agent_states = [
            joint_states[:, i, :].reshape(num_steps, -1).to(self.device)
            for i in range(self.agent_num)
        ]
        per_agent_next_states = [
            next_joint_states[:, i, :].reshape(num_steps, -1).to(self.device)
            for i in range(self.agent_num)
        ]

        # Targets and TD errors are fixed for the whole update, so none of this
        # needs (or should carry) an autograd graph into the actor losses.
        with torch.no_grad():
            v_next, _ = self._joint_value(per_agent_next_states)
            td_target = rewards + self.gamma * v_next * (1.0 - dones)
            v, _ = self._joint_value(per_agent_states)
            td_delta = td_target - v

        # NOTE(review): compute_advantage is assumed to return a tensor that
        # broadcasts against the (steps, 1) ratio -- confirm against util.
        advantage = compute_advantage(
            self.gamma, self.GAE_LAMBDA, td_delta.cpu(), dones.cpu()
        )
        advantage = advantage.to(self.device).detach()

        for i in range(self.agent_num):
            self._update_actor(
                i, per_agent_states[i], actions[:, i].view(-1, 1), advantage
            )

        self._update_critics(per_agent_states, td_target)
