import torch
import torch.nn.functional as F
import numpy as np
import random
from itertools import combinations
import copy
from ComputeCredits_updated2 import compute_credits
from Network import PolicyNet, ValueNet
from util import compute_advantage


class MAPPO:
    """Multi-agent PPO with one actor per agent and a centralized critic.

    The critic scores the concatenation of every agent's state and every agent's
    action id. In ``update``, the critic is also queried for sampled coalitions of
    agents: coalition members keep the action they actually took, and
    non-members switch to their greedy (argmax) action. ``compute_credits`` turns
    those coalition advantages, plus the grand advantage from
    ``compute_advantage``, into per-agent credits. Each credit replaces the
    advantage in that agent's clipped PPO objective.
    """

    def __init__(self, agent_num, state_dim_list, hidden_dim, action_num_list,
                 actor_lr, critic_lr, epochs, eps, gamma, device,
                 sample_size, entropy_soft=False, lmbda=0.95,
                 entropy_coef=0.01, max_grad_norm=40.0):
        """Build the actors, the centralized critic and their AdamW optimizers.

        Args:
            agent_num: Number of agents.
            state_dim_list: Per-agent state sizes. The critic input size is
                ``sum(state_dim_list) + agent_num``: all states plus one action
                id per agent.
            hidden_dim: Hidden width passed to every ``PolicyNet`` / ``ValueNet``.
            action_num_list: Per-agent number of discrete actions.
            actor_lr: Learning rate of every actor optimizer.
            critic_lr: Learning rate of the critic optimizer.
            epochs: Gradient steps per ``update``, for each actor and for the critic.
            eps: PPO clipping range for the probability ratio.
            gamma: Discount factor.
            device: Torch device the networks are moved to.
            sample_size: Number of coalitions sampled per ``update``. Capped at
                the number of available coalitions.
            entropy_soft: If True, subtract an entropy bonus from each actor loss.
            lmbda: Lambda passed to ``compute_advantage`` (formerly hard-coded 0.95).
            entropy_coef: Weight of the entropy bonus (formerly hard-coded 0.01).
            max_grad_norm: Gradient-norm clip for actors and critic
                (formerly hard-coded 40.0).
        """
        self.actors = [PolicyNet(state_dim_list[i], hidden_dim, action_num_list[i]).to(device)
                       for i in range(agent_num)]
        # Input = flattened joint state + one action id per agent.
        self.critic = ValueNet(sum(state_dim_list) + agent_num, hidden_dim).to(device)
        self.actor_optimizers = [torch.optim.AdamW(actor.parameters(), lr=actor_lr)
                                 for actor in self.actors]
        self.critic_optimizer = torch.optim.AdamW(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.epochs = epochs
        self.eps = eps
        self.device = device
        self.agent_num = agent_num
        self.sample_size = sample_size
        self.entropy_soft = entropy_soft
        self.lmbda = lmbda
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm

        # Every non-empty proper subset of agents; the grand coalition is excluded.
        self.coalition_set = self._build_coalition_set(agent_num)

    @staticmethod
    def _build_coalition_set(agent_num):
        """Return all non-empty proper subsets of ``range(agent_num)`` as lists.

        Subsets are ordered by size, then lexicographically (``itertools.combinations``
        order). Fewer than two agents yields an empty list.
        """
        return [list(subset)
                for size in range(1, agent_num)
                for subset in combinations(range(agent_num), size)]

    def take_action(self, state_list):
        """Sample one action per agent from its current policy, without gradients.

        Args:
            state_list: Per-agent state tensors, indexed by agent.

        Returns:
            List of sampled action tensors, one per agent.
        """
        action_list = []
        with torch.no_grad():
            for i in range(self.agent_num):
                probs = self.actors[i](state_list[i].to(self.device))
                action_list.append(torch.distributions.Categorical(probs).sample())
        return action_list

    def update(self, transition_dict):
        """Run one update of every actor and of the critic on a batch of transitions.

        Args:
            transition_dict: Mapping of tensors. ``'states'`` is indexed as
                ``[row, agent, feature]``. ``'next_states'`` is flattened per row
                like ``'states'``. ``'actions'`` holds one action id per agent per
                row. ``'rewards'`` and ``'dones'`` hold one value per row. Every
                entry is moved to ``self.device``.
                NOTE(review): the per-agent slicing of a single 3-D tensor assumes
                all agents share one state size (or are padded) — confirm.
        """
        joint_states = transition_dict['states'].to(self.device)
        buffer_size = joint_states.shape[0]
        global_states = joint_states.view(buffer_size, -1)
        joint_actions = transition_dict['actions'].to(self.device).view(buffer_size, -1)
        # NOTE: the next joint action is approximated by the action stored in the
        # next buffer row. The last row wraps around to row 0; that is only masked
        # out by (1 - dones) when the last transition is terminal.
        next_joint_actions = torch.roll(joint_actions, shifts=-1, dims=0)
        next_global_states = transition_dict['next_states'].to(self.device).view(buffer_size, -1)
        rewards = transition_dict['rewards'].to(self.device).view(-1, 1)
        dones = transition_dict['dones'].float().view(-1, 1).to(self.device)
        actions = [joint_actions[:, i].view(-1, 1) for i in range(self.agent_num)]

        critic_input = torch.cat([global_states, joint_actions], dim=1)
        next_critic_input = torch.cat([next_global_states, next_joint_actions], dim=1)

        with torch.no_grad():
            td_target = rewards + self.gamma * self.critic(next_critic_input) * (1 - dones)

            # Baseline: critic value when every agent plays its greedy (argmax) action.
            max_actions = [
                torch.argmax(self.actors[i](joint_states[:, i, :]), dim=1).view(-1, 1)
                for i in range(self.agent_num)
            ]
            value_base = self.critic(torch.cat([global_states, torch.cat(max_actions, dim=1)], dim=1))

            sampled_coalition_set = self._sample_coalitions()
            advantage_np = self._coalition_advantages(
                sampled_coalition_set, global_states, actions, max_actions, value_base)

        # TD errors are measured against the greedy-action baseline, not Q(s, a_taken).
        td_delta = td_target - value_base
        advantage_grand = compute_advantage(self.gamma, self.lmbda, td_delta, dones)
        solution = compute_credits(self.agent_num, sampled_coalition_set,
                                   advantage_np, advantage_grand.cpu().numpy())

        # The result is indexed [row, agent] (inferred from the column slice).
        # NOTE(review): credits are scaled by agent_num, as before — presumably to undo
        # a split of the grand advantage across agents; confirm against compute_credits.
        credits = [
            self.agent_num * torch.as_tensor(
                solution[:, i], dtype=value_base.dtype, device=self.device).view(-1, 1)
            for i in range(self.agent_num)
        ]

        for i in range(self.agent_num):
            self._update_actor(i, joint_states[:, i, :], actions[i], credits[i])
        self._update_critic(critic_input, td_target)

    def _sample_coalitions(self):
        """Sample ``sample_size`` distinct coalitions, capped at how many exist.

        ``random.sample`` raises ValueError when asked for more items than the
        population holds. For example, 2 agents have only 2 proper coalitions. The
        cap makes that case use every coalition instead of crashing.
        """
        k = min(self.sample_size, len(self.coalition_set))
        return random.sample(self.coalition_set, k)

    @staticmethod
    def _mix_actions(coalition, actions, max_actions):
        """Return the joint action for ``coalition``: members keep their taken action, others act greedily.

        ``actions`` and ``max_actions`` are per-agent column tensors. The result
        concatenates one column per agent along dim 1.
        """
        members = set(coalition)
        return torch.cat([taken if i in members else greedy
                          for i, (taken, greedy) in enumerate(zip(actions, max_actions))], dim=1)

    def _coalition_advantages(self, coalitions, global_states, actions, max_actions, value_base):
        """Return each coalition's critic value minus ``value_base``, stacked as a numpy array.

        The result has one leading entry per coalition; each entry has the shape
        of ``value_base``.
        """
        advantages = []
        with torch.no_grad():
            for coalition in coalitions:
                joint = self._mix_actions(coalition, actions, max_actions)
                value_c = self.critic(torch.cat([global_states, joint], dim=1))
                advantages.append((value_c - value_base).cpu().numpy())
        return np.array(advantages)

    def _update_actor(self, agent_idx, states, actions, credit):
        """Take ``self.epochs`` clipped-PPO steps for one actor, using ``credit`` as the advantage."""
        actor = self.actors[agent_idx]
        optimizer = self.actor_optimizers[agent_idx]
        with torch.no_grad():
            old_log_probs = torch.log(actor(states).gather(1, actions))
        for _ in range(self.epochs):
            # One forward pass serves both the ratio and the entropy term.
            action_probs = actor(states)
            log_probs = torch.log(action_probs.gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * credit
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * credit
            actor_loss = torch.mean(-torch.min(surr1, surr2))

            if self.entropy_soft:
                entropy = -(action_probs * torch.log(action_probs + 1e-8)).sum(dim=1).mean()
                actor_loss = actor_loss - self.entropy_coef * entropy

            optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=self.max_grad_norm)
            optimizer.step()

    def _update_critic(self, critic_input, td_target):
        """Regress the critic onto the fixed TD targets for ``self.epochs`` steps (MSE loss)."""
        target = td_target.detach()
        for _ in range(self.epochs):
            critic_loss = F.mse_loss(self.critic(critic_input), target)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.max_grad_norm)
            self.critic_optimizer.step()
