import tensorflow as tf
import tensorflow_probability as tfp
from gail_models import LatentKLPreprocessing, KLPreprocessing, FactoredPreprocessing

LN2 = 0.69314718056


def huber_loss(errors, delta=1):
    """Piecewise quadratic/linear penalty applied element-wise to ``errors``.

    Elements whose magnitude is below ``delta`` are penalised with their
    square; every other element is penalised with ``delta * |error|``. The two
    pieces take the same value at ``|error| == delta``.

    Args:
        errors: Tensor of residuals.
        delta: Magnitude threshold separating the quadratic and linear regimes.

    Returns:
        Tensor of element-wise losses with the same shape as ``errors``.
    """
    abs_errors = tf.abs(errors)
    # Compare magnitudes (not signed values) so that large negative residuals
    # also fall in the linear regime instead of being squared.
    loss = tf.where(abs_errors < delta, tf.pow(errors, 2),
                    delta * abs_errors)
    return loss


class StatisticsNet(tf.keras.layers.Layer):
    """Keras layer that scores its inputs with a caller-supplied layer stack."""

    def __init__(self, mi_layers):
        """Store ``mi_layers``, the iterable of layers applied in order."""
        super().__init__()
        self._mi_layers = mi_layers

    def layers_out(self, inputs):
        """Feed ``inputs`` through every stored layer and return the result."""
        hidden = inputs
        for stage in self._mi_layers:
            hidden = stage(hidden)
        return hidden

    @tf.function
    def call(self, inputs):
        """Concatenate ``inputs`` on the last axis and run the layer stack."""
        joined = tf.concat(inputs, axis=-1)
        return self.layers_out(joined)


class MIEstimator(StatisticsNet):
    """Statistics network built from ``mi_layers``; adds nothing to its base."""

    def __init__(self, mi_layers):
        super().__init__(mi_layers)


class KLEstimator(StatisticsNet):
    """Statistics network built from ``kl_layers``; adds nothing to its base."""

    def __init__(self, kl_layers):
        super().__init__(kl_layers)


