import torch
from allennlp.modules import Seq2SeqEncoder
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.layer_norm import LayerNorm
from allennlp.nn.activations import Activation
from ..utils.utils import MaskedMultiHeadSelfAttention
from allennlp.nn.util import add_positional_features


class MaskedStackedSelfAttentionEncoder(Seq2SeqEncoder):
    # pylint: disable=line-too-long
    """
    Implements a stacked self-attention encoder similar to the Transformer
    architecture in `Attention is all you Need
    <https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077>`_ .  # noqa

    Each 'block' applies, in this order:

    1. Layer Normalisation followed by a 2 layer FeedForward network, dropout
       and a residual connection (the residual is skipped when the feedforward
       output size differs from its input size, which can happen in the first
       block when ``input_dim != hidden_dim``).
    2. Layer Normalisation followed by masked multi-headed self attention,
       which uses 2 learnt linear projections to perform a dot-product
       similarity between every pair of elements, then dropout and a residual
       connection.

    These blocks are stacked ``num_layers`` times and a final Layer
    Normalisation is applied to the output of the last block.

    NOTE(review): an earlier version of this docstring stated that the
    attention scores are scaled by the square root of the sequence length. The
    scaling is implemented inside ``MaskedMultiHeadSelfAttention``, which is
    not visible from this file -- confirm against that class.

    Parameters
    ----------
    input_dim : ``int``, required.
        The input dimension of the encoder.
    hidden_dim : ``int``, required.
        The hidden dimension used for the _input_ to self attention layers
        and the _output_ from the feedforward layers.
    projection_dim : ``int``, required.
        The dimension of the linear projections for the self-attention layers.
        It is passed as both ``attention_dim`` and ``values_dim``.
    feedforward_hidden_dim : ``int``, required.
        The middle dimension of the FeedForward network. The input and output
        dimensions are fixed to ensure sizes match up for the self attention layers.
    num_layers : ``int``, required.
        The number of stacked feedforward -> self attention blocks.
    num_attention_heads : ``int``, required.
        The number of attention heads to use per layer.
    use_positional_encoding: ``bool``, optional, (default = True)
        Whether to add sinusoidal frequencies to the input tensor. This is strongly recommended,
        as without this feature, the self attention layers have no idea of absolute or relative
        position (as they are just computing pairwise similarity between vectors of elements),
        which can be important features for many tasks.
    dropout_prob : ``float``, optional, (default = 0.2)
        The dropout probability for the feedforward network. The same
        probability is used for the dropout applied to the feedforward and
        attention outputs before their residual connections.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        projection_dim,
        feedforward_hidden_dim,
        num_layers,
        num_attention_heads,
        use_positional_encoding=True,
        dropout_prob=0.2,
    ):
        super(MaskedStackedSelfAttentionEncoder, self).__init__()

        self._use_positional_encoding = use_positional_encoding
        # Plain Python lists are not tracked by torch.nn.Module, so every layer
        # is additionally registered through ``add_module`` below; the lists
        # only exist to make iteration in ``forward`` convenient.
        self._attention_layers = []
        self._feedfoward_layers = []
        self._layer_norm_layers = []
        self._feed_forward_layer_norm_layers = []

        # Only the first block consumes ``input_dim``; every later block
        # consumes the ``hidden_dim`` sized output of the previous one.
        feedfoward_input_dim = input_dim
        for i in range(num_layers):
            feedfoward = FeedForward(
                feedfoward_input_dim,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[feedforward_hidden_dim, hidden_dim],
                num_layers=2,
                dropout=dropout_prob,
            )

            # ``add_module`` takes (name, module). The name carries the layer
            # index so that each block gets a distinct, stable parameter prefix.
            self.add_module("feedforward_{}".format(i), feedfoward)
            self._feedfoward_layers.append(feedfoward)

            feedforward_layer_norm = LayerNorm(feedfoward.get_input_dim())
            self.add_module("feedforward_layer_norm_{}".format(i), feedforward_layer_norm)
            self._feed_forward_layer_norm_layers.append(feedforward_layer_norm)

            self_attention = MaskedMultiHeadSelfAttention(
                num_heads=num_attention_heads,
                input_dim=hidden_dim,
                attention_dim=projection_dim,
                values_dim=projection_dim,
            )
            self.add_module("self_attention_{}".format(i), self_attention)
            self._attention_layers.append(self_attention)

            layer_norm = LayerNorm(self_attention.get_input_dim())
            self.add_module("layer_norm_{}".format(i), layer_norm)
            self._layer_norm_layers.append(layer_norm)

            feedfoward_input_dim = hidden_dim

        self.dropout = torch.nn.Dropout(dropout_prob)
        self._input_dim = input_dim
        # The encoder's output size is whatever the last attention layer emits.
        self._output_dim = self._attention_layers[-1].get_output_dim()
        self._output_layer_norm = LayerNorm(self._output_dim)

    def get_input_dim(self):
        """Return the ``input_dim`` this encoder was constructed with."""
        return self._input_dim

    def get_output_dim(self):
        """Return the output dimension of the last self attention layer."""
        return self._output_dim

    def is_bidirectional(self):
        """Return ``False``; this encoder does not report itself as bidirectional."""
        return False

    def forward(self, inputs, mask):  # pylint: disable=arguments-differ
        """
        Run ``inputs`` through every feedforward / self attention block.

        Parameters
        ----------
        inputs : ``torch.Tensor``, required.
            The sequence to encode; presumably of shape
            (batch_size, timesteps, input_dim) -- TODO confirm against callers.
        mask : required.
            Passed unchanged as the second argument of every
            ``MaskedMultiHeadSelfAttention`` layer.

        Returns
        -------
        The layer-normalised output of the last block.
        """
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs
        for (attention, feedforward, feedforward_layer_norm, layer_norm) in zip(
            self._attention_layers,
            self._feedfoward_layers,
            self._feed_forward_layer_norm_layers,
            self._layer_norm_layers,
        ):
            cached_input = output
            # Normalise, then project through the feedforward network to
            # ``hidden_dim`` so the result can be fed to self attention.
            feedforward_output = feedforward(feedforward_layer_norm(output))
            feedforward_output = self.dropout(feedforward_output)
            if feedforward_output.size() == cached_input.size():
                # The first block may change the feature size
                # (input_dim -> hidden_dim), in which case a residual
                # connection is impossible and is skipped.
                feedforward_output += cached_input
            # shape (batch_size, sequence_length, hidden_dim)
            attention_output = attention(layer_norm(feedforward_output), mask)
            # Residual connection around the attention sub-layer.
            output = self.dropout(attention_output) + feedforward_output
        return self._output_layer_norm(output)

    @classmethod
    def from_params(cls, params):
        """
        Build an encoder from a configuration object.

        ``params`` must provide ``input_dim``, ``hidden_dim`` and
        ``feedforward_hidden_dim``; the remaining keys fall back to the
        defaults given below. Any unconsumed key is reported through
        ``params.assert_empty``.
        """
        input_dim = params.pop_int("input_dim")
        hidden_dim = params.pop_int("hidden_dim")
        # NOTE(review): ``projection_dim`` defaults to ``None`` here although
        # the constructor documents it as required -- confirm that
        # ``MaskedMultiHeadSelfAttention`` accepts ``None`` for its dimensions.
        projection_dim = params.pop_int("projection_dim", None)
        feedforward_hidden_dim = params.pop_int("feedforward_hidden_dim")
        num_layers = params.pop_int("num_layers", 2)
        num_attention_heads = params.pop_int("num_attention_heads", 3)
        use_positional_encoding = params.pop_bool("use_positional_encoding", True)
        dropout_prob = params.pop_float("dropout_prob", 0.2)
        params.assert_empty(cls.__name__)

        return cls(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            feedforward_hidden_dim=feedforward_hidden_dim,
            projection_dim=projection_dim,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            use_positional_encoding=use_positional_encoding,
            dropout_prob=dropout_prob,
        )
