"""Neural network architectures."""

import gin
import tensorflow as tf


def mlp(hidden_sizes, activation, name=None, **kwargs):
    """Build a stack of dense layers as a `tf.keras.Sequential` model.

    Every layer except the final one applies `activation`; the final layer is
    created with `activation=None`.

    Args:
        hidden_sizes: Sequence with the number of units of each dense layer,
            in order.
        activation: Activation used by all layers but the last one.
        name: Optional name given to the resulting `Sequential` model.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            `tf.keras.layers.Dense` constructor.

    Returns:
        A `tf.keras.Sequential` holding one `Dense` layer per entry of
        `hidden_sizes`.
    """
    dense_layers = []
    for index, units in enumerate(hidden_sizes):
        # Only the output layer is left without an activation.
        is_output_layer = index + 1 == len(hidden_sizes)
        dense_layers.append(
            tf.keras.layers.Dense(
                units,
                activation=None if is_output_layer else activation,
                **kwargs,
            )
        )
    return tf.keras.Sequential(dense_layers, name)


# pylint: disable=keyword-arg-before-vararg
def cnn(conv_filters=(64, 64, 64),
        conv_kernel_sizes=(3, 3, 3),
        conv_strides=(1, 1, 1),
        max_pool_sizes=(2, 2, 2),
        max_pool_strides=(2, 2, 2),
        padding='SAME',
        activation_fn=tf.keras.activations.elu,
        name='cnn',
        *args,
        **kwargs):
    """Build a convolutional feature extractor ending in a `Flatten` layer.

    The model is a sequence of blocks, each made of a `Conv2D` layer followed
    by a `MaxPool2D` layer. The five per-block sequences are zipped together,
    so the number of blocks equals the length of the shortest one.

    Args:
        conv_filters: Number of convolution filters of each block.
        conv_kernel_sizes: Convolution kernel size of each block.
        conv_strides: Convolution stride of each block.
        max_pool_sizes: Max-pooling window size of each block.
        max_pool_strides: Max-pooling stride of each block.
        padding: Padding mode used by both the convolution and the pooling
            layers.
        activation_fn: Activation applied by every convolution layer.
        name: Name of the returned model.
        *args: Extra positional arguments forwarded to every `Conv2D`.
        **kwargs: Extra keyword arguments forwarded to every `Conv2D`.

    Returns:
        A `tf.keras.Sequential` containing the blocks (named
        `conv_block_<index>`) followed by a `Flatten` layer.
    """
    # pylint: disable=missing-return-doc
    def build_block(index, filters, kernel_size, stride, pool_size,
                    pool_stride):
        # One stage: convolution followed by max pooling.
        convolution = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=stride,
            padding=padding,
            activation=activation_fn,
            *args,
            **kwargs)
        pooling = tf.keras.layers.MaxPool2D(
            pool_size=pool_size,
            strides=pool_stride,
            padding=padding,
        )
        return tf.keras.Sequential([convolution, pooling],
                                   name=f'conv_block_{index}')

    block_specs = zip(conv_filters, conv_kernel_sizes, conv_strides,
                      max_pool_sizes, max_pool_strides)
    blocks = [
        build_block(index, *spec) for index, spec in enumerate(block_specs)
    ]

    return tf.keras.Sequential(
        (*blocks, tf.keras.layers.Flatten()),
        name=name,
    )


