import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet
from util import compute_advantage

__all__ = ["QMIXPPO"]


class QMIXMixer(torch.nn.Module):
    """Monotonic mixing network mapping per-agent values
    to a joint value conditioned on the global state.

    Hypernetworks turn the global state into the weights and biases of a
    two-layer mixing MLP.  The generated weights go through an absolute
    value, so they are non-negative, which preserves the monotonicity
    constraint ∂V_tot / ∂V_i ≥ 0.
    """

    def __init__(
        self,
        n_agents: int,
        state_dim: int,
        embed_dim: int = 32,
        hypernet_layers: int = 2,
        hypernet_embed: int = 64,
    ) -> None:
        """Build the hypernetworks that generate the mixing parameters.

        Args:
            n_agents: Number of per-agent values mixed together.
            state_dim: Feature width of the global state fed to every
                hypernetwork.
            embed_dim: Width of the hidden mixing layer.
            hypernet_layers: Depth of the weight hypernetworks; ``1`` uses a
                single linear layer, ``2`` a Linear-ReLU-Linear stack.
            hypernet_embed: Hidden width of the weight hypernetworks when
                ``hypernet_layers`` is ``2``.

        Raises:
            ValueError: If ``hypernet_layers`` is neither 1 nor 2.
        """
        super().__init__()
        # Validate first so that every unsupported depth (0 and negative
        # values included) is reported with the value that was passed.
        if hypernet_layers not in (1, 2):
            raise ValueError(
                f"hypernet_layers must be 1 or 2, got {hypernet_layers!r}"
            )

        self.n_agents = n_agents
        self.state_dim = state_dim
        self.embed_dim = embed_dim

        # NOTE: submodule names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        if hypernet_layers == 1:
            self.hyper_w1 = torch.nn.Linear(state_dim, embed_dim * n_agents)
            self.hyper_w_final = torch.nn.Linear(state_dim, embed_dim)
        else:
            self.hyper_w1 = torch.nn.Sequential(
                torch.nn.Linear(state_dim, hypernet_embed),
                torch.nn.ReLU(),
                torch.nn.Linear(hypernet_embed, embed_dim * n_agents),
            )
            self.hyper_w_final = torch.nn.Sequential(
                torch.nn.Linear(state_dim, hypernet_embed),
                torch.nn.ReLU(),
                torch.nn.Linear(hypernet_embed, embed_dim),
            )

        # First-layer bias: state-dependent and unconstrained in sign.
        self.hyper_b1 = torch.nn.Linear(state_dim, embed_dim)
        # State-dependent scalar offset added to the mixed output.
        self.V = torch.nn.Sequential(
            torch.nn.Linear(state_dim, embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embed_dim, 1),
        )

    def forward(self, agent_values: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent values into one joint value per batch row.

        Args:
            agent_values: Tensor whose first dimension is the batch size and
                which holds ``n_agents`` values per row.
            state: Global state with one ``state_dim``-wide row per batch
                entry.

        Returns:
            Tensor of shape ``(batch, 1)`` with the joint value.
        """
        bs = agent_values.size(0)
        agent_qs = agent_values.view(bs, 1, self.n_agents)

        # Layer 1: non-negative weights (abs) keep the output monotone in
        # every agent value; the bias may take any sign.
        w1 = torch.abs(self.hyper_w1(state)).view(bs, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(state).view(bs, 1, self.embed_dim)
        hidden = torch.relu(torch.bmm(agent_qs, w1) + b1)

        # Layer 2: non-negative weights again, plus a state-dependent offset.
        w_final = torch.abs(self.hyper_w_final(state)).view(bs, self.embed_dim, 1)
        v = self.V(state).view(bs, 1, 1)
        y = torch.bmm(hidden, w_final) + v
        return y.view(bs, 1)


class QMIXPPO:
    """PPO for several agents whose critics are combined by a QMIX mixer.

    Every agent owns a ``PolicyNet`` actor and a ``ValueNet`` critic.  The
    per-agent critic outputs are concatenated and passed, together with the
    flattened joint state, through a ``QMIXMixer`` to obtain one joint value.
    That joint value provides the TD target and the advantage shared by all
    actors in their clipped-surrogate updates.
    """

    def __init__(
        self,
        agent_num: int,
        state_dim_list: list[int],
        hidden_dim: int,
        action_num_list: list[int],
        actor_lr: float,
        critic_lr: float,
        epochs: int,
        eps: float,
        gamma: float,
        device: torch.device,
        mixing_embed_dim: int = 32,
        hypernet_layers: int = 2,
        hypernet_embed: int = 64,
        sample_size=None,
        entropy_soft: bool = False,
        entropy_coef: float = 0.01,
        *,
        gae_lambda: float = 0.95,
        max_grad_norm: float = 40.0,
    ) -> None:
        """Create actors, critics, the mixer and their AdamW optimizers.

        Args:
            agent_num: Number of agents.
            state_dim_list: Observation width of each agent.
            hidden_dim: Hidden width passed to ``PolicyNet`` / ``ValueNet``.
            action_num_list: Number of discrete actions of each agent.
            actor_lr: Learning rate of every actor optimizer.
            critic_lr: Learning rate of every critic optimizer and of the
                mixer optimizer.
            epochs: Optimisation passes per ``update`` call.
            eps: PPO clipping range for the probability ratio.
            gamma: Discount factor.
            device: Device that holds all networks and batch tensors.
            mixing_embed_dim: ``embed_dim`` of the mixer.
            hypernet_layers: ``hypernet_layers`` of the mixer.
            hypernet_embed: ``hypernet_embed`` of the mixer.
            sample_size: Accepted for interface compatibility; not used by
                this class.
            entropy_soft: Accepted for interface compatibility; not used by
                this class.
            entropy_coef: Weight of the entropy bonus in the actor loss.
            gae_lambda: Lambda handed to ``compute_advantage``.
            max_grad_norm: Gradient-norm clipping threshold applied to each
                actor, each critic and the mixer.
        """
        self.agent_num = agent_num
        self.device = device
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.entropy_coef = entropy_coef
        self.gae_lambda = gae_lambda
        self.max_grad_norm = max_grad_norm

        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.actor_opts = [
            torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors
        ]

        self.critics = [
            ValueNet(state_dim_list[i], hidden_dim).to(device) for i in range(agent_num)
        ]
        self.critic_opts = [
            torch.optim.AdamW(critic.parameters(), lr=critic_lr) for critic in self.critics
        ]

        # NOTE(review): update() feeds the mixer the flattened joint-state
        # tensor, which has sum(state_dim_list) features only if that tensor
        # stores each agent's observation at its own width -- in practice
        # equal widths for all agents.  TODO confirm for heterogeneous agents.
        self.state_dim = sum(state_dim_list)
        self.mixer = QMIXMixer(
            n_agents=agent_num,
            state_dim=self.state_dim,
            embed_dim=mixing_embed_dim,
            hypernet_layers=hypernet_layers,
            hypernet_embed=hypernet_embed,
        ).to(device)
        self.mixer_opt = torch.optim.AdamW(self.mixer.parameters(), lr=critic_lr)

    def take_action(self, state_list):
        """Sample one action per agent from its current policy.

        Args:
            state_list: Per-agent observation tensors, indexed by agent.

        Returns:
            List with one sampled action tensor per agent.  Each actor's
            output is treated as categorical probabilities.
        """
        actions = []
        with torch.no_grad():
            for i in range(self.agent_num):
                probs = self.actors[i](state_list[i].to(self.device))
                actions.append(torch.distributions.Categorical(probs).sample())
        return actions

    def _joint_value(self, per_agent_states, global_state):
        """Evaluate every critic and mix the results into a joint value.

        Args:
            per_agent_states: One observation batch per agent, in the same
                order as ``self.critics``.
            global_state: Flattened joint state handed to the mixer.

        Returns:
            Tuple ``(v_tot, v_is)``: the mixer output and the list of
            per-agent critic outputs.
        """
        v_is = [critic(s_i) for critic, s_i in zip(self.critics, per_agent_states)]
        # Critic outputs are joined along dim 1 -> one column per agent.
        v_cat = torch.cat(v_is, dim=1)
        v_tot = self.mixer(v_cat, global_state)
        return v_tot, v_is

    def update(self, transition_dict: dict) -> None:
        """Run ``self.epochs`` PPO passes for every actor, then fit the critics.

        Args:
            transition_dict: Batch with tensor entries ``"states"`` and
                ``"next_states"`` (agents on axis 1), ``"rewards"`` and
                ``"dones"`` (one entry per transition) and ``"actions"``
                (agents on axis 1).
        """
        # Move the batch to the target device once instead of per agent slice.
        joint_states = transition_dict["states"].to(self.device)
        next_joint_states = transition_dict["next_states"].to(self.device)
        rewards = transition_dict["rewards"].reshape(-1, 1).to(self.device)
        dones = transition_dict["dones"].reshape(-1, 1).float().to(self.device)
        actions = transition_dict["actions"].long().to(self.device)

        T = joint_states.shape[0]
        per_agent_states = [
            joint_states[:, i, :].reshape(T, -1) for i in range(self.agent_num)
        ]
        per_agent_next_states = [
            next_joint_states[:, i, :].reshape(T, -1) for i in range(self.agent_num)
        ]
        global_state = joint_states.reshape(T, -1)
        next_global_state = next_joint_states.reshape(T, -1)

        # Targets and advantages are constants for the optimisation below.
        # Computing the current value without autograd as well keeps the
        # advantage detached from the critic/mixer graph; otherwise the actor
        # backward passes could flow into the critics and fail on the second
        # pass with "backward through the graph a second time".
        with torch.no_grad():
            v_next, _ = self._joint_value(per_agent_next_states, next_global_state)
            td_target = rewards + self.gamma * v_next * (1.0 - dones)
            v, _ = self._joint_value(per_agent_states, global_state)
            td_delta = td_target - v

        # compute_advantage is a project helper; presumably it returns GAE
        # advantages as a tensor with one row per transition -- verify in util.
        advantage = (
            compute_advantage(self.gamma, self.gae_lambda, td_delta.cpu(), dones.cpu())
            .to(self.device)
            .detach()
        )

        # --- actors: clipped surrogate objective with an entropy bonus ---
        for i in range(self.agent_num):
            actor = self.actors[i]
            actor_opt = self.actor_opts[i]
            s_i = per_agent_states[i]
            a_i = actions[:, i].reshape(-1, 1)

            # Log-probabilities of the taken actions under the pre-update
            # policy; fixed reference for the probability ratio.
            with torch.no_grad():
                old_log_pi = torch.log(actor(s_i).gather(1, a_i))

            for _ in range(self.epochs):
                pi = actor(s_i)
                log_pi = torch.log(pi.gather(1, a_i))
                ratio = torch.exp(log_pi - old_log_pi)
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
                # 1e-8 guards log(0) for actions with zero probability.
                entropy = -(pi * torch.log(pi + 1e-8)).sum(dim=1).mean()
                actor_loss = (
                    -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy
                )

                actor_opt.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(actor.parameters(), self.max_grad_norm)
                actor_opt.step()

        # --- critics + mixer: regress the joint value onto the TD target ---
        for _ in range(self.epochs):
            v_tot_pred, _ = self._joint_value(per_agent_states, global_state)
            critic_loss = F.mse_loss(v_tot_pred, td_target)

            for opt in self.critic_opts:
                opt.zero_grad()
            self.mixer_opt.zero_grad()
            critic_loss.backward()

            # Each critic and the mixer are clipped separately.
            for critic in self.critics:
                torch.nn.utils.clip_grad_norm_(critic.parameters(), self.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.mixer.parameters(), self.max_grad_norm)

            for opt in self.critic_opts:
                opt.step()
            self.mixer_opt.step()
