import numpy as np
from collections import OrderedDict

import torch
import torch.optim as optim
from torch import nn
from torch import autograd
import torch.nn.functional as F

import torch
import rlkit.torch.pytorch_util as ptu
from rlkit.torch.core import np_to_pytorch_batch
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.torch.torch_base_algorithm import TorchBaseAlgorithm
from rlkit.data_management.path_builder import PathBuilder

from tqdm import tqdm


class AdvIRL_LfO(TorchBaseAlgorithm):
    '''
        Adversarial IRL from observations (LfO).

        Besides the discriminator / policy updates of adversarial IRL, this
        algorithm trains two models owned by the exploration policy:
            - a state predictor (``exploration_policy.state_predictor``),
              fit to (s, s') pairs sampled from the expert buffer
            - an inverse dynamic model (``exploration_policy.inverse_dynamic``),
              fit to (s, a, s') transitions from the agent's own data
              (or from the expert buffer in 'sl-test' mode)

        Depending on choice of reward function and size of replay
        buffer this will be:
            - AIRL
            - GAIL (without extra entropy term)
            - FAIRL
            - Discriminator Actor Critic

        I did not implement the reward-wrapping mentioned in
        https://arxiv.org/pdf/1809.02925.pdf though

        Features removed from v1.0:
            - gradient clipping
            - target disc (exponential moving average disc)
            - target policy (exponential moving average policy)
            - disc input noise
    '''
    def __init__(
        self,

        mode,  # airl, gail, fairl, gail2, sl, or sl-test
        inverse_mode,  # MLE or MSE (MAE is accepted but not implemented)
        discriminator,
        policy_trainer,

        expert_replay_buffer,

        state_only=False,
        sas=False,
        qss=False,
        sss=False,
        state_diff=False,
        union=False,
        union_sp=False,
        multi_step=False,
        reward_penelty=False,
        inv_buffer=False,
        update_weight=False,
        penelty_weight=1.0,
        step_num=1,

        disc_optim_batch_size=1024,
        policy_optim_batch_size=1024,
        policy_optim_batch_size_from_expert=0,

        num_update_loops_per_train_call=1,
        num_disc_updates_per_loop_iter=100,
        num_policy_updates_per_loop_iter=100,
        num_state_predictor_updates_per_loop_iter=100,
        num_inverse_dynamic_updates_per_loop_iter=100,
        num_pretrain_updates=20,
        pretrain_steps_per_epoch=5000,

        disc_lr=1e-3,
        disc_momentum=0.9,
        disc_optimizer_class=optim.Adam,

        state_predictor_lr=1e-3,
        state_predictor_alpha=20,
        state_predictor_momentum=0.0,
        state_predictor_optimizer_class=optim.Adam,

        inverse_dynamic_lr=1e-3,
        inverse_dynamic_beta=0.5,
        inverse_dynamic_momentum=0.0,
        inverse_dynamic_optimizer_class=optim.Adam,

        decay_ratio=1.0,

        use_grad_pen=True,
        use_wgan=True,
        grad_pen_weight=10,

        rew_clip_min=-10,
        rew_clip_max=10,

        **kwargs
    ):
        """
        Build optimizers, schedulers and bookkeeping for all trained models.

        Remaining ``kwargs`` are forwarded to ``TorchBaseAlgorithm``; the base
        class is expected to provide ``exploration_policy``, ``replay_buffer``
        and ``env`` (all three are read below).
        """
        assert mode in ['airl', 'gail', 'fairl', 'gail2', 'sl', 'sl-test'], 'Invalid adversarial irl algorithm!'
        assert inverse_mode in ['MSE', 'MLE', 'MAE'], 'Invalid bco algorithm!'
        super().__init__(**kwargs)

        self.mode = mode
        self.inverse_mode = inverse_mode
        self.state_only = state_only
        self.sas = sas
        self.qss = qss
        self.sss = sss
        self.state_diff = state_diff
        self.union = union
        self.union_sp = union_sp
        self.reward_penelty = reward_penelty
        self.penelty_weight = penelty_weight
        self.inv_buffer = inv_buffer
        self.update_weight = update_weight
        self.multi_step = multi_step
        self.step_num = step_num

        if self.mode in ['sl', 'sl-test']:
            # Supervised-only modes never use joint (union) training.
            self.union = False
            self.union_sp = False

        if self.union_sp:
            assert mode in ['airl', 'gail', 'fairl', 'gail2']

        self.expert_replay_buffer = expert_replay_buffer
        # Inverse-dynamic data comes from the main replay buffer unless a
        # dedicated buffer is requested.
        self.inv_replay_buffer = self.replay_buffer
        if self.inv_buffer:
            self.inv_replay_buffer = EnvReplayBuffer(
                1000000,
                self.env,
                random_seed=np.random.randint(10000)
            )
        # Frozen copy of the state predictor used to roll predictions forward
        # in the multi-step loss; re-synced after every state-predictor step.
        self.target_state_predictor = None
        if self.multi_step:
            self.target_state_predictor = self.exploration_policy.state_predictor.copy()

        self.policy_trainer = policy_trainer
        self.policy_optim_batch_size = policy_optim_batch_size
        self.policy_optim_batch_size_from_expert = policy_optim_batch_size_from_expert

        self.discriminator = discriminator
        self.disc_optimizer = disc_optimizer_class(
            self.discriminator.parameters(),
            lr=disc_lr,
            betas=(disc_momentum, 0.999)
        )
        self.state_predictor_optimizer = state_predictor_optimizer_class(
            self.exploration_policy.state_predictor.parameters(),
            lr=state_predictor_lr,
            betas=(state_predictor_momentum, 0.999)
        )
        self.inverse_dynamic_optimizer = inverse_dynamic_optimizer_class(
            self.exploration_policy.inverse_dynamic.parameters(),
            lr=inverse_dynamic_lr,
            betas=(inverse_dynamic_momentum, 0.999)
        )
        self.state_predictor_alpha = state_predictor_alpha
        self.inverse_dynamic_beta = inverse_dynamic_beta
        if self.union_sp:
            # With union_sp the state predictor is trained inside
            # policy_trainer.train_step while the inverse dynamic model is
            # trained separately in _do_training, so the inverse-dynamic weight
            # handed to the trainer is zeroed. This must come after the
            # assignment above (previously it was silently overwritten).
            self.inverse_dynamic_beta = 0

        # NOTE: these schedulers are constructed but never stepped by this class.
        self.inverse_dynamic_scheduler = optim.lr_scheduler.ExponentialLR(self.inverse_dynamic_optimizer, gamma=decay_ratio)

        self.state_predictor_scheduler = optim.lr_scheduler.ExponentialLR(self.state_predictor_optimizer, gamma=decay_ratio)

        self.disc_optim_batch_size = disc_optim_batch_size
        self.state_predictor_optim_batch_size = policy_optim_batch_size
        self.inverse_dynamic_optim_batch_size = policy_optim_batch_size

        self.pretrain_steps_per_epoch = pretrain_steps_per_epoch

        print('\n\nDISC MOMENTUM: %f\n\n' % disc_momentum)
        print('\n\nSTATE-PREDICTOR MOMENTUM: %f\n\n' % state_predictor_momentum)
        print('\n\nINVERSE-DYNAMIC MOMENTUM: %f\n\n' % inverse_dynamic_momentum)
        if self.update_weight:
            print('\n\nUPDATE WEIGHT!\n\n')
        if self.union_sp:
            print('\n\nUNION STATE PREDICTOR!\n\n')
        if self.union:
            print('\n\nUNION TRAINING!\n\n')
        if self.reward_penelty:
            print('\n\nREWARD PENELTY!\n\n')
        if self.multi_step:
            print('\n\nMULTI STEP - {}!\n\n'.format(self.step_num))

        print('\n\nNum_inverse_dynamic_updates_per_loop_iter: %f\n\n' % num_inverse_dynamic_updates_per_loop_iter)

        # Targets for the non-WGAN discriminator loss: the expert half of the
        # concatenated batch is labelled 1, the policy half 0.
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_targets = torch.cat(
            [
                torch.ones(disc_optim_batch_size, 1),
                torch.zeros(disc_optim_batch_size, 1)
            ],
            dim=0
        )
        self.bce.to(ptu.device)
        self.bce_targets = self.bce_targets.to(ptu.device)

        self.use_grad_pen = use_grad_pen
        self.use_wgan = use_wgan
        self.grad_pen_weight = grad_pen_weight

        self.num_update_loops_per_train_call = num_update_loops_per_train_call
        self.num_disc_updates_per_loop_iter = num_disc_updates_per_loop_iter
        self.num_policy_updates_per_loop_iter = num_policy_updates_per_loop_iter
        self.num_state_predictor_updates_per_loop_iter = num_state_predictor_updates_per_loop_iter
        self.num_inverse_dynamic_updates_per_loop_iter = num_inverse_dynamic_updates_per_loop_iter
        self.num_pretrain_updates = num_pretrain_updates

        self.rew_clip_min = rew_clip_min
        self.rew_clip_max = rew_clip_max
        self.clip_min_rews = rew_clip_min is not None
        self.clip_max_rews = rew_clip_max is not None

        self.disc_eval_statistics = None
        self.policy_eval_statistics = None


    def get_batch(self, batch_size, from_expert, from_inv=False, keys=None, multi_step=False, step_num=1):
        """
        Sample a batch and convert it to torch tensors.

        The expert buffer is used when ``from_expert`` is true (``from_inv`` is
        then ignored); otherwise the inverse-dynamic buffer when ``from_inv``
        is true, else the main replay buffer.
        """
        if from_expert:
            buffer = self.expert_replay_buffer
        elif from_inv:
            buffer = self.inv_replay_buffer
        else:
            buffer = self.replay_buffer

        batch = buffer.random_batch(batch_size, keys=keys, multi_step=multi_step, step_num=step_num)
        return np_to_pytorch_batch(batch)


    def _end_epoch(self):
        """Reset per-epoch disc stats and optionally grow the auxiliary loss weights (x1.1, capped at 10)."""
        self.policy_trainer.end_epoch()
        self.disc_eval_statistics = None
        if self.update_weight:
            self.state_predictor_alpha = min(10.0, self.state_predictor_alpha * 1.1)
            self.inverse_dynamic_beta = min(10.0, self.inverse_dynamic_beta * 1.1)
        super()._end_epoch()


    def evaluate(self, epoch):
        """Gather model statistics for this epoch, then defer to the base-class evaluation."""
        self.eval_statistics = OrderedDict()

        if self.policy_eval_statistics is not None:
            self.eval_statistics.update(self.policy_eval_statistics)
        if 'sl' not in self.mode:
            # disc_eval_statistics stays None until a discriminator or policy
            # update has run this epoch; OrderedDict.update(None) would raise.
            if self.disc_eval_statistics is not None:
                self.eval_statistics.update(self.disc_eval_statistics)
            self.eval_statistics.update(self.policy_trainer.get_eval_statistics())

        super().evaluate(epoch, pred_obs=True)

    def pretrain(self):
        """
        Do anything before the main training phase.

        For each of ``num_pretrain_updates`` rounds: step the environment
        ``pretrain_steps_per_epoch`` times with actions from
        ``_get_action_and_info(observation, True)`` (presumably the flag
        requests random/exploratory actions — confirm in the base class),
        store every transition, then take one state-predictor step on expert
        data and one inverse-dynamic step on agent data.
        """
        print("Pretraining ...")
        self._current_path_builder = PathBuilder()
        observation = self._start_new_rollout()

        for _ in tqdm(range(self.num_pretrain_updates)):
            for _step in range(self.pretrain_steps_per_epoch):
                pred_obs_prime, action, agent_info = self._get_action_and_info(observation, True)
                if self.render: self.training_env.render()

                next_ob, raw_reward, terminal, env_info = (
                    self.training_env.step(action)
                )
                if self.no_terminal: terminal = False
                self._n_env_steps_total += 1

                reward = np.array([raw_reward])
                terminal = np.array([terminal])

                timeout = False
                if len(self._current_path_builder) >= (self.max_path_length - 1):
                    timeout = True
                timeout = np.array([timeout])

                self._handle_step(
                    observation,
                    action,
                    reward,
                    next_ob,
                    np.array([False]) if (self.no_terminal or self.wrap_absorbing) else terminal,
                    timeout,
                    pred_obs=pred_obs_prime,
                    absorbing=np.array([0., 0.]),
                    agent_info=agent_info,
                    env_info=env_info,
                )
                if terminal[0]:
                    if self.wrap_absorbing:
                        # Wrapping absorbing states (Discriminator Actor Critic)
                        # adds two extra transitions, both with a random action
                        # and an irrelevant reward:
                        #   (s_T,    a_rand, zeros)  with absorbing flags [0, 1]
                        #   (zeros,  a_rand, zeros)  with absorbing flags [1, 1]
                        # The absorbing flags are later appended to s / s' as
                        # an extra input dimension.
                        self._handle_step(
                            next_ob,
                            self.training_env.action_space.sample(),
                            reward,
                            np.zeros_like(next_ob),
                            np.array([False]),
                            timeout,
                            pred_obs=pred_obs_prime,
                            absorbing=np.array([0.0, 1.0]),
                            agent_info=agent_info,
                            env_info=env_info
                        )
                        self._handle_step(
                            np.zeros_like(next_ob),
                            self.training_env.action_space.sample(),
                            reward,
                            np.zeros_like(next_ob),
                            np.array([False]),
                            timeout,
                            pred_obs=pred_obs_prime,
                            absorbing=np.array([1.0, 1.0]),
                            agent_info=agent_info,
                            env_info=env_info
                        )
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                elif len(self._current_path_builder) >= self.max_path_length:
                    self._handle_rollout_ending()
                    observation = self._start_new_rollout()
                else:
                    observation = next_ob
            self._do_state_predictor_training(-1, True)
            self._do_inverse_dynamic_training(-1, False)

    def _do_training(self, epoch):
        """
        Run one training call.

        Adversarial modes update the discriminator then the policy. Unless
        ``union`` is set, the state predictor (skipped under ``union_sp``,
        where the policy trainer handles it) and the inverse dynamic model are
        also trained here.
        """
        for _ in range(self.num_update_loops_per_train_call):
            if 'sl' not in self.mode:
                for _ in range(self.num_disc_updates_per_loop_iter):
                    self._do_reward_training(epoch)
                for _ in range(self.num_policy_updates_per_loop_iter):
                    self._do_policy_training(epoch)
            if not self.union:
                if not self.union_sp:  # union sp do not train state predictor
                    for _ in range(self.num_state_predictor_updates_per_loop_iter):
                        self._do_state_predictor_training(epoch, True)
                for _ in range(self.num_inverse_dynamic_updates_per_loop_iter):
                    if self.mode == 'sl-test':
                        self._do_inverse_dynamic_training(epoch, True)
                    if (self.mode == 'sl') or (self.union_sp):
                        self._do_inverse_dynamic_training(epoch, False)


    def _do_state_predictor_training(self, epoch, use_expert_buffer=True):
        '''
            Train the state predictor for one step.

            The loss is the summed squared error between the prediction for s
            and s' (or s' - s when ``state_diff``), averaged over the batch.
            With ``multi_step`` the prediction is rolled forward through the
            target state predictor and compared to the following observations.
            The error on agent data is logged but not optimized.
        '''
        self.state_predictor_optimizer.zero_grad()

        exp_keys = ['observations', 'next_observations']
        for i in np.arange(1, self.step_num + 1):
            exp_keys.append('next{}_observations'.format(i))
        exp_batch = self.get_batch(
            self.state_predictor_optim_batch_size,
            keys=exp_keys,
            from_expert=use_expert_buffer,
            multi_step=self.multi_step,
            step_num=self.step_num
        )

        agent_batch = self.get_batch(
            self.state_predictor_optim_batch_size,
            keys=['observations', 'next_observations'],
            from_expert=False,
            multi_step=self.multi_step
        )

        obs = exp_batch['observations']
        next_obs = exp_batch['next_observations']

        agent_obs = agent_batch['observations']
        agent_next_obs = agent_batch['next_observations']

        pred_obs = self.exploration_policy.state_predictor(obs)
        agent_pred_obs = self.exploration_policy.state_predictor(agent_obs)

        label_obs = next_obs - obs if self.state_diff else next_obs
        loss = torch.sum((pred_obs - label_obs) ** 2, dim=-1)
        if self.multi_step:
            next_pred_obs = pred_obs
            for i in np.arange(1, self.step_num + 1):
                next_i_obs = exp_batch['next{}_observations'.format(i)]
                next_pred_obs = self.target_state_predictor(next_pred_obs)
                loss += torch.sum((next_pred_obs - next_i_obs) ** 2, dim=-1)
        loss = loss.mean()
        if self.policy_eval_statistics is None:
            self.policy_eval_statistics = OrderedDict()
        self.policy_eval_statistics['State-Pred-Expt-MSE'] = ptu.get_numpy(loss)

        agent_label_obs = agent_next_obs - agent_obs if self.state_diff else agent_next_obs
        agent_loss = torch.sum((agent_pred_obs - agent_label_obs) ** 2, dim=-1).mean()
        self.policy_eval_statistics['State-Pred-Real-MSE'] = ptu.get_numpy(agent_loss)

        loss.backward()
        self.state_predictor_optimizer.step()

        if self.multi_step:
            ptu.copy_model_params_from_to(self.exploration_policy.state_predictor, self.target_state_predictor)


    def _do_inverse_dynamic_training(self, epoch, use_expert_buffer=False):
        '''
            Train the inverse dynamic model (s, s') -> a for one step.

            'MLE' minimizes the negative log-likelihood of the stored action;
            'MSE' minimizes the summed squared error of the predicted action.
            Raises NotImplementedError for 'MAE'.
        '''
        self.inverse_dynamic_optimizer.zero_grad()

        batch = self.get_batch(
            self.inverse_dynamic_optim_batch_size,
            keys=['observations', 'actions', 'next_observations'],
            from_expert=use_expert_buffer,
            from_inv=True,
        )

        obs = batch['observations']
        acts = batch['actions']
        next_obs = batch['next_observations']

        if self.policy_eval_statistics is None:
            self.policy_eval_statistics = OrderedDict()

        if self.inverse_mode == 'MLE':
            log_prob = self.exploration_policy.inverse_dynamic.get_log_prob(obs, next_obs, acts)
            loss = -1.0 * log_prob.mean()
            self.policy_eval_statistics['Inverse-Dynamic-Log-Likelihood'] = ptu.get_numpy(-1.0 * loss)
        elif self.inverse_mode == 'MSE':
            pred_acts = self.exploration_policy.inverse_dynamic(obs, next_obs)[0]
            loss = torch.sum((pred_acts - acts) ** 2, dim=-1).mean()
            self.policy_eval_statistics['Inverse-Dynamic-MSE'] = ptu.get_numpy(loss)
        else:  # 'MAE'
            raise NotImplementedError

        # The message only references names defined in every mode; the old
        # message used undefined names and raised NameError instead.
        assert not torch.isnan(loss).any(), "nan-inverse-dynamic-training, obs: {}, next_obs: {}, acts: {}".format(
            obs, next_obs, acts)
        loss.backward()
        self.inverse_dynamic_optimizer.step()

    def _do_reward_training(self, epoch):
        '''
            Train the discriminator for one step.

            Expert vs. policy inputs are (s, s'), (s, a, s') with ``sas`` or
            (s, pred_s', s') with ``sss``, optionally extended by absorbing
            flags. Uses a WGAN critic loss or BCE, plus an optional gradient
            penalty on interpolated inputs.
        '''
        self.disc_optimizer.zero_grad()

        # next_observations is part of every discriminator input, so it is
        # always requested (previously it was missing in the default setting).
        keys = ['observations', 'next_observations']
        if self.sas or not self.sss:
            keys.append('actions')
        if self.sss:
            keys.append('pred_observations')
        if self.wrap_absorbing:
            keys.append('absorbing')

        expert_batch = self.get_batch(self.disc_optim_batch_size, True, keys=keys)
        policy_batch = self.get_batch(self.disc_optim_batch_size, False, keys=keys)

        expert_obs = expert_batch['observations']
        policy_obs = policy_batch['observations']
        expert_next_obs = expert_batch['next_observations']
        policy_next_obs = policy_batch['next_observations']

        if self.wrap_absorbing:
            # absorbing[:, 0] flags s, absorbing[:, 1] flags s'.
            expert_obs = torch.cat([expert_obs, expert_batch['absorbing'][:, 0:1]], dim=-1)
            policy_obs = torch.cat([policy_obs, policy_batch['absorbing'][:, 0:1]], dim=-1)
            expert_next_obs = torch.cat([expert_next_obs, expert_batch['absorbing'][:, 1:]], dim=-1)
            policy_next_obs = torch.cat([policy_next_obs, policy_batch['absorbing'][:, 1:]], dim=-1)

        expert_inputs = [expert_obs, expert_next_obs]
        policy_inputs = [policy_obs, policy_next_obs]

        if self.sas:
            # Expert actions are unobserved (LfO): infer them with the inverse
            # dynamic model.
            expert_acts = self.exploration_policy.inverse_dynamic(expert_obs, expert_next_obs)[0]
            policy_acts = policy_batch['actions']

            expert_inputs = [expert_obs, expert_acts, expert_next_obs]
            policy_inputs = [policy_obs, policy_acts, policy_next_obs]
        if self.sss:
            # NOTE(review): the expert side uses its true next state in the
            # "predicted" slot — presumably intentional; confirm.
            expert_pred_obs = expert_batch['next_observations']
            policy_pred_obs = policy_batch['pred_observations']

            expert_inputs = [expert_obs, expert_pred_obs, expert_next_obs]
            policy_inputs = [policy_obs, policy_pred_obs, policy_next_obs]

        expert_disc_input = torch.cat(expert_inputs, dim=1)
        policy_disc_input = torch.cat(policy_inputs, dim=1)

        if self.use_wgan:
            expert_logits = self.discriminator(expert_disc_input)
            policy_logits = self.discriminator(policy_disc_input)

            disc_ce_loss = -torch.sum(expert_logits) + torch.sum(policy_logits)
        else:
            disc_input = torch.cat([expert_disc_input, policy_disc_input], dim=0)

            disc_logits = self.discriminator(disc_input)
            disc_preds = (disc_logits > 0).type(disc_logits.data.type())
            disc_ce_loss = self.bce(disc_logits, self.bce_targets)
            accuracy = (disc_preds == self.bce_targets).type(torch.FloatTensor).mean()

        if self.use_grad_pen:
            eps = ptu.rand(expert_obs.size(0), 1)
            # Tensor.to is not in-place; the result must be kept.
            eps = eps.to(ptu.device)

            interp_obs = eps * expert_disc_input + (1 - eps) * policy_disc_input
            interp_obs = interp_obs.detach()
            interp_obs.requires_grad_(True)

            gradients = autograd.grad(
                outputs=self.discriminator(interp_obs).sum(),
                inputs=[interp_obs],
                create_graph=True, retain_graph=True, only_inputs=True
            )
            total_grad = gradients[0]

            # GP from Gulrajani et al. (alternative: Mescheder et al. use
            # 0.5 * weight * ||grad||^2).
            gradient_penalty = ((total_grad.norm(2, dim=1) - 1) ** 2).mean()
            disc_grad_pen_loss = gradient_penalty * self.grad_pen_weight
        else:
            disc_grad_pen_loss = 0.0

        disc_total_loss = disc_ce_loss + disc_grad_pen_loss
        assert not torch.max(
            torch.isnan(disc_total_loss)), "nan-reward-training, disc_ce_loss: {}, disc_grad_pen_loss: {}".format(
            disc_ce_loss, disc_grad_pen_loss)
        disc_total_loss.backward()
        self.disc_optimizer.step()

        # Statistics are recorded only for the first batch after the dict is
        # reset to None (done in _end_epoch).
        if self.disc_eval_statistics is None:
            self.disc_eval_statistics = OrderedDict()

            self.disc_eval_statistics['Disc CE Loss'] = np.mean(ptu.get_numpy(disc_ce_loss))
            if self.use_wgan:
                self.disc_eval_statistics['Expert D Logits'] = np.mean(ptu.get_numpy(expert_logits))
                self.disc_eval_statistics['Policy D Logits'] = np.mean(ptu.get_numpy(policy_logits))
            else:
                self.disc_eval_statistics['Disc Acc'] = np.mean(ptu.get_numpy(accuracy))
            if self.use_grad_pen:
                self.disc_eval_statistics['Grad Pen'] = np.mean(ptu.get_numpy(gradient_penalty))
                self.disc_eval_statistics['Grad Pen W'] = np.mean(self.grad_pen_weight)


    def _do_policy_training(self, epoch):
        """
        Relabel a policy batch with discriminator rewards and take one policy-trainer step.

        Reward per mode: airl -> logits; gail -> softplus(logits);
        gail2 -> softplus(logits, beta=-1); otherwise (fairl) ->
        exp(logits) * -logits. Optionally subtracts a state-prediction-error
        penalty, then clips to [rew_clip_min, rew_clip_max].
        """
        if self.policy_optim_batch_size_from_expert > 0:
            policy_batch_from_policy_buffer = self.get_batch(
                self.policy_optim_batch_size - self.policy_optim_batch_size_from_expert, False)
            policy_batch_from_expert_buffer = self.get_batch(
                self.policy_optim_batch_size_from_expert, True)
            policy_batch = {
                k: torch.cat(
                    [policy_batch_from_policy_buffer[k], policy_batch_from_expert_buffer[k]],
                    dim=0
                )
                for k in policy_batch_from_policy_buffer
            }
        else:
            policy_batch = self.get_batch(self.policy_optim_batch_size, False, multi_step=self.multi_step)

        obs = policy_batch['observations']
        next_obs = policy_batch['next_observations']

        if self.wrap_absorbing:
            obs = torch.cat([obs, policy_batch['absorbing'][:, 0:1]], dim=-1)
            next_obs = torch.cat([next_obs, policy_batch['absorbing'][:, 1:]], dim=-1)

        policy_inputs = [obs, next_obs]

        if self.sas:
            acts = policy_batch['actions']
            policy_inputs = [obs, acts, next_obs]
        if self.sss:
            pred_next_obs = policy_batch['pred_observations']
            policy_inputs = [obs, pred_next_obs, next_obs]

        self.discriminator.eval()
        disc_input = torch.cat(policy_inputs, dim=1)
        disc_logits = self.discriminator(disc_input).detach()
        self.discriminator.train()

        # compute the reward using the algorithm
        if self.mode == 'airl':
            # If you compute log(D) - log(1-D) then you just get the logits
            policy_batch['rewards'] = disc_logits
        elif self.mode == 'gail':
            policy_batch['rewards'] = F.softplus(disc_logits, beta=1)
        elif self.mode == 'gail2':
            policy_batch['rewards'] = F.softplus(disc_logits, beta=-1)
        else:  # fairl
            policy_batch['rewards'] = torch.exp(disc_logits) * (-1.0 * disc_logits)

        # Normally created by _do_reward_training; still None here when
        # num_disc_updates_per_loop_iter == 0.
        if self.disc_eval_statistics is None:
            self.disc_eval_statistics = OrderedDict()

        disc_rews = ptu.get_numpy(policy_batch['rewards'])
        self.disc_eval_statistics['Disc Rew Mean'] = np.mean(disc_rews)
        self.disc_eval_statistics['Disc Rew Std'] = np.std(disc_rews)
        self.disc_eval_statistics['Disc Rew Max'] = np.max(disc_rews)
        self.disc_eval_statistics['Disc Rew Min'] = np.min(disc_rews)

        if self.reward_penelty:
            # The penalty is a reward term, not a training signal for the
            # state predictor, so no autograd graph is built for it.
            with torch.no_grad():
                agent_pred_obs = self.exploration_policy.state_predictor(obs)
                # Same target that _do_state_predictor_training fits:
                # s' normally, s' - s when state_diff is enabled.
                label_obs = next_obs - obs if self.state_diff else next_obs
                pred_mse = torch.sum((agent_pred_obs - label_obs) ** 2, dim=-1, keepdim=True)
                reward_penelty = self.penelty_weight * pred_mse
            policy_batch['rewards'] = policy_batch['rewards'] - reward_penelty

            penalty_np = ptu.get_numpy(reward_penelty)
            self.disc_eval_statistics['Penelty Rew Mean'] = np.mean(penalty_np)
            self.disc_eval_statistics['Penelty Rew Std'] = np.std(penalty_np)
            self.disc_eval_statistics['Penelty Rew Max'] = np.max(penalty_np)
            self.disc_eval_statistics['Penelty Rew Min'] = np.min(penalty_np)

        if self.clip_max_rews:
            policy_batch['rewards'] = torch.clamp(policy_batch['rewards'], max=self.rew_clip_max)
        if self.clip_min_rews:
            policy_batch['rewards'] = torch.clamp(policy_batch['rewards'], min=self.rew_clip_min)

        # policy optimization step
        if self.union or self.union_sp:
            # Joint training: the trainer also receives expert data for the
            # state predictor and (s, a, s') data for the inverse dynamics.
            exp_keys = ['observations', 'next_observations', 'actions']
            for i in np.arange(1, self.step_num + 1):
                exp_keys.append('next{}_observations'.format(i))

            expert_batch = self.get_batch(
                self.state_predictor_optim_batch_size,
                keys=exp_keys,
                from_expert=True,
                multi_step=self.multi_step,
                step_num=self.step_num
            )
            inv_batch = expert_batch
            if self.inv_buffer:
                inv_batch = self.get_batch(
                    self.state_predictor_optim_batch_size,
                    keys=['observations', 'next_observations', 'actions'],
                    from_expert=False,
                    from_inv=True,
                )
            self.policy_trainer.train_step(policy_batch, qss=self.qss,
                                           alpha=self.state_predictor_alpha,
                                           beta=self.inverse_dynamic_beta,
                                           expert_batch=expert_batch,
                                           inv_batch=inv_batch,
                                           state_diff=self.state_diff,
                                           multi_step=self.multi_step,
                                           step_num=self.step_num,
                                           target_state_predictor=self.target_state_predictor)
            if self.multi_step:
                ptu.copy_model_params_from_to(self.exploration_policy.state_predictor, self.target_state_predictor)
        else:
            self.policy_trainer.train_step(policy_batch, qss=self.qss)

        total_rews = ptu.get_numpy(policy_batch['rewards'])
        self.disc_eval_statistics['Total Rew Mean'] = np.mean(total_rews)
        self.disc_eval_statistics['Total Rew Std'] = np.std(total_rews)
        self.disc_eval_statistics['Total Rew Max'] = np.max(total_rews)
        self.disc_eval_statistics['Total Rew Min'] = np.min(total_rews)

    def _handle_step(
        self,
        observation,
        action,
        reward,
        next_observation,
        terminal,
        timeout,
        pred_obs,
        absorbing,
        agent_info,
        env_info,
    ):
        """
        Record one transition in the current path and the replay buffer(s).

        The dedicated inverse-dynamic buffer (when enabled) receives the same
        transition but without ``pred_observations``.
        """
        self._current_path_builder.add_all(
            observations=observation,
            actions=action,
            rewards=reward,
            next_observations=next_observation,
            terminals=terminal,
            pred_observations=pred_obs,
            absorbing=absorbing,
            agent_infos=agent_info,
            env_infos=env_info,
        )
        self.replay_buffer.add_sample(
            observation=observation,
            action=action,
            reward=reward,
            terminal=terminal,
            next_observation=next_observation,
            timeout=timeout,
            pred_observations=pred_obs,
            absorbing=absorbing,
            agent_info=agent_info,
            env_info=env_info,
        )
        if self.inv_buffer:
            self.inv_replay_buffer.add_sample(
                observation=observation,
                action=action,
                reward=reward,
                terminal=terminal,
                next_observation=next_observation,
                timeout=timeout,
                absorbing=absorbing,
                agent_info=agent_info,
                env_info=env_info,
            )

    @property
    def networks(self):
        """All modules managed by this algorithm (target state predictor included when multi_step)."""
        res = [self.discriminator] + self.policy_trainer.networks + [self.policy_trainer.policy.state_predictor, self.policy_trainer.policy.inverse_dynamic]
        if self.multi_step:
            res += [self.target_state_predictor]
        return res


    def get_epoch_snapshot(self, epoch):
        """Extend the base snapshot with the discriminator, Q/V functions and the LfO models."""
        snapshot = super().get_epoch_snapshot(epoch)
        snapshot.update(disc=self.discriminator)
        snapshot.update(dict(
            qf1=self.policy_trainer.qf1,
            qf2=self.policy_trainer.qf2,
            state_predictor=self.policy_trainer.policy.state_predictor,
            inverse_dynamic=self.policy_trainer.policy.inverse_dynamic,
            vf=self.policy_trainer.vf,
            target_vf=self.policy_trainer.target_vf,
        ))
        return snapshot


    def to(self, device):
        """Move the BCE loss and its targets to ``device`` (previously ``ptu.device`` was used), then defer to the base class."""
        self.bce.to(device)
        self.bce_targets = self.bce_targets.to(device)
        super().to(device)