class NoisyChannelLayer(tf.keras.layers.Layer):
    """Pass message logits through a uniform-noise channel.

    The logits are turned into probabilities, multiplied by a channel
    transition matrix and converted back to (log-probability) logits. The
    noise probability is stored in a non-trainable variable so it can be
    changed after construction through the `noise_probability` property.
    """

    @staticmethod
    def _noise_uniform(p, alphabet_size):
        """Return the `alphabet_size` x `alphabet_size` transition matrix.

        The matrix has `1 - p` on the diagonal and `p / (alphabet_size - 1)`
        in every off-diagonal entry.
        """
        identity = tf.eye(alphabet_size, dtype='float32')
        num_symbols = tf.cast(alphabet_size, 'float32')
        off_diagonal = p / (num_symbols - 1.)
        # Scaling the identity and adding the constant yields `1 - p` on the
        # diagonal once the off-diagonal mass is added back.
        return identity * (1. - p - off_diagonal) + off_diagonal

    def __init__(self,
                 alphabet_size=5,
                 noise_probability=0.0,
                 ):
        """Create the layer.

        Args:
            alphabet_size: Number of symbols; size of the transition matrix.
            noise_probability: Initial value of the noise probability.
        """
        super().__init__()
        self._alphabet_size = alphabet_size
        self._noise_probability = tf.Variable(
            noise_probability,
            dtype='float32',
            trainable=False,
        )

    # pylint: disable=arguments-differ
    def call(self, inputs, training=None, mask=None,
             gradient_trick=tf.constant(False)):
        """Apply the noisy channel to `inputs`.

        Args:
            inputs: Logits; softmax is taken over the last axis, which is
                then multiplied by the transition matrix.
            training: Unused.
            mask: Unused.
            gradient_trick: When true, the noisy probabilities keep their
                value but their gradient is routed straight to `inputs`.

        Returns:
            Logarithm of the noisy probabilities (with `1e-20` added for
            numerical safety).
        """
        del training, mask
        channel = NoisyChannelLayer._noise_uniform(
            self._noise_probability, self._alphabet_size)
        noisy_probs = tf.matmul(tf.nn.softmax(inputs), channel)
        if gradient_trick:
            # Same forward value, but gradients bypass the channel.
            noisy_probs = inputs + tf.stop_gradient(noisy_probs - inputs)
        return tf.math.log(noisy_probs + 1e-20)

    @property
    def noise_probability(self):
        """Current noise probability as a Python float."""
        current = self._noise_probability.numpy()
        return float(current)

    @noise_probability.setter
    def noise_probability(self, noise_probability):
        self._noise_probability.assign(noise_probability)


@gin.configurable
class Sender(tf.keras.Model):
    """Encodes scene into a message.

    The input is passed through a CNN, flattened, fed to an MLP and reshaped
    to `(message_len, alphabet_size)` logits. A `NoisyChannelLayer` then
    produces the noisy version of those logits.
    """

    def __init__(self,
                 message_len=2,
                 alphabet_size=5,
                 conv_filters=(64, 64, 64),
                 conv_kernel_sizes=(3, 3, 3),
                 conv_strides=(1, 1, 1),
                 max_pool_sizes=(2, 2, 2),
                 max_pool_strides=(2, 2, 2),
                 padding='SAME',
                 embedding=64,
                 weight_decay=1e-3,
                 activation_fn=tf.keras.activations.elu,
                 input_dim=None,
                 restore_pretrain_fname=None,
                 log_regularizer=0.0):
        """Create the sender.

        Args:
            message_len: Number of symbols in a message.
            alphabet_size: Number of possible values of each symbol.
            conv_filters: Per-block filter counts passed to `cnn`.
            conv_kernel_sizes: Per-block kernel sizes passed to `cnn`.
            conv_strides: Per-block convolution strides passed to `cnn`.
            max_pool_sizes: Per-block pooling sizes passed to `cnn`.
            max_pool_strides: Per-block pooling strides passed to `cnn`.
            padding: Padding mode passed to `cnn`.
            embedding: Width of the two hidden layers of the MLP.
            weight_decay: L2 regularization factor for CNN and MLP kernels.
            activation_fn: Activation used by the MLP.
            input_dim: Input shape without the batch dimension; only used to
                build the CNN when `restore_pretrain_fname` is given.
            restore_pretrain_fname: Optional path of a saved model whose
                `_cnn` weights are copied into this sender's CNN. When set,
                the CNN layers are created non-trainable and a
                `stop_gradient` layer is appended.
            log_regularizer: Constant added to the probabilities before
                taking the logarithm in `call`.
        """
        super().__init__()

        self._message_len = message_len
        self._alphabet_size = alphabet_size
        self._log_regularizer = log_regularizer

        # cnn block
        # NOTE(review): `activation_fn` is not forwarded here, so the
        # convolutions always use the default activation of `cnn` - confirm
        # this is intended.
        self._cnn = cnn(
            conv_filters=conv_filters,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            max_pool_sizes=max_pool_sizes,
            max_pool_strides=max_pool_strides,
            padding=padding,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            trainable=(restore_pretrain_fname is None),
            name='sender_cnn'
        )

        if restore_pretrain_fname:
            # set_weights requires the variables to exist, so the CNN has to
            # be built before the pretrained weights are copied in
            self._cnn.build((None,) + input_dim)
            pretrained_model = tf.keras.models.load_model(
                restore_pretrain_fname)
            self._cnn.set_weights(pretrained_model._cnn.get_weights())
            self._cnn.add(tf.keras.layers.Lambda(tf.stop_gradient))

        # first dense part after cnn (can be pretrained)
        self._mlp = mlp(
            (embedding, embedding, alphabet_size * message_len),
            activation=activation_fn,
            name='sender_mlp',
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        )

        self._body = tf.keras.Sequential([
            self._cnn,
            tf.keras.layers.Flatten(),
            self._mlp,
            tf.keras.layers.Reshape((message_len, alphabet_size)),
        ])

        self._noisy_channel = NoisyChannelLayer(
            alphabet_size=alphabet_size,
        )

    @property
    def noise_probability(self):
        """Noise probability of the underlying noisy channel."""
        return self._noisy_channel.noise_probability

    @noise_probability.setter
    def noise_probability(self, noise_probability_):
        self._noisy_channel.noise_probability = noise_probability_

    # pylint: disable=arguments-differ
    def call(self, inputs, training=None, mask=None,
             gradient_trick=tf.constant(False)):
        """Encode `inputs` into message logits.

        Args:
            inputs: Batch of scenes fed to the CNN.
            training: Unused.
            mask: Unused.
            gradient_trick: Forwarded to the noisy channel; see
                `NoisyChannelLayer.call`.

        Returns:
            A tuple `(logits, noisy_logits, latent_cross_entropy,
            latent_entropy)`: `logits` and `noisy_logits` have shape
            `(B, message_len, alphabet_size)`, the two remaining values are
            scalars.
        """
        # INFO: we assume that sender's alphabet
        # size is the same for each message
        logits = self._body(inputs)  # (B, message_len, alphabet_size)

        # compute noisy logits
        # `gradient_trick` must be passed by keyword: the second positional
        # argument of a Keras layer call binds to `training`, which the
        # noisy channel discards, so the flag would silently be ignored.
        noisy_logits = self._noisy_channel(
            logits, gradient_trick=gradient_trick)

        # cross_entropy between uniform and latent
        # (a.k.a. message logits before noisy channel)
        probs = tf.nn.softmax(logits)
        log_probs = tf.math.log(probs + self._log_regularizer)
        # average over batch and alphabet, then sum over message positions
        latent_cross_entropy = -tf.reduce_mean(log_probs, axis=[0, 2])
        latent_cross_entropy = tf.reduce_sum(latent_cross_entropy)
        latent_entropy = -tf.reduce_mean(tf.reduce_sum(probs*log_probs, axis=2))

        # logits, noisy_logits are tensors (B, message_len, alphabet_size)
        return logits, noisy_logits, latent_cross_entropy, latent_entropy


