# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer

from . import build_monotonic_attention


class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer whose self-attention hides later positions.

    ``forward`` passes an additive mask to the parent layer. The mask is 0 on
    and below the diagonal and -inf strictly above it, so (assuming rows index
    queries, as in fairseq's ``TransformerEncoderLayer``) position ``i`` may
    only attend to positions ``j <= i``.
    """

    def forward(self, x, encoder_padding_mask):
        """Run the parent encoder layer under a causal self-attention mask.

        Args:
            x: 3-D input tensor whose first dimension is the sequence length.
                The remaining dims are presumably (batch, embed_dim), following
                fairseq's encoder-layer convention; TODO confirm.
            encoder_padding_mask: forwarded unchanged to the parent layer.

        Returns:
            Whatever ``TransformerEncoderLayer.forward`` returns.
        """
        # Tuple-unpacking (instead of x.size(0)) also rejects non-3-D inputs.
        length, _unused_a, _unused_b = x.size()
        # triu(1) keeps only the strict upper triangle of the -inf fill and
        # zeroes everything else, giving 0 on/below the diagonal and -inf above.
        # The mask inherits x's dtype and device.
        future_mask = x.new_full((length, length), float("-inf")).triu(1)
        return super().forward(x, encoder_padding_mask, future_mask)


class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
    """Transformer decoder layer whose encoder attention is monotonic.

    The parent ``TransformerDecoderLayer`` is always built without its own
    encoder attention. It is replaced by the module returned from
    ``build_monotonic_attention(args)`` plus a dedicated layer norm.
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        # Validate configuration before building any submodules. An explicit
        # raise (unlike ``assert``) is not stripped under ``python -O``, and
        # getattr gives the intended message even if ``args`` lacks the field.
        if getattr(args, "simul_type", None) is None:
            raise ValueError("A --simul-type is needed.")

        # ``no_encoder_attn`` is accepted only for signature compatibility and
        # is ignored: the parent is always told to skip its encoder attention
        # because a monotonic one is installed below.
        super().__init__(
            args,
            no_encoder_attn=True,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )

        self.encoder_attn = build_monotonic_attention(args)
        self.encoder_attn_layer_norm = LayerNorm(
            self.embed_dim, export=getattr(args, "char_inputs", False)
        )

    def get_head_steps(self, incremental_state):
        """Return the ``"head_step"`` entry of the monotonic attention buffer.

        Returns None when the buffer has no such entry.
        """
        return self.encoder_attn._get_monotonic_buffer(incremental_state).get(
            "head_step"
        )

    def prune_incremental_state(self, incremental_state):
        """Drop the most recent step from the self-attention key/value cache.

        Each of ``prev_key`` and ``prev_value`` is sliced to remove its last
        index along dim 2. That dim is presumably the time axis of a
        (batch, heads, time, head_dim) cache; TODO confirm.

        If either entry is missing/None, or holds at most one step, nothing
        would remain after pruning. In that case the whole buffer is reset to
        ``{}``, and any other entries it held are discarded. When pruning
        succeeds, entries other than ``prev_key``/``prev_value`` are left
        untouched.
        """
        input_buffer = self.self_attn._get_input_buffer(incremental_state)
        for key in ("prev_key", "prev_value"):
            cached = input_buffer.get(key)
            if cached is None or cached.size(2) <= 1:
                # Empty cache or a single cached step: reset instead of
                # indexing a missing entry or leaving a zero-length tensor.
                input_buffer = {}
                break
            input_buffer[key] = cached[:, :, :-1, :]
        self.self_attn._set_input_buffer(incremental_state, input_buffer)

    def get_steps(self, incremental_state):
        """Return the ``"step"`` entry of the monotonic attention buffer.

        Returns 0 when the buffer has no such entry.
        """
        return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
