from indigo.nn.wrappers.layer import Layer
from indigo.nn.base.block import Block
from indigo.nn.base.attention_with_bias import AttentionWithBias
from indigo.permutation_utils import pt_permutation_to_relative_l2r
import tensorflow as tf


class EncoderWithPositionLayer(Layer):

    def __init__(self,
                 input_size,
                 hidden_size,
                 heads,
                 queries_dropout=0.,
                 keys_dropout=0.,
                 values_dropout=0.,
                 causal=True,
                 num_pos=1,
                 **kwargs):
        """Creates a Transformer encoder layer that applies multi head
        self attention with an additional relative-position bias added
        to the attention scores

        Arguments:

        input_size: int
            the number of units in the input tensor to this layer
            also the output size of the model; the per-head reshapes
            in call require this to be divisible by heads
        hidden_size: int
            the number of units in the hidden variables used
            in each multi head attention layer
        heads: int
            the number of heads in each multi head attention layer
            a good default is 4 or 8
        queries_dropout: float
            the ratio of units to drop during training to the
            number of units in each attention layer
        keys_dropout: float
            the ratio of units to drop during training to the
            number of units in each attention layer
        values_dropout: float
            the ratio of units to drop during training to the
            number of units in each attention layer
        causal: bool
            specifies if the transformer should decode using
            a causal mask to preserve the auto regressive property
        num_pos: int
            the number of relative positions; forwarded to
            pt_permutation_to_relative_l2r when building the
            relative-position features in call"""
        super().__init__()

        # the core attention and processing variables; block0 produces
        # queries, keys and values together (hence input_size * 3)
        self.block0 = Block(hidden_size, input_size * 3, **kwargs)
        # projects relative-position features to one vector per head
        self.pos_embedding = tf.keras.layers.Dense(input_size, **kwargs)
        self.attention = AttentionWithBias(queries_dropout=queries_dropout,
                                           keys_dropout=keys_dropout,
                                           values_dropout=values_dropout,
                                           causal=causal)
        self.block1 = Block(hidden_size, input_size, **kwargs)

        # these parameters need to be stored so that
        # tf.layers.model.save_model works
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.heads = heads
        self.queries_dropout = queries_dropout
        self.keys_dropout = keys_dropout
        self.values_dropout = values_dropout
        self.causal = causal
        self.num_pos = num_pos
        self.kwargs = kwargs

    def call(self, inputs, **kwargs):
        """Runs a forward pass on a multi head self attention layer whose
        attention scores receive a relative-position bias

        Arguments:

        inputs: list
            the 16 model inputs, in order: queries, values,
            queries_mask, values_mask, ids, permutation,
            absolute_positions, relative_positions, pointer_labels,
            logits_labels, partial_pos, pointer_probs, log_probs,
            object_detections, object_features, object_boxes;
            some of these may be empty tensors

        Returns:

        outputs: list
            the same 16 entries in the same order, with values replaced
            by the result of self attention followed by a feed forward
            block, each wrapped in a residual connection"""

        # unpack all the required model inputs, some might be empty tensors
        [queries, values, queries_mask, values_mask, ids, permutation,
         absolute_positions, relative_positions, pointer_labels,
         logits_labels, partial_pos, pointer_probs, log_probs,
         object_detections, object_features, object_boxes] = inputs

        # pass the input through a feed forward processing block and
        # separate heads from channels; each head's last axis holds a
        # query, key and value slice of width dim each
        shape, dim = tf.shape(values), self.input_size // self.heads
        x = self.block0(values, **kwargs)
        x = tf.transpose(tf.reshape(x, [
            shape[0], shape[1], self.heads, dim * 3]), [0, 2, 1, 3])

        # relative-position features for a single sequence of length
        # shape[1]; the reshape below requires they cover 1 * L * L
        # query/key pairs (shape of rel itself is defined by the helper)
        rel = pt_permutation_to_relative_l2r(
            1, shape[1], tf.constant(self.num_pos))

        # add a position-conditioned bias to the attention scores
        # in log-space: https://arxiv.org/pdf/1902.01370.pdf
        pos = self.pos_embedding(rel, **kwargs)
        pos = tf.transpose(tf.reshape(pos, [
            1, shape[1], shape[1], self.heads, dim]), [0, 3, 1, 2, 4])

        # dot every query with the embedding of its position relative to
        # each key; pos has a batch axis of size 1 that matmul broadcasts
        # against the batch of queries, and squeezing axis 3 leaves one
        # score per (batch, head, query, key)
        bias = tf.squeeze(tf.matmul(
            tf.expand_dims(x[..., :dim], 3), pos, transpose_b=True), 3)

        # pass the input through an attention processing block and
        # flatten the heads and channels
        mask = tf.expand_dims(values_mask, 1)
        x = self.attention([x[..., :dim], x[..., dim:2 * dim],
                            x[..., 2 * dim:], mask, mask, bias], **kwargs)
        x = tf.reshape(tf.transpose(x, [
            0, 2, 1, 3]), [shape[0], shape[1], self.heads * dim])

        # pass the outputs of the attention through another feed forward
        # processing block, with a residual connection around each stage
        values = values + x
        values = values + self.block1(values, **kwargs)

        return [queries, values, queries_mask, values_mask, ids, permutation,
                absolute_positions, relative_positions,
                pointer_labels, logits_labels,
                partial_pos, pointer_probs, log_probs,
                object_detections, object_features, object_boxes]

    def get_config(self):
        """Creates a state dictionary that can be used to rebuild
        the layer in another python process

        Returns:

        config: dict
            a dictionary that contains all parameters to the
            layers base class and all class parameters"""

        # these are all that is needed to rebuild this class
        config = dict(input_size=self.input_size,
                      hidden_size=self.hidden_size,
                      heads=self.heads,
                      queries_dropout=self.queries_dropout,
                      keys_dropout=self.keys_dropout,
                      values_dropout=self.values_dropout,
                      causal=self.causal,
                      num_pos=self.num_pos,
                      **self.kwargs)

        # zero-argument super() always resolves to this class, so the
        # base config cannot be requested through a wrong/undefined class
        base_config = super().get_config()

        # class parameters take precedence over base-class entries
        return {**base_config, **config}
