import tensorflow as tf
import tensorflow_probability as tfp

class Discriminator(tf.keras.layers.Layer):
    """Discriminator over concatenated (observation, action) inputs.

    The wrapped layers are applied in order to
    ``tf.concat([obs, acs], axis=-1)``. The final output is treated as a
    logit: ``get_prob`` returns its sigmoid, flattened to one dimension.

    Args:
        layers: iterable of callables applied sequentially by ``call``.
        stab_const: constant added inside every ``tf.math.log`` taken by
            ``get_reward`` (guards against ``log(0)``).
        rew: reward form used by ``get_reward``. ``'linear'``,
            ``'positive'`` and ``'negative'`` are recognised; any other
            value (including the default ``'mixed'``) selects the mixed form.
    """

    def __init__(self, layers, stab_const=0.0, rew='mixed'):
        super(Discriminator, self).__init__()
        self._dis_layers = layers
        self._sb = stab_const
        self._rew = rew

    def call(self, inputs):
        """Feed ``inputs`` through every wrapped layer, in order."""
        out = inputs
        for layer in self._dis_layers:
            out = layer(out)
        return out

    def get_prob(self, obs, acs):
        """Return ``sigmoid`` of the network output, reshaped to ``[-1]``."""
        model_out = self(tf.concat([obs, acs], axis=-1))
        return tf.reshape(tf.sigmoid(model_out), [-1])

    def get_reward(self, obs, acs):
        """Return the reward selected by ``rew``, as a 1-D tensor.

        With ``D = get_prob(obs, acs)`` and ``c = stab_const``:

        * ``'linear'``:   raw network output (no sigmoid)
        * ``'positive'``: ``-log(1 - D + c)``
        * ``'negative'``: ``log(D + c)``
        * otherwise:      ``log(D + c) - log(1 - D + c)``
        """
        if self._rew == 'linear':
            model_out = self(tf.concat([obs, acs], axis=-1))
            return tf.reshape(model_out, [-1])

        # One forward pass is enough for all remaining modes; the mixed form
        # needs D twice but must not evaluate the network twice.
        prob = self.get_prob(obs, acs)
        if self._rew == 'positive':
            return -1 * tf.math.log(1 - prob + self._sb)
        if self._rew == 'negative':
            return tf.math.log(prob + self._sb)
        return (tf.math.log(prob + self._sb) -
                tf.math.log(1 - prob + self._sb))


class VisualDiscriminator(tf.keras.layers.Layer):
    """Discriminator that scores a single input tensor (``ims``).

    The wrapped layers are applied in order to the input. The final output
    is treated as a logit: ``get_prob`` returns its sigmoid, flattened to
    one dimension.

    Args:
        layers: iterable of callables applied sequentially by ``call``.
        stab_const: constant added inside every ``tf.math.log`` taken by
            ``get_reward`` (guards against ``log(0)``).
        rew: reward form used by ``get_reward``. ``'positive'`` and
            ``'negative'`` are recognised; any other value (including the
            default ``'mixed'``) selects the mixed form.
    """

    def __init__(self, layers, stab_const=0.0, rew='mixed'):
        super(VisualDiscriminator, self).__init__()
        self._dis_layers = layers
        self._sb = stab_const
        self._rew = rew

    def call(self, inputs):
        """Feed ``inputs`` through every wrapped layer, in order."""
        out = inputs
        for layer in self._dis_layers:
            out = layer(out)
        return out

    def get_prob(self, ims):
        """Return ``sigmoid`` of the network output, reshaped to ``[-1]``."""
        model_out = self(ims)
        return tf.reshape(tf.sigmoid(model_out), [-1])

    def get_reward(self, ims):
        """Return the reward selected by ``rew``, as a 1-D tensor.

        With ``D = get_prob(ims)`` and ``c = stab_const``:

        * ``'positive'``: ``-log(1 - D + c)``
        * ``'negative'``: ``log(D + c)``
        * otherwise:      ``log(D + c) - log(1 - D + c)``
        """
        # Evaluate the network once; the mixed form uses D in both terms.
        prob = self.get_prob(ims)
        if self._rew == 'positive':
            return -1 * tf.math.log(1 - prob + self._sb)
        if self._rew == 'negative':
            return tf.math.log(prob + self._sb)
        return (tf.math.log(prob + self._sb) -
                tf.math.log(1 - prob + self._sb))


