import torch
import torch.nn.functional as F
import numpy as np
from Network import PolicyNet, ValueNet
from util import compute_advantage


class MAPPO:
    """Multi-agent PPO with one policy network per agent and one shared critic.

    Agent ``i`` owns ``PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i])``
    and its own AdamW optimizer.  A single ``ValueNet(sum(state_dim_list), hidden_dim)``
    is evaluated on the per-step concatenation of every agent's state.
    """

    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list, actor_lr, critic_lr,
                 epochs, eps, gamma, device, sample_size, entropy_soft=False,
                 gae_lambda=0.95, entropy_coef=0.01, max_grad_norm=40.0):
        """Build the networks and optimizers and store the hyper-parameters.

        Args:
            agent_num: Number of agents (one actor is created per agent).
            state_dim_list: Per-agent state dimension; their sum is the critic input size.
            hidden_dim: Hidden size handed to every ``PolicyNet`` / ``ValueNet``.
            action_num_list: Per-agent number of discrete actions.
            actor_lr: Learning rate of every actor optimizer.
            critic_lr: Learning rate of the critic optimizer.
            epochs: Number of gradient steps taken per ``update`` call, for each
                actor and for the critic.
            eps: PPO clipping range; the probability ratio is clamped to
                ``[1 - eps, 1 + eps]``.
            gamma: Discount factor.
            device: Device the networks live on and the batch is moved to.
            sample_size: Stored as an attribute; not read by this class.
            entropy_soft: When truthy, an entropy bonus is subtracted from the
                actor loss.
            gae_lambda: Second argument forwarded to ``compute_advantage``
                (previously hard-coded to 0.95).
            entropy_coef: Weight of the entropy bonus (previously hard-coded to 0.01).
            max_grad_norm: Gradient-norm clipping threshold for actors and critic
                (previously hard-coded to 40.0).
        """
        self.actors = [
            PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
            for i in range(agent_num)
        ]
        self.critic = ValueNet(sum(state_dim_list), hidden_dim).to(device)
        self.actor_optimizers = [torch.optim.AdamW(actor.parameters(), lr=actor_lr) for actor in self.actors]
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.device = device
        self.agent_num = agent_num
        self.sample_size = sample_size
        self.entropy_soft = entropy_soft
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

    def take_action(self, state_list):
        """Sample one action per agent from its current policy.

        Args:
            state_list: Indexable collection holding, for each agent index, the
                state tensor to feed to that agent's actor.

        Returns:
            A list of ``agent_num`` tensors, each sampled from a categorical
            distribution parameterised by the actor's output probabilities.
        """
        action_list = []
        with torch.no_grad():  # acting only; no graph is needed
            for i in range(self.agent_num):
                state = state_list[i].to(self.device)
                probs = self.actors[i](state)
                action_list.append(torch.distributions.Categorical(probs).sample())
        return action_list

    def update(self, transition_dict):
        """Run ``epochs`` clipped-PPO steps per actor, then ``epochs`` critic steps.

        Args:
            transition_dict: Mapping with tensor values under the keys
                ``'states'``, ``'actions'``, ``'rewards'``, ``'next_states'`` and
                ``'dones'``.  ``'states'`` is indexed as ``[step, agent, feature]``
                and ``'actions'`` as ``[step, agent]``; ``'rewards'`` and
                ``'dones'`` are flattened to one column (presumably one entry
                per step -- confirm against the rollout code).
        """
        device = self.device
        # Move the whole batch to the training device, not only `dones`.
        joint_states = transition_dict['states'].to(device)
        next_joint_states = transition_dict['next_states'].to(device)
        joint_actions = transition_dict['actions'].to(device)
        rewards = transition_dict['rewards'].to(device).reshape(-1, 1)
        dones = transition_dict['dones'].float().reshape(-1, 1).to(device)

        # Critic input: every agent's state concatenated per step.
        episode_length = joint_states.shape[0]
        global_states = joint_states.reshape(episode_length, -1)
        next_global_states = next_joint_states.reshape(episode_length, -1)

        # Targets and TD residuals are constants for the optimisation below,
        # so no autograd graph is built for them.
        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_global_states) * (1 - dones)
            td_delta = td_target - self.critic(global_states)

        advantage = compute_advantage(self.gamma, self.gae_lambda, td_delta.cpu(), dones).to(device)
        # NOTE(review): negative advantages are zeroed, so the policy is never
        # pushed away from worse-than-expected actions. Kept as in the original;
        # confirm this is intentional.
        advantage = advantage.clamp(min=0).detach()

        for i in range(self.agent_num):
            agent_states = joint_states[:, i, :]
            # gather() requires an int64 index with one column.
            agent_actions = joint_actions[:, i].reshape(-1, 1).long()
            self._update_actor(i, agent_states, agent_actions, advantage)

        self._update_critic(global_states, td_target)

    def _update_actor(self, agent_idx, states, actions, advantage):
        """Take ``epochs`` clipped-surrogate gradient steps on one agent's actor."""
        actor = self.actors[agent_idx]
        optimizer = self.actor_optimizers[agent_idx]

        # Log-probabilities under the pre-update policy (fixed reference).
        with torch.no_grad():
            old_log_probs = torch.log(actor(states).gather(1, actions))

        for _ in range(self.epochs):
            # One forward pass serves both the ratio and the entropy term.
            probs = actor(states)
            log_probs = torch.log(probs.gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage
            actor_loss = -torch.min(surr1, surr2).mean()

            if self.entropy_soft:
                # Clamp before log so a probability of exactly 0 contributes 0
                # instead of 0 * -inf = NaN.
                tiny = torch.finfo(probs.dtype).tiny
                entropy = -(probs * torch.log(probs.clamp_min(tiny))).sum(dim=1).mean()
                actor_loss = actor_loss - self.entropy_coef * entropy

            optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=self.max_grad_norm)
            optimizer.step()

    def _update_critic(self, global_states, td_target):
        """Take ``epochs`` MSE regression steps of the critic towards ``td_target``."""
        target = td_target.detach()
        for _ in range(self.epochs):
            # F.mse_loss already averages over all elements.
            critic_loss = F.mse_loss(self.critic(global_states), target)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.max_grad_norm)
            self.critic_optimizer.step()