import numpy as np
import torch
import torch.nn as nn
from macpo.utils.util import get_gard_norm, huber_loss, mse_loss
from macpo.utils.popart import PopArt
from macpo.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor
from torch.nn.utils import clip_grad_norm
import copy
from macpo.algorithms.utils.util import check
#from memory_profiler import profile

# EPS = 1e-8

class MAFOCOPS():
    """
    Trainer class for MAFOCOPS (multi-agent FOCOPS) policy updates.

    For every sample yielded by the buffer's data generator, ``trpo_update`` (name kept for
    compatibility with callers) first updates the Lagrange multiplier ``nu`` from the sample's
    average episode cost, then runs ``train_pi_iteration`` passes of shuffled minibatch steps on
    the reward critic, the cost critic and the actor (FOCOPS loss), optionally stopping early once
    the mean KL to the pre-update actor exceeds ``kl_threshold``.

    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: (R_MAPPO_Policy) policy to update; must expose ``actor``, ``critic``, ``cost_critic``
                   and ``actor_optimizer``, ``critic_optimizer``, ``cost_optimizer``.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):

        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta
        self.episode_length = args.episode_length

        self.kl_threshold = args.kl_threshold  # early-stopping threshold on the mean KL
        self.EPS = args.EPS                    # guards the division in the KL formula

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks

        # FOCOPS parameters
        self.focops_lam = args.focops_lam      # the advantage term of the actor loss is scaled by 1 / focops_lam
        self.nu = 0.0                          # Lagrange multiplier of the cost constraint, updated online
        self.eta = args.focops_eta             # per-sample KL bound; samples with KL >= eta are masked out of the loss
        self.nu_lr = args.nu_lr
        self.nu_max = args.nu_max
        self.cost_limit = args.safety_bound
        self.train_pi_iteration = args.ls_step  # number of passes over the minibatch loader per update
        self._use_kl_early_stopping = args.use_kl_early_stopping
        self.mini_batch = args.batch_size       # DataLoader batch size
        self.factor = args.use_factor           # multiply the surrogate term by factor_batch
        self.nu_infity = args.nu_infity         # when true, nu is not capped at nu_max

        if self._use_popart:
            self.value_normalizer = PopArt(1, device=self.device)
        else:
            self.value_normalizer = None

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Calculate value function loss (used for both the reward critic and the cost critic).
        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value predictions from data batch (used for value clip loss)
        :param return_batch: (torch.Tensor) reward-to-go (or cost-to-go) returns.
        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timestep.

        :return value_loss: (torch.Tensor) value function loss.
        """
        # Keep the new prediction within clip_param of the prediction made at rollout time.
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                    self.clip_param)
        if self._use_popart:
            # The normalizer is deliberately invoked once per error term, as before: it may update
            # running statistics on each call, so merging the calls could change results.
            # NOTE(review): the same normalizer is applied to cost returns as well — confirm intended.
            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
            error_original = self.value_normalizer(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()

        return value_loss

    def flat_grad(self, grads):
        """Concatenate the non-None gradients in ``grads`` into one 1-D tensor."""
        return torch.cat([grad.view(-1) for grad in grads if grad is not None])

    def flat_params(self, model):
        """Return all parameters of ``model`` (detached data) concatenated into one 1-D tensor."""
        return torch.cat([param.data.view(-1) for param in model.parameters()])

    def update_model(self, model, new_params):
        """
        Copy a flat parameter vector (as produced by ``flat_params``) back into ``model`` in place.
        Parameters are consumed in ``model.parameters()`` order.
        """
        index = 0
        for params in model.parameters():
            params_length = params.numel()
            new_param = new_params[index: index + params_length].view(params.size())
            params.data.copy_(new_param)
            index += params_length

    def kl_divergence(self, obs, rnn_states, action, masks, available_actions, active_masks, new_actor, old_actor):
        """
        KL divergence between the current actor and a frozen reference actor, treating both
        action distributions as diagonal Gaussians parameterised by the (mu, std) that each
        actor's ``evaluate_actions`` returns in its last two outputs.

        Direction: KL(pi_new || pi_old) = log(std_old / std) + (std^2 + (mu - mu_old)^2) / (2 std_old^2) - 1/2,
        which is the divergence appearing in the FOCOPS objective. Only ``new_actor`` receives gradients.

        :return: (torch.Tensor) per-sample KL, summed over dim 1 with keepdim=True.
        """
        _, _, mu, std = new_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks)
        # The reference policy is a constant; skip building an autograd graph for it.
        with torch.no_grad():
            _, _, mu_old, std_old = old_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions,
                                                               active_masks)
        logstd = torch.log(std)
        logstd_old = torch.log(std_old)

        kl = logstd_old - logstd + (std.pow(2) + (mu_old - mu).pow(2)) / \
             (self.EPS + 2.0 * std_old.pow(2)) - 0.5

        return kl.sum(1, keepdim=True)

    def update_lagrange_multiplier(self, ep_costs):
        """
        Projected gradient ascent step on ``nu``: increase it when the mean episode cost exceeds
        ``cost_limit`` and decrease it otherwise. ``nu`` is clamped to be >= 0 and, unless
        ``nu_infity`` is set, <= ``nu_max``.
        :param ep_costs: array/tensor of episode costs (must support ``.mean().item()``).
        """
        mean_cost = ep_costs.mean().item()
        self.nu += self.nu_lr * (mean_cost - self.cost_limit)
        if self.nu < 0.0:
            self.nu = 0.0
        elif not self.nu_infity and self.nu > self.nu_max:
            self.nu = self.nu_max

    def _optimizer_step(self, loss, optimizer, module):
        """
        Run one zero_grad / backward / (clip) / step cycle on ``module`` and return its gradient
        norm as a Python float (the value reported by ``clip_grad_norm_`` or ``get_gard_norm``).
        """
        optimizer.zero_grad()
        loss.backward()
        if self._use_max_grad_norm:
            grad_norm = nn.utils.clip_grad_norm_(module.parameters(), self.max_grad_norm)
        else:
            grad_norm = get_gard_norm(module.parameters())
        optimizer.step()
        # float() accepts both a 0-d tensor and a plain Python number.
        return float(grad_norm)

    def trpo_update(self, sample, update_actor=True):
        """
        Update the Lagrange multiplier, the reward critic, the cost critic and the actor on one sample.
        :param sample: (Tuple) contains data batch with which to update networks.
        :param update_actor: (bool) accepted for interface compatibility; the actor is always updated.

        :return value_loss: (float) mean reward-critic loss over all minibatch steps.
        :return critic_grad_norm: (float) mean reward-critic gradient norm.
        :return kl_mean: (torch.Tensor) mean KL to the pre-update actor after the last pass (no grad).
        :return loss_pi: (float) mean actor (FOCOPS) loss.
        :return dist_entropy: (torch.Tensor) action entropy of the updated actor on the whole sample.
        :return ratio: (float) mean importance ratio.
        :return cost_loss: (float) mean cost-critic loss.
        :return cost_grad_norm: (float) mean cost-critic gradient norm.
        :return cost_preds_batch_w, cost_returns_barch_w: (torch.Tensor) cost predictions/returns of the sample.
        :return action_mu, action_std: (torch.Tensor) action distribution parameters of the updated actor.
        :return nu: (float) current Lagrange multiplier.
        :return pi_grad_norm: (float) mean actor gradient norm.
        """
        share_obs_batch_w, obs_batch_w, rnn_states_batch_w, rnn_states_critic_batch_w, actions_batch_w, \
        value_preds_batch, return_batch, masks_batch_w, active_masks_batch_w, old_action_log_probs_batch, \
        adv_targ, available_actions_batch_w, factor_batch, cost_preds_batch_w, cost_returns_barch_w, rnn_states_cost_batch_w, \
        cost_adv_targ, aver_episode_costs = sample

        share_obs_batch_w = check(share_obs_batch_w).to(**self.tpdv)
        obs_batch_w = check(obs_batch_w).to(**self.tpdv)
        rnn_states_batch_w = check(rnn_states_batch_w).to(**self.tpdv)
        rnn_states_critic_batch_w = check(rnn_states_critic_batch_w).to(**self.tpdv)
        rnn_states_cost_batch_w = check(rnn_states_cost_batch_w).to(**self.tpdv)
        actions_batch_w = check(actions_batch_w).to(**self.tpdv)

        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)
        cost_adv_targ = check(cost_adv_targ).to(**self.tpdv)

        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        # Masks must live on the same device/dtype as the network inputs.
        masks_batch_w = check(masks_batch_w).to(**self.tpdv)
        active_masks_batch_w = check(active_masks_batch_w).to(**self.tpdv)
        factor_batch = check(factor_batch).to(**self.tpdv)
        cost_returns_barch_w = check(cost_returns_barch_w).to(**self.tpdv)
        cost_preds_batch_w = check(cost_preds_batch_w).to(**self.tpdv)

        # Dual update, only when the sample carries a scalar average episode cost.
        ep_cost = aver_episode_costs
        if len(ep_cost.shape) == 0:
            self.update_lagrange_multiplier(ep_cost)

        dataset = torch.utils.data.TensorDataset(share_obs_batch_w, obs_batch_w, rnn_states_batch_w,
                                                 rnn_states_critic_batch_w, actions_batch_w, value_preds_batch,
                                                 return_batch, masks_batch_w, active_masks_batch_w,
                                                 old_action_log_probs_batch, adv_targ, factor_batch,
                                                 cost_preds_batch_w, cost_returns_barch_w,
                                                 rnn_states_cost_batch_w, cost_adv_targ)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch, shuffle=True)

        # Snapshot of the actor before this update; it is the reference policy for the KL terms.
        old_actor = R_Actor(self.policy.args,
                            self.policy.obs_space,
                            self.policy.act_space,
                            self.device)
        self.update_model(old_actor, self.flat_params(self.policy.actor))

        value_loss_num = []
        critic_grad_norm_num = []
        loss_pi_num = []
        cost_grad_norm_num = []
        pi_grad_norm_num = []
        ratio_num = []
        cost_loss_num = []
        kl_mean = torch.zeros((), **self.tpdv)  # defined even if train_pi_iteration == 0

        for _ in range(self.train_pi_iteration):
            for (share_obs_mb, obs_mb, rnn_states_mb, rnn_states_critic_mb, actions_mb, value_preds_mb,
                 return_mb, masks_mb, active_masks_mb, old_action_log_probs_mb, adv_targ_mb, factor_mb,
                 cost_preds_mb, cost_returns_mb, rnn_states_cost_mb, cost_adv_targ_mb) in loader:

                values, action_log_probs, dist_entropy, cost_values, _, _ = self.policy.evaluate_actions(
                    share_obs_mb,
                    obs_mb,
                    rnn_states_mb,
                    rnn_states_critic_mb,
                    actions_mb,
                    masks_mb,
                    None,  # available_actions are not batched through the loader
                    active_masks_mb,
                    rnn_states_cost_mb)

                # Reward critic update.
                value_loss = self.cal_value_loss(values, value_preds_mb, return_mb, active_masks_mb)
                value_loss *= self.value_loss_coef
                value_loss_num.append(value_loss.item())
                critic_grad_norm_num.append(
                    self._optimizer_step(value_loss, self.policy.critic_optimizer, self.policy.critic))

                # Cost critic update.
                cost_loss = self.cal_value_loss(cost_values, cost_preds_mb, cost_returns_mb, active_masks_mb)
                cost_loss *= self.value_loss_coef
                cost_loss_num.append(cost_loss.item())
                cost_grad_norm_num.append(
                    self._optimizer_step(cost_loss, self.policy.cost_optimizer, self.policy.cost_critic))

                # Actor update with the FOCOPS loss:
                #   [KL(pi || pi_old) - (1/lam) * ratio * (A - nu * A_C)] * 1{KL < eta}
                kl = self.kl_divergence(obs_mb,
                                        rnn_states_mb,
                                        actions_mb,
                                        masks_mb,
                                        None,
                                        active_masks_mb,
                                        new_actor=self.policy.actor,
                                        old_actor=old_actor)

                ratio = torch.exp(action_log_probs - old_action_log_probs_mb)
                ratio_num.append(ratio.mean().item())
                trust_mask = (kl.detach() < self.eta).type(torch.float32)
                if self.factor:
                    loss_pi = (kl - (1 / self.focops_lam) * ratio * factor_mb
                               * (adv_targ_mb - self.nu * cost_adv_targ_mb)) * trust_mask
                else:
                    loss_pi = (kl - (1 / self.focops_lam) * ratio
                               * (adv_targ_mb - self.nu * cost_adv_targ_mb)) * trust_mask
                loss_pi = loss_pi.mean()
                loss_pi -= self.entropy_coef * dist_entropy.mean()
                loss_pi_num.append(loss_pi.item())
                pi_grad_norm_num.append(
                    self._optimizer_step(loss_pi, self.policy.actor_optimizer, self.policy.actor))

            # Monitoring / early stopping only: no gradients needed.
            with torch.no_grad():
                kl_mean = self.kl_divergence(obs_batch_w,
                                             rnn_states_batch_w,
                                             actions_batch_w,
                                             masks_batch_w,
                                             available_actions_batch_w,
                                             active_masks_batch_w,
                                             new_actor=self.policy.actor,
                                             old_actor=old_actor).mean()
            if self._use_kl_early_stopping and kl_mean.item() > self.kl_threshold:
                break

        value_loss = np.mean(value_loss_num)
        critic_grad_norm = np.mean(critic_grad_norm_num)
        loss_pi = np.mean(loss_pi_num)
        ratio = np.mean(ratio_num)
        cost_loss = np.mean(cost_loss_num)
        cost_grad_norm = np.mean(cost_grad_norm_num)
        pi_grad_norm = np.mean(pi_grad_norm_num)

        # Statistics of the updated actor on the whole sample; evaluated without autograd so the
        # returned tensors do not keep computation graphs alive in the caller.
        with torch.no_grad():
            _, _, dist_entropy, _, action_mu, action_std = self.policy.evaluate_actions(
                share_obs_batch_w,
                obs_batch_w,
                rnn_states_batch_w,
                rnn_states_critic_batch_w,
                actions_batch_w,
                masks_batch_w,
                available_actions_batch_w,
                active_masks_batch_w,
                rnn_states_cost_batch_w)

        return value_loss, critic_grad_norm, kl_mean, loss_pi, dist_entropy, ratio, cost_loss, cost_grad_norm, \
            cost_preds_batch_w, cost_returns_barch_w, action_mu, action_std, self.nu, pi_grad_norm

    @staticmethod
    def _normalize_advantages(adv, active_masks):
        """Standardize ``adv`` with mean/std taken over entries whose active mask is non-zero."""
        masked = adv.copy()
        masked[active_masks == 0.0] = np.nan
        return (adv - np.nanmean(masked)) / (np.nanstd(masked) + 1e-5)

    def train(self, buffer, shared_buffer=None, update_actor=True):
        """
        Perform a training update using minibatch GD.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param shared_buffer: unused; accepted for interface compatibility.
        :param update_actor: (bool) forwarded to ``trpo_update``.

        :return train_info: (dict) training statistics averaged over the ``trpo_update`` calls
                            actually performed (one per sample yielded by the data generator).
        """
        if self._use_popart:
            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
        else:
            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
        advantages = self._normalize_advantages(advantages, buffer.active_masks[:-1])

        if self._use_popart:
            cost_adv = buffer.cost_returns[:-1] - self.value_normalizer.denormalize(buffer.cost_preds[:-1])
        else:
            cost_adv = buffer.cost_returns[:-1] - buffer.cost_preds[:-1]
        cost_adv = self._normalize_advantages(cost_adv, buffer.active_masks[:-1])

        train_info = {key: 0 for key in ('value_loss', 'kl', 'dist_entropy', 'loss_pi', 'critic_grad_norm',
                                         'ratio', 'cost_loss', 'cost_grad_norm', 'pi_grad_norm',
                                         'cost_preds_batch', 'cost_returns_barch', 'nu', 'action_mu',
                                         'action_std')}

        if self._use_recurrent_policy:
            data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length,
                                                        cost_adv=cost_adv)
        elif self._use_naive_recurrent:
            data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch, cost_adv=cost_adv)
        else:
            data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch, cost_adv=cost_adv)

        num_updates = 0
        for sample in data_generator:
            (value_loss, critic_grad_norm, kl, loss_pi, dist_entropy, ratio, cost_loss, cost_grad_norm,
             cost_preds_batch, cost_returns_barch, action_mu, action_std, nu,
             pi_grad_norm) = self.trpo_update(sample, update_actor)
            num_updates += 1

            train_info['value_loss'] += float(value_loss)
            train_info['kl'] += kl
            train_info['loss_pi'] += loss_pi
            train_info['dist_entropy'] += dist_entropy.item()
            train_info['critic_grad_norm'] += critic_grad_norm
            train_info['pi_grad_norm'] += pi_grad_norm
            train_info['ratio'] += ratio.mean()
            train_info['cost_loss'] += float(cost_loss)
            train_info['cost_grad_norm'] += cost_grad_norm

            train_info['cost_preds_batch'] += cost_preds_batch.mean()
            train_info['cost_returns_barch'] += cost_returns_barch.mean()
            train_info['nu'] += nu

            train_info['action_mu'] += action_mu.float().mean()
            train_info['action_std'] += action_std.float().mean()

        # There is no epoch loop here, so average over the updates that actually ran.
        if num_updates > 0:
            for k in train_info:
                train_info[k] /= num_updates

        return train_info

    def prep_training(self):
        """Put actor, reward critic and cost critic in training mode."""
        self.policy.actor.train()
        self.policy.critic.train()
        self.policy.cost_critic.train()

    def prep_rollout(self):
        """Put actor, reward critic and cost critic in evaluation mode."""
        self.policy.actor.eval()
        self.policy.critic.eval()
        self.policy.cost_critic.eval()