class KLPreprocessing(tf.keras.layers.Layer):
    """Sequentially applies a list of preprocessing layers to its input.

    Args:
        preprocessing_layers: iterable of callables; each one receives the
            previous one's output, the first receives the layer input.
    """

    def __init__(self, preprocessing_layers):
        super().__init__()
        self._pre_layers = preprocessing_layers

    @tf.function
    def call(self, inputs):
        """Return ``inputs`` after every preprocessing layer has been applied."""
        result = inputs
        for pre_layer in self._pre_layers:
            result = pre_layer(result)
        return result

class FactoredPreprocessing(tf.keras.layers.Layer):
    """Applies the same preprocessing stack to every factor of the input.

    ``call`` merges the first two input axes, runs the layers on the merged
    tensor, then un-merges so that the per-factor outputs sit side by side
    along the last axis.

    Args:
        preprocessing_layers: iterable of callables applied sequentially.
    """

    def __init__(self, preprocessing_layers, ):
        super(FactoredPreprocessing, self).__init__()
        self._pre_layers = preprocessing_layers

    def layers_out(self, inputs):
        """Feed ``inputs`` through every preprocessing layer, in order."""
        out = inputs
        for layer in self._pre_layers:
            out = layer(out)
        return out

    @tf.function
    def call(self, inputs):
        """Preprocess each factor and concatenate the results per sample.

        Args:
            inputs: tensor whose first two axes are merged before the layers
                run (assumed ``[batch, factors, ...]`` - TODO confirm axis
                meaning with callers). Axes from index 1 onward must be
                statically known; the leading axis may be dynamic.

        Returns:
            Tensor of shape ``[batch, factors * features]`` where
            ``features`` is axis 1 of the layers' output.
        """
        input_shape = inputs.get_shape()
        num_factors = input_shape[1]
        # -1 lets TensorFlow infer the merged batch*factors axis, so an
        # unknown (None) static batch size no longer breaks the reshape.
        merged = tf.reshape(inputs, [-1] + list(input_shape[2:]))
        merged_out = self.layers_out(merged)
        features = merged_out.get_shape()[1]
        return tf.reshape(merged_out, [-1, num_factors * features])

class LatentKLPreprocessing(tf.keras.layers.Layer):
    """Stochastic per-factor preprocessing that outputs Normal samples.

    The preprocessing layers' output is split in half along the last axis
    into a mean and an unconstrained log-stddev. The standard deviation is
    ``exp(tanh(log_stddev)) * scale_stddev``, i.e. bounded to
    ``[exp(-1), exp(1)] * scale_stddev``.

    Args:
        preprocessing_layers: iterable of callables applied sequentially;
            the size of their last output axis must be even.
        scale_stddev: multiplier applied to the bounded standard deviation.
    """

    def __init__(self, preprocessing_layers, scale_stddev=1):
        super(LatentKLPreprocessing, self).__init__()
        self._pre_layers = preprocessing_layers
        self._scale_stddev = scale_stddev

    def layers_out(self, inputs):
        """Feed ``inputs`` through every preprocessing layer, in order."""
        out = inputs
        for layer in self._pre_layers:
            out = layer(out)
        return out

    def _mean_and_stddev(self, inputs):
        """Return the ``(mean, stddev)`` tensors the layers produce for ``inputs``."""
        out = self.layers_out(inputs)
        mean, log_stddev = tf.split(out, 2, axis=-1)
        stddev = tf.exp(tf.nn.tanh(log_stddev)) * self._scale_stddev
        return mean, stddev

    @staticmethod
    def _merge_leading_axes(inputs):
        """Collapse the first two axes of ``inputs``; return it with axis 1's size."""
        input_shape = inputs.get_shape()
        # -1 lets TensorFlow infer the merged axis, so an unknown (None)
        # static batch size does not break the reshape.
        merged = tf.reshape(inputs, [-1] + list(input_shape[2:]))
        return merged, input_shape[1]

    def dist(self, inputs):
        """Return the ``tfp.distributions.Normal`` parameterised by ``inputs``."""
        mean, stddev = self._mean_and_stddev(inputs)
        return tfp.distributions.Normal(loc=mean, scale=stddev)

    @tf.function
    def call(self, inputs):
        """Sample one latent per factor and concatenate them per sample.

        Args:
            inputs: tensor whose first two axes are merged before the layers
                run (assumed ``[batch, factors, ...]`` - TODO confirm axis
                meaning with callers).

        Returns:
            Samples reshaped to ``[batch, factors * latent_size]``.
        """
        merged, num_factors = self._merge_leading_axes(inputs)
        samples = self.dist(merged).sample()
        latent_size = samples.get_shape()[1]
        return tf.reshape(samples, [-1, num_factors * latent_size])

    @tf.function
    def get_distribution_info(self, inputs):
        """Return the distribution parameters instead of samples.

        Args:
            inputs: same layout as for ``call``.

        Returns:
            ``(mean, stddev)``, each reshaped to
            ``[batch, factors * latent_size]``.
        """
        merged, num_factors = self._merge_leading_axes(inputs)
        mean, stddev = self._mean_and_stddev(merged)
        flat_width = num_factors * mean.get_shape()[1]
        reshaped_mean = tf.reshape(mean, [-1, flat_width])
        reshaped_stddev = tf.reshape(stddev, [-1, flat_width])
        return reshaped_mean, reshaped_stddev