@gin.configurable
class Receiver(tf.keras.Model):
    """Receiver. Decodes message.

    The message is flattened, optionally passed through a small filter
    network and then through an MLP whose output is reshaped to
    `(message_len, alphabet_size)` logits.
    """

    def __init__(self,
                 message_len,
                 alphabet_size,
                 hidden_sizes,
                 filter_size,
                 weight_decay,
                 activation_fn=tf.keras.activations.elu):
        """Create the receiver.

        Args:
            message_len: Number of symbols in the decoded output.
            alphabet_size: Number of possible values of each symbol.
            hidden_sizes: Widths of the hidden layers of the MLP; an output
                layer of `message_len * alphabet_size` units is appended.
            filter_size: Width of the three filter layers. A falsy value
                disables the filter network in `call`.
            weight_decay: L2 regularization factor of all dense kernels.
            activation_fn: Activation used by the filter network and MLP.
        """
        super().__init__()

        def l2_penalty():
            return tf.keras.regularizers.l2(weight_decay)

        def filter_dense(layer_name, activation):
            return tf.keras.layers.Dense(
                filter_size,
                activation=activation,
                kernel_regularizer=l2_penalty(),
                name=layer_name,
            )

        self._activation_fn = activation_fn
        self._filter_size = filter_size

        self._flatten = tf.keras.layers.Flatten()

        # The first two layers are linear; their sum is activated in `call`.
        self._filter_1 = filter_dense('receiver_filter_1', None)
        self._filter_2 = filter_dense('receiver_filter_2', None)
        self._filter_3 = filter_dense('receiver_filter_3', activation_fn)

        layer_sizes = [*hidden_sizes, message_len * alphabet_size]
        decoder = mlp(
            layer_sizes,
            activation=activation_fn,
            name='receiver_mlp',
            kernel_regularizer=l2_penalty(),
        )
        to_symbols = tf.keras.layers.Reshape((message_len, alphabet_size))
        self._body = tf.keras.Sequential([decoder, to_symbols])

    def call(self, inputs, training=None, mask=None):
        """Decode a message.

        Args:
            inputs: Message of shape (B, message_len, alphabet).
            training: Unused.
            mask: Unused.

        Returns:
            A tuple `(preds, preds_logits)`: the `int32` argmax over the last
            axis and the logits it was computed from.
        """
        features = self._flatten(inputs)
        if self._filter_size:
            # Combine a view of the message with a view of its complement.
            mixed = self._filter_1(features) + self._filter_2(1 - features)
            features = self._filter_3(self._activation_fn(mixed))

        preds_logits = self._body(features)
        probabilities = tf.nn.softmax(preds_logits, axis=-1)
        best_symbol = tf.math.argmax(probabilities, axis=-1)

        return tf.cast(best_symbol, 'int32'), preds_logits


