import torch
import torch.nn.functional as F
import numpy as np
from Network import PolicyNet, ValueNet
from util import compute_advantage


class HAPPO:
    """
    Discrete HAPPO implementation:
        - Agents are updated one by one in a random order.
        - Each agent uses PPO-Clip with importance weighting via a multiplicative factor.
        - Centralized critic takes concatenated states as input.

    The multiplicative factor starts at one for every sample and, after each
    agent's policy update, is multiplied by that agent's new/old action
    probability ratio, so agents later in the order see an advantage that is
    re-weighted by the updates of the agents before them.

    NOTE(review): every ``torch.distributions.Categorical(x)`` call below
    passes the actor output positionally, i.e. as ``probs``. This is correct
    only if ``PolicyNet`` returns (unnormalised) probabilities; if it returns
    raw logits these calls should be ``Categorical(logits=x)`` -- confirm
    against ``Network.PolicyNet``.
    """
    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list,
                 actor_lr, critic_lr, epochs, eps, gamma,
                 device, sample_size, entropy_soft=False):
        """
        Build one actor per agent and a single shared critic.

        Args:
            agent_num: number of agents (one actor/optimizer per agent).
            state_dim_list: per-agent observation sizes; their sum is the
                critic's input size.
            hidden_dim: hidden width forwarded to PolicyNet / ValueNet.
            action_num_list: per-agent number of discrete actions.
            actor_lr: learning rate for every actor's AdamW optimizer.
            critic_lr: learning rate for the critic's AdamW optimizer.
            epochs: gradient steps per call for each actor and for the critic.
            eps: PPO clipping range (ratio is clamped to [1 - eps, 1 + eps]).
            gamma: discount factor.
            device: torch device that networks and batches are moved to.
            sample_size: stored on the instance; not used inside this class.
            entropy_soft: when True, an entropy bonus is added to actor loss.
        """
        self.agent_num   = agent_num
        self.device      = device
        self.epochs      = epochs
        self.eps         = eps
        self.gamma       = gamma
        self.sample_size = sample_size
        self.entropy_on  = entropy_soft

        self.lmbda       = 0.95   # lambda handed to compute_advantage
        self.entropy_c   = 0.01   # weight of the optional entropy bonus

        self.actors = [PolicyNet(state_dim_list[i], hidden_dim,
                                 action_num_list[i]).to(device)
                       for i in range(agent_num)]
        self.critic = ValueNet(sum(state_dim_list), hidden_dim).to(device)

        self.actor_optimizers = [torch.optim.AdamW(a.parameters(), lr=actor_lr)
                                 for a in self.actors]
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)

    # =====================================================================
    # Action sampling (non-deterministic or greedy)
    # =====================================================================
    @torch.no_grad()
    def take_action(self, state_list, deterministic=False):
        """
        Choose one action per agent.

        Args:
            state_list: indexable collection holding one observation tensor
                per agent (``state_list[i]`` is fed to actor ``i``).
            deterministic: if True take the arg-max over the last dimension
                of the actor output, otherwise sample from the categorical
                distribution.

        Returns:
            List of CPU tensors, one action per agent, in agent order.
        """
        actions = []
        for i in range(self.agent_num):
            logits = self.actors[i](state_list[i].to(self.device))
            dist   = torch.distributions.Categorical(logits)
            a      = torch.argmax(logits, dim=-1) if deterministic else dist.sample()
            actions.append(a.cpu())
        return actions

    # =====================================================================
    # Training update
    # =====================================================================
    def update(self, transition_dict):
        """
        Run one HAPPO update from a batch of transitions.

        Args:
            transition_dict: mapping with tensor entries ``'states'``,
                ``'actions'``, ``'rewards'``, ``'dones'`` and
                ``'next_states'``. States are indexed as ``[:, agent, :]``
                and actions as ``[:, agent]``; rewards and dones are reshaped
                to a single column.
                # assumes all agents share one observation width so states
                # stack into a single tensor -- TODO confirm with the caller
        """
        s   = transition_dict['states'].to(self.device)
        a   = transition_dict['actions'].long().to(self.device)
        r   = transition_dict['rewards'].to(self.device).view(-1, 1)
        d   = transition_dict['dones'].float().to(self.device).view(-1, 1)
        ns  = transition_dict['next_states'].to(self.device)

        # Flatten the per-agent observations into one global state per step
        # for the centralized critic.
        T   = s.shape[0]
        gs  = s.view(T, -1)
        ngs = ns.view(T, -1)

        with torch.no_grad():
            v_s  = self.critic(gs)
            v_ns = self.critic(ngs)
            td   = r + self.gamma * v_ns * (1 - d) - v_s
            adv  = compute_advantage(self.gamma, self.lmbda, td, d).to(self.device)

        # Running product of the already-updated agents' probability ratios.
        factor = torch.ones_like(adv)

        for idx in np.random.permutation(self.agent_num):
            idx = int(idx)
            # The pre-update log-probs must be captured *before* the actor
            # changes; they are the denominator of this agent's ratio.
            old_logp = self._update_actor(idx, s, a, adv, factor)
            self._update_factor(idx, s, a, factor, old_logp)

        self._update_critic(gs, r, ngs, d)

    # =====================================================================
    # Actor update for one agent
    # =====================================================================
    def _update_actor(self, i, s, a, adv, factor):
        """
        PPO-Clip update of agent ``i`` on the factor-weighted advantage.

        Args:
            i: index of the agent to update.
            s: batch of states, indexed as ``s[:, i, :]``.
            a: batch of actions, indexed as ``a[:, i]``.
            adv: advantage column shared by all agents.
            factor: current multiplicative factor (same shape as ``adv``).

        Returns:
            Column tensor with the log-probabilities of the taken actions
            under the policy as it was *before* this update (no gradient
            attached).
        """
        actor, opt = self.actors[i], self.actor_optimizers[i]
        obs_i      = s[:, i, :]
        act_i      = a[:, i]

        with torch.no_grad():
            old_logits = actor(obs_i)
            old_logp   = torch.distributions.Categorical(old_logits)\
                          .log_prob(act_i).view(-1, 1)

        for _ in range(self.epochs):
            logits = actor(obs_i)
            dist   = torch.distributions.Categorical(logits)
            logp   = dist.log_prob(act_i).view(-1, 1)

            ratio  = torch.exp(logp - old_logp)
            s1     = ratio * adv * factor
            s2     = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * adv * factor
            loss   = -torch.min(s1, s2).mean()

            if self.entropy_on:
                loss = loss - self.entropy_c * dist.entropy().mean()

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), 40.0)
            opt.step()

        return old_logp

    # =====================================================================
    # Update importance weight (factor) for the next agents
    # =====================================================================
    @torch.no_grad()
    def _update_factor(self, i, s, a, factor, old_logp=None):
        """
        Multiply ``factor`` in place by agent ``i``'s new/old probability ratio.

        Args:
            i: index of the agent that has just been updated.
            s: batch of states, indexed as ``s[:, i, :]``.
            a: batch of actions, indexed as ``a[:, i]``.
            factor: tensor modified in place.
            old_logp: log-probabilities of the taken actions under agent
                ``i``'s policy before its update (as returned by
                ``_update_actor``). If omitted the pre-update policy is no
                longer recoverable, so the factor is left unchanged.
        """
        if old_logp is None:
            return

        obs_i = s[:, i, :]
        act_i = a[:, i]

        # Evaluate the *updated* actor; comparing it with itself (as an
        # earlier version did) always yields a ratio of exactly one.
        logits_new = self.actors[i](obs_i)
        dist_new   = torch.distributions.Categorical(logits_new)
        logp_new   = dist_new.log_prob(act_i).view(-1, 1)

        factor *= torch.exp(logp_new - old_logp)

    # =====================================================================
    # Critic update
    # =====================================================================
    def _update_critic(self, gs, r, ngs, d):
        """
        Fit the centralized critic to one-step TD targets.

        The target is recomputed from the current critic on every epoch
        (without gradient) and regressed with an MSE loss.

        Args:
            gs: flattened global states.
            r: reward column.
            ngs: flattened global next states.
            d: done-flag column (1.0 where the episode terminated).
        """
        for _ in range(self.epochs):
            v_s = self.critic(gs)
            with torch.no_grad():
                target = r + self.gamma * self.critic(ngs) * (1 - d)
            loss = F.mse_loss(v_s, target)
            self.critic_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 40.0)
            self.critic_optimizer.step()
