import numpy as np
from collections import OrderedDict

import torch
import torch.optim as optim
from torch import nn
from torch import autograd
import torch.nn.functional as F

import rlkit.torch.pytorch_util as ptu
from rlkit.torch.core import np_to_pytorch_batch
from rlkit.torch.torch_base_algorithm import TorchBaseAlgorithm
from rlkit.data_management.path_builder import PathBuilder

from tqdm import  tqdm


class AdvIRL_LfO(TorchBaseAlgorithm):
    '''
        Adversarial imitation learning from observations (LfO).

        Depending on choice of reward function and size of replay
        buffer this will be:
            - AIRL
            - GAIL (without extra entropy term)
            - FAIRL
            - Discriminator Actor Critic

        The discriminator scores (s, s') pairs, or (s, a, s') triples when
        ``sas`` is set. In the latter case expert actions are inferred with
        the inverse dynamic model.

        Besides the discriminator and the policy (trained through
        ``policy_trainer``), two auxiliary models that live on the
        exploration policy are trained here:
            - a state predictor, fit on expert (s, s') pairs
            - an inverse dynamic model, fit on the agent's own
              (s, a, s') transitions
        In ``'sl'`` mode the discriminator and policy trainer are skipped
        and only these two auxiliary models are trained.

        I did not implement the reward-wrapping mentioned in
        https://arxiv.org/pdf/1809.02925.pdf though

        Features removed from v1.0:
            - gradient clipping
            - target disc (exponential moving average disc)
            - target policy (exponential moving average policy)
            - disc input noise
    '''
    def __init__(
        self,

        mode, # airl, gail, fairl, or sl
        inverse_mode, # MLE or MSE
        discriminator,
        policy_trainer,

        expert_replay_buffer,

        state_only=False,
        sas=False,
        qss=False,
        state_diff=False,

        disc_optim_batch_size=1024,
        policy_optim_batch_size=1024,
        policy_optim_batch_size_from_expert=0,

        num_update_loops_per_train_call=1,
        num_disc_updates_per_loop_iter=100,
        num_policy_updates_per_loop_iter=100,
        num_state_predictor_updates_per_loop_iter=100,
        num_inverse_dynamic_updates_per_loop_iter=100,
        num_pretrain_updates=20,
        pretrain_steps_per_epoch=5000,

        disc_lr=1e-3,
        disc_momentum=0.9,
        disc_optimizer_class=optim.Adam,

        state_predictor_lr=1e-3,
        state_predictor_momentum=0.0,
        state_predictor_optimizer_class=optim.Adam,

        inverse_dynamic_lr=1e-3,
        inverse_dynamic_momentum=0.0,
        inverse_dynamic_optimizer_class=optim.Adam,

        decay_ratio = 1.0,

        use_grad_pen=True,
        use_wgan = True,
        grad_pen_weight=10,

        rew_clip_min=None,
        rew_clip_max=None,

        **kwargs
    ):
        '''
            Args (selected):
                mode: 'airl', 'gail', 'fairl' or 'sl'. Selects how
                    discriminator logits become rewards. 'sl' disables
                    adversarial training altogether.
                inverse_mode: 'MLE' (maximise log-likelihood) or 'MSE'
                    (regress the inverse dynamic model's first output).
                sas: feed (s, a, s') instead of (s, s') to the discriminator.
                qss: forwarded to ``policy_trainer.train_step``.
                state_diff: make the state predictor regress s' - s
                    instead of s'.
                state_only: stored on the instance. The discriminator in
                    this class always consumes next observations.
                *_optimizer_class: must accept a ``betas`` keyword (e.g. Adam).
                decay_ratio: gamma of the ExponentialLR schedulers. Note
                    that this class never steps those schedulers itself.
                **kwargs: forwarded to ``TorchBaseAlgorithm``.
                    ``wrap_absorbing=True`` is rejected.
        '''
        assert mode in ['airl', 'gail', 'fairl', 'sl'], 'Invalid adversarial irl algorithm!'
        assert inverse_mode in ['MSE', 'MLE'], 'Invalid bco algorithm!'
        # Absorbing-state wrapping is not supported. ``.get`` avoids a
        # KeyError when the caller relies on the base class default.
        if kwargs.get('wrap_absorbing', False):
            raise NotImplementedError()
        super().__init__(**kwargs)

        self.mode = mode
        self.inverse_mode = inverse_mode
        self.state_only = state_only
        self.sas = sas
        self.qss = qss
        self.state_diff = state_diff

        self.expert_replay_buffer = expert_replay_buffer

        self.policy_trainer = policy_trainer
        self.policy_optim_batch_size = policy_optim_batch_size
        self.policy_optim_batch_size_from_expert = policy_optim_batch_size_from_expert

        self.discriminator = discriminator
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
            betas=(disc_momentum, 0.999)
        )
        self.state_predictor_optimizer = state_predictor_optimizer_class(
            self.exploration_policy.state_predictor.parameters(),
            lr=state_predictor_lr,
            betas=(state_predictor_momentum, 0.999)
        )
        self.inverse_dynamic_optimizer = inverse_dynamic_optimizer_class(
            self.exploration_policy.inverse_dynamic.parameters(),
            lr=inverse_dynamic_lr,
            betas=(inverse_dynamic_momentum, 0.999)
        )
        # NOTE: these schedulers are created but never stepped in this class,
        # so ``decay_ratio`` has no effect unless ``.step()`` is called
        # elsewhere (e.g. from ``_end_epoch``).
        self.inverse_dynamic_scheduler = optim.lr_scheduler.ExponentialLR(
            self.inverse_dynamic_optimizer, gamma=decay_ratio)
        self.state_predictor_scheduler = optim.lr_scheduler.ExponentialLR(
            self.state_predictor_optimizer, gamma=decay_ratio)

        self.disc_optim_batch_size = disc_optim_batch_size
        # The auxiliary models reuse the policy batch size.
        self.state_predictor_optim_batch_size = policy_optim_batch_size
        self.inverse_dynamic_optim_batch_size = policy_optim_batch_size

        self.pretrain_steps_per_epoch = pretrain_steps_per_epoch

        print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum)
        print('\n\nSTATE-PREDICTOR MOMENTUM: %f\n\n' % state_predictor_momentum)
        print('\n\nINVERSE-DYNAMIC MOMENTUM: %f\n\n' % inverse_dynamic_momentum)

        # Targets for the non-WGAN discriminator: the expert half of the
        # stacked batch is labelled 1, the policy half 0 (matching the
        # concatenation order in ``_do_reward_training``).
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_targets = torch.cat(
            [
                torch.ones(disc_optim_batch_size, 1),
                torch.zeros(disc_optim_batch_size, 1)
            ],
            dim=0
        )
        self.bce.to(ptu.device)
        self.bce_targets = self.bce_targets.to(ptu.device)

        self.use_grad_pen = use_grad_pen
        self.use_wgan = use_wgan
        self.grad_pen_weight = grad_pen_weight

        self.num_update_loops_per_train_call = num_update_loops_per_train_call
        self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter
        self.num_policy_updates_per_loop_iter = num_policy_updates_per_loop_iter
        self.num_state_predictor_updates_per_loop_iter = num_state_predictor_updates_per_loop_iter
        self.num_inverse_dynamic_updates_per_loop_iter = num_inverse_dynamic_updates_per_loop_iter
        self.num_pretrain_updates = num_pretrain_updates

        self.rew_clip_min = rew_clip_min
        self.rew_clip_max = rew_clip_max
        self.clip_min_rews = rew_clip_min is not None
        self.clip_max_rews = rew_clip_max is not None

        # Statistics collected during training and merged in ``evaluate``.
        self.disc_eval_statistics = None
        self.policy_eval_statistics = None


    def get_batch(self, batch_size, from_expert, keys=None):
        '''
            Sample a batch from the expert buffer (``from_expert=True``) or
            the agent's replay buffer, converted with ``np_to_pytorch_batch``.
            ``keys`` is forwarded to ``random_batch`` (None = buffer default).
        '''
        buffer = self.expert_replay_buffer if from_expert else self.replay_buffer
        batch = buffer.random_batch(batch_size, keys=keys)
        return np_to_pytorch_batch(batch)


    def _end_epoch(self):
        '''
            Per-epoch housekeeping. Resets the discriminator statistics so
            that the next discriminator update records fresh ones.
        '''
        self.policy_trainer.end_epoch()
        self.disc_eval_statistics = None
        # LR schedulers are intentionally not stepped here. Call
        # self.state_predictor_scheduler.step() and
        # self.inverse_dynamic_scheduler.step() to enable decay_ratio.
        super()._end_epoch()


    def evaluate(self, epoch):
        '''
            Merge the auxiliary-model, discriminator and policy-trainer
            statistics into ``self.eval_statistics``, then defer to the base
            class evaluation.
        '''
        self.eval_statistics = OrderedDict()

        if self.policy_eval_statistics is not None:
            self.eval_statistics.update(self.policy_eval_statistics)
        if self.mode != 'sl':
            # May still be None if no discriminator update ran since the
            # last reset in ``_end_epoch``.
            if self.disc_eval_statistics is not None:
                self.eval_statistics.update(self.disc_eval_statistics)
            self.eval_statistics.update(self.policy_trainer.get_eval_statistics())

        super().evaluate(epoch)

    def pretrain(self):
        """
        Do anything before the main training phase.

        For each of ``num_pretrain_updates`` rounds:
          1. step the training env ``pretrain_steps_per_epoch`` times, using
             actions from ``_get_action_and_info``, and store each
             transition through ``_handle_step``;
          2. take one state-predictor step on expert data and one
             inverse-dynamic step on the agent's replay buffer.
        The original author describes the acting policy here as random,
        presumably the not-yet-trained exploration policy (TODO confirm).
        """
        print("Pretraining ...")
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()
        for _ in tqdm(range(self.num_pretrain_updates)):
            for _step in range(self.pretrain_steps_per_epoch):
                action, agent_info = self._get_action_and_info(observation)
                if self.render:
                    self.training_env.render()

                next_ob, raw_reward, terminal, env_info = (
                    self.training_env.step(action)
                )
                if self.no_terminal:
                    terminal = False
                self._n_env_steps_total += 1

                reward = np.array([raw_reward])
                terminal = np.array([terminal])

                # Flag the last step before the path length cap as a timeout.
                timeout = np.array([
                    len(self._current_path_builder) >= (self.max_path_length - 1)
                ])

                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    np.array([False]) if self.no_terminal else terminal,
                    timeout,
                    absorbing=np.array([0., 0.]),
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal[0]:
                    if self.wrap_absorbing:
                        raise NotImplementedError()
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                elif len(self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                else:
                    observation = next_ob
            self._do_state_predictor_training(-1, True)
            self._do_inverse_dynamic_training(-1, False)

    def _do_training(self, epoch):
        '''
            One training call. In adversarial modes, alternate discriminator
            and policy updates. In every mode, then train the state
            predictor (expert data) and inverse dynamic model (agent data).
        '''
        for _ in range(self.num_update_loops_per_train_call):
            if self.mode != 'sl':
                for _ in range(self.num_disc_updates_per_loop_iter):
                    self._do_reward_training(epoch)
                for _ in range(self.num_policy_updates_per_loop_iter):
                    self._do_policy_training(epoch)
            for _ in range(self.num_state_predictor_updates_per_loop_iter):
                self._do_state_predictor_training(epoch, True)
            for _ in range(self.num_inverse_dynamic_updates_per_loop_iter):
                self._do_inverse_dynamic_training(epoch, False)


    def _do_state_predictor_training(self, epoch, use_expert_buffer=True):
        '''
            Train the state predictor for one step.

            Minimises the squared error (summed over the last dim, averaged
            over the batch) between ``state_predictor(s)`` and s', or s' - s
            when ``state_diff`` is set.
        '''
        self.state_predictor_optimizer.zero_grad()

        batch = self.get_batch(
            self.state_predictor_optim_batch_size,
            keys=['observations', 'next_observations'],
            from_expert=use_expert_buffer
        )

        obs = batch['observations']
        next_obs = batch['next_observations']

        pred_obs = self.exploration_policy.state_predictor(obs)
        label_obs = next_obs - obs if self.state_diff else next_obs
        loss = torch.sum((pred_obs - label_obs) ** 2, dim=-1).mean()

        if self.policy_eval_statistics is None:
            self.policy_eval_statistics = OrderedDict()
        self.policy_eval_statistics['State-Predictor-MSE'] = ptu.get_numpy(loss)

        loss.backward()
        self.state_predictor_optimizer.step()


    def _do_inverse_dynamic_training(self, epoch, use_expert_buffer=False):
        '''
            Train the inverse dynamic model for one step.

            'MLE': maximise ``inverse_dynamic.get_log_prob(s, s', a)``.
            'MSE': squared error between the model's first output for
            (s, s') and the stored action a.
        '''
        self.inverse_dynamic_optimizer.zero_grad()

        batch = self.get_batch(
            self.inverse_dynamic_optim_batch_size,
            keys=['observations', 'actions', 'next_observations'],
            from_expert=use_expert_buffer
        )

        obs = batch['observations']
        acts = batch['actions']
        next_obs = batch['next_observations']

        if self.policy_eval_statistics is None:
            self.policy_eval_statistics = OrderedDict()

        if self.inverse_mode == 'MLE':
            log_prob = self.exploration_policy.inverse_dynamic.get_log_prob(obs, next_obs, acts)
            loss = -1.0 * log_prob.mean()
            self.policy_eval_statistics['Inverse-Dynamic-Log-Likelihood'] = ptu.get_numpy(-1.0 * loss)
        else:
            pred_acts = self.exploration_policy.inverse_dynamic(obs, next_obs)[0]
            loss = torch.sum((pred_acts - acts) ** 2, dim=-1).mean()
            self.policy_eval_statistics['Inverse-Dynamic-MSE'] = ptu.get_numpy(loss)

        # Only names defined in both modes are referenced, so the message
        # itself cannot raise while reporting a NaN.
        assert not torch.isnan(loss).any(), \
            "nan-inverse-dynamic-training, obs: {}, next_obs: {}, acts: {}, loss: {}".format(
                obs, next_obs, acts, loss)
        loss.backward()
        self.inverse_dynamic_optimizer.step()

    def _do_reward_training(self, epoch):
        '''
            Train the discriminator for one step.

            Expert and policy transitions are encoded as (s, s'), or as
            (s, a, s') when ``sas`` is set. In the latter case the expert
            action is inferred with the inverse dynamic model.
            Loss: WGAN critic loss (``use_wgan``) or BCE on logits, plus an
            optional gradient penalty.
        '''
        self.disc_optimizer.zero_grad()

        # (s, s') is always consumed; actions only for the (s, a, s') variant.
        keys = ['observations', 'next_observations']
        if self.sas:
            keys.append('actions')
        expert_batch = self.get_batch(self.disc_optim_batch_size, True, keys)
        policy_batch = self.get_batch(self.disc_optim_batch_size, False, keys)

        expert_obs = expert_batch['observations']
        policy_obs = policy_batch['observations']
        expert_next_obs = expert_batch['next_observations']
        policy_next_obs = policy_batch['next_observations']

        if self.sas:
            # Detach so the discriminator loss does not backpropagate into
            # the inverse dynamic model; that model has its own optimizer,
            # which zeroes gradients before stepping anyway.
            expert_acts = self.exploration_policy.inverse_dynamic(
                expert_obs, expert_next_obs)[0].detach()
            policy_acts = policy_batch['actions']
            expert_inputs = [expert_obs, expert_acts, expert_next_obs]
            policy_inputs = [policy_obs, policy_acts, policy_next_obs]
        else:
            expert_inputs = [expert_obs, expert_next_obs]
            policy_inputs = [policy_obs, policy_next_obs]

        expert_disc_input = torch.cat(expert_inputs, dim=1)
        policy_disc_input = torch.cat(policy_inputs, dim=1)

        if self.use_wgan:
            expert_logits = self.discriminator(expert_disc_input)
            policy_logits = self.discriminator(policy_disc_input)
            # Critic loss, summed (not averaged) over the batch.
            disc_ce_loss = -torch.sum(expert_logits) + torch.sum(policy_logits)
        else:
            # Expert rows first, policy rows second: matches self.bce_targets.
            disc_input = torch.cat([expert_disc_input, policy_disc_input], dim=0)
            disc_logits = self.discriminator(disc_input)
            disc_preds = (disc_logits > 0).type(disc_logits.data.type())
            disc_ce_loss = self.bce(disc_logits, self.bce_targets)
            accuracy = (disc_preds == self.bce_targets).type(torch.FloatTensor).mean()

        if self.use_grad_pen:
            # Tensor.to is not in-place, so the result must be kept.
            eps = ptu.rand(expert_obs.size(0), 1).to(ptu.device)

            interp_obs = eps * expert_disc_input + (1 - eps) * policy_disc_input
            interp_obs = interp_obs.detach()
            interp_obs.requires_grad_(True)

            gradients = autograd.grad(
                outputs=self.discriminator(interp_obs).sum(),
                inputs=[interp_obs],
                create_graph=True, retain_graph=True, only_inputs=True
            )
            total_grad = gradients[0]

            # GP from Gulrajani et al. (Mescheder et al. alternative:
            # (||grad||^2).mean() * 0.5 * grad_pen_weight).
            gradient_penalty = ((total_grad.norm(2, dim=1) - 1) ** 2).mean()
            disc_grad_pen_loss = gradient_penalty * self.grad_pen_weight
        else:
            disc_grad_pen_loss = 0.0

        disc_total_loss = disc_ce_loss + disc_grad_pen_loss
        assert not torch.isnan(disc_total_loss).any(), \
            "nan-reward-training, disc_ce_loss: {}, disc_grad_pen_loss: {}".format(
                disc_ce_loss, disc_grad_pen_loss)
        disc_total_loss.backward()
        self.disc_optimizer.step()

        # Save some statistics for eval. ``_end_epoch`` resets this to None,
        # so these are only computed for the first batch of each epoch.
        if self.disc_eval_statistics is None:
            self.disc_eval_statistics = OrderedDict()

            self.disc_eval_statistics['Disc CE Loss'] = np.mean(ptu.get_numpy(disc_ce_loss))
            if self.use_wgan:
                self.disc_eval_statistics['Expert D Logits'] = np.mean(ptu.get_numpy(expert_logits))
                self.disc_eval_statistics['Policy D Logits'] = np.mean(ptu.get_numpy(policy_logits))
            else:
                self.disc_eval_statistics['Disc Acc'] = np.mean(ptu.get_numpy(accuracy))
            if self.use_grad_pen:
                self.disc_eval_statistics['Grad Pen'] = np.mean(ptu.get_numpy(gradient_penalty))
                self.disc_eval_statistics['Grad Pen W'] = np.mean(self.grad_pen_weight)


    def _do_policy_training(self, epoch):
        '''
            Relabel a batch with discriminator-derived rewards and take one
            ``policy_trainer.train_step`` on it.

            The batch optionally mixes in ``policy_optim_batch_size_from_expert``
            expert transitions. Rewards per mode:
              - airl:  logits (= log D - log(1 - D))
              - gail:  softplus(logits, beta=-1) = log(sigmoid(logits))
              - fairl: exp(logits) * (-logits)
            They are then clipped to [rew_clip_min, rew_clip_max] when set.
        '''
        if self.policy_optim_batch_size_from_expert > 0:
            policy_batch_from_policy_buffer = self.get_batch(
                self.policy_optim_batch_size - self.policy_optim_batch_size_from_expert, False)
            policy_batch_from_expert_buffer = self.get_batch(
                self.policy_optim_batch_size_from_expert, True)
            policy_batch = {
                k: torch.cat(
                    [
                        policy_batch_from_policy_buffer[k],
                        policy_batch_from_expert_buffer[k]
                    ],
                    dim=0
                )
                for k in policy_batch_from_policy_buffer
            }
        else:
            policy_batch = self.get_batch(self.policy_optim_batch_size, False)

        if self.wrap_absorbing:
            # Rejected in __init__. Fail explicitly rather than with a
            # NameError on ``disc_logits`` below.
            raise NotImplementedError()

        obs = policy_batch['observations']
        next_obs = policy_batch['next_observations']
        if self.sas:
            policy_inputs = [obs, policy_batch['actions'], next_obs]
        else:
            policy_inputs = [obs, next_obs]

        # Score in eval mode and detach: rewards must not carry gradients
        # back into the discriminator.
        self.discriminator.eval()
        disc_logits = self.discriminator(torch.cat(policy_inputs, dim=1)).detach()
        self.discriminator.train()

        if self.mode == 'airl':
            # If you compute log(D) - log(1-D) then you just get the logits
            rewards = disc_logits
        elif self.mode == 'gail':
            rewards = F.softplus(disc_logits, beta=-1)
        else: # fairl
            rewards = torch.exp(disc_logits) * (-1.0 * disc_logits)

        if self.clip_max_rews:
            rewards = torch.clamp(rewards, max=self.rew_clip_max)
        if self.clip_min_rews:
            rewards = torch.clamp(rewards, min=self.rew_clip_min)
        policy_batch['rewards'] = rewards

        # policy optimization step
        self.policy_trainer.train_step(policy_batch, qss=self.qss)

        # May be None if no discriminator update has run since the last reset.
        if self.disc_eval_statistics is None:
            self.disc_eval_statistics = OrderedDict()
        rewards_np = ptu.get_numpy(policy_batch['rewards'])
        self.disc_eval_statistics['Disc Rew Mean'] = np.mean(rewards_np)
        self.disc_eval_statistics['Disc Rew Std'] = np.std(rewards_np)
        self.disc_eval_statistics['Disc Rew Max'] = np.max(rewards_np)
        self.disc_eval_statistics['Disc Rew Min'] = np.min(rewards_np)


    @property
    def networks(self):
        '''
            All trainable modules: discriminator, policy-trainer networks,
            and the policy's state predictor and inverse dynamic model.
        '''
        return [self.discriminator] + self.policy_trainer.networks + [
            self.policy_trainer.policy.state_predictor,
            self.policy_trainer.policy.inverse_dynamic,
        ]


    def get_epoch_snapshot(self, epoch):
        '''
            Extend the base snapshot with the discriminator, auxiliary models
            and policy-trainer networks. Assumes the trainer exposes
            qf1/qf2/vf/target_vf (SAC-style); TODO confirm for other trainers.
        '''
        snapshot = super().get_epoch_snapshot(epoch)
        snapshot.update(disc=self.discriminator)
        snapshot.update(dict(
            qf1=self.policy_trainer.qf1,
            qf2=self.policy_trainer.qf2,
            state_predictor=self.policy_trainer.policy.state_predictor,
            inverse_dynamic=self.policy_trainer.policy.inverse_dynamic,
            vf=self.policy_trainer.vf,
            target_vf=self.policy_trainer.target_vf,
        ))
        return snapshot


    def to(self, device):
        '''
            Move the BCE loss and its targets to ``device``, then defer to
            the base class. The requested device is used (not the global
            default) so the targets sit next to the networks.
        '''
        self.bce.to(device)
        self.bce_targets = self.bce_targets.to(device)
        super().to(device)
