import numpy as np
import torch
import torch.nn as nn
from utils.util import get_gard_norm, huber_loss, mse_loss
from torch.nn import functional as F
from utils.valuenorm import ValueNorm
from algorithms.utils.util import check


class MATTrainer:
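    """
    Trainer class for MAT to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: policy to update; its optimizers and `transformer` sub-modules are used directly by this trainer.
    :param num_agents: (int) number of agents; used to size the expert batches sampled for discriminator training.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """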
    def __init__(self,
                 args,
                 policy,
                 num_agents,
                 device=torch.device("cpu")):

        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy
        self.num_agents = num_agents

        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.epoch_disc = args.epoch_disc
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_valuenorm = args.use_valuenorm
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.dec_actor = args.dec_actor
        self._all_data_length = args.n_rollout_threads * args.episode_length
        # use for gail
        self._use_gail = args.use_gail
        self._fix_encoder = args.fix_encoder
        self._disc_batch_size = args.disc_batch_size
        self._disc_warm_up = args.disc_warm_up
        self._disc_warm_up_round = args.disc_warm_up_round
        self._disc_warm_up_epoch = args.disc_warm_up_epoch
        self._disc_use_act_prob = args.disc_use_act_prob
        ## use for disc update delay and decay
        self._use_disc_early_stop = args.use_disc_early_stop
        self._use_disc_lr_decay = args.use_disc_lr_decay
        self._disc_stop_acc = args.disc_stop_acc
        self._disc_stop_round = args.disc_stop_round
        self._disc_restart_type = args.disc_restart_type
        self._disc_restart_acc = args.disc_restart_acc
        ## use for wail to train disc
        self._use_wail = args.use_wail
        self._use_disc_grad_penalty = args.use_disc_grad_penalty
        ## add for encoder-decoder disc(decentralization)
        self._disc_use_decoder = args.disc_use_decoder
        self._disc_cal_last_loss = args.disc_cal_last_loss

        if self._use_valuenorm:
            self.value_normalizer = ValueNorm(1, device=self.device)
        else:
            self.value_normalizer = None

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Compute the (optionally clipped and masked) value-function loss.

        :param values: (torch.Tensor) current value predictions.
        :param value_preds_batch: (torch.Tensor) value predictions stored in the batch; the
            current predictions are clipped to stay within ``clip_param`` of these.
        :param return_batch: (torch.Tensor) return targets.
        :param active_masks_batch: (torch.Tensor) weights applied to the per-element loss
            when value active masks are enabled.

        :return value_loss: (torch.Tensor) scalar value loss.
        """
        eps = self.clip_param
        clipped_values = value_preds_batch + torch.clamp(values - value_preds_batch, -eps, eps)

        # With value normalization the running statistics are updated first and
        # the regression target becomes the normalized return.
        if self._use_valuenorm:
            self.value_normalizer.update(return_batch)
            targets = self.value_normalizer.normalize(return_batch)
        else:
            targets = return_batch

        clipped_error = targets - clipped_values
        unclipped_error = targets - values

        if self._use_huber_loss:
            clipped_loss = huber_loss(clipped_error, self.huber_delta)
            unclipped_loss = huber_loss(unclipped_error, self.huber_delta)
        else:
            clipped_loss = mse_loss(clipped_error)
            unclipped_loss = mse_loss(unclipped_error)

        # Element-wise maximum of the clipped and unclipped losses when clipping is enabled.
        per_element_loss = (
            torch.max(unclipped_loss, clipped_loss)
            if self._use_clipped_value_loss
            else unclipped_loss
        )

        if not self._use_value_active_masks:
            return per_element_loss.mean()

        # Masked mean: entries with a zero mask do not contribute.
        return (per_element_loss * active_masks_batch).sum() / active_masks_batch.sum()

    def cal_grad_pen(self, exp_share_obs, exp_obs, exp_actions, share_obs, obs, actions, lambda_=10):
        """
        Compute a gradient penalty for the discriminator on expert/policy mix-ups.

        Expert and policy samples are interpolated with one random coefficient per
        row, the discriminator is evaluated on the interpolated inputs, and the
        penalty ``lambda_ * mean((||grad||_2 - 1) ** 2)`` is computed from the
        gradient of its output with respect to the mixed observations and actions.

        :param exp_share_obs: expert shared observations.
        :param exp_obs: expert observations.
        :param exp_actions: expert actions (action indices when the action space is
            'Discrete' and action probabilities are not used).
        :param share_obs: policy shared observations.
        :param obs: policy observations.
        :param actions: policy actions, in the same format as ``exp_actions``.
        :param lambda_: weight of the penalty term.

        :return grad_pen: (torch.Tensor) scalar penalty, differentiable with respect to
            the discriminator parameters.

        NOTE(review): the mix-up coefficient has shape ``(exp_obs.size(0), 1)`` and is
        expanded to each input, so inputs are assumed to be 2-D ``(batch, feature)``
        with matching batch sizes — confirm against the callers.
        """
        # Unify dtype/device of every input.
        exp_share_obs = check(exp_share_obs).to(**self.tpdv)
        exp_obs = check(exp_obs).to(**self.tpdv)
        exp_actions = check(exp_actions).to(**self.tpdv)
        share_obs = check(share_obs).to(**self.tpdv)
        obs = check(obs).to(**self.tpdv)
        actions = check(actions).to(**self.tpdv)

        # Discrete action indices are turned into one-hot vectors before they are
        # fed to the discriminator (unless action probabilities are used instead).
        if self.policy.action_type == 'Discrete' and not self._disc_use_act_prob:
            exp_actions = F.one_hot(exp_actions.squeeze(-1).long(), num_classes=self.policy.act_dim).float()
            actions = F.one_hot(actions.squeeze(-1).long(), num_classes=self.policy.act_dim).float()

        # One random mix-up rate per row, shared by all three inputs of that row.
        alpha = torch.rand(exp_obs.size(0), 1).to(**self.tpdv)

        def _mix(expert, policy):
            """Interpolate expert/policy data and enable gradients on the result."""
            rate = alpha.expand_as(expert)
            mixed = rate * expert + (1 - rate) * policy
            mixed.requires_grad = True
            return mixed

        mixup_share_obs = _mix(exp_share_obs, share_obs)
        mixup_obs = _mix(exp_obs, obs)
        mixup_actions = _mix(exp_actions, actions)

        disc = self.policy.get_discriminator_logit(mixup_share_obs, mixup_obs, mixup_actions)

        # The gradient w.r.t. share_obs is not computed (the original code notes
        # that share_obs is not used by the discriminator).
        grad_inputs = (mixup_obs, mixup_actions)
        grads = torch.autograd.grad(
            outputs=disc,
            inputs=grad_inputs,
            grad_outputs=torch.ones_like(disc),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
            allow_unused=True)
        # With allow_unused=True an input that does not influence the output
        # yields None, which torch.cat cannot handle; a zero gradient has the
        # same meaning and contributes nothing to the norm.
        grads = [
            torch.zeros_like(source) if g is None else g
            for g, source in zip(grads, grad_inputs)
        ]
        grad = torch.cat(grads, dim=-1)
        grad_pen = lambda_ * (grad.norm(2, dim=-1) - 1).pow(2).mean()

        return grad_pen

    def ppo_update(self, sample):
        """
        Run one PPO optimisation step on a single mini-batch.

        :param sample: (Tuple) 14-element mini-batch, unpacked in the order shown below.
            According to the original notes: share_obs_batch is
            (buffer_size * agent_num, share_obs_dim), obs_batch is
            (buffer_size * agent_num, obs_dim), his_obs_batch is
            (buffer_size * agent_num, his_len, obs_dim) and actions_batch is
            (buffer_size * agent_num, act_dim).

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: gradient norm of the critic parameters.
        :return policy_loss: (torch.Tensor) actor (policy) loss value.
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: gradient norm of the decoder (actor) parameters.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        """
        (share_obs_batch, obs_batch, his_obs_batch, rnn_states_batch, rnn_states_critic_batch,
         actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch,
         old_action_log_probs_batch, _old_action_probs_batch, adv_targ,
         available_actions_batch) = sample

        def _as_tensor(array):
            # Convert to a tensor with the trainer's dtype/device.
            return check(array).to(**self.tpdv)

        old_action_log_probs_batch = _as_tensor(old_action_log_probs_batch)
        adv_targ = _as_tensor(adv_targ)
        value_preds_batch = _as_tensor(value_preds_batch)
        return_batch = _as_tensor(return_batch)
        active_masks_batch = _as_tensor(active_masks_batch)

        # Single forward pass over the whole mini-batch.
        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(
            share_obs_batch,
            obs_batch,
            his_obs_batch,
            rnn_states_batch,
            rnn_states_critic_batch,
            actions_batch,
            masks_batch,
            available_actions_batch,
            active_masks_batch,
        )

        # ---- actor: clipped surrogate objective ----
        imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)
        unclipped_objective = imp_weights * adv_targ
        clipped_objective = torch.clamp(imp_weights,
                                        1.0 - self.clip_param,
                                        1.0 + self.clip_param) * adv_targ
        summed_objective = torch.sum(torch.min(unclipped_objective, clipped_objective),
                                     dim=-1,
                                     keepdim=True)
        if self._use_policy_active_masks:
            # Masked mean over the entries whose mask is non-zero.
            policy_loss = (-summed_objective * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            policy_loss = -summed_objective.mean()

        # ---- critic ----
        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

        loss = policy_loss - dist_entropy * self.entropy_coef + value_loss * self.value_loss_coef

        # The encoder optimizer only takes part when the encoder is not frozen.
        optimizers = [self.policy.optimizer_decoder, self.policy.optimizer_critic]
        if not self._fix_encoder:
            optimizers.append(self.policy.optimizer_encoder)

        for optimizer in optimizers:
            optimizer.zero_grad()

        loss.backward()

        critic_params = self.policy.transformer.critic.parameters()
        decoder_params = self.policy.transformer.decoder.parameters()
        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(critic_params, self.max_grad_norm)
            actor_grad_norm = nn.utils.clip_grad_norm_(decoder_params, self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(critic_params)
            actor_grad_norm = get_gard_norm(decoder_params)

        for optimizer in optimizers:
            optimizer.step()

        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights

    def train(self, buffer, expert_buffer):
        """
        Perform a training update using minibatch GD.

        When GAIL is enabled the discriminator is processed first (optionally with a
        warm-up number of epochs, gradient penalty, lr decay and early stop/restart),
        then ``ppo_epoch`` passes of PPO updates are run over the buffer.

        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param expert_buffer: source of expert data, queried through
            ``sample_batch_data``; only used when GAIL is enabled.

        :return train_info: (dict) contains information regarding training update
            (e.g. loss, grad norms, etc). PPO entries are averaged over
            ``ppo_epoch * num_mini_batch`` updates; accumulated discriminator entries
            are averaged over the number of discriminator mini-batches processed.
        """
        # Normalize advantages with statistics taken over active entries only:
        # inactive entries are replaced by NaN and ignored by nanmean / nanstd.
        masked_advantages = buffer.advantages.copy()
        masked_advantages[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(masked_advantages)
        std_advantages = np.nanstd(masked_advantages)
        advantages = (buffer.advantages - mean_advantages) / (std_advantages + 1e-5)

        train_info = {
            'value_loss': 0,
            'policy_loss': 0,
            'dist_entropy': 0,
            'actor_grad_norm': 0,
            'critic_grad_norm': 0,
            'ratio': 0,
        }
        if self._use_gail:
            train_info.update(
                disc_loss=0,
                disc_grad_norm=0,
                disc_expert_acc=0,
                disc_policy_acc=0,
                disc_grad_pen=0,
            )

        all_expert_score, all_policy_score = [], []
        all_expert_acc, all_policy_acc = [], []
        # Number of discriminator mini-batches processed; used to average the
        # accumulated discriminator metrics.
        disc_updates = 0

        # use gail to train MAT
        if self._use_gail:
            # warm up for training disc
            in_warm_up = self._disc_warm_up and self.policy._update_times <= self._disc_warm_up_round
            disc_epoch = self._disc_warm_up_epoch if in_warm_up else self.epoch_disc
            for _ in range(disc_epoch):
                data_generator = buffer.feed_forward_generator_transformer(
                    advantages,
                    num_mini_batch=self._all_data_length // self._disc_batch_size,
                    mini_batch_size=None,
                    force_not_shuffle=True,
                )
                for sample in data_generator:
                    disc_updates += 1
                    share_obs, obs, _, _, _, actions, _, _, _, _, _, action_probs, _, _ = sample
                    exp_share_obs, exp_obs, exp_actions = expert_buffer.sample_batch_data(
                        share_obs.shape[0] // self.num_agents)
                    # the disc sees action probabilities or the raw actions
                    use_act_prob = self.policy.action_type == 'Discrete' and self._disc_use_act_prob
                    policy_actions = action_probs if use_act_prob else actions
                    # get logits of (state, action) for policy and expert
                    expert_logits = self.policy.get_discriminator_logit(exp_share_obs, exp_obs, exp_actions)
                    policy_logits = self.policy.get_discriminator_logit(share_obs, obs, policy_actions)
                    # record policy and expert score
                    all_expert_score.append(self.policy.get_discriminator_rewards_from_logits(expert_logits))
                    all_policy_score.append(self.policy.get_discriminator_rewards_from_logits(policy_logits))
                    # record accuracy: expert samples are labelled 0 and policy
                    # samples 1 (see the BCE targets below)
                    all_expert_acc.append((torch.sigmoid(expert_logits) < 0.5).float().mean().item())
                    all_policy_acc.append((torch.sigmoid(policy_logits) > 0.5).float().mean().item())
                    # only use last token to calculate loss if necessary
                    if self._disc_use_decoder and self._disc_cal_last_loss:
                        expert_logits = expert_logits[:, -1]
                        policy_logits = policy_logits[:, -1]
                    # calculate disc loss
                    if self._use_wail:
                        disc_loss = torch.mean(torch.tanh(expert_logits)) - torch.mean(torch.tanh(policy_logits))
                    else:
                        disc_loss = F.binary_cross_entropy_with_logits(
                            expert_logits, torch.zeros_like(expert_logits)
                        ) + F.binary_cross_entropy_with_logits(
                            policy_logits, torch.ones_like(policy_logits)
                        )
                    if self.policy._train_disc_flag:
                        self.policy.optimizer_discriminator.zero_grad()
                        all_disc_loss = disc_loss
                        if self._use_disc_grad_penalty:
                            grad_pen = self.cal_grad_pen(
                                exp_share_obs, exp_obs, exp_actions, share_obs, obs, policy_actions)
                            all_disc_loss = all_disc_loss + grad_pen
                            train_info['disc_grad_pen'] += grad_pen.item()
                        all_disc_loss.backward()
                        # clip grad norm for disc
                        disc_params = self.policy.transformer.discriminator.parameters()
                        if self._use_max_grad_norm:
                            disc_grad_norm = nn.utils.clip_grad_norm_(disc_params, self.max_grad_norm)
                        else:
                            disc_grad_norm = get_gard_norm(disc_params)
                        train_info['disc_grad_norm'] += disc_grad_norm
                        self.policy.optimizer_discriminator.step()
                        # decay disc lr if necessary
                        if self._use_disc_lr_decay:
                            self.policy.scheduler_discriminator.step()
                    # the loss is recorded even while disc training is paused
                    train_info['disc_loss'] += disc_loss.item()

        for _ in range(self.ppo_epoch):
            data_generator = buffer.feed_forward_generator_transformer(advantages, self.num_mini_batch)

            for sample in data_generator:
                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights \
                    = self.ppo_update(sample)

                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                train_info['actor_grad_norm'] += actor_grad_norm
                train_info['critic_grad_norm'] += critic_grad_norm
                # detach so the logged statistic does not keep the autograd graph alive
                train_info['ratio'] += imp_weights.mean().detach()

        ppo_updates = self.ppo_epoch * self.num_mini_batch
        # Discriminator metrics are summed once per mini-batch, so they are averaged
        # over the number of mini-batches processed (which also covers the warm-up
        # epoch count); max() guards the case where no mini-batch was processed.
        disc_denominator = max(disc_updates, 1)

        for k in train_info:
            denominator = disc_denominator if 'disc' in k else ppo_updates
            train_info[k] = train_info[k] / denominator

        if self._use_gail:
            expert_scores = np.concatenate(all_expert_score)
            policy_scores = np.concatenate(all_policy_score)
            expert_agent_mean_scores = np.mean(expert_scores, axis=0)
            policy_agent_mean_scores = np.mean(policy_scores, axis=0)
            train_info['all_expert_agent_mean_scores'] = {}
            train_info['all_policy_agent_mean_scores'] = {}
            for agent_i in range(expert_agent_mean_scores.shape[0]):
                agent_key = 'agent_' + str(agent_i)
                train_info['all_expert_agent_mean_scores'][agent_key] = expert_agent_mean_scores[agent_i][0]
                train_info['all_policy_agent_mean_scores'][agent_key] = policy_agent_mean_scores[agent_i][0]
            train_info['disc_expert_score'] = {
                'max_score': np.max(expert_scores),
                'min_score': np.min(expert_scores),
                'mean_score': np.mean(expert_scores),
            }
            train_info['disc_policy_score'] = {
                'max_score': np.max(policy_scores),
                'min_score': np.min(policy_scores),
                'mean_score': np.mean(policy_scores),
            }
            train_info['disc_expert_acc'] = np.mean(all_expert_acc)
            train_info['disc_policy_acc'] = np.mean(all_policy_acc)
            # early stop training disc if acc reach level if necessary
            if self._use_disc_early_stop and self.policy._train_disc_flag and \
                    train_info['disc_expert_acc'] > self._disc_stop_acc and \
                    train_info['disc_policy_acc'] > self._disc_stop_acc:
                self.policy._train_disc_flag = False
            # restart training
            elif self._use_disc_early_stop and not self.policy._train_disc_flag:
                self.policy._dis_now_stop_round += 1
                # restart training disc if policy acc is low
                if self._disc_restart_type == 'low_acc' and train_info['disc_policy_acc'] < self._disc_restart_acc:
                    self.policy._dis_now_stop_round = 0
                    self.policy._train_disc_flag = True
                # restart training disc at fixed interval
                elif self._disc_restart_type == 'fix_epoch' and self.policy._dis_now_stop_round >= self._disc_stop_round:
                    self.policy._dis_now_stop_round = 0
                    self.policy._train_disc_flag = True
        # update update time for policy
        self.policy._update_times += 1

        return train_info

    def train_offline(self, share_obs, obs, actions):
        loss = self.policy.train_offline(share_obs=share_obs, obs=obs, actions=actions)
        self.policy.optimizer.zero_grad()
        loss.backward()
        self.policy.optimizer.step()
        self.policy.scheduler.step()

        return loss.item()

    def prep_training(self):
        """Call ``self.policy.train()`` (presumably switches the policy networks to training mode — confirm in the policy class)."""
        self.policy.train()

    def prep_rollout(self):
        """Call ``self.policy.eval()`` (presumably switches the policy networks to evaluation mode for rollouts — confirm in the policy class)."""
        self.policy.eval()
