import torch
import torch.nn.functional as F
import numpy as np
from Network import PolicyNetContinuous, ValueNet


class HAPPOContinuous:
    """Heterogeneous-Agent PPO (HAPPO) for continuous action spaces.

    Each agent owns its own Gaussian policy network (``self.actors[i]``); a
    single centralized critic scores the concatenation of all agents'
    observations. In ``update`` the agents are updated sequentially in a
    random order, and every agent's advantage is reweighted by the product of
    the new/old probability ratios of the agents updated before it.

    Tensor layout (inferred from the indexing below -- confirm against the
    rollout code): ``states`` / ``next_states`` are
    ``(T, agent_num, state_dim)``, ``actions`` is ``(T, agent_num, action_dim)``,
    and ``rewards`` / ``dones`` hold one scalar per time step.
    """

    # Weight of the entropy bonus subtracted from each actor loss.
    entropy_coef = 0.001
    # Max gradient norm used when clipping actor and critic gradients.
    max_grad_norm = 40.0

    def __init__(self, agent_num, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, sample_size, bound, device):
        """Build one actor per agent, a shared critic, and their optimizers.

        Args:
            agent_num: number of agents (one actor each).
            state_dim: per-agent observation size; the critic input is
                ``state_dim * agent_num``.
            hidden_dim: forwarded to ``PolicyNetContinuous`` and ``ValueNet``.
            action_dim: per-agent action size.
            actor_lr / critic_lr: Adam learning rates (weight decay 1e-5).
            lmbda: GAE lambda.
            epochs: gradient steps per ``update`` for each actor and the critic.
            eps: PPO clip range; ratios are clamped to ``[1 - eps, 1 + eps]``.
            gamma: discount factor.
            sample_size: accepted for interface compatibility; not used here.
            bound: forwarded to ``PolicyNetContinuous`` (meaning defined there).
            device: torch device holding all networks.
        """
        self.agent_num = agent_num
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.lmbda = lmbda
        self.epochs = epochs
        self.eps = eps
        self.gamma = gamma
        self.device = device

        self.actors = [PolicyNetContinuous(state_dim, hidden_dim, action_dim, bound).to(device)
                       for _ in range(agent_num)]

        self.critic = ValueNet(state_dim * agent_num, hidden_dim).to(device)

        self.actor_optimizers = [torch.optim.Adam(actor.parameters(), lr=actor_lr, weight_decay=1e-5)
                                 for actor in self.actors]
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=1e-5)

        self.action_aggregation = 'prod'

    def take_action(self, obs, eval=False):
        """Select one action per agent for a single (non-batched) step.

        Args:
            obs: per-agent observations; ``obs[i]`` is fed to actor ``i``.
            eval: if True return the policy means, otherwise sample from
                ``Normal(mu, sigma)``.

        Returns:
            numpy array stacking one action per agent.
        """
        obs_torch = torch.tensor(obs, dtype=torch.float).to(self.device)
        with torch.no_grad():
            if eval:
                mu = torch.zeros((self.agent_num, self.action_dim), device=self.device)
                for i in range(self.agent_num):
                    mu[i], _ = self.actors[i](obs_torch[i])
                return mu.cpu().numpy()
            actions = []
            for i in range(self.agent_num):
                mu, sigma = self.actors[i](obs_torch[i])
                actions.append(torch.distributions.Normal(mu, sigma).sample())
            return torch.stack(actions).cpu().numpy()

    def take_action_async(self, obs, eval=False):
        """Select actions for a batch of environments at once.

        Args:
            obs: array indexable as ``obs[:, i, :]`` for agent ``i``
                (batch, agent, feature).
            eval: if True return the policy means, otherwise sample.

        Returns:
            numpy array of shape ``(batch, agent_num, action_dim)``.
        """
        # Inputs must live on the same device as the actor networks.
        obs_torch = torch.tensor(obs, dtype=torch.float).to(self.device)
        if eval:
            return self._deterministic_action(obs_torch)
        return self._stochastic_action(obs_torch)

    def _batched_policy(self, obs):
        """Run each actor on its agent slice; return (mu, sigma) stacked on dim 1."""
        outputs = [self.actors[i](obs[:, i, :]) for i in range(self.agent_num)]
        mu = torch.stack([m for m, _ in outputs], dim=1)
        sigma = torch.stack([s for _, s in outputs], dim=1)
        return mu, sigma

    def _stochastic_action(self, obs):
        """Sample batched actions from every agent's Gaussian policy."""
        with torch.no_grad():
            mu, sigma = self._batched_policy(obs)
            return torch.distributions.Normal(mu, sigma).sample().cpu().numpy()

    def _deterministic_action(self, obs):
        """Return batched policy means (no sampling)."""
        # no_grad is required: otherwise mu tracks gradients and .numpy() raises.
        with torch.no_grad():
            mu, _ = self._batched_policy(obs)
            return mu.cpu().numpy()

    def update(self, transition_dict):
        """Run one HAPPO update from a batch of transitions.

        ``transition_dict`` must provide tensors under 'states', 'actions',
        'rewards', 'dones' and 'next_states'.
        """
        states = transition_dict['states']
        actions = transition_dict['actions']
        rewards = transition_dict['rewards'].view(-1, 1)
        dones = transition_dict['dones'].float().view(-1, 1)
        next_states = transition_dict['next_states']

        advantages = self._compute_gae(states, rewards, dones, next_states)

        # Random sequential order; `factor` accumulates pi_new / pi_old of the
        # agents already updated and reweights later agents' advantages.
        agent_order = np.random.permutation(self.agent_num)
        factor = torch.ones_like(advantages)

        for idx in agent_order:
            old_log_probs = self._update_actor(idx, states, actions, advantages, factor)
            self._update_factor(idx, states, actions, factor, old_log_probs)

        self._update_critic(states, rewards, next_states, dones)

    @staticmethod
    def _joint_state(states):
        """Flatten per-agent observations into one critic input row per step."""
        return states.reshape(len(states), -1)

    def _log_prob(self, agent_idx, states, actions):
        """Return (log-prob summed over action dims, keepdim) and the distribution."""
        mu, sigma = self.actors[agent_idx](states[:, agent_idx, :])
        dist = torch.distributions.Normal(mu, sigma)
        return dist.log_prob(actions[:, agent_idx, :]).sum(-1, keepdim=True), dist

    def _compute_gae(self, states, rewards, dones, next_states=None):
        """Compute GAE advantages (no gradients).

        Args:
            states: batch of joint observations.
            rewards, dones: ``(T, 1)`` tensors.
            next_states: successor observations. When given, ``V(s_{t+1})`` is
                evaluated on them; when omitted, the next row's value is used
                and the final step bootstraps from its own value.

        Returns:
            ``(T, 1)`` advantage tensor.
        """
        with torch.no_grad():
            state_values = self.critic(self._joint_state(states))
            if next_states is None:
                next_values = torch.cat([state_values[1:], state_values[-1:]], dim=0)
            else:
                next_values = self.critic(self._joint_state(next_states))
            td_deltas = rewards + self.gamma * next_values * (1 - dones) - state_values

            advantages = torch.zeros_like(td_deltas)
            gae = 0.0
            for t in reversed(range(len(td_deltas))):
                # Reset the running sum at episode boundaries.
                gae = td_deltas[t] + self.gamma * self.lmbda * gae * (1 - dones[t])
                advantages[t] = gae
        return advantages

    def _update_actor(self, agent_idx, states, actions, advantages, factor):
        """Apply ``self.epochs`` clipped-PPO steps to one agent's actor.

        Returns:
            Log-probs of ``actions`` under the actor *before* this update.
            ``_update_factor`` needs them to form the HAPPO ratio.
        """
        actor = self.actors[agent_idx]
        optimizer = self.actor_optimizers[agent_idx]
        # HAPPO surrogate uses the factor-weighted advantage; constant over epochs.
        weighted_adv = advantages * factor

        with torch.no_grad():
            old_log_probs, _ = self._log_prob(agent_idx, states, actions)

        for _ in range(self.epochs):
            log_probs, dist = self._log_prob(agent_idx, states, actions)

            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * weighted_adv
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * weighted_adv
            actor_loss = -torch.min(surr1, surr2).mean()
            actor_loss = actor_loss - self.entropy_coef * dist.entropy().mean()

            optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), self.max_grad_norm)
            optimizer.step()

        return old_log_probs

    def _update_factor(self, agent_idx, states, actions, factor, old_log_probs=None):
        """Multiply ``factor`` in place by pi_new / pi_old for one agent.

        Args:
            old_log_probs: log-probs captured before the agent's update, as
                returned by ``_update_actor``. Without them there is nothing
                to compare the current policy against, so ``factor`` is left
                unchanged.
        """
        if old_log_probs is None:
            return
        with torch.no_grad():
            new_log_probs, _ = self._log_prob(agent_idx, states, actions)
            factor *= (new_log_probs - old_log_probs).exp()

    def _update_critic(self, states, rewards, next_states, dones):
        """Regress the critic onto one-step TD targets for ``self.epochs`` steps."""
        for _ in range(self.epochs):
            current_v = self.critic(self._joint_state(states))
            # Target is recomputed each epoch from the current critic, without gradients.
            with torch.no_grad():
                next_v = self.critic(self._joint_state(next_states))

            target_v = rewards + self.gamma * next_v * (1 - dones)
            critic_loss = F.mse_loss(current_v, target_v)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_optimizer.step()