class RegularizedDiscriminator(VisualDiscriminator):
    """``VisualDiscriminator`` under a distinct name; adds no behaviour."""

    def __init__(self, layers, stab_const=0.0, rew='mixed'):
        # Forward every argument unchanged to the parent constructor.
        super().__init__(layers, stab_const, rew)


class LatentDiscriminator(VisualDiscriminator):
    """``VisualDiscriminator`` under a distinct name; adds no behaviour."""

    def __init__(self, layers, stab_const=0.0, rew='mixed'):
        # Forward every argument unchanged to the parent constructor.
        super().__init__(layers, stab_const, rew)


class GAIL(tf.keras.Model):
    """Adversarial imitation learning wrapper around an agent.

    The discriminator is trained with binary cross-entropy to output 0 for
    the learner's (observation, action) pairs and 1 for pairs sampled from
    the expert buffer. ``call`` attaches the discriminator reward to the
    agent's output.

    Args:
        agent: callable returning a dict with an ``'act'`` entry; must also
            provide ``train(buffer, batch_size, updates, act_delay)``.
        discriminator: object exposing ``get_prob(obs, acs)``,
            ``get_reward(obs, acs)`` and ``trainable_weights``.
        expert_buffer: object exposing ``get_random_batch(n)`` that returns
            a mapping with ``'obs'`` and ``'act'`` entries.
        d_optimizer: optimizer used for the discriminator. ``None``
            (default) creates a new ``Adam(1e-3)`` for this instance, so
            optimizer state is never shared between ``GAIL`` objects.
        label_smoothing: forwarded as ``label_smoothing`` to
            ``tf.losses.binary_crossentropy``.
        stab_const: constant added to the discriminator probabilities
            before the cross-entropy is computed.
        grad_penalty: weight of the gradient penalty term; values ``<= 0``
            disable the penalty.
    """

    def __init__(self, agent, discriminator, expert_buffer,
                 d_optimizer=None, label_smoothing=0.0,
                 stab_const=1e-7, grad_penalty=0.0):
        super(GAIL, self).__init__()
        # A default built in the signature would be created once at import
        # time and silently shared by every instance.
        if d_optimizer is None:
            d_optimizer = tf.keras.optimizers.Adam(1e-3)

        self._disc = discriminator
        self._agent = agent

        self._exp_buff = expert_buffer
        self._sb = stab_const
        self._grad_penalty = grad_penalty
        self._train_disc = self.make_disc_training_op(d_optimizer, label_smoothing)

    def call(self, inputs):
        """Run the agent on ``inputs`` and add the reward under ``'rew'``."""
        out = self._agent(inputs)
        out['rew'] = self._disc.get_reward(inputs, out['act'])
        return out

    def _gan_loss(self, l_obs_batch, l_act_batch, e_obs_batch, e_act_batch, lb):
        """Cross-entropy of the discriminator: learner -> 0, expert -> 1."""
        l_disc_prob = self._disc.get_prob(l_obs_batch, l_act_batch) + self._sb
        e_disc_prob = self._disc.get_prob(e_obs_batch, e_act_batch) + self._sb
        labels = tf.concat([tf.zeros_like(l_disc_prob),
                            tf.ones_like(e_disc_prob)], axis=0)
        probs = tf.concat([l_disc_prob, e_disc_prob], axis=0)
        return tf.losses.binary_crossentropy(labels, probs, label_smoothing=lb)

    def _gradient_penalty(self, l_obs_batch, l_act_batch, e_obs_batch, e_act_batch):
        """Mean squared deviation from 1 of the input-gradient norm.

        The gradient of the discriminator probability is taken at random
        per-sample interpolations between learner and expert data. Must be
        called inside the outer ``GradientTape`` so the penalty itself can
        be differentiated with respect to the discriminator weights.

        NOTE: the interpolation requires learner and expert batches of
        broadcast-compatible shape, and the ``(batch, 1)`` mixing weights
        and ``axis=1`` concat assume rank-2 observations and actions.
        """
        alpha = tf.random.uniform((tf.shape(l_obs_batch)[0], 1))
        m_obs_batch = l_obs_batch * alpha + e_obs_batch * (1 - alpha)
        m_act_batch = l_act_batch * alpha + e_act_batch * (1 - alpha)
        with tf.GradientTape() as inner_tape:
            # The interpolates are plain tensors, so they must be watched.
            inner_tape.watch(m_obs_batch)
            inner_tape.watch(m_act_batch)
            m_disc_prob = self._disc.get_prob(m_obs_batch, m_act_batch)
        input_grads = inner_tape.gradient(m_disc_prob, [m_obs_batch, m_act_batch])
        input_grads = tf.concat(input_grads, axis=1)
        return tf.reduce_mean(tf.pow(tf.norm(input_grads, axis=1) - 1, 2))

    def make_disc_training_op(self, optimizer, lb):
        """Build the ``tf.function`` performing one discriminator update.

        Args:
            optimizer: optimizer applied to the discriminator weights.
            lb: label smoothing passed to the cross-entropy loss.

        Returns:
            A ``tf.function`` taking
            ``(l_obs_batch, l_act_batch, e_obs_batch, e_act_batch)`` that
            applies one gradient step and returns the loss.
        """
        # Decided once here, so the traced graph contains only one variant.
        use_penalty = self._grad_penalty > 0.0

        def train(l_obs_batch, l_act_batch, e_obs_batch, e_act_batch):
            with tf.GradientTape() as tape:
                loss = self._gan_loss(l_obs_batch, l_act_batch,
                                      e_obs_batch, e_act_batch, lb)
                if use_penalty:
                    gp_loss = self._gradient_penalty(l_obs_batch, l_act_batch,
                                                     e_obs_batch, e_act_batch)
                    loss = loss + self._grad_penalty * gp_loss
            gradients = tape.gradient(loss, self._disc.trainable_weights)
            optimizer.apply_gradients(zip(gradients, self._disc.trainable_weights))
            return loss

        return tf.function(train)

    def train(self, dac_buffer, l_batch_size=128, l_updates=1, l_act_delay=1,
              d_updates=1, d_e_batch_size=128, d_l_batch_size=128,
              d_mixture_proportion=0.0):
        """Update the discriminator ``d_updates`` times, then train the agent.

        Args:
            dac_buffer: learner buffer exposing ``get_random_batch(n, flag)``;
                it is also handed to ``agent.train``.
            l_batch_size: forwarded to ``agent.train``.
            l_updates: forwarded to ``agent.train``.
            l_act_delay: forwarded to ``agent.train``.
            d_updates: number of discriminator updates.
            d_e_batch_size: number of expert samples per discriminator update.
            d_l_batch_size: total number of samples labelled as learner data
                per discriminator update.
            d_mixture_proportion: fraction of the learner-labelled batch that
                is actually drawn from the expert buffer.
        """
        # Truncating both shares independently can lose a sample (e.g.
        # 10 * 0.9 -> 9 and 10 * (1 - 0.9) -> 0); derive the learner share
        # as the remainder so the two always add up to d_l_batch_size.
        num_f_l_data = tf.cast(d_l_batch_size * d_mixture_proportion, tf.int32)
        num_l_data = tf.cast(d_l_batch_size, tf.int32) - num_f_l_data

        for _ in range(d_updates):
            l_batch = dac_buffer.get_random_batch(num_l_data, False)
            l_obs, l_acs = l_batch['obs'], l_batch['act']
            if num_f_l_data > 0:
                f_l_batch = self._exp_buff.get_random_batch(num_f_l_data)
                f_l_obs, f_l_acs = f_l_batch['obs'], f_l_batch['act']
                l_obs = tf.concat([l_obs, f_l_obs], axis=0)
                l_acs = tf.concat([l_acs, f_l_acs], axis=0)
            e_batch = self._exp_buff.get_random_batch(d_e_batch_size)
            e_obs, e_acs = e_batch['obs'], e_batch['act']
            self._train_disc(l_obs, l_acs, e_obs, e_acs)
        self._agent.train(dac_buffer, l_batch_size, l_updates, l_act_delay)
