import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet
from util import compute_advantage


class VDPPO:
    """Value-Decomposed Proximal Policy Optimization (VD-PPO).

    A discrete-action variant that follows the VDAC / VD-PPO idea of
    *per-agent* value functions whose *sum* approximates the global critic.

    Compared with MAPPO, each agent keeps its own value network.  During
    training the critic loss is back-propagated through the *sum* of those
    values, so the decomposition constraint sum_i V_i(s_i) ~= V_tot(s) is
    enforced implicitly.
    """

    # Lower bound applied to an action probability before taking its log, so a
    # probability that underflows to 0 cannot produce -inf / NaN gradients.
    # Matches the epsilon already used in the entropy term.
    _PROB_FLOOR = 1e-8

    def __init__(
        self,
        agent_num: int,
        state_dim_list: list[int],
        hidden_dim: int,
        action_num_list: list[int],
        actor_lr: float,
        critic_lr: float,
        epochs: int,
        eps: float,
        gamma: float,
        device: torch.device,
        sample_size: int = 1,
        entropy_soft=None,
        gae_lambda: float = 0.95,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 40.0,
    ) -> None:
        """Build one actor and one critic (plus AdamW optimisers) per agent.

        Args:
            agent_num: Number of agents.
            state_dim_list: Observation size of each agent.
            hidden_dim: Hidden width passed to every ``PolicyNet``/``ValueNet``.
            action_num_list: Number of discrete actions of each agent.
            actor_lr: Learning rate of every actor optimiser.
            critic_lr: Learning rate of every critic optimiser.
            epochs: Gradient steps per call to :meth:`update`.
            eps: PPO clipping range (ratio is clipped to ``[1-eps, 1+eps]``).
            gamma: Discount factor.
            device: Device on which all networks live.
            sample_size: Accepted for interface compatibility; not read by
                this class.
            entropy_soft: Accepted for interface compatibility; not read by
                this class.
            gae_lambda: Lambda forwarded to ``compute_advantage``.
            entropy_coef: Weight of the entropy bonus in the actor loss.
            max_grad_norm: Max gradient norm used when clipping actor and
                critic gradients.
        """
        self.agent_num = agent_num
        self.device = device
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.entropy_coef = entropy_coef
        self.gae_lambda = gae_lambda
        self.max_grad_norm = max_grad_norm

        # ---  Actors & optimisers  ---------------------------------------------------
        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.actor_opts = [
            torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors
        ]

        # ---  Per-agent critics & optimisers  ----------------------------------------
        self.critics = [
            ValueNet(state_dim_list[i], hidden_dim).to(device) for i in range(agent_num)
        ]
        self.critic_opts = [
            torch.optim.AdamW(critic.parameters(), lr=critic_lr) for critic in self.critics
        ]

    # ---------------------------------------------------------------------  rollout
    def take_action(self, state_list):
        """Sample one action per agent without tracking gradients.

        Args:
            state_list: One observation tensor per agent; each is moved to
                ``self.device`` and fed to that agent's actor, whose output is
                used as categorical probabilities.

        Returns:
            A list with one sampled-action tensor per agent.
        """
        with torch.no_grad():
            return [
                torch.distributions.Categorical(
                    self.actors[i](state_list[i].to(self.device))
                ).sample()
                for i in range(self.agent_num)
            ]

    # ---------------------------------------------------------------------  helpers
    def _joint_value(self, per_agent_states):
        """Return ``(V_tot, [V_i])``.  States must already be on *self.device*.

        ``V_tot`` is the element-wise sum of the per-agent critic outputs.
        """
        v_is = [self.critics[i](per_agent_states[i]) for i in range(self.agent_num)]
        v_tot = torch.stack(v_is, dim=0).sum(dim=0)
        return v_tot, v_is

    def _log_prob(self, probs, actions):
        """Log-probability of *actions* under *probs*, floored to stay finite."""
        chosen = probs.gather(1, actions)
        return torch.log(torch.clamp(chosen, min=self._PROB_FLOOR))

    def _update_actor(self, i, s_i, a_i, advantage):
        """Run ``self.epochs`` clipped-PPO steps on agent *i*'s actor."""
        actor, opt = self.actors[i], self.actor_opts[i]

        # The behaviour policy is a constant: no graph is needed for it.
        with torch.no_grad():
            old_log_pi = self._log_prob(actor(s_i), a_i)

        for _ in range(self.epochs):
            pi = actor(s_i)
            ratio = torch.exp(self._log_prob(pi, a_i) - old_log_pi)

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            policy_loss = -torch.min(surr1, surr2).mean()

            # Entropy bonus encourages exploration.
            entropy = -(pi * torch.log(pi + 1e-8)).sum(dim=1).mean()
            actor_loss = policy_loss - self.entropy_coef * entropy

            opt.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), self.max_grad_norm)
            opt.step()

    def _update_critics(self, per_agent_states, td_target):
        """Run ``self.epochs`` MSE steps on the summed value ``V_tot``.

        A single backward pass through the sum sends gradients to every
        per-agent critic; each critic is then clipped and stepped on its own.
        """
        td_target = td_target.detach()
        for _ in range(self.epochs):
            v_tot_pred, _ = self._joint_value(per_agent_states)
            critic_loss = F.mse_loss(v_tot_pred, td_target)

            for opt in self.critic_opts:
                opt.zero_grad()
            critic_loss.backward()
            for critic, opt in zip(self.critics, self.critic_opts):
                torch.nn.utils.clip_grad_norm_(critic.parameters(), self.max_grad_norm)
                opt.step()

    # ---------------------------------------------------------------------  update
    def update(self, transition_dict: dict) -> None:
        """Update all actors and critics from one batch of transitions.

        Required keys in *transition_dict* are the ones produced by the MAPPO
        rollout helper (``states``, ``actions``, ``next_states``, ``rewards``,
        ``dones``).  Expected shapes (T = number of steps):

          states        [T, n_agents, s_dim]
          actions       [T, n_agents, 1]   (cast to long here)
          next_states   same as states
          rewards       [T, 1]             (shared team reward)
          dones         [T, 1]             (0/1)

        An empty batch (T == 0) is a no-op.
        """
        joint_states = transition_dict["states"]
        T = joint_states.shape[0]
        if T == 0:
            # Nothing to learn from; avoids NaN losses / ambiguous reshapes.
            return

        next_joint_states = transition_dict["next_states"]
        rewards = transition_dict["rewards"].view(-1, 1).to(self.device)  # [T,1]
        dones = transition_dict["dones"].view(-1, 1).float().to(self.device)  # [T,1]
        actions = transition_dict["actions"].long()

        # ----------  Build per-agent state tensors  ---------------------------------
        per_agent_states = [
            joint_states[:, i, :].reshape(T, -1).to(self.device)
            for i in range(self.agent_num)
        ]
        per_agent_next_states = [
            next_joint_states[:, i, :].reshape(T, -1).to(self.device)
            for i in range(self.agent_num)
        ]

        # ----------  TD target & advantage (global)  --------------------------------
        # Targets and advantages are constants for both losses, so the whole
        # computation runs without building an autograd graph.
        with torch.no_grad():
            v_next, _ = self._joint_value(per_agent_next_states)  # [T,1]
            td_target = rewards + self.gamma * v_next * (1.0 - dones)  # [T,1]
            v, _ = self._joint_value(per_agent_states)  # [T,1]
            td_delta = td_target - v
        advantage = compute_advantage(
            self.gamma, self.gae_lambda, td_delta.cpu(), dones.cpu()
        ).to(self.device)

        # ----------  Actor updates  --------------------------------------------------
        for i in range(self.agent_num):
            a_i = actions[:, i].reshape(-1, 1).to(self.device)  # [T,1]
            self._update_actor(i, per_agent_states[i], a_i, advantage)

        # ----------  Critic updates (joint loss)  ------------------------------------
        self._update_critics(per_agent_states, td_target)
