# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
    TransformerMonotonicDecoderLayer,
    TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
    base_architecture,
    transformer_iwslt_de_en,
    transformer_vaswani_wmt_en_de_big,
    transformer_vaswani_wmt_en_fr_big,
)


# Default limits on the number of source / target positions.
# NOTE(review): neither constant is referenced in the visible part of this
# file -- presumably kept for parity with fairseq's transformer module; confirm
# before removing.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    """Transformer model whose encoder is a :class:`TransformerMonotonicEncoder`.

    Only the encoder factory is overridden; everything else, including the
    decoder factory, is inherited from :class:`TransformerModel`.
    """

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return a :class:`TransformerMonotonicEncoder` built from the given
        arguments, source dictionary and token embedding."""
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)


@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
    """Transformer model pairing :class:`TransformerMonotonicEncoder` with
    :class:`TransformerMonotonicDecoder`.

    Besides the encoder/decoder factories, the class provides helpers that
    operate on a ``states`` dictionary. The keys read or written here are
    ``states["indices"]["src"|"tgt"]``, ``states["steps"]["src"]``,
    ``states["tokens"]["src"]``, ``states["finish_read"]``,
    ``states["model_states"]``, ``states["encoder_states"]`` and
    ``states["decoder_features"]``.
    """

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return the monotonic encoder for this model."""
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Return the monotonic decoder for this model."""
        return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

    def _indices_from_states(self, states):
        """Extract source and target index sequences from ``states``.

        When ``states["indices"]["src"]`` is a Python list, both sequences are
        converted to ``torch.long`` tensors with a leading dimension of size 1,
        placed on the device of the model parameters; the target sequence is
        prefixed with the decoder dictionary's EOS index. Otherwise the stored
        objects are used as they are (the source one sliced).

        In both cases only the first ``1 + states["steps"]["src"]`` source
        entries are kept.

        Returns:
            tuple: ``(src_indices, None, tgt_indices)``; the middle element is
            the placeholder that callers pass on as ``src_lengths``.
        """
        src_all = states["indices"]["src"]
        src_end = 1 + states["steps"]["src"]

        if isinstance(src_all, list):
            # Create the tensors directly on the device that holds the model
            # parameters instead of going through the legacy
            # ``torch.LongTensor`` / ``torch.cuda.LongTensor`` constructors.
            device = next(self.parameters()).device

            src_indices = torch.tensor(
                [src_all[:src_end]], dtype=torch.long, device=device
            )

            tgt_indices = torch.tensor(
                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]],
                dtype=torch.long,
                device=device,
            )
        else:
            src_indices = src_all[:src_end]
            tgt_indices = states["indices"]["tgt"]

        return src_indices, None, tgt_indices

    def predict_from_states(self, states):
        """Greedily pick the next target token from cached decoder features.

        ``states["decoder_features"]`` (stored by :meth:`decision_from_states`)
        is projected with the decoder's output layer, the last position is
        normalized to log-probabilities and the arg-max entry is selected.

        Returns:
            tuple: ``(token, index)`` where ``token`` is the string produced
            by the decoder dictionary for the selected index and ``index`` is
            that index as a Python number.
        """
        decoder_states = self.decoder.output_layer(states["decoder_features"])
        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

        index = lprobs.argmax(dim=-1)

        token = self.decoder.dictionary.string(index)

        return token, index[0, 0].item()

    def decision_from_states(self, states):
        """
        Take the states dictionary as input and give the agent a decision on
        whether to read another token from the server. The decoder features
        are also computed here and stored in ``states["decoder_features"]``,
        so a target token can be generated afterwards without recomputing
        everything.

        Side effects: switches the model to eval mode and updates
        ``states["encoder_states"]``, ``states["model_states"]`` and
        ``states["decoder_features"]``.

        Returns:
            ``0`` when no source token is available yet; otherwise the
            ``"action"`` entry of the decoder's extra outputs.
        """

        self.eval()

        if len(states["tokens"]["src"]) == 0:
            return 0

        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

        # Update encoder states if needed
        # NOTE(review): the cache check compares dimension 1 of the first
        # encoder output against the source step count; if that tensor is laid
        # out time-first, dimension 0 may have been intended -- confirm.
        if (
            "encoder_states" not in states
            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
        ):
            encoder_out_dict = self.encoder(src_indices, src_lengths)
            states["encoder_states"] = encoder_out_dict
        else:
            encoder_out_dict = states["encoder_states"]

        # online means we still need tokens to feed the model
        states["model_states"]["online"] = not (
            states["finish_read"]
            and len(states["tokens"]["src"]) == states["steps"]["src"]
        )

        states["model_states"]["steps"] = states["steps"]

        x, outputs = self.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=encoder_out_dict,
            incremental_state=states["model_states"],
            features_only=True,
        )

        states["decoder_features"] = x

        return outputs["action"]


class TransformerMonotonicEncoder(TransformerEncoder):
    """Transformer encoder whose ``layers`` are
    :class:`TransformerMonotonicEncoderLayer` instances."""

    def __init__(self, args, dictionary, embed_tokens):
        """Run the parent constructor, then overwrite ``self.layers`` with
        ``args.encoder_layers`` monotonic encoder layers."""
        super().__init__(args, dictionary, embed_tokens)

        self.dictionary = dictionary

        monotonic_layers = [
            TransformerMonotonicEncoderLayer(args)
            for _ in range(args.encoder_layers)
        ]
        self.layers = nn.ModuleList(monotonic_layers)


class TransformerMonotonicDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerMonotonicDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        # Forward the caller's ``no_encoder_attn`` to the parent as well, so
        # the parent and the monotonic layers below agree on the setting.
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )

        self.dictionary = dictionary
        # Overwrite ``self.layers`` with monotonic decoder layers.
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )

    def pre_attention(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None
    ):
        """Embed the target tokens and unpack the encoder output.

        When ``incremental_state`` is given, only the last target token (and
        its position embedding) is kept.

        Returns:
            tuple: ``(x, encoder_out, encoder_padding_mask)`` where ``x`` is
            the embedded input transposed to ``T x B x C`` and the other two
            are the attributes of the same name read from
            ``encoder_out_dict``.
        """
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = encoder_out_dict.encoder_out
        encoder_padding_mask = encoder_out_dict.encoder_padding_mask

        return x, encoder_out, encoder_padding_mask

    def post_attention(self, x):
        """Apply the optional final layer norm, transpose back to
        ``B x T x C`` and apply the optional output projection."""
        # Explicit ``None`` check: do not rely on the truthiness of a module.
        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x

    def extract_features(
        self, prev_output_tokens, encoder_out, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        With an ``incremental_state`` whose ``"online"`` entry is true, every
        layer's chosen steps are checked against
        ``incremental_state["steps"]["src"]``. As soon as a layer's next step
        reaches that value, the state of the layers run so far is pruned and
        the method returns early with ``{"action": 0}``; the features returned
        in that case have not gone through :meth:`post_attention`.

        With an ``incremental_state`` that is not online, the largest step over
        all layers plus one is appended to ``incremental_state["fastest_step"]``.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs; it always
                  contains ``"action"`` and, on the regular return path,
                  ``"attn_list"``, ``"step_list"``, ``"encoder_out"`` and
                  ``"encoder_padding_mask"``
        """
        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
            prev_output_tokens, encoder_out, incremental_state
        )
        attn_list = []
        step_list = []

        for i, layer in enumerate(self.layers):

            x, attn, _ = layer(
                x=x,
                encoder_out=encoder_outs,
                encoder_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None
                else None,
            )

            attn_list.append(attn)

            if incremental_state is not None:
                curr_steps = layer.get_steps(incremental_state)
                step_list.append(curr_steps)

                if incremental_state.get("online", False):
                    # Selection probability at each entry's current step.
                    p_choose = (
                        attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
                    )

                    # Advance by one step wherever the probability is < 0.5.
                    new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)

                    if (new_steps >= incremental_state["steps"]["src"]).any():
                        # We need to prune the last self_attn saved_state
                        # if model decide not to read
                        # otherwise there will be duplicated saved_state
                        for j in range(i + 1):
                            self.layers[j].prune_incremental_state(incremental_state)

                        return x, {"action": 0}

        if incremental_state is not None and not incremental_state.get("online", False):
            # Here is for fast evaluation
            fastest_step = (
                torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
            )

            if "fastest_step" in incremental_state:
                incremental_state["fastest_step"] = torch.cat(
                    [incremental_state["fastest_step"], fastest_step], dim=1
                )
            else:
                incremental_state["fastest_step"] = fastest_step

        x = self.post_attention(x)

        return x, {
            "action": 1,
            "attn_list": attn_list,
            "step_list": step_list,
            "encoder_out": encoder_out,
            "encoder_padding_mask": encoder_padding_mask,
        }

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the incremental state via the parent implementation and,
        when present, re-index ``incremental_state["fastest_step"]`` along
        dimension 0 with ``new_order``."""
        super().reorder_incremental_state(incremental_state, new_order)
        if "fastest_step" in incremental_state:
            incremental_state["fastest_step"] = incremental_state[
                "fastest_step"
            ].index_select(0, new_order)


@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
    """Apply the base transformer defaults to ``args`` and default
    ``args.encoder_unidirectional`` to ``False`` when it is not set.

    NOTE(review): the function name is missing an "a" ("rchitecture"); it is
    kept as is because it is referenced by name elsewhere in this file.
    """
    base_architecture(args)
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    """Apply the ``transformer_iwslt_de_en`` defaults to ``args``, then the
    base monotonic defaults."""
    transformer_iwslt_de_en(args)
    base_monotonic_rchitecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    """Apply the ``transformer_vaswani_wmt_en_de_big`` defaults to ``args``,
    then the base monotonic defaults.

    The second call mirrors ``transformer_monotonic_iwslt_de_en`` so that
    ``args.encoder_unidirectional`` is always defined for this architecture;
    values already present on ``args`` are left untouched.
    """
    transformer_vaswani_wmt_en_de_big(args)
    base_monotonic_rchitecture(args)


@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    """Apply the ``transformer_vaswani_wmt_en_fr_big`` defaults to ``args``,
    then the base monotonic defaults.

    This function previously called itself unconditionally, so selecting the
    architecture always ended in ``RecursionError``. It now delegates to the
    non-monotonic preset imported from ``fairseq.models.transformer``.
    """
    transformer_vaswani_wmt_en_fr_big(args)
    base_monotonic_rchitecture(args)


@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    """Apply the ``transformer_iwslt_de_en`` defaults to ``args`` for the
    ``transformer_unidirectional`` model."""
    transformer_iwslt_de_en(args)