@gin.configurable
class ReferenceGameReceiver(tf.keras.Model):
    """Receiver for a reference game.

    The sender sees a target image and communicates it to the receiver. The
    receiver is shown `num_candidates` candidate images (the target plus the
    distractors) together with the sender's message and must indicate the
    target image. An instance therefore takes two tensors - the message and
    the visual inputs (the candidate images stacked together) - and returns
    two tensors: predictions (the index of the target image) and raw logits.
    The logit of each candidate image is the inner product of the embedding
    of that image and the embedding of the message.
    """

    def __init__(
            self,
            visual_input_dim,
            message_mlp_hidden_sizes,
            message_filter_size,
            weight_decay,
            num_candidates: int,
            conv_filters=(64, 64, 64),
            conv_kernel_sizes=(3, 3, 3),
            conv_strides=(1, 1, 1),
            max_pool_sizes=(2, 2, 2),
            max_pool_strides=(2, 2, 2),
            padding='SAME',
            embedding_size=64,
            restore_pretrain_fname=None,
            activation_fn=tf.keras.activations.elu,
    ):
        """Create the receiver.

        Args:
            visual_input_dim: Shape of one image without the batch dimension;
                only used to build the CNN when `restore_pretrain_fname` is
                given.
            message_mlp_hidden_sizes: Layer widths of the message MLP.
            message_filter_size: Width of the three message filter layers. A
                falsy value disables the filter network in `call`.
            weight_decay: L2 regularization factor of CNN and dense kernels.
            num_candidates: Number of candidate images per example.
            conv_filters: Per-block filter counts passed to `cnn`.
            conv_kernel_sizes: Per-block kernel sizes passed to `cnn`.
            conv_strides: Per-block convolution strides passed to `cnn`.
            max_pool_sizes: Per-block pooling sizes passed to `cnn`.
            max_pool_strides: Per-block pooling strides passed to `cnn`.
            padding: Padding mode passed to `cnn`.
            embedding_size: Width of both layers of the image MLP.
            restore_pretrain_fname: Optional path of a saved model whose
                `_cnn` weights are copied into this receiver's CNN.
            activation_fn: Activation of the MLPs and the filter network.
        """
        super().__init__()

        def l2_penalty():
            return tf.keras.regularizers.l2(weight_decay)

        def filter_dense(layer_name, activation):
            return tf.keras.layers.Dense(
                message_filter_size,
                activation=activation,
                kernel_regularizer=l2_penalty(),
                name=layer_name,
            )

        self._activation_fn = activation_fn
        self._filter_size = message_filter_size
        self._num_candidates = num_candidates

        # Image branch: CNN feature extractor.
        # NOTE(review): `activation_fn` is not forwarded here, so the
        # convolutions always use the default activation of `cnn` - confirm
        # this is intended.
        freeze_cnn = restore_pretrain_fname is not None
        self._cnn = cnn(
            conv_filters=conv_filters,
            conv_kernel_sizes=conv_kernel_sizes,
            conv_strides=conv_strides,
            max_pool_sizes=max_pool_sizes,
            max_pool_strides=max_pool_strides,
            padding=padding,
            kernel_regularizer=l2_penalty(),
            trainable=not freeze_cnn,
            name='receiver_cnn'
        )
        if restore_pretrain_fname:
            # The variables must exist before set_weights can fill them.
            self._cnn.build((None,) + visual_input_dim)
            donor = tf.keras.models.load_model(restore_pretrain_fname)
            self._cnn.set_weights(donor._cnn.get_weights())
            self._cnn.add(tf.keras.layers.Lambda(tf.stop_gradient))

        # Image branch: features -> image embedding.
        self._image_mlp = mlp(
            (embedding_size, embedding_size),
            activation=activation_fn,
            name='receiver_image_mlp',
            kernel_regularizer=l2_penalty(),
        )

        self._image_encoder = tf.keras.Sequential([
            self._cnn,
            tf.keras.layers.Flatten(),
            self._image_mlp,
        ])

        # Message branch: flatten and optional filter network.
        self._flatten = tf.keras.layers.Flatten()

        self._filter_1 = filter_dense('receiver_filter_1', None)
        self._filter_2 = filter_dense('receiver_filter_2', None)
        self._filter_3 = filter_dense('receiver_filter_3', activation_fn)

        # Message branch: filtered message -> message embedding.
        self._message_mlp = mlp(
            message_mlp_hidden_sizes,
            activation=activation_fn,
            name='receiver_message_mlp',
            kernel_regularizer=l2_penalty(),
        )

    def call(  # pylint: disable=arguments-differ
        self,
        message,  # (B, message_len, alphabet_size)
        visual_inputs,  # (B, num_candidates, ...)
    ):
        """Score every candidate image against the message.

        Returns:
            A tuple `(preds, logits)`: the `int32` index of the best scoring
            candidate and the `(B, num_candidates)` logits.
        """
        features = self._flatten(message)
        if self._filter_size:
            # Combine a view of the message with a view of its complement.
            mixed = self._filter_1(features) + self._filter_2(1 - features)
            features = self._filter_3(self._activation_fn(mixed))

        message_embedding = self._message_mlp(features)  # (B, embedding_size)

        # Encode each candidate separately; every entry is
        # (B, embedding_size).
        per_candidate = [
            self._image_encoder(visual_inputs[:, index, ...])
            for index in range(self._num_candidates)
        ]
        # (B, num_candidates, embedding_size)
        image_embeddings = tf.stack(per_candidate, axis=1)

        # Batched inner product of every image embedding with the message.
        message_column = tf.expand_dims(message_embedding, axis=-1)
        scores = tf.matmul(image_embeddings, message_column)
        logits = tf.squeeze(scores, axis=-1)  # (B, num_candidates)

        probabilities = tf.nn.softmax(logits, axis=-1)
        best_candidate = tf.math.argmax(probabilities, axis=-1)
        return tf.cast(best_candidate, 'int32'), logits


