import torch
import torch.nn.functional as F
from Network import PolicyNet, ValueNet
from util import compute_advantage


class MAAC:
    """Multi-agent actor-critic: one policy network per agent plus one shared, centralized critic.

    Agent ``i`` acts from its own local state through ``self.actors[i]``. A single ``ValueNet``
    scores the concatenation of all agents' states, and its one-step TD error is used as the
    advantage for every actor. ``update`` applies no importance weighting, so the batch is
    presumably collected with the current policies (on-policy) — confirm against the caller.
    """

    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list, actor_lr, critic_lr, epochs, eps,
                 gamma, device, sample_size, entropy_soft=False, entropy_coef=0.01, max_grad_norm=40.0):
        """Build the per-agent actors, the shared critic and their AdamW optimizers.

        Args:
            agent_num: number of agents.
            state_dim_list: per-agent local state sizes; the critic's input size is their sum.
            hidden_dim: hidden width passed to every network.
            action_num_list: per-agent number of discrete actions.
            actor_lr: AdamW learning rate for every actor.
            critic_lr: AdamW learning rate for the critic.
            epochs: accepted for interface compatibility; not used by this class.
            eps: accepted for interface compatibility; not used by this class.
            gamma: discount factor of the TD target.
            device: torch device holding the networks; ``update`` moves its inputs here.
            sample_size: stored on the instance; not read by the methods of this class.
            entropy_soft: if True, subtract an entropy bonus from each actor loss.
            entropy_coef: weight of that entropy bonus (0.01, the previously hard-coded value).
            max_grad_norm: gradient-norm clipping threshold for actors and critic (40.0, as before).
        """
        self.actors = [PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
                       for i in range(agent_num)]
        self.critic = ValueNet(sum(state_dim_list), hidden_dim).to(device)
        self.actor_optimizers = [torch.optim.AdamW(actor.parameters(), lr=actor_lr)
                                 for actor in self.actors]
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.device = device
        self.agent_num = agent_num
        self.sample_size = sample_size
        self.entropy_soft = entropy_soft
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

    def take_action(self, state_list):
        """Sample one action per agent from its current policy (no gradient tracking).

        Args:
            state_list: sequence of per-agent state tensors, indexed by agent.

        Returns:
            list of sampled action tensors (on ``self.device``), one per agent, in agent order.
        """
        action_list = []
        with torch.no_grad():
            for i in range(self.agent_num):
                state = state_list[i].to(self.device)
                probs = self.actors[i](state)
                dist = torch.distributions.Categorical(probs)
                action_list.append(dist.sample())
        return action_list

    def update(self, transition_dict):
        """Take one optimizer step for every actor, then one for the critic.

        ``transition_dict`` must provide ``'states'``, ``'next_states'``, ``'actions'``,
        ``'rewards'`` and ``'dones'`` tensors. States are indexed ``[T, N, s]`` (time, agent,
        feature), so all agents presumably share one local state size — TODO confirm against
        the caller. Actions are indexed ``[:, i]`` per agent. Rewards and dones are flattened to
        one value per time step. Every input is moved to ``self.device`` before use.
        """
        device = self.device
        joint_states = transition_dict['states'].to(device)  # [T, N, s_i]
        next_joint_states = transition_dict['next_states'].to(device)
        episode_length = joint_states.shape[0]
        # reshape (not view) so non-contiguous inputs such as transposed stacks also work.
        global_states = joint_states.reshape(episode_length, -1)
        next_global_states = next_joint_states.reshape(episode_length, -1)

        all_actions = transition_dict['actions'].to(device)
        rewards = transition_dict['rewards'].to(device).reshape(-1, 1)
        dones = transition_dict['dones'].to(device).reshape(-1, 1)

        # The TD target and advantage are constants for both losses, so no autograd graph is needed.
        with torch.no_grad():
            values = self.critic(global_states)
            next_values = self.critic(next_global_states)
            # Match the critic's dtype so float64/bool inputs do not promote the loss targets.
            rewards = rewards.to(values.dtype)
            dones = dones.to(values.dtype)
            td_target = rewards + self.gamma * next_values * (1 - dones)
            advantage = td_target - values
        # Alternative kept disabled: GAE via compute_advantage(self.gamma, 0.95, td_delta, dones).

        for i in range(self.agent_num):
            actor = self.actors[i]
            optimizer = self.actor_optimizers[i]
            probs = actor(joint_states[:, i, :])
            dist = torch.distributions.Categorical(probs)
            # reshape(-1) keeps a 1-D batch even when T == 1 (squeeze() would yield a scalar).
            log_probs = dist.log_prob(all_actions[:, i].reshape(-1)).view(-1, 1)
            actor_loss = -torch.mean(log_probs * advantage)

            if self.entropy_soft:
                actor_loss = actor_loss - self.entropy_coef * dist.entropy().mean()

            optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=self.max_grad_norm)
            optimizer.step()

        # Fresh forward pass with gradients; td_target was built under no_grad, so it is a constant.
        critic_loss = F.mse_loss(self.critic(global_states), td_target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.max_grad_norm)
        self.critic_optimizer.step()
