import numpy as np
import torch
import torch.nn as nn

import onpolicy.algorithms.gail.gail_utils as gail_utils
from onpolicy.algorithms.r_mappo.r_mappo import R_MAPPO
from onpolicy.algorithms.utils.util import check


class GAIL(R_MAPPO):
    """R_MAPPO trainer extended with a GAIL discriminator and optional behavior cloning.

    On top of the PPO actor/critic updates inherited from ``R_MAPPO`` this class:

    * builds a discriminator, an observation-processing function and an expert
      dataset through ``gail_utils``;
    * converts discriminator outputs on rollout data into a reward
      (:meth:`compute_gail_reward`);
    * takes one discriminator optimizer step per PPO minibatch
      (:meth:`update_gail`, called from :meth:`train`);
    * optionally adds a behavior-cloning negative log-likelihood term to the
      actor loss (:meth:`compute_extra_actor_loss`).
    """

    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):
        super().__init__(args, policy, device=device)

        # diffrl: weight applied to the last ("reference") entry of the
        # entropy vector in compute_entropy_loss.
        self.reference_entropy_coef = args.reference_entropy_coef

        # gail
        # With eval_retrain_discriminator set, only the discriminator is trained;
        # actor/critic updates are skipped in train().
        self.update_actor_critic = not args.eval_retrain_discriminator
        self.gail_batch_size = args.gail_batch_size
        self.gradient_penalty_coef = args.gradient_penalty_coef
        self.window_size = args.gail_window_size
        self.gail_task = args.gail_task
        self.task_reward_coef = args.task_reward_coef
        self.gail_reward_coef = args.gail_reward_coef
        self.obs_func = gail_utils.build_obs_func(task=self.gail_task,
                                                  window_size=self.window_size,
                                                  algorithm_name=args.algorithm_name,
                                                  scenario_name=args.scenario_name)

        self.discriminator = gail_utils.build_discriminator(obs_dim=args.gail_obs_dim,
                                                            device=device,
                                                            window_size=self.window_size)
        # Record which expert dataset was used. The file is opened in append
        # mode, so terminate each entry with a newline to keep runs separate.
        with open('log.txt', 'a', encoding='utf-8') as f:
            f.write('dataset_path=' + str(args.dataset_path) + '\n')
        # NOTE(review): the keyword `extend_tarj_length` (sic) presumably matches
        # the parameter name declared by gail_utils.build_dataset; rename both
        # together if it is ever corrected.
        self.dataset = gail_utils.build_dataset(task=self.gail_task,
                                                window_size=self.window_size,
                                                dataset_path=args.dataset_path,
                                                scenario_name=args.scenario_name,
                                                extend_traj=getattr(args, 'extend_traj', False),
                                                extend_tarj_length=getattr(args, 'extend_traj_length', 100))

        # bc
        self.behavior_cloning = args.behavior_cloning
        self.behavior_cloning_coef = args.behavior_cloning_coef

    def compute_entropy_loss(self, dist_entropy):
        """Return the entropy bonus term used in the actor loss.

        For the ``diff-gail`` / ``diff-infogail`` algorithms, all entries of
        ``dist_entropy`` except the last are averaged and scaled by
        ``entropy_coef``; the last entry is scaled by ``reference_entropy_coef``.
        Other algorithms defer to ``R_MAPPO.compute_entropy_loss``.
        """
        if self.algorithm_name in ["diff-gail", "diff-infogail"]:
            # assert (dist_entropy.shape == (2,)), (dist_entropy.shape)
            return dist_entropy[:-1].mean(dim=0) * self.entropy_coef + dist_entropy[-1] * self.reference_entropy_coef
        return super().compute_entropy_loss(dist_entropy)

    def _rollout_tensors(self, obs, actions, masks):
        """Slice agent index 0 out of rollout arrays and convert them with ``self.tpdv``.

        The last step along axis 0 of ``obs`` and ``masks`` is dropped while
        ``actions`` is used whole (presumably ``obs``/``masks`` carry one extra
        bootstrap step — confirm against the buffer). ``on_reset`` is
        ``1 - masks``.

        :return: ``(obs, actions, on_reset)`` as tensors.
        """
        obs = check(obs[:-1, :, 0, :]).to(**self.tpdv)
        actions = check(actions[:, :, 0, :]).to(**self.tpdv)
        on_reset = check(1 - masks[:-1, :, 0, :]).to(**self.tpdv)
        return obs, actions, on_reset

    def compute_gail_reward(self, obs, actions, masks):
        """Compute the GAIL reward ``-log(1 - D(x) + 1e-6) * mask`` for a rollout.

        The discriminator scores every window of ``window_size`` consecutive
        steps; the first ``window_size - 1`` steps, which have no full window,
        are padded with a discriminator output of 0 before the log.

        :param obs: rollout observations indexed ``[step, env, agent, ...]``
            (assumed layout — only agent index 0 is used).
        :param actions: rollout actions with the same leading layout.
        :param masks: rollout masks with the same leading layout.
        :return: numpy array of rewards with a singleton axis inserted at
            position 2.
        """
        obs, actions, on_reset = self._rollout_tensors(obs, actions, masks)
        T, B = obs.shape[:2]
        # T x B x ... -> W x ((T - W + 1) x B) x ...
        x_fake, mask = self.obs_func.process_fake_data(obs, actions, on_reset)
        with torch.no_grad():
            d_fake = self.discriminator.compute(x_fake)
        d_fake = d_fake.reshape(T - self.window_size + 1, B)
        padding = torch.zeros(self.window_size - 1, B, device=self.device, dtype=torch.float32)
        d_fake = torch.cat([padding, d_fake], 0)
        gail_reward = (-(1. - d_fake.unsqueeze(-1) + 1e-6).log()) * mask
        return gail_reward.cpu().numpy()[:, :, np.newaxis, :]

    def update_gail(self, obs, actions, masks):
        """Take one discriminator optimizer step on expert vs. rollout samples.

        The loss is binary cross-entropy style (``-log D(real)`` and
        ``-log(1 - D(fake))``) plus ``gradient_penalty_coef`` times the
        penalty returned by ``gail_utils.compute_grad2`` on the real samples.

        :return: dict of scalar diagnostics (losses and mean discriminator
            outputs).
        """
        obs, actions, on_reset = self._rollout_tensors(obs, actions, masks)

        self.discriminator.optimizer.zero_grad()
        x_real = self.obs_func.process_real_data(self.dataset.sample(batch_size=self.gail_batch_size),
                                                 device=self.device)
        x_fake, _ = self.obs_func.process_fake_data(*gail_utils.sample_from_rollouts(obs,
                                                                                    actions,
                                                                                    on_reset,
                                                                                    batch_size=self.gail_batch_size,
                                                                                    window_size=self.window_size))

        # Input gradients are only needed for the penalty on real samples;
        # the fake branch needs parameter gradients only.
        x_real.requires_grad_()

        d_x_real = self.discriminator.compute(x_real)
        d_x_fake = self.discriminator.compute(x_fake)

        dloss_real = -((d_x_real + 1e-6).log()).mean()
        # Keep the graph: compute_grad2 differentiates d_x_real w.r.t. x_real next.
        dloss_real.backward(retain_graph=True)
        reg_real = self.gradient_penalty_coef * gail_utils.compute_grad2(d_x_real, x_real, batch_size=self.gail_batch_size).mean()
        reg_real.backward()

        dloss_fake = -(1. - d_x_fake + 1e-6).log().mean()
        dloss_fake.backward()

        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.max_grad_norm)

        self.discriminator.optimizer.step()

        return dict(dloss_real=dloss_real.item(),
                    d_x_real=d_x_real.mean().item(),
                    dloss_fake=dloss_fake.item(),
                    d_x_fake=d_x_fake.mean().item(),
                    reg_real=reg_real.item(),
                    dloss=(dloss_real + dloss_fake + reg_real).item())

    def compute_extra_actor_loss(self):
        """Return the behavior-cloning loss term and its diagnostics.

        When ``behavior_cloning`` is enabled, samples an expert batch, keeps
        the last step (index ``-1``) of the sampled observations/actions, and
        returns ``behavior_cloning_coef * NLL`` of the expert actions under
        the current actor. The actor is evaluated with zero RNN states and
        all-ones masks.

        :return: ``(loss, info_dict)``; ``(0, {})`` when behavior cloning is off.
        """
        if not self.behavior_cloning:
            return 0, dict()

        observations, actions = self.dataset.sample(batch_size=self.gail_batch_size)

        bs = self.gail_batch_size
        observations = observations[-1].to(self.device)  # bs, d
        actions = actions[-1].to(self.device)  # bs, d

        action_log_probs, _ = self.policy.actor.evaluate_actions(observations, np.zeros((bs, 1)), actions, np.ones((bs, 1)))

        # behavior cloning loss = negative log likelihood of actions
        bc_loss = -action_log_probs.mean()

        return self.behavior_cloning_coef * bc_loss, dict(bc_loss=bc_loss)

    def train(self, buffer, update_actor=True):
        """
        Perform a training update using minibatch GD.

        Each minibatch runs a PPO actor/critic update (unless only the
        discriminator is being re-trained) followed by one discriminator step
        on the full buffer. All reported values are averaged over
        ``ppo_epoch * num_mini_batch`` minibatches.

        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        if self._use_popart or self._use_valuenorm:
            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
        else:
            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
        # Normalize using statistics over active entries only.
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        train_info = dict.fromkeys(['value_loss', 'policy_loss', 'dist_entropy',
                                    'actor_grad_norm', 'critic_grad_norm', 'ratio'], 0)

        def accumulate(key, value):
            # Extra keys (BC / discriminator diagnostics) start at 0 on first use.
            train_info[key] = train_info.get(key, 0) + value

        for _ in range(self.ppo_epoch):
            if self._use_recurrent_policy:
                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
            elif self._use_naive_recurrent:
                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
            else:
                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:
                if self.update_actor_critic:
                    value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights, extra_loss_dict \
                        = self.ppo_update(sample, update_actor)

                    accumulate('value_loss', value_loss.item())
                    accumulate('policy_loss', policy_loss.item())
                    accumulate('dist_entropy', dist_entropy.sum().item())
                    accumulate('actor_grad_norm', actor_grad_norm)
                    accumulate('critic_grad_norm', critic_grad_norm)
                    # .item(): do not keep autograd-attached tensors alive across
                    # minibatches; matches the float type of the other entries.
                    accumulate('ratio', imp_weights.mean().item())
                    for k, v in extra_loss_dict.items():
                        accumulate(k, v.item())

                for k, v in self.update_gail(buffer.share_obs, buffer.actions, buffer.masks).items():
                    accumulate(k, v)

        num_updates = self.ppo_epoch * self.num_mini_batch

        for k in train_info.keys():
            train_info[k] /= num_updates

        return train_info

    def prep_training(self):
        """Put the networks being trained into train mode."""
        if self.update_actor_critic:
            if hasattr(self.policy, 'actor'):
                self.policy.actor.train()
            if hasattr(self.policy, 'critic'):
                self.policy.critic.train()
        self.discriminator.net.train()

    def prep_rollout(self):
        """Put actor, critic (when present) and discriminator into eval mode."""
        if hasattr(self.policy, 'actor'):
            self.policy.actor.eval()
        if hasattr(self.policy, 'critic'):
            self.policy.critic.eval()
        self.discriminator.net.eval()