@tf.function
def gumbel_softmax(logits, tau, straight_through=tf.constant(True),
                   eps=tf.constant(1e-20)):
    """Calculate gumbel softmax.

    Args:
        logits: Unnormalized log-probabilities; the softmax is taken over the
            last axis.
        tau: Temperature dividing the perturbed logits.
        straight_through: Boolean tensor. When true the result is the one-hot
            argmax of the soft sample, with the gradient of the soft sample;
            otherwise the soft sample itself is returned.
        eps: Small constant guarding both logarithms against `log(0)`.

    Returns:
        A tensor with the same shape and dtype as `logits`.
    """
    # Keep eps in the dtype of the logits so non-float32 logits do not fail
    # on a dtype mismatch (no-op for float32).
    eps = tf.cast(eps, logits.dtype)
    # Use the dynamic shape: the static `logits.shape` may contain unknown
    # dimensions (e.g. the batch size) when this function is traced.
    uniform = tf.random.uniform(tf.shape(logits), dtype=logits.dtype)
    # pylint: disable=invalid-unary-operand-type # pylint error
    g = tf.math.log(-tf.math.log(uniform + eps) + eps)
    y = tf.nn.softmax((logits - g) / tau, axis=-1)
    # pylint: disable=no-value-for-parameter
    y_hard = tf.cast(tf.one_hot(tf.argmax(y, -1), y.shape[-1]), y.dtype)
    # Forward pass yields y_hard, backward pass uses the gradient of y.
    y_hard = tf.stop_gradient(y_hard - y) + y
    return tf.where(straight_through, y_hard, y)  # select y or y_hard


@tf.function
def gumbel_sample(x):
    """Sample an index along the last axis of `x`.

    Uniform noise of the same shape as `x` is transformed with
    `log(-log(u))`, subtracted from `x`, and the argmax over the last axis is
    returned as `int32`.
    """
    # Taken from Baselines
    noise = tf.random.uniform(tf.shape(x), dtype=x.dtype)
    # pylint: disable=invalid-unary-operand-type # pylint error
    perturbed = x - tf.math.log(-tf.math.log(noise))
    return tf.argmax(perturbed, axis=-1, output_type=tf.int32)
