import torch
import torch.nn.functional as F
import numpy as np
from Network import PolicyNet, ValueNet
from util import compute_advantage


class MAPPO:
    """Multi-agent PPO: one independent actor per agent plus one centralized critic.

    Each actor ``i`` is a ``PolicyNet`` over agent ``i``'s own observation.
    The critic is a ``ValueNet`` over the concatenation of all agents'
    observations (input size ``sum(state_dim_list)``).
    """

    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list, actor_lr, critic_lr,
                 epochs, eps, gamma, device, sample_size, entropy_soft=False,
                 lmbda=0.95, entropy_coef=0.01, max_grad_norm=40.0):
        """Build the actors, the centralized critic and their AdamW optimizers.

        Args:
            agent_num: number of agents / actors.
            state_dim_list: per-agent observation size; the critic input is their sum.
            hidden_dim: hidden width passed to both PolicyNet and ValueNet.
            action_num_list: per-agent number of discrete actions.
            actor_lr: learning rate of every actor optimizer.
            critic_lr: learning rate of the critic optimizer.
            epochs: optimization passes per ``update`` call (actors and critic).
            eps: PPO ratio clip range; the ratio is clamped to [1 - eps, 1 + eps].
            gamma: discount factor.
            device: torch device the networks live on.
            sample_size: stored on the instance; not used inside this class.
            entropy_soft: if True, add an entropy bonus to each actor loss.
            lmbda: GAE lambda forwarded to ``compute_advantage``.
            entropy_coef: weight of the entropy bonus when ``entropy_soft`` is True.
            max_grad_norm: gradient-norm clipping threshold for actors and critic.
        """
        self.actors = [PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device) for i in range(agent_num)]
        self.critic = ValueNet(sum(state_dim_list), hidden_dim).to(device)
        self.actor_optimizers = [torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors]
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.device = device
        self.agent_num = agent_num
        self.sample_size = sample_size
        self.entropy_soft = entropy_soft
        self.lmbda = lmbda
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

    def take_action(self, state_list):
        """Sample one action per agent from that agent's actor.

        Args:
            state_list: indexable of per-agent observation tensors; entry ``i``
                is fed to ``self.actors[i]`` after moving it to ``self.device``.

        Returns:
            A list of sampled action tensors, one per agent, on ``self.device``.
        """
        action_list = []
        with torch.no_grad():
            for i in range(self.agent_num):
                probs = self.actors[i](state_list[i].to(self.device))
                action_list.append(torch.distributions.Categorical(probs).sample())
        return action_list

    @staticmethod
    def _policy_entropy(probs):
        """Return the mean entropy of a batch of categorical distributions.

        ``probs`` holds action probabilities with actions on the last dim.
        ``Categorical.entropy`` clamps the log-probabilities, so zero-probability
        actions contribute a finite value. Computing ``probs * log(probs)`` by
        hand instead gives NaN (0 * -inf) for such actions.
        """
        return torch.distributions.Categorical(probs=probs).entropy().mean()

    def update(self, transition_dict):
        """Run one PPO update for every actor and for the centralized critic.

        Reads the keys 'states', 'actions', 'rewards', 'next_states' and 'dones'
        from ``transition_dict``. Based on the indexing below, 'states' and
        'next_states' are assumed to be (T, agent, feature) and 'actions' to be
        (T, agent). TODO: confirm these shapes against the caller.
        """
        device = self.device
        joint_states = transition_dict['states'].to(device)
        next_joint_states = transition_dict['next_states'].to(device)
        episode_length = joint_states.shape[0]
        # The centralized critic sees every agent's observation concatenated.
        # reshape (not view) also works on non-contiguous inputs.
        global_states = joint_states.reshape(episode_length, -1)
        next_global_states = next_joint_states.reshape(episode_length, -1)

        # Per-agent (T, 1) action indices; gather() requires int64 on the same device.
        actions = [transition_dict['actions'][:, i].reshape(-1, 1).long().to(device)
                   for i in range(self.agent_num)]
        states = [joint_states[:, i, :] for i in range(self.agent_num)]

        rewards = transition_dict['rewards'].view(-1, 1).to(device)
        dones = transition_dict['dones'].float().view(-1, 1).to(device)

        # Targets and advantages are constants for the optimization below.
        # Computing them without a graph avoids keeping critic activations alive.
        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_global_states) * (1 - dones)
            td_delta = td_target - self.critic(global_states)
        # NOTE(review): dones stays on self.device while td_delta is moved to CPU,
        # as in the original call. Confirm that compute_advantage accepts this mix.
        advantage = compute_advantage(self.gamma, self.lmbda, td_delta.cpu(), dones).to(device)

        for actor, optimizer, agent_states, agent_actions in zip(
                self.actors, self.actor_optimizers, states, actions):
            with torch.no_grad():
                old_log_probs = torch.log(actor(agent_states).gather(1, agent_actions))
            for _ in range(self.epochs):
                # One forward pass serves both the ratio and the entropy bonus.
                probs = actor(agent_states)
                log_probs = torch.log(probs.gather(1, agent_actions))
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
                actor_loss = -torch.min(surr1, surr2).mean()

                if self.entropy_soft:
                    actor_loss = actor_loss - self.entropy_coef * self._policy_entropy(probs)

                optimizer.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=self.max_grad_norm)
                optimizer.step()

        for _ in range(self.epochs):
            # mse_loss already averages; td_target carries no graph.
            critic_loss = F.mse_loss(self.critic(global_states), td_target)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.max_grad_norm)
            self.critic_optimizer.step()
