import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations
import random
import copy
from ComputeCredits_updated import compute_credits
from Network import PolicyNetContinuous, ValueNet


def compute_advantage(gamma, lmbda, td_delta, dones):
    """Compute GAE advantages, restarting the recursion at terminal steps.

    Walks the sequence backwards with
    ``A_t = delta_t + gamma * lmbda * A_{t+1}``, where ``A_{t+1}`` is taken
    as 0 whenever step ``t`` itself is flagged as done.

    Args:
        gamma: discount factor.
        lmbda: GAE lambda.
        td_delta: TD errors indexed as ``[t, 0]``; detached before use.
        dones: termination flags indexed as ``[t, 0]``; a value equal to
            ``1.0`` marks a terminal step.

    Returns:
        Tensor shaped like ``td_delta`` holding the advantage at ``[t, 0]``.
    """
    deltas = td_delta.detach()
    out = torch.zeros_like(deltas)
    running = 0.0
    decay = gamma * lmbda
    for step in range(deltas.shape[0] - 1, -1, -1):
        # A terminal step must not bootstrap from the following episode.
        carry = 0.0 if dones[step, 0] == 1.0 else running
        running = deltas[step, 0] + decay * carry
        out[step, 0] = running
    return out


class PPOContinuous:
    """Multi-agent PPO with one Gaussian actor per agent and coalition-based credits.

    Two centralized critics are trained toward the same TD target:
    ``critic`` takes the concatenated joint state and joint action, and
    ``critic_v`` takes the joint state only. In :meth:`update`, per-agent
    credits come from the external ``compute_credits`` helper. It is fed with
    the GAE computed from ``critic_v`` and, for a random sample of agent
    coalitions, with the gap between ``critic`` (agents outside the coalition
    replaced by their mean action) and ``critic_v``.
    """

    def __init__(self, agent_num, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, sample_size, bound, device,
                 entropy_coef=0.001, max_grad_norm=40.0):
        """Build the actors, both critics and their Adam optimizers.

        Args:
            agent_num: number of agents (one actor each).
            state_dim: per-agent observation size.
            hidden_dim: hidden width passed to every network.
            action_dim: per-agent action size.
            actor_lr: Adam learning rate of the actors.
            critic_lr: Adam learning rate of both critics.
            lmbda: GAE lambda.
            epochs: gradient steps per :meth:`update` for every network.
            eps: PPO ratio clipping range.
            gamma: discount factor.
            sample_size: coalitions drawn per update. ``random.sample`` raises
                if it exceeds ``len(self.coalition_set)`` (``2**agent_num - 2``).
            bound: forwarded to ``PolicyNetContinuous`` (presumably the action
                bound -- confirm against Network.py).
            device: torch device for the networks and internal tensors.
            entropy_coef: weight of the entropy bonus in the actor loss.
            max_grad_norm: gradient-norm clip applied to every network.
        """
        self.agent_num = agent_num
        self.actors = [PolicyNetContinuous(state_dim, hidden_dim, action_dim, bound).to(device)
                       for _ in range(agent_num)]
        self.critic = ValueNet(state_dim * agent_num + action_dim * agent_num, hidden_dim).to(device)
        self.critic_v = ValueNet(state_dim * agent_num, hidden_dim).to(device)
        self.actor_optimizers = [torch.optim.Adam(actor.parameters(), lr=actor_lr, weight_decay=1e-5)
                                 for actor in self.actors]
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr, weight_decay=1e-5)
        self.critic_v_optimizer = torch.optim.Adam(self.critic_v.parameters(), lr=critic_lr, weight_decay=1e-5)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs
        self.eps = eps
        self.device = device
        self.action_dim = action_dim
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        # Every non-empty proper subset of agents, ordered by size and then
        # lexicographically (the grand coalition is handled separately).
        self.coalition_set = [list(subset)
                              for size in range(1, agent_num)
                              for subset in combinations(range(agent_num), size)]
        self.sample_size = sample_size

    def _actor_outputs(self, obs, agent_axis):
        """Run actor ``i`` on the ``i``-th slice of ``obs`` along ``agent_axis``.

        ``agent_axis`` is 0 for a single ``[agent, feature]`` observation and 1
        for a batched ``[batch, agent, feature]`` one. Returns ``(mu, sigma)``,
        each of shape ``obs.shape[:agent_axis] + (agent_num, action_dim)`` and
        allocated on ``self.device``.
        """
        shape = (*obs.shape[:agent_axis], self.agent_num, self.action_dim)
        mu = torch.zeros(shape, device=self.device)
        sigma = torch.zeros(shape, device=self.device)
        for i in range(self.agent_num):
            index = (slice(None),) * agent_axis + (i,)
            mu_i, sigma_i = self.actors[i](obs[index])
            mu[index] = mu_i
            sigma[index] = sigma_i
        return mu, sigma

    def take_action_async(self, obs, eval=False):
        """Select actions for a batch of parallel environments.

        Args:
            obs: array-like indexed ``[env, agent, feature]`` when sampling.
                With ``eval=True`` it is indexed ``[agent, feature]`` (a single
                environment), exactly like :meth:`take_action`.
            eval: return the actors' mean actions instead of sampling.

        Returns:
            NumPy array of actions: ``[env, agent, action]`` when sampling,
            ``[agent, action]`` when ``eval`` is set.
        """
        # Build the input on the actors' device; a CPU tensor would not match
        # networks that were moved to another device in __init__.
        obs_torch = torch.as_tensor(obs, dtype=torch.float, device=self.device)
        with torch.no_grad():
            mu, sigma = self._actor_outputs(obs_torch, agent_axis=0 if eval else 1)
            action = mu if eval else torch.distributions.Normal(mu, sigma).sample()
        return action.cpu().numpy()

    def take_action(self, obs, eval=False):
        """Select one action per agent for a single environment.

        Args:
            obs: array-like indexed ``[agent, feature]``.
            eval: return the actors' mean actions instead of sampling.

        Returns:
            NumPy array of actions indexed ``[agent, action]``.
        """
        obs_torch = torch.as_tensor(obs, dtype=torch.float, device=self.device)
        with torch.no_grad():
            mu, sigma = self._actor_outputs(obs_torch, agent_axis=0)
            action = mu if eval else torch.distributions.Normal(mu, sigma).sample()
        return action.cpu().numpy()

    def update(self, transition_dict):
        """Run one PPO update of every actor and of both critics.

        Args:
            transition_dict: tensors keyed ``'states'``, ``'next_states'`` and
                ``'actions'`` (indexed ``[t, agent, feature]``), plus
                ``'rewards'`` and ``'dones'`` (flattened to one column here).
                Presumably already on ``self.device`` -- TODO confirm with caller.
        """
        states = transition_dict['states']
        actions = transition_dict['actions']
        rewards = transition_dict['rewards'].view(-1, 1)
        dones = transition_dict['dones'].float().view(-1, 1)
        next_states = transition_dict['next_states']
        memory_length = states.shape[0]

        joint_states = states.view(memory_length, -1)
        next_joint_states = next_states.view(memory_length, -1)
        input_critic = torch.cat([joint_states, actions.view(memory_length, -1)], dim=1)

        with torch.no_grad():
            # Mean actions stand in for the agents left out of a coalition.
            mu, _ = self._actor_outputs(states, agent_axis=1)

            td_target = rewards + self.gamma * self.critic_v(next_joint_states) * (1 - dones)
            value_base = self.critic_v(joint_states)
            td_delta = td_target - value_base
            grand_advantage = compute_advantage(self.gamma, self.lmbda, td_delta, dones).to(self.device)

            sampled_coalition_set = random.sample(self.coalition_set, self.sample_size)
            advantage = []
            for coalition in sampled_coalition_set:
                actions_c = actions.clone()
                for i in range(self.agent_num):
                    if i not in coalition:
                        actions_c[:, i, :] = mu[:, i, :]
                input_c = torch.cat([joint_states, actions_c.view(memory_length, -1)], dim=1)
                advantage.append((self.critic(input_c) - value_base).cpu().numpy())

        solution = compute_credits(self.agent_num, sampled_coalition_set,
                                   np.array(advantage), grand_advantage.cpu().numpy())
        # Cast to float32 on the training device so the actor loss is not
        # silently promoted to float64 by NumPy-derived credits.
        # credit is indexed [t, agent] below; presumably (T, agent_num) -- confirm.
        credit = torch.as_tensor(solution, dtype=torch.float, device=self.device)

        for i in range(self.agent_num):
            actor = self.actors[i]
            optimizer = self.actor_optimizers[i]
            obs_i = states[:, i, :]
            action_i = actions[:, i, :]
            credit_i = credit[:, i].view(-1, 1)
            # The pre-update policy is a fixed reference; no graph is needed.
            with torch.no_grad():
                mu_old, sigma_old = actor(obs_i)
                old_log_probs = torch.distributions.Normal(mu_old, sigma_old) \
                    .log_prob(action_i).sum(-1, keepdim=True)
            for _ in range(self.epochs):
                mu_i, sigma_i = actor(obs_i)
                action_dists = torch.distributions.Normal(mu_i, sigma_i)
                log_probs = action_dists.log_prob(action_i).sum(-1, keepdim=True)
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * credit_i
                surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * credit_i
                actor_loss = (-torch.min(surr1, surr2).mean()
                              - self.entropy_coef * action_dists.entropy().mean())
                optimizer.zero_grad()
                actor_loss.backward()
                torch.nn.utils.clip_grad_norm_(actor.parameters(), self.max_grad_norm)
                optimizer.step()

        # td_target was computed under no_grad, so it is already a constant target.
        for _ in range(self.epochs):
            critic_v_loss = F.mse_loss(self.critic_v(joint_states), td_target)
            critic_loss = F.mse_loss(self.critic(input_critic), td_target)
            self.critic_v_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            critic_v_loss.backward()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic_v.parameters(), self.max_grad_norm)
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_v_optimizer.step()
            self.critic_optimizer.step()