class DisentanGAIL(tf.keras.Model):
    def __init__(self,
                 agent,
                 make_discriminator,
                 make_preprocessing,
                 expert_buffer,
                 prior_expert_buffer=None,
                 prior_agent_buffer=None,
                 make_kl_estimator=None,
                 make_mi_estimator=None,
                 make_mi2_estimator=None,
                 use_min_double_mi=False,
                 double_preprocessing=False,
                 d_loss='ce',
                 d_optimizer=tf.keras.optimizers.Adam(1e-3),
                 mi_optimizer=tf.keras.optimizers.Adam(1e-3),
                 label_smoothing=0.0,
                 expert_shuffle=False,
                 stab_const=0.0,
                 grad_penalty=0.0,
                 latent_grad_penalty=False,
                 huber_grad_penalty=False,
                 kl_regularization=0.0,
                 entropy_constant=0.0,
                 kl_estimation='Gaussian',
                 kl_optimizer=tf.keras.optimizers.Adam(1e-4),
                 adaptive_kl=False,
                 max_kl_regularization=1e-1,
                 min_kl_regularization=1e-6,
                 kl_elites=1.0,
                 mi_constant=0.0,
                 adaptive_mi=False,
                 max_mi=1.0,
                 min_mi=0.5,
                 prior_mi_constant=0.0,
                 max_mi_prior=0.01,
                 negative_priors=False,
                 min_mi_prior_constant=1e-4,
                 use_dual_mi=False,
                 mi_lagrangian_optimizer=tf.keras.optimizers.Adam(1e-3),
                 max_mi_constant=10,
                 min_mi_constant=1e-4,
                 unbiased_mi=False,
                 unbiased_mi_decay=0.99,
                 clip_mi_predictions=False,
                 mi_elites=1.0,
                 max_kl_clip=1e3,
                 weighted_dac_buffer=False,
                 pre_gradient_projection='None',
                 im_side=32,
                 past_frames=4):
        """Build the networks, estimators and training ops of the model.

        Args:
            agent: Callable applied to ``inputs['obs']`` in ``call``.
            make_discriminator: Zero-argument factory for the discriminator.
            make_preprocessing: Zero-argument factory for a preprocessing
                network. It must return a ``KLPreprocessing``,
                ``LatentKLPreprocessing`` or ``FactoredPreprocessing``; it is
                called a second time when ``double_preprocessing`` is set.
            expert_buffer: Buffer sampled for expert image batches.
            prior_expert_buffer: Prior-domain expert buffer; required when
                ``prior_mi_constant > 0`` or ``negative_priors`` is set.
            prior_agent_buffer: Prior-domain agent buffer; same requirement.
            make_kl_estimator: Factory for the KL statistics network; required
                when ``kl_estimation == 'DV'``.
            make_mi_estimator: Factory for the MI statistics network; required
                when ``mi_constant > 0`` or ``adaptive_kl`` is set.
            make_mi2_estimator: Optional factory for a second MI estimator.
            d_loss: Discriminator loss, ``'ce'`` or ``'hinge'``.
            d_optimizer: Optimizer for the discriminator training op.
            mi_optimizer: Optimizer handed to every MI estimator training op.
            kl_optimizer: Optimizer for the ``'DV'`` KL estimator training op.
            mi_lagrangian_optimizer: Optimizer handed to every dual (log
                multiplier) update.
            kl_estimation: ``'Gaussian'`` or ``'DV'``.
            pre_gradient_projection: ``'None'`` or ``'gan<->mi'``.
            im_side: Side of the dummy image used to build the preprocessing
                networks when gradient projection is enabled.
            past_frames: Number of frames per sample; also the second
                dimension of that dummy image batch.

        All other arguments are stored as hyper-parameters on the instance.
        ``kl_elites`` and ``mi_elites`` are accepted but not used by this
        constructor.

        Raises:
            NotImplementedError: For an unsupported preprocessing type,
                ``kl_estimation`` or ``pre_gradient_projection`` value.
            AssertionError: When a required factory/buffer is missing or the
                adaptive-penalty bounds are inconsistent.
        """
        # NOTE(review): the default optimizers in the signature are created
        # once, at function definition time, so every instance relying on the
        # defaults shares the same optimizer objects.

        super(DisentanGAIL, self).__init__()
        self._disc = make_discriminator()
        self._pre = make_preprocessing()
        self._d_loss = d_loss
        self._mi_est = None
        self._unbiased_mi = unbiased_mi
        self._unbiased_mi_decay = unbiased_mi_decay
        self._clip_mi_predictions = clip_mi_predictions

        self._double_preprocessing = double_preprocessing
        self._weighted_dac_buffer = weighted_dac_buffer
        self._use_min_double_mi = use_min_double_mi
        self._use_dual_mi = use_dual_mi
        self._pre_gradient_projection = pre_gradient_projection
        if self._double_preprocessing:
            self._pre2 = make_preprocessing()
        else:
            self._pre2 = None
        # ``_lat`` controls whether encodings are reshaped per frame before
        # the KL/entropy regularisation.
        if isinstance(self._pre, KLPreprocessing):
            self._lat = False
        elif isinstance(self._pre, (LatentKLPreprocessing, FactoredPreprocessing)):
            self._lat = True
        else:
            raise NotImplementedError
        self._agent = agent
        self._past_frames = past_frames
        self._exp_buff = expert_buffer

        self._expert_shuffle = expert_shuffle
        self._sb = stab_const
        self._grad_penalty = grad_penalty
        self._huber_grad_penalty = huber_grad_penalty
        self._latent_grad_penalty = latent_grad_penalty
        self._kl_regularization = tf.Variable(kl_regularization, trainable=False)
        self._kl_estimation = kl_estimation
        if self._kl_estimation == 'Gaussian':
            self.kl_entropy_regularization = self._gaussian_kl_entropy_regularization
        elif self._kl_estimation == 'DV':
            assert make_kl_estimator is not None
            self._kl_est = make_kl_estimator()
            self._train_kl = self.make_dv_kl_training_op(self._kl_est, kl_optimizer)
            self.kl_entropy_regularization = self._donsker_varadhan_kl_entropy_regularization
        else:
            raise NotImplementedError
        self._adaptive_kl = adaptive_kl
        self._max_kl_regularization = max_kl_regularization
        self._min_kl_regularization = min_kl_regularization
        self._entropy_constant = entropy_constant
        self._max_kl_clip = max_kl_clip

        self._adaptive_mi = adaptive_mi
        # Store the MI bounds before the dual-MI branch: that branch reads
        # ``self._max_mi`` and previously hit an AttributeError.
        self._max_mi = max_mi
        self._min_mi = min_mi
        self._max_mi_constant = max_mi_constant
        self._min_mi_constant = min_mi_constant
        if self._use_dual_mi:
            self._log_mi_constant = tf.Variable(tf.math.log(mi_constant))
            # Snapshot of the multiplier at construction time; the training op
            # re-evaluates ``exp(log_mi_constant)`` itself.
            self._mi_constant = tf.exp(self._log_mi_constant)
            self.update_dual_mi_constant = self.make_dual_mi_constant_update(
                log_mi_constant=self._log_mi_constant,
                max_mi=self._max_mi,
                optimizer=mi_lagrangian_optimizer
            )
            # Cast first: ``tf.math.log`` rejects integer inputs and the
            # default ``max_mi_constant`` is the integer 10.
            self._log_max_mi_constant = tf.math.log(tf.cast(max_mi_constant, tf.float32))
            self._log_min_mi_constant = tf.math.log(tf.cast(min_mi_constant, tf.float32))
        else:
            self._mi_constant = tf.Variable(mi_constant, trainable=False)
        if self._adaptive_mi:
            assert self._mi_constant > 0.0, (
                'When using adaptive mi penalty, '
                'initialize the mi constant to some '
                'positive value')
            assert self._max_mi > self._min_mi, (
                'The maximum MI for the adaptive penalty '
                'should be greater than the minimum MI')
            assert self._max_mi_constant > self._min_mi_constant, (
                'The maximum MI constant '
                'should be greater than '
                'the minimum MI constant')
        if self._adaptive_kl:
            assert self._kl_regularization > 0.0, (
                'When using adaptive kl regularization, '
                'initialize the kl regularization term to some '
                'positive value')
            assert not self._adaptive_mi, (
                'Incompatible to utilize both adaptive kl and mi '
                'regularization')
            assert self._max_kl_regularization > self._min_kl_regularization, (
                'The maximum KL regularization term should be greater than the minimum KL '
                'regularization term')

        if self._mi_constant > 0.0 or self._adaptive_kl:
            assert make_mi_estimator is not None
            self._mi_est = make_mi_estimator()
            if self._unbiased_mi:
                self._unbiased_mi_ma = tf.Variable(1.0, trainable=False)
            else:
                self._unbiased_mi_ma = None
            self._train_mi = self.make_mi_training_op(self._mi_est, mi_optimizer,
                                                      self._unbiased_mi_ma)
            if make_mi2_estimator is not None:
                self._mi2_est = make_mi2_estimator()
                if self._unbiased_mi:
                    self._unbiased_mi_ma2 = tf.Variable(1.0, trainable=False)
                else:
                    self._unbiased_mi_ma2 = None
                self._train_mi2 = self.make_mi_training_op(self._mi2_est, mi_optimizer,
                                                           self._unbiased_mi_ma2)
                self._double_mi = True
            else:
                self._double_mi = False
        self._mi_prior_constant = prior_mi_constant
        self._max_mi_prior = max_mi_prior

        self._pr_exp_buff = prior_expert_buffer
        self._pr_age_buff = prior_agent_buffer
        if self._mi_prior_constant > 0.0:
            assert self._pr_exp_buff is not None
            assert self._pr_age_buff is not None
            self._log_min_mi_prior_constant = tf.math.log(min_mi_prior_constant)
            self._prior_domains_data = True
            self._log_mi_prior_constant = tf.Variable(tf.math.log(prior_mi_constant))
            self.update_dual_mi_prior_constant = self.make_dual_mi_constant_update(
                log_mi_constant=self._log_mi_prior_constant,
                max_mi=self._max_mi_prior,
                optimizer=mi_lagrangian_optimizer
            )
        else:
            self._prior_domains_data = False
        self._negative_priors = negative_priors
        if self._negative_priors:
            assert self._pr_exp_buff is not None
            assert self._pr_age_buff is not None

        if self._pre_gradient_projection != 'None':
            # Run a dummy batch through the preprocessing network(s) so their
            # weights exist, then record each weight's shape and element count
            # for flattening/unflattening gradients in the training op.
            init_input = tf.random.uniform([1, past_frames, im_side, im_side, 3])
            _ = self._pre(init_input)
            if self._double_preprocessing:
                _ = self._pre2(init_input)
            self._trainable_pre_weight_shapes = [tf.shape(w) for w in self.get_pre_weights()]
            self._trainable_pre_weight_dims = [tf.reduce_prod(shape) for shape in self._trainable_pre_weight_shapes]

        self._train_disc = self.make_disc_training_op(d_optimizer, label_smoothing)

    def get_expert_pre_batch(self, e_ims_batch):
        """Encode expert images.

        Uses the second preprocessing network when double preprocessing is
        enabled and the primary one otherwise.
        """
        encoder = self._pre2 if self._double_preprocessing else self._pre
        return encoder(e_ims_batch)

    def get_pre_batch(self, l_ims_batch, e_ims_batch):
        """Return ``(learner_encodings, expert_encodings)`` for two image batches."""
        learner_encodings = self._pre(l_ims_batch)
        return learner_encodings, self.get_expert_pre_batch(e_ims_batch)

    def get_pre_weights(self, ):
        """Return the trainable weights of the preprocessing network(s).

        Returns:
            A new list holding the primary network's trainable weights,
            followed by the second network's when double preprocessing is
            enabled.
        """
        # Copy into a fresh list: extending the object returned by the
        # ``trainable_weights`` property in place could otherwise mutate a
        # list owned by the layer.
        pre_weights = list(self._pre.trainable_weights)
        if self._double_preprocessing:
            pre_weights.extend(self._pre2.trainable_weights)
        return pre_weights

    def reshape_pre_encodings(self, encodings):
        """Reshape encodings to ``[dim0 * past_frames, dim1 // past_frames]``."""
        dims = tf.shape(encodings)
        n_rows = dims[0] * self._past_frames
        n_cols = dims[1] // self._past_frames
        return tf.reshape(encodings, [n_rows, n_cols])

    def call(self, inputs):
        """Run the agent and attach encodings, estimator outputs and reward.

        Args:
            inputs: Mapping with the keys ``'obs'`` and ``'ims'``.

        Returns:
            The agent's output mapping extended with ``'pre'``, ``'exp_pre'``,
            ``'rew'`` and, depending on the configuration, ``'mi'``, ``'mi2'``
            and ``'kl'``.
        """
        out = self._agent(inputs['obs'])
        images = inputs['ims']
        # The same images go through both the learner and expert encoders.
        out['pre'], out['exp_pre'] = self.get_pre_batch(images, images)
        frame_encodings = self.reshape_pre_encodings(out['pre'])
        if self._mi_constant > 0.0 or self._adaptive_kl:
            # Append a column of ones as the domain label of every row.
            domain_column = tf.ones([tf.shape(frame_encodings)[0], 1])
            labelled_encodings = tf.concat([frame_encodings, domain_column], axis=1)
            out['mi'] = self._mi_est(labelled_encodings)
            if self._double_mi:
                out['mi2'] = self._mi2_est(labelled_encodings)
        if self._kl_regularization > 0.0 and self._kl_estimation == 'DV':
            out['kl'] = self._kl_est(frame_encodings)
        out['rew'] = self._disc.get_reward(out['pre'])
        return out

    def _gaussian_kl_entropy_regularization(self, agent_encodings, expert_encodings):
        """Estimate KL(agent || expert) and expert entropy with diagonal Gaussians.

        Each batch is summarised by its per-dimension mean and variance over
        axis 0. NaN divergences are replaced by ``max_kl_clip`` and the
        divergence is capped at that value.

        Returns:
            Tuple ``(kl_divergence, expert_entropy)``.
        """
        if self._lat:
            agent_encodings = self.reshape_pre_encodings(agent_encodings)
            expert_encodings = self.reshape_pre_encodings(expert_encodings)

        def fit_diagonal_gaussian(samples):
            mean, variance = tf.nn.moments(samples, axes=0)
            return tfp.distributions.MultivariateNormalDiag(
                loc=mean, scale_diag=tf.sqrt(variance))

        agent_gaussian = fit_diagonal_gaussian(agent_encodings)
        expert_gaussian = fit_diagonal_gaussian(expert_encodings)

        raw_kl = agent_gaussian.kl_divergence(expert_gaussian)
        nan_replacement = tf.ones_like(raw_kl) * self._max_kl_clip
        finite_kl = tf.where(tf.math.is_nan(raw_kl), nan_replacement, raw_kl)
        kl_divergence = tf.reduce_mean(tf.minimum(finite_kl, self._max_kl_clip))
        expert_entropy = tf.reduce_mean(expert_gaussian.entropy())
        return kl_divergence, expert_entropy

    def _donsker_varadhan_kl_entropy_regularization(self, agent_encodings, expert_encodings):
        """Estimate KL(agent || expert) with the learned KL statistics network.

        Args:
            agent_encodings: Encodings of the learner batch.
            expert_encodings: Encodings of the expert batch.

        Returns:
            Tuple ``(kl_divergence, expert_entropy)``; the entropy term is
            always ``0.0`` for this estimator.
        """
        if self._lat:
            agent_encodings = self.reshape_pre_encodings(agent_encodings)
            expert_encodings = self.reshape_pre_encodings(expert_encodings)
        # The expert Gaussian that used to be fitted here was never read, so
        # it has been dropped.
        kl_divergence = self.dv_kl(self._kl_est, agent_encodings, expert_encodings)
        expert_entropy = 0.0
        return kl_divergence, expert_entropy

    def gradient_penalty(self, l_ims_batch, e_ims_batch):
        """Gradient penalty of the discriminator with respect to mixed images.

        Learner and expert images are mixed with one random coefficient per
        sample, and the norm of the discriminator output's gradient with
        respect to each mixed sample is pushed towards 1.

        Args:
            l_ims_batch: Learner image batch (rank 5, batch axis first).
            e_ims_batch: Expert image batch of the same shape.

        Returns:
            Scalar mean penalty (squared or Huber-style deviation from 1).
        """
        batch_size = tf.shape(l_ims_batch)[0]
        alpha = tf.random.uniform((batch_size, 1, 1, 1, 1))
        m_ims_batch = l_ims_batch * alpha + e_ims_batch * (1 - alpha)
        with tf.GradientTape() as tape2:
            tape2.watch(m_ims_batch)
            m_pre_batch = self._pre(m_ims_batch)
            m_disc = self._disc(m_pre_batch)
        # Taken outside the inner tape's context; an enclosing tape, if any,
        # still records these ops for the second-order gradient.
        grad_obs = tape2.gradient(m_disc, m_ims_batch)
        # Flatten every non-batch axis so the norm is taken per sample, not
        # only along axis 1 of the rank-5 gradient.
        grad_obs = tf.reshape(grad_obs, [batch_size, -1])
        norm_deviations = tf.norm(grad_obs, axis=1) - 1
        if self._huber_grad_penalty:
            norm_losses = huber_loss(norm_deviations, delta=1)
        else:
            norm_losses = tf.pow(norm_deviations, 2)
        gp_loss = tf.reduce_mean(norm_losses)
        return gp_loss

    def pre_gradient_penalty(self, l_pre_batch, e_pre_batch):
        """Gradient penalty of the discriminator with respect to mixed encodings.

        Learner and expert encodings are mixed with one random coefficient per
        sample, and the norm of the discriminator output's gradient with
        respect to each mixed encoding is pushed towards 1.

        Args:
            l_pre_batch: Learner encodings (batch axis first, static rank known).
            e_pre_batch: Expert encodings of the same shape.

        Returns:
            Scalar mean penalty (squared or Huber-style deviation from 1).
        """
        batch_size = tf.shape(l_pre_batch)[0]
        # Give alpha the same rank as the encodings. A fixed rank-5 alpha
        # broadcasts against rank-2 encodings to [B, 1, 1, B, D], mixing every
        # sample with every coefficient.
        alpha_shape = [batch_size] + [1] * (l_pre_batch.shape.rank - 1)
        alpha = tf.random.uniform(alpha_shape)
        m_pre_batch = l_pre_batch * alpha + e_pre_batch * (1 - alpha)
        with tf.GradientTape() as tape2:
            tape2.watch(m_pre_batch)
            m_disc = self._disc(m_pre_batch)
        # Taken outside the inner tape's context; an enclosing tape, if any,
        # still records these ops for the second-order gradient.
        grad_obs = tape2.gradient(m_disc, m_pre_batch)
        # Per-sample norm over all non-batch axes.
        grad_obs = tf.reshape(grad_obs, [batch_size, -1])
        norm_deviations = tf.norm(grad_obs, axis=1) - 1
        if self._huber_grad_penalty:
            norm_losses = huber_loss(norm_deviations, delta=1)
        else:
            norm_losses = tf.pow(norm_deviations, 2)
        gp_loss = tf.reduce_mean(norm_losses)
        return gp_loss

    def gan_loss(self, l_disc_prob, e_disc_prob, lb):
        """Dispatch to the discriminator loss selected by ``d_loss``.

        ``lb`` (label smoothing) is only forwarded to the cross-entropy loss.

        Raises:
            NotImplementedError: If ``d_loss`` is neither ``'ce'`` nor ``'hinge'``.
        """
        if self._d_loss == 'ce':
            return self.ce_gan_loss(l_disc_prob, e_disc_prob, lb)
        if self._d_loss == 'hinge':
            return self.hinge_gan_loss(l_disc_prob, e_disc_prob)
        raise NotImplementedError

    def hinge_gan_loss(self, l_disc_prob, e_disc_prob):
        """Hinge loss: mean ``relu(1 - learner)`` plus mean ``relu(1 + expert)``."""
        # NOTE(review): this pushes learner outputs up and expert outputs
        # down, the opposite labelling of ``ce_gan_loss`` - confirm intended.
        learner_term = tf.reduce_mean(tf.nn.relu(1 - l_disc_prob))
        expert_term = tf.reduce_mean(tf.nn.relu(1 + e_disc_prob))
        return learner_term + expert_term

    def ce_gan_loss(self, l_disc_prob, e_disc_prob, lb):
        """Binary cross-entropy with learner samples labelled 0 and expert samples 1.

        ``lb`` is passed through as the label-smoothing amount.
        """
        targets = tf.concat([tf.zeros_like(l_disc_prob),
                             tf.ones_like(e_disc_prob)], axis=0)
        predictions = tf.concat([l_disc_prob, e_disc_prob], axis=0)
        return tf.losses.binary_crossentropy(y_true=targets,
                                             y_pred=predictions,
                                             label_smoothing=lb)

    @staticmethod
    def dv_kl(est, p_samples, q_samples):
        """Donsker-Varadhan style divergence estimate, divided by ``LN2``.

        Computes ``mean(est(p)) - log(mean(exp(est(q))))``.

        Args:
            est: Callable statistics network.
            p_samples: Samples scored under the first distribution.
            q_samples: Samples scored under the second distribution.

        Returns:
            Scalar estimate.
        """
        p_samples_estimate = tf.reduce_mean(est(p_samples))
        q_scores = est(q_samples)
        # log(mean(exp(x))) == logsumexp(x) - log(N); the log-sum-exp form
        # does not overflow when the estimator emits large scores.
        n_q = tf.cast(tf.size(q_scores), q_scores.dtype)
        q_samples_estimate = tf.reduce_logsumexp(q_scores) - tf.math.log(n_q)
        return (p_samples_estimate - q_samples_estimate) / LN2

    @staticmethod
    def get_mi_batches(l_pre_batch, e_pre_batch, past_frames=4):
        """Build correctly-labelled and label-shuffled inputs for the MI estimator.

        Both encoding batches are reshaped to one row per frame and stacked
        (learner rows first). A domain label column (0 for learner rows, 1 for
        expert rows) is appended; the second output uses a random permutation
        of that column instead.

        Args:
            l_pre_batch: Learner encodings; the static size of axis 0 must be known.
            e_pre_batch: Expert encodings; the static size of axis 0 must be known.
            past_frames: Number of frames folded into each encoding row.

        Returns:
            Tuple ``(positive_ordering, negative_ordering)``.
        """
        l_pre_batch_n = l_pre_batch.get_shape()[0] * past_frames
        e_pre_batch_n = e_pre_batch.get_shape()[0] * past_frames
        # Each batch is reshaped with its OWN row count; these were swapped
        # before, which only worked when both batches had the same size.
        l_pre_batch = tf.reshape(l_pre_batch, [l_pre_batch_n, -1])
        e_pre_batch = tf.reshape(e_pre_batch, [e_pre_batch_n, -1])
        input_correct_batch = tf.concat([l_pre_batch, e_pre_batch], axis=0)
        domain_labels = tf.concat([tf.zeros([l_pre_batch_n, 1]),
                                   tf.ones([e_pre_batch_n, 1])], axis=0)
        shuffled_domain_labels = tf.random.shuffle(domain_labels)
        positive_ordering = tf.concat([input_correct_batch, domain_labels],
                                      axis=1)
        negative_ordering = tf.concat([input_correct_batch, shuffled_domain_labels],
                                      axis=1)
        return positive_ordering, negative_ordering

    def mi_loss(self, mi_est, l_pre_batch, e_pre_batch):
        """Return the negated ``dv_kl`` estimate between true and shuffled labels."""
        true_labels, shuffled_labels = self.get_mi_batches(
            l_pre_batch, e_pre_batch, self._past_frames)
        mi_estimate = self.dv_kl(mi_est, true_labels, shuffled_labels)
        return -1 * mi_estimate

    def make_dv_kl_training_op(self, kl_est, optimizer):
        """Build a ``tf.function`` performing one ascent step on the KL estimate.

        The returned function takes learner and expert image batches, updates
        only ``kl_est``'s trainable weights and returns the KL estimate.
        """
        def train(l_ims_batch, e_ims_batch):
            with tf.GradientTape() as tape:
                learner_pre, expert_pre = self.get_pre_batch(l_ims_batch, e_ims_batch)
                kl_estimate = self.dv_kl(kl_est,
                                         self.reshape_pre_encodings(learner_pre),
                                         self.reshape_pre_encodings(expert_pre))
                # Minimising the negated estimate maximises the estimate.
                objective = -1 * kl_estimate
            estimator_weights = kl_est.trainable_weights
            estimator_gradients = tape.gradient(objective, estimator_weights)
            optimizer.apply_gradients(zip(estimator_gradients, estimator_weights))
            return kl_estimate

        return tf.function(train)

    def make_mi_training_op(self, mi_est, optimizer, mi_ma=None):
        """Build a ``tf.function`` performing one training step of an MI estimator.

        Without ``mi_ma`` the estimator is trained on ``mi_loss`` directly.
        With ``mi_ma`` (a variable holding a moving average of the mean
        exponentiated score on shuffled labels) the gradient comes from a loss
        whose exponential term is divided by that moving average. The returned
        value is the plain MI loss in both cases. Only ``mi_est``'s weights
        are updated.
        """
        if mi_ma is None:
            def train(l_ims_batch, e_ims_batch):
                with tf.GradientTape() as tape:
                    learner_pre, expert_pre = self.get_pre_batch(l_ims_batch, e_ims_batch)
                    mi_loss = self.mi_loss(mi_est, learner_pre, expert_pre)
                estimator_weights = mi_est.trainable_weights
                estimator_gradients = tape.gradient(mi_loss, estimator_weights)
                optimizer.apply_gradients(zip(estimator_gradients, estimator_weights))
                return mi_loss
        else:
            def train(l_ims_batch, e_ims_batch):
                with tf.GradientTape() as tape:
                    learner_pre, expert_pre = self.get_pre_batch(l_ims_batch, e_ims_batch)
                    true_labels, shuffled_labels = self.get_mi_batches(
                        learner_pre, expert_pre, self._past_frames)
                    joint_mean = tf.reduce_mean(mi_est(true_labels))
                    shuffled_exp_mean = tf.reduce_mean(tf.exp(mi_est(shuffled_labels)))
                    # Refresh the moving average before it is used as the
                    # denominator below.
                    mi_ma.assign(tf.stop_gradient(
                        self._unbiased_mi_decay * mi_ma +
                        (1 - self._unbiased_mi_decay) * shuffled_exp_mean))
                    unbiased_loss = -(joint_mean - shuffled_exp_mean / mi_ma) / LN2
                    mi_loss = -(joint_mean - tf.math.log(shuffled_exp_mean)) / LN2
                estimator_weights = mi_est.trainable_weights
                estimator_gradients = tape.gradient(unbiased_loss, estimator_weights)
                optimizer.apply_gradients(zip(estimator_gradients, estimator_weights))
                return mi_loss
        return tf.function(train)

    def make_disc_training_op(self, optimizer, lb):
        """Build the ``tf.function`` that trains discriminator and preprocessing.

        The behaviour of the returned function is fixed at construction time
        by the instance configuration (gradient penalty, MI penalties, prior
        data, expert shuffling, negative priors, gradient projection).

        Args:
            optimizer: Optimizer applied to the discriminator weights followed
                by the preprocessing weights.
            lb: Label smoothing forwarded to the GAN loss.

        Returns:
            ``tf.function`` with signature ``train(l_ims_batch, e_ims_batch,
            s_e_ims_batch=None, l_prior_ims_batch=None,
            e_prior_ims_batch=None)`` returning
            ``(gan_loss, gp_loss, kl, entropy)``.

        Raises:
            NotImplementedError: For an unknown ``pre_gradient_projection``.
        """
        def compute_losses(l_ims_batch, e_ims_batch, s_e_ims_batch, l_prior_ims_batch, e_prior_ims_batch):
            l_pre_batch, e_pre_batch, l_prior_pre_batch, e_prior_pre_batch = gather_pre_batches(l_ims_batch,
                                                                                                e_ims_batch,
                                                                                                l_prior_ims_batch,
                                                                                                e_prior_ims_batch)
            l_disc_prob, e_disc_prob = get_gan_probs(l_pre_batch, e_pre_batch, s_e_ims_batch, l_prior_pre_batch,
                                                     e_prior_pre_batch)
            gan_loss = self.gan_loss(l_disc_prob, e_disc_prob, lb)

            gp_loss = get_gp(l_pre_batch, e_pre_batch, l_ims_batch, e_ims_batch)

            mi = get_mi(l_pre_batch, e_pre_batch)
            prior_mi = get_prior_mi(l_prior_pre_batch, e_prior_pre_batch)
            kl, entropy = self.kl_entropy_regularization(l_pre_batch, e_pre_batch)
            return gan_loss, gp_loss, mi, prior_mi, kl, entropy

        # Gradient Penalty
        if self._grad_penalty > 0.0:
            if self._latent_grad_penalty:
                def get_gp(l_pre_batch, e_pre_batch, l_ims_batch, e_ims_batch):
                    return self._grad_penalty * self.pre_gradient_penalty(l_pre_batch, e_pre_batch)
            else:
                def get_gp(l_pre_batch, e_pre_batch, l_ims_batch, e_ims_batch):
                    return self._grad_penalty * self.gradient_penalty(l_ims_batch, e_ims_batch)
        else:
            def get_gp(l_pre_batch, e_pre_batch, l_ims_batch, e_ims_batch):
                return 0.0

        # MI
        if self._mi_constant > 0.0:
            if self._use_dual_mi:
                def get_mi_constant():
                    return tf.exp(self._log_mi_constant)
            else:
                def get_mi_constant():
                    return self._mi_constant

            if self._double_mi and self._use_min_double_mi:
                def get_mi(l_pre_batch, e_pre_batch):
                    return tf.math.maximum(
                        -1 * tf.math.minimum(self.mi_loss(self._mi_est, l_pre_batch, e_pre_batch),
                                             self.mi_loss(self._mi2_est, l_pre_batch, e_pre_batch)),
                        0.0)
            else:
                def get_mi(l_pre_batch, e_pre_batch):
                    return tf.math.maximum(-1 * self.mi_loss(self._mi_est, l_pre_batch, e_pre_batch), 0.0)
        else:
            def get_mi_constant():
                return 0.0

            def get_mi(l_pre_batch, e_pre_batch):
                return 0.0

        # (Prior switch) data processing preprocessing
        if self._prior_domains_data or self._negative_priors:
            def gather_pre_batches(l_ims_batch, e_ims_batch, l_prior_ims_batch, e_prior_ims_batch):
                l_pre_batch, e_pre_batch = self.get_pre_batch(l_ims_batch, e_ims_batch)
                l_prior_pre_batch, e_prior_pre_batch = self.get_pre_batch(l_prior_ims_batch, e_prior_ims_batch)
                return l_pre_batch, e_pre_batch, l_prior_pre_batch, e_prior_pre_batch
        else:
            def gather_pre_batches(l_ims_batch, e_ims_batch, l_prior_ims_batch, e_prior_ims_batch):
                l_pre_batch, e_pre_batch = self.get_pre_batch(l_ims_batch, e_ims_batch)
                return l_pre_batch, e_pre_batch, None, None

        # Prior MI
        if self._prior_domains_data:
            def get_prior_mi_constant():
                return tf.exp(self._log_mi_prior_constant)

            if self._double_mi and self._use_min_double_mi:
                def get_prior_mi(l_prior_pre_batch, e_prior_pre_batch):
                    return tf.math.maximum(
                        -1 * tf.math.minimum(self.mi_loss(self._mi_est, l_prior_pre_batch, e_prior_pre_batch),
                                             self.mi_loss(self._mi2_est, l_prior_pre_batch, e_prior_pre_batch)),
                        0.0)
            else:
                def get_prior_mi(l_prior_pre_batch, e_prior_pre_batch):
                    return tf.math.maximum(-1 * self.mi_loss(self._mi_est, l_prior_pre_batch, e_prior_pre_batch), 0.0)
        else:
            def get_prior_mi_constant():
                return 0.0

            def get_prior_mi(l_prior_pre_batch, e_prior_pre_batch):
                return 0.0

        # Expert shuffle
        if self._expert_shuffle:
            def augment_l_pre_batch(l_pre_batch, s_e_ims_batch):
                s_e_pre_batch = self.get_expert_pre_batch(s_e_ims_batch)
                return tf.concat([l_pre_batch, s_e_pre_batch], axis=0)
        else:
            def augment_l_pre_batch(l_pre_batch, s_e_ims_batch):
                return l_pre_batch

        # Negative priors: prior-domain samples from both buffers are scored
        # alongside the learner samples.
        if self._negative_priors:
            def get_gan_probs(l_pre_batch, e_pre_batch, s_e_ims_batch, l_prior_pre_batch, e_prior_pre_batch):
                l_pre_batch = augment_l_pre_batch(l_pre_batch, s_e_ims_batch)
                l_disc_prob = self._disc.get_prob(l_pre_batch) + self._sb
                e_disc_prob = self._disc.get_prob(e_pre_batch) + self._sb
                l_prior_disc_prob = self._disc.get_prob(l_prior_pre_batch) + self._sb
                e_prior_disc_prob = self._disc.get_prob(e_prior_pre_batch) + self._sb
                return tf.concat([l_disc_prob, l_prior_disc_prob, e_prior_disc_prob], axis=0), e_disc_prob
        else:
            def get_gan_probs(l_pre_batch, e_pre_batch, s_e_ims_batch, l_prior_pre_batch, e_prior_pre_batch):
                l_pre_batch = augment_l_pre_batch(l_pre_batch, s_e_ims_batch)
                l_disc_prob = self._disc.get_prob(l_pre_batch) + self._sb
                e_disc_prob = self._disc.get_prob(e_pre_batch) + self._sb
                return l_disc_prob, e_disc_prob

        if self._pre_gradient_projection == 'None':
            def get_gradients(l_ims_batch, e_ims_batch, s_e_ims_batch, l_prior_ims_batch, e_prior_ims_batch):
                with tf.GradientTape(persistent=True) as tape:
                    gan_loss, gp_loss, mi, prior_mi, kl, entropy = compute_losses(l_ims_batch=l_ims_batch,
                                                                                  e_ims_batch=e_ims_batch,
                                                                                  s_e_ims_batch=s_e_ims_batch,
                                                                                  l_prior_ims_batch=l_prior_ims_batch,
                                                                                  e_prior_ims_batch=e_prior_ims_batch)

                    weighted_mi = mi * get_mi_constant() + get_prior_mi_constant() * prior_mi
                    loss = gan_loss + gp_loss + weighted_mi
                    if self._kl_regularization > 0.0 or self._entropy_constant > 0.0:
                        loss += (self._kl_regularization * kl - self._entropy_constant * entropy)
                gradients = tape.gradient(loss, self._disc.trainable_weights + self.get_pre_weights())
                losses = (gan_loss, gp_loss, kl, entropy)
                del tape
                return gradients, losses

        elif self._pre_gradient_projection == 'gan<->mi':
            def get_gradients(l_ims_batch, e_ims_batch, s_e_ims_batch, l_prior_ims_batch, e_prior_ims_batch):
                with tf.GradientTape(persistent=True) as tape:
                    gan_loss, gp_loss, mi, prior_mi, kl, entropy = compute_losses(l_ims_batch, e_ims_batch,
                                                                                  s_e_ims_batch,
                                                                                  l_prior_ims_batch, e_prior_ims_batch)
                    weighted_mi = mi * get_mi_constant() + get_prior_mi_constant() * prior_mi
                    loss = gan_loss + gp_loss + weighted_mi
                    # Everything in ``loss`` that is neither the GAN term nor
                    # the MI term.
                    residual_loss = gp_loss
                    if self._kl_regularization > 0.0 or self._entropy_constant > 0.0:
                        kl_entropy_terms = (self._kl_regularization * kl - self._entropy_constant * entropy)
                        loss += kl_entropy_terms
                        residual_loss += kl_entropy_terms
                losses = (gan_loss, gp_loss, kl, entropy)

                pre_weights = self.get_pre_weights()

                def flat_pre_gradient(target):
                    # Flattened gradient of ``target`` w.r.t. all preprocessing
                    # weights. Unconnected weights contribute zeros instead of
                    # ``None``; python-float targets are converted first.
                    weight_gradients = tape.gradient(
                        tf.convert_to_tensor(target), pre_weights,
                        unconnected_gradients=tf.UnconnectedGradients.ZERO)
                    return tf.concat([tf.reshape(gradient, [-1]) for gradient in weight_gradients], axis=0)

                pre_gan_gradients_vec = flat_pre_gradient(gan_loss)
                pre_mi_gradients_vec = flat_pre_gradient(weighted_mi)
                # The gradient-penalty and KL/entropy terms must also reach the
                # preprocessing weights; ``residual_loss`` was previously
                # computed but never differentiated.
                pre_residual_gradients_vec = flat_pre_gradient(residual_loss)

                dot_prod = tf.reduce_sum(pre_gan_gradients_vec * pre_mi_gradients_vec)
                pre_mi_sq_norm = tf.reduce_sum(tf.square(pre_mi_gradients_vec))
                # Remove from the GAN gradient its component along the MI
                # gradient: g - (g.m / ||m||^2) m. The earlier code subtracted
                # the scalar g.m / ||m|| from every coordinate.
                pre_gan_gradients_vec_corrected = (
                    pre_gan_gradients_vec
                    - (dot_prod / (pre_mi_sq_norm + 1e-12)) * pre_mi_gradients_vec)
                # Project only when the two gradients conflict and the MI
                # estimate exceeds its allowed maximum.
                needs_projection = tf.logical_and(dot_prod < 0.0, mi > self._max_mi)
                pre_gan_gradients_vec = tf.where(needs_projection,
                                                 pre_gan_gradients_vec_corrected,
                                                 pre_gan_gradients_vec)
                pre_gradients_vec = (pre_gan_gradients_vec + pre_mi_gradients_vec
                                     + pre_residual_gradients_vec)

                pre_gradients_flatten = tf.split(pre_gradients_vec,
                                                 self._trainable_pre_weight_dims)
                pre_gradients = [tf.reshape(gradient, dim) for gradient, dim in
                                 zip(pre_gradients_flatten, self._trainable_pre_weight_shapes)]
                disc_gradients = tape.gradient(loss, self._disc.trainable_weights)
                gradients = disc_gradients + pre_gradients
                del tape
                return gradients, losses
        else:
            raise NotImplementedError

        def train(l_ims_batch, e_ims_batch, s_e_ims_batch=None, l_prior_ims_batch=None, e_prior_ims_batch=None):
            gradients, losses = get_gradients(l_ims_batch=l_ims_batch,
                                              e_ims_batch=e_ims_batch,
                                              s_e_ims_batch=s_e_ims_batch,
                                              l_prior_ims_batch=l_prior_ims_batch,
                                              e_prior_ims_batch=e_prior_ims_batch)
            # Gradients are ordered discriminator weights first, then
            # preprocessing weights, matching the list built here.
            optimizer.apply_gradients(zip(gradients, self._disc.trainable_weights + self.get_pre_weights()))
            gan_loss, gp_loss, kl, entropy = losses
            return gan_loss, gp_loss, kl, entropy

        return tf.function(train)

    @staticmethod
    def make_dual_mi_constant_update(log_mi_constant, max_mi, optimizer):
        """Build the update function for a log-space MI penalty multiplier.

        The returned function takes an MI estimate and applies one optimizer
        step to ``log_mi_constant`` on the loss
        ``log_mi_constant * (max_mi - mi_estimate)``, with the difference
        treated as a constant.
        """
        def update_dual_mi_constant(mi_estimate):
            slack = tf.stop_gradient(max_mi - mi_estimate)
            with tf.GradientTape() as tape:
                dual_objective = log_mi_constant * slack
            dual_gradients = tape.gradient(dual_objective, [log_mi_constant])
            optimizer.apply_gradients(zip(dual_gradients, [log_mi_constant]))

        return update_dual_mi_constant

    def get_random_im_batches(self,
                              dac_buffer,
                              d_e_batch_size,
                              d_l_batch_size,
                              num_f_l_data=0,
                              weighted=False):
        """Sample learner and expert image batches for a discriminator update.

        ``num_f_l_data`` expert samples are appended to the learner batch and
        the draw from ``dac_buffer`` is reduced by the same amount. When the
        model uses a weighted buffer but ``weighted`` is false,
        ``weighted=False`` is passed to the buffer explicitly.

        Returns:
            Tuple ``(l_ims, e_ims)``.
        """
        n_buffer_samples = d_l_batch_size - num_f_l_data
        # Only pass ``weighted`` when a weighted buffer must be overridden.
        extra_kwargs = {}
        if self._weighted_dac_buffer and not weighted:
            extra_kwargs['weighted'] = False
        l_ims = dac_buffer.get_random_batch(n_buffer_samples, False, **extra_kwargs)['ims']
        if num_f_l_data > 0:
            f_l_ims = self._exp_buff.get_random_batch(num_f_l_data)['ims']
            l_ims = tf.concat([l_ims, f_l_ims], axis=0)
        e_ims = self._exp_buff.get_random_batch(d_e_batch_size)['ims']
        return l_ims, e_ims

    def get_random_im_prior_batches(self,
                                    d_e_batch_size,
                                    d_l_batch_size, ):
        """Sample image batches from the two prior buffers, if priors are in use.

        Args:
            d_e_batch_size: Number of samples requested from ``self._pr_exp_buff``.
            d_l_batch_size: Number of samples requested from ``self._pr_age_buff``.

        Returns:
            ``(l_ims, e_ims)`` with the ``'ims'`` entries of the two batches,
            or ``(None, None)`` when neither ``self._prior_domains_data`` nor
            ``self._negative_priors`` is set.
        """
        if not (self._prior_domains_data or self._negative_priors):
            return None, None
        prior_agent_ims = self._pr_age_buff.get_random_batch(d_l_batch_size)['ims']
        prior_expert_ims = self._pr_exp_buff.get_random_batch(d_e_batch_size)['ims']
        return prior_agent_ims, prior_expert_ims

    def get_random_expert_shuffled_batch(self,
                                         d_s_e_batch_size, ):
        """Sample an expert image batch and shuffle it across its two leading axes.

        Args:
            d_s_e_batch_size: Number of samples requested from ``self._exp_buff``.

        Returns:
            A tensor with the same shape as the sampled ``'ims'`` entry whose
            slices along the merged first two axes have been randomly
            permuted, or ``None`` when ``self._expert_shuffle`` is not set.
        """
        if not self._expert_shuffle:
            return None
        expert_ims = self._exp_buff.get_random_batch(d_s_e_batch_size)['ims']
        original_shape = list(expert_ims.shape)
        # Merge the first two axes so the shuffle mixes slices across samples.
        # assumes axis 1 of the images has length self._past_frames -- TODO confirm
        merged_shape = [d_s_e_batch_size * self._past_frames] + original_shape[2:]
        merged_ims = tf.reshape(expert_ims, merged_shape)
        return tf.reshape(tf.random.shuffle(merged_ims), original_shape)

    def train(self, dac_buffer, l_batch_size=128, l_updates=1, l_act_delay=1,
              d_updates=1, kl_updates=1, mi_updates=1, d_e_batch_size=128, d_e_s_batch_size=128,
              d_l_batch_size=128, d_mixture_proportion=0.0,
              get_distribution_stats=False):
        """Run the discriminator, MI/KL estimator and agent updates for one call.

        For each of ``d_updates`` iterations, fresh batches are sampled and
        ``self._train_disc`` is called. When MI estimation is active, the MI
        estimator(s) are then trained ``mi_updates // d_updates`` times and the
        dual constants are updated. When KL estimation is ``'DV'``, the KL
        estimator is trained ``kl_updates // d_updates`` times. After the loop
        the adaptive MI / KL coefficients are rescaled and clipped, and
        ``self._agent.train`` is called.

        Args:
            dac_buffer: Buffer the learner batches are sampled from; also
                forwarded to ``self._agent.train``.
            l_batch_size: Forwarded to ``self._agent.train``.
            l_updates: Forwarded to ``self._agent.train``.
            l_act_delay: Forwarded to ``self._agent.train``.
            d_updates: Number of discriminator updates.
            kl_updates: Total KL-estimator updates (only used for ``'DV'``).
            mi_updates: Total MI-estimator updates.
            d_e_batch_size: Expert batch size for all sampled batches.
            d_e_s_batch_size: Batch size of the shuffled expert batch.
            d_l_batch_size: Learner batch size for all sampled batches.
            d_mixture_proportion: Fraction of the discriminator's learner batch
                that is drawn from the expert buffer instead of ``dac_buffer``.
            get_distribution_stats: When true, the per-update KL and entropy
                values are collected and returned.

        Returns:
            ``(kl, entropy, mi)`` lists when ``get_distribution_stats`` is
            true, otherwise ``None``.

        Raises:
            AssertionError: If fewer MI (or DV KL) updates than discriminator
                updates are requested while that estimator is in use.
        """
        kl = []
        entropy = []
        mi = []
        pr_mi = []
        # This guard must match the condition that gates the MI loop below;
        # otherwise m_updates_per_d is unbound when only prior-domain data
        # enables the loop.
        if self._mi_constant > 0 or self._adaptive_kl or self._prior_domains_data:
            m_updates_per_d = mi_updates // d_updates
            assert m_updates_per_d > 0, 'mi updates is less than d updates'
        if self._kl_estimation == 'DV':
            kl_updates_per_d = kl_updates // d_updates
            assert kl_updates_per_d > 0, 'DV kl updates is less than d updates'
        num_f_l_data = tf.cast(d_l_batch_size * d_mixture_proportion, tf.int32)
        for _ in range(d_updates):
            # Discriminator update on freshly sampled batches.
            l_ims, e_ims = self.get_random_im_batches(dac_buffer=dac_buffer,
                                                      d_e_batch_size=d_e_batch_size,
                                                      d_l_batch_size=d_l_batch_size,
                                                      num_f_l_data=num_f_l_data,
                                                      weighted=False)

            s_e_ims = self.get_random_expert_shuffled_batch(d_e_s_batch_size)

            l_prior_ims, e_prior_ims = self.get_random_im_prior_batches(d_e_batch_size=d_e_batch_size,
                                                                        d_l_batch_size=d_l_batch_size)
            c_gan_loss, c_gp_loss, c_kl, c_entropy = self._train_disc(l_ims, e_ims, s_e_ims, l_prior_ims, e_prior_ims)
            if get_distribution_stats:
                kl.append(c_kl)
                entropy.append(c_entropy)
            if self._mi_constant > 0.0 or self._adaptive_kl or self._prior_domains_data:
                for _ in range(m_updates_per_d):
                    # MI estimator update; no expert data is mixed into the
                    # learner batch here (num_f_l_data=0).
                    l_ims, e_ims = self.get_random_im_batches(dac_buffer=dac_buffer,
                                                              d_e_batch_size=d_e_batch_size,
                                                              d_l_batch_size=d_l_batch_size,
                                                              num_f_l_data=0,
                                                              weighted=False)
                    mi_loss = self._train_mi(l_ims, e_ims)
                    if self._prior_domains_data:
                        l_ims, e_ims = self.get_random_im_prior_batches(d_e_batch_size=d_e_batch_size,
                                                                        d_l_batch_size=d_l_batch_size)
                        mi_prior_loss = self._train_mi(l_ims, e_ims)
                        # Estimates are the negated estimator losses.
                        mi_prior_est = -1 * mi_prior_loss
                    if self._double_mi:
                        l_ims, e_ims = self.get_random_im_batches(dac_buffer=dac_buffer,
                                                                  d_e_batch_size=d_e_batch_size,
                                                                  d_l_batch_size=d_l_batch_size,
                                                                  num_f_l_data=0,
                                                                  weighted=False)
                        mi2_loss = self._train_mi2(l_ims, e_ims)
                        # Negating the smaller loss keeps the larger of the
                        # two estimates.
                        mi_est = -1 * tf.minimum(mi_loss, mi2_loss)
                        if self._clip_mi_predictions:
                            mi_est = tf.clip_by_value(mi_est, 0.0, 1.0)
                        mi.append(mi_est)
                        if self._prior_domains_data:
                            l_ims, e_ims = self.get_random_im_prior_batches(d_e_batch_size=d_e_batch_size,
                                                                            d_l_batch_size=d_l_batch_size)
                            mi2_prior_loss = self._train_mi2(l_ims, e_ims)
                            mi2_prior_est = -1 * mi2_prior_loss
                            mi_prior_est = tf.maximum(mi_prior_est, mi2_prior_est)
                            if self._clip_mi_predictions:
                                mi_prior_est = tf.clip_by_value(mi_prior_est, 0.0, 1.0)
                            pr_mi.append(mi_prior_est)
                    else:
                        mi_est = -1 * mi_loss
                        if self._clip_mi_predictions:
                            mi_est = tf.clip_by_value(mi_est, 0.0, 1.0)
                        mi.append(mi_est)
                        if self._prior_domains_data:
                            if self._clip_mi_predictions:
                                mi_prior_est = tf.clip_by_value(mi_prior_est, 0.0, 1.0)
                            pr_mi.append(mi_prior_est)
                # The dual updates use the estimates from the last inner
                # iteration only.
                if self._use_dual_mi:
                    self.update_dual_mi_constant(mi_est)
                    self._log_mi_constant.assign(tf.clip_by_value(self._log_mi_constant,
                                                                  self._log_min_mi_constant,
                                                                  self._log_max_mi_constant))
                    self._mi_constant = tf.exp(self._log_mi_constant)
                if self._prior_domains_data:
                    self.update_dual_mi_prior_constant(mi_prior_est)
                    self._log_mi_prior_constant.assign(tf.maximum(self._log_mi_prior_constant,
                                                                  self._log_min_mi_prior_constant))
                    self._mi_prior_constant = tf.exp(self._log_mi_prior_constant)

            if self._kl_estimation == 'DV':
                for _ in range(kl_updates_per_d):
                    l_ims, e_ims = self.get_random_im_batches(dac_buffer=dac_buffer,
                                                              d_e_batch_size=d_e_batch_size,
                                                              d_l_batch_size=d_l_batch_size,
                                                              num_f_l_data=0)
                    # Called for its training side effect; the result is unused.
                    self._train_kl(l_ims, e_ims)

        if self._mi_constant > 0.0:
            average_mi = tf.reduce_mean(mi)
        # NOTE(review): average_mi is only bound above when _mi_constant > 0;
        # _adaptive_mi with a non-positive constant would fail here -- confirm
        # that this configuration is never used.
        if self._adaptive_mi and not self._use_dual_mi:
            # Multiplicative adjustment that pushes the average MI estimate
            # towards the [min_mi, max_mi] band.
            if average_mi > self._max_mi:
                self._mi_constant.assign(self._mi_constant * 1.5)
            elif average_mi < self._min_mi:
                self._mi_constant.assign(self._mi_constant / 1.5)
            self._mi_constant.assign(tf.clip_by_value(self._mi_constant,
                                                      self._min_mi_constant,
                                                      self._max_mi_constant))
        if self._adaptive_kl:
            average_mi = tf.reduce_mean(mi)
            if average_mi > self._max_mi:
                self._kl_regularization.assign(self._kl_regularization * 1.5)
            elif average_mi < self._min_mi:
                self._kl_regularization.assign(self._kl_regularization / 1.5)
            self._kl_regularization.assign(tf.clip_by_value(self._kl_regularization,
                                                            self._min_kl_regularization,
                                                            self._max_kl_regularization))

        if self._weighted_dac_buffer:
            dac_buffer.update_buffer()
        self._agent.train(dac_buffer, l_batch_size, l_updates, l_act_delay)
        if get_distribution_stats:
            return kl, entropy, mi


