import numpy as np
import torch
import torch.nn as nn

import onpolicy.algorithms.gail.gail_utils as gail_utils
from onpolicy.algorithms.r_mappo.r_mappo import R_MAPPO
from onpolicy.algorithms.utils.util import check

class DIAYN(R_MAPPO):
    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):
        """
        Set up the trainer on top of the parent R_MAPPO trainer.
        :param args: configuration object; this method reads gail_batch_size,
                     gail_window_size, z_latent_dim, info_task, algorithm_name,
                     scenario_name and gail_obs_dim from it.
        :param policy: policy object, forwarded unchanged to R_MAPPO.__init__.
        :param device: (torch.device) forwarded to the parent trainer and to the
                       info-decoder builder.
        """
        super().__init__(args, policy, device=device)

        # Settings used by the info decoder and its observation pre-processing.
        self.gail_batch_size = args.gail_batch_size
        self.window_size = args.gail_window_size
        self.z_latent_dim = args.z_latent_dim
        self.info_task = args.info_task

        # Callable object that turns rollout data into decoder inputs.
        obs_func_kwargs = dict(
            task=self.info_task,
            window_size=self.window_size,
            algorithm_name=args.algorithm_name,
            scenario_name=args.scenario_name,
        )
        self.info_obs_func = gail_utils.build_obs_func(**obs_func_kwargs)

        # Decoder object; the rest of this class uses its .compute, .optimizer
        # and .net members.
        decoder_kwargs = dict(
            z_latent_dim=self.z_latent_dim,
            obs_dim=args.gail_obs_dim,
            device=device,
            window_size=self.window_size,
        )
        self.info_decoder = gail_utils.build_info_decoder(**decoder_kwargs)

    def compute_info_reward(self, obs, actions, masks):
        """
        Score rollout data with the info decoder and return the result as a reward array.
        Only index 0 along the third axis of every input is used.
        NOTE(review): the inputs are presumably laid out as
        (time, rollout thread, agent, feature) -- confirm against the caller.

        :param obs: 4-D array whose last entry along axis 0 is dropped. The first
                    z_latent_dim entries of the last axis are taken as the latent code z;
                    the remaining entries are passed to the observation function.
        :param actions: 4-D array of actions.
        :param masks: 4-D array; `1 - masks` (last entry along axis 0 dropped) is passed
                      to the observation function as the on-reset flag.

        :return: (np.ndarray) decoder output per window, zero-padded for the first
                 window_size - 1 steps, multiplied by the mask returned by
                 process_fake_data, with a singleton axis inserted at position 2.
        """
        win = self.window_size
        latent = self.z_latent_dim

        obs = obs[:-1, :, 0, :]
        actions = actions[:, :, 0, :]
        on_reset = 1 - masks[:-1, :, 0, :]

        num_steps, num_envs = obs.shape[:2]
        num_windows = num_steps - win + 1

        # T x B x ... -> ((T - W + 1) x B) x ...: one latent code per window,
        # read at the window's final step.
        z = obs[win - 1:, ..., :latent]
        z = check(z.reshape(num_windows * num_envs, *z.shape[2:])).to(**self.tpdv)

        # Strip the latent code so only the remaining features reach the observation function.
        obs = obs[..., latent:]
        x, mask = self.info_obs_func.process_fake_data(
            check(obs).to(**self.tpdv),
            check(actions).to(**self.tpdv),
            check(on_reset).to(**self.tpdv),
        )

        # Reward computation must not build an autograd graph.
        with torch.no_grad():
            z_log_prob = self.info_decoder.compute(x, z)

        # The first W - 1 steps have no complete window, so they receive zero reward.
        padding = torch.zeros(win - 1, num_envs, device=self.device, dtype=torch.float32)
        reward = torch.cat([padding, z_log_prob.reshape(num_windows, num_envs)], 0)
        reward = reward.unsqueeze(-1) * mask

        return reward.cpu().numpy()[:, :, np.newaxis, :]

    def update_gail(self, obs, actions, masks):
        """
        Call the parent class's update_gail with the latent code removed from obs.
        :param obs: observations; the first z_latent_dim entries of the last axis are dropped.
        :param actions: forwarded unchanged.
        :param masks: forwarded unchanged.

        :return: whatever the parent class's update_gail returns.
        """
        obs_without_z = obs[..., self.z_latent_dim:]
        return super().update_gail(obs_without_z, actions, masks)

    def update_info_decoder(self, obs, actions, masks):
        """
        Run one optimizer step on the info decoder using windows sampled from rollout data.
        Only index 0 along the third axis of every input is used.

        :param obs: 4-D array whose last entry along axis 0 is dropped; the first
                    z_latent_dim entries of the last axis are the latent code z.
        :param actions: 4-D array of actions.
        :param masks: 4-D array; `1 - masks` (last entry along axis 0 dropped) is used
                      as the on-reset flag.

        :return z_info: (dict) key 'z_log_prob' holds the mean of the decoder output
                        over the sampled batch, as a Python float.
        """
        obs = check(obs[:-1, :, 0, :]).to(**self.tpdv)
        actions = check(actions[:, :, 0, :]).to(**self.tpdv)
        on_reset = check(1 - masks[:-1, :, 0, :]).to(**self.tpdv)

        optimizer = self.info_decoder.optimizer
        optimizer.zero_grad()

        obs, actions, on_reset = gail_utils.sample_from_rollouts(
            obs,
            actions,
            on_reset,
            batch_size=self.gail_batch_size,
            window_size=self.window_size,
        )

        # The latent code is read from the last step along axis 0 of the sampled data;
        # the remaining features feed the observation function.
        z = obs[-1, ..., :self.z_latent_dim]
        obs = obs[..., self.z_latent_dim:]

        x, _ = self.info_obs_func.process_fake_data(obs, actions, on_reset)

        # Minimising the negated mean maximises the decoder output for the true z.
        mean_z_log_prob = self.info_decoder.compute(x, z).mean()
        info_loss = -mean_z_log_prob
        info_loss.backward()

        optimizer.step()

        return {'z_log_prob': mean_z_log_prob.item()}
    
    def train(self, buffer, update_actor = True):
        """
        Run the PPO minibatch updates and, after each one, an info-decoder update.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.

        :return train_info: (dict) per-update averages of the PPO statistics and of
                            the statistics returned by update_info_decoder.
        """
        value_estimates = buffer.value_preds[:-1]
        if self._use_popart or self._use_valuenorm:
            value_estimates = self.value_normalizer.denormalize(value_estimates)
        advantages = buffer.returns[:-1] - value_estimates

        # Normalise advantages with statistics taken over active entries only:
        # inactive entries become NaN so nanmean / nanstd ignore them.
        masked_advantages = advantages.copy()
        masked_advantages[buffer.active_masks[:-1] == 0.0] = np.nan
        adv_mean = np.nanmean(masked_advantages)
        adv_std = np.nanstd(masked_advantages)
        advantages = (advantages - adv_mean) / (adv_std + 1e-5)

        train_info = {
            'value_loss': 0,
            'policy_loss': 0,
            'dist_entropy': 0,
            'actor_grad_norm': 0,
            'critic_grad_norm': 0,
            'ratio': 0,
        }

        for _ in range(self.ppo_epoch):
            if self._use_recurrent_policy:
                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
            elif self._use_naive_recurrent:
                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
            else:
                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:
                (value_loss, critic_grad_norm, policy_loss,
                 dist_entropy, actor_grad_norm, imp_weights) = self.ppo_update(sample, update_actor)

                batch_stats = {
                    'value_loss': value_loss.item(),
                    'policy_loss': policy_loss.item(),
                    'dist_entropy': dist_entropy.sum().item(),
                    'actor_grad_norm': actor_grad_norm,
                    'critic_grad_norm': critic_grad_norm,
                    'ratio': imp_weights.mean(),
                }
                for key, value in batch_stats.items():
                    train_info[key] += value

                # One decoder update per PPO minibatch, drawn from the full buffer.
                z_info = self.update_info_decoder(buffer.obs, buffer.actions, buffer.masks)
                for key, value in z_info.items():
                    train_info[key] = train_info.get(key, 0) + value

        # Convert the running sums into per-update averages.
        num_updates = self.ppo_epoch * self.num_mini_batch
        for key in train_info:
            train_info[key] /= num_updates

        return train_info

    def prep_training(self):
        """Run the parent's prep_training, then put the info decoder network in training mode."""
        super().prep_training()
        decoder_net = self.info_decoder.net
        decoder_net.train()

    def prep_rollout(self):
        """Run the parent's prep_rollout, then put the info decoder network in evaluation mode."""
        super().prep_rollout()
        decoder_net = self.info_decoder.net
        decoder_net.eval()