class DisentanglePreprocessing(tf.keras.layers.Layer):
    """Apply a layer stack to every frame and split the features in two groups.

    The first two input axes are merged before the layer stack is applied and
    separated again afterwards. The last feature axis is then split into the
    first ``ds_dims`` channels and the remaining channels, and within each
    group the second input axis is folded into the channel axis.

    Args:
        preprocessing_layers: Iterable of callables applied in order.
        ds_dims: Number of leading channels that form the first output group.
    """

    def __init__(self, preprocessing_layers, ds_dims=1):
        super(DisentanglePreprocessing, self).__init__()
        self._pre_layers = preprocessing_layers
        self._ds_dims = ds_dims

    def layers_out(self, inputs):
        """Return ``inputs`` passed through every preprocessing layer in order."""
        out = inputs
        for layer in self._pre_layers:
            out = layer(out)
        return out

    @tf.function
    def call(self, inputs):
        """Compute the two feature groups for ``inputs``.

        The transpose permutations have five entries, so the layer stack must
        produce a rank-4 output per merged sample. The leading (batch)
        dimension may be statically unknown; every other dimension that is
        used must be static.

        Returns:
            Tuple ``(reshaped_ds, reshaped_di)``: the first ``ds_dims`` channels
            and the remaining channels, each with the second input axis folded
            into the last axis.
        """
        input_shape = inputs.get_shape()
        num_frames = input_shape[1]
        # -1 lets TensorFlow infer the merged leading dimension, so the layer
        # also works when the static batch size is None.
        out = tf.reshape(inputs, [-1] + list(input_shape[2:]))
        out = self.layers_out(out)
        out_shape = out.get_shape()
        out = tf.reshape(out, [-1, num_frames] + list(out_shape[1:]))
        di_dims = out_shape[-1] - self._ds_dims
        out_ds, out_di = tf.split(out, [self._ds_dims, di_dims], -1)
        # Move the frame axis next to the channel axis so the final reshape
        # concatenates the per-frame channels along the last axis.
        out_ds = tf.transpose(out_ds, perm=[0, 2, 3, 1, 4])
        out_di = tf.transpose(out_di, perm=[0, 2, 3, 1, 4])
        reshaped_ds = tf.reshape(out_ds,
                                 [-1] + list(out_shape[1:-1]) +
                                 [self._ds_dims * num_frames])
        reshaped_di = tf.reshape(out_di,
                                 [-1] + list(out_shape[1:-1]) +
                                 [di_dims * num_frames])
        return reshaped_ds, reshaped_di
