# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
from .s4plusplus import s4plusplus
from .gated_cross_attention import GatedCrossAttention
from .normalized_feedforward_network import NormalizedFeedForwardNetwork
from torch import Tensor


class S4plusplusEncoderLayer(nn.Module):
    """Encoder layer block.

    Applies an ``s4plusplus`` layer to the input and, when a positive
    ``encoder_ffn_embed_dim`` is configured, a
    ``NormalizedFeedForwardNetwork`` afterwards.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.s4plusplus_layer = self.build_s4plusplus_layer(self.embed_dim, args)

        # The feed-forward sub-layer is optional: it is only built for a
        # positive hidden size.
        ffn_dim = args.encoder_ffn_embed_dim
        use_ffn = ffn_dim is not None and ffn_dim > 0
        self.nffn = self.build_nffn_layer(self.embed_dim, args) if use_ffn else None

    def build_s4plusplus_layer(self, embed_dim, args):
        """Construct the ``s4plusplus`` sub-layer from ``args``.

        ``qdim``, ``ndim`` and ``bidirectional`` are fixed here; ``dropout``
        and ``attention_dropout`` are required on ``args``; every other
        option falls back to the default given below when ``args`` lacks it.
        """
        layer_kwargs = dict(
            embed_dim=embed_dim,
            qdim=64,
            vdim=getattr(args, 'encoder_ffn_embed_dim', 256),
            ndim=24,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            hidden_dropout=args.dropout,
            chunk_size=getattr(args, 'chunk_size', -1),
            truncation=getattr(args, 'truncation_length', None),
            rel_pos_bias=getattr(args, 'rel_pos_bias', 'simple'),
            max_positions=getattr(args, 'max_target_positions', 1024),
            activation=getattr(args, 'activation_fn', 'silu'),
            attention_activation=getattr(args, 'attention_activation_fn', 'softmax'),
            bidirectional=False,
            norm_type=getattr(args, 'normalization_type', 'layernorm'),
            prenorm=getattr(args, 'normalize_before', True),
            feature_dropout=getattr(args, 'feature_dropout', False),
        )
        return s4plusplus(**layer_kwargs)

    def build_nffn_layer(self, embed_dim, args):
        """Construct the ``NormalizedFeedForwardNetwork`` sub-layer from ``args``.

        ``dropout`` is required on ``args``; the remaining options fall back
        to the defaults given below when ``args`` lacks them.
        """
        ffn_kwargs = dict(
            embed_dim=embed_dim,
            ffn_hidden_dim=getattr(args, 'encoder_ffn_embed_dim', 160),
            dropout=args.dropout,
            hidden_dropout=getattr(args, 'activation_dropout', 0.0),
            activation=getattr(args, 'activation_fn', 'silu'),
            norm_type=getattr(args, 'normalization_type', 'layernorm'),
            prenorm=getattr(args, 'normalize_before', True),
            feature_dropout=getattr(args, 'feature_dropout', False),
        )
        return NormalizedFeedForwardNetwork(**ffn_kwargs)

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # The second value returned by the s4plusplus layer is not used here.
        hidden, _ = self.s4plusplus_layer(x, encoder_padding_mask)
        if self.nffn is None:
            return hidden
        return self.nffn(hidden)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        The rename is done in place on ``state_dict`` and covers both the
        ``weight`` and ``bias`` entries; keys that are absent are skipped.
        """
        renames = (("0", "self_attn_layer_norm"), ("1", "final_layer_norm"))
        for legacy_index, modern_name in renames:
            for suffix in ("weight", "bias"):
                legacy_key = "{}.layer_norms.{}.{}".format(name, legacy_index, suffix)
                if legacy_key not in state_dict:
                    continue
                modern_key = "{}.{}.{}".format(name, modern_name, suffix)
                state_dict[modern_key] = state_dict.pop(legacy_key)

class S4plusplusDecoderLayer(nn.Module):
    """Decoder layer block.

    Applies an ``s4plusplus`` layer to the input, then (unless disabled) a
    ``GatedCrossAttention`` over the encoder output, then (when a positive
    ``decoder_ffn_embed_dim`` is configured) a
    ``NormalizedFeedForwardNetwork``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_cross_attention (bool, optional): when ``True`` no cross-attention
            sub-layer is built and ``self.cross_attn`` is ``None``
            (default: ``False``).
    """

    def __init__(self, args, no_cross_attention=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.s4plusplus_layer = self.build_s4plusplus_layer(self.embed_dim, args)
        self.cross_attn = None if no_cross_attention else self.build_cross_attn(self.embed_dim, args)
        # The feed-forward sub-layer is optional: it is only built for a
        # positive hidden size.
        if args.decoder_ffn_embed_dim is not None and args.decoder_ffn_embed_dim > 0:
            self.nffn = self.build_nffn_layer(self.embed_dim, args)
        else:
            self.nffn = None

        # Flags toggled by make_generation_fast_ / prepare_for_onnx_export_.
        # Neither is read anywhere else in this class.
        self.need_attn = False
        self.onnx_trace = False

    def build_s4plusplus_layer(self, embed_dim, args):
        """Construct the ``s4plusplus`` sub-layer from ``args``.

        ``qdim``, ``ndim`` and ``bidirectional`` are fixed here; ``dropout``
        and ``attention_dropout`` are required on ``args``; every other
        option falls back to the default given below when ``args`` lacks it.
        """
        return s4plusplus(
            embed_dim=embed_dim,
            qdim=64,
            vdim=getattr(args, 'decoder_ffn_embed_dim', 160),
            ndim=16,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            hidden_dropout=args.dropout,
            chunk_size=getattr(args, 'chunk_size', -1),
            truncation=getattr(args, 'truncation_length', None),
            rel_pos_bias=getattr(args, 'rel_pos_bias', 'simple'),
            max_positions=getattr(args, 'max_target_positions', 1024),
            activation=getattr(args, 'activation_fn', 'silu'),
            attention_activation=getattr(args, 'attention_activation_fn', 'softmax'),
            bidirectional=False,
            norm_type=getattr(args, 'normalization_type', 'layernorm'),
            prenorm=getattr(args, 'normalize_before', True),
            feature_dropout=getattr(args, 'feature_dropout', False)
        )

    def build_cross_attn(self, embed_dim, args):
        """Construct the ``GatedCrossAttention`` sub-layer from ``args``.

        ``zdim`` and ``ndim`` are fixed here; ``dropout`` and
        ``attention_dropout`` are required on ``args``; every other option
        falls back to the default given below when ``args`` lacks it.
        """
        return GatedCrossAttention(
            embed_dim=embed_dim,
            zdim=64,
            ndim=16,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            hidden_dropout=args.dropout,
            activation=getattr(args, 'activation_fn', 'silu'),
            attention_activation=getattr(args, 'attention_activation_fn', 'softmax'),
            norm_type=getattr(args, 'normalization_type', 'layernorm'),
            prenorm=getattr(args, 'normalize_before', True),
            feature_dropout=getattr(args, 'feature_dropout', False),
            rel_pos_bias=getattr(args, 'rel_pos_bias', 'simple'),
            max_positions=getattr(args, 'max_target_positions', 1024),
        )

    def build_nffn_layer(self, embed_dim, args):
        """Construct the ``NormalizedFeedForwardNetwork`` sub-layer from ``args``.

        ``dropout`` is required on ``args``; the remaining options fall back
        to the defaults given below when ``args`` lacks them.
        """
        return NormalizedFeedForwardNetwork(
            embed_dim=embed_dim,
            ffn_hidden_dim=getattr(args, 'decoder_ffn_embed_dim', 160),
            dropout=args.dropout,
            hidden_dropout=getattr(args, 'activation_dropout', 0.0),
            activation=getattr(args, 'activation_fn', 'silu'),
            norm_type=getattr(args, 'normalization_type', 'layernorm'),
            prenorm=getattr(args, 'normalize_before', True),
            feature_dropout=getattr(args, 'feature_dropout', False),
        )

    def prepare_for_onnx_export_(self):
        """Set the ``onnx_trace`` flag on this layer."""
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        decoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor): encoder out for cross attention `(src_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``.
            decoder_padding_mask: accepted for interface compatibility; not
                used by this method.
            incremental_state: dictionary for caching incremental states.
            prev_self_attn_state (list, optional): previous key and value
                (and optionally a previous key padding mask as third entry)
                to store in the s4plusplus layer's input buffer; requires
                ``incremental_state``.
            prev_attn_state: accepted for interface compatibility; not used
                by this method.
            self_attn_mask (Tensor, optional): mask passed to the s4plusplus
                layer as ``attn_mask``.
            self_attn_padding_mask (Tensor, optional): padding mask passed to
                the s4plusplus layer as ``padding_mask``.
            need_attn (bool, optional): return attention weights from the
                cross-attention sub-layer.
            need_head_weights (bool, optional): if ``True``, forces
                ``need_attn`` to ``True``.

        Returns:
            tuple ``(x, attn, None)`` where ``x`` is the encoded output of
            shape `(seq_len, batch, embed_dim)` and ``attn`` is the second
            value returned by the cross-attention sub-layer, or by the
            s4plusplus layer when there is no cross attention.
        """
        if need_head_weights:
            need_attn = True

        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.s4plusplus_layer._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.s4plusplus_layer._get_input_buffer(incremental_state)

        # With cross attention and no cached "prev_key", the self-attention
        # masks are widened by the encoder length.
        # NOTE(review): the masks are widened but the s4plusplus layer below
        # still receives `x` alone -- confirm that layer expects masks of this
        # width.
        if self.cross_attn is not None and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            # The encoder output supplies the extra mask width and is
            # required on this path even when no mask was passed.
            assert encoder_out is not None
            if self_attn_mask is not None:
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    # No encoder mask given: use an all-zero mask of shape
                    # (encoder_out.size(1), encoder_out.size(0)). This value
                    # is also what the cross attention receives below.
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )

        x, attn = self.s4plusplus_layer(x=x, padding_mask=self_attn_padding_mask,
                                        incremental_state=incremental_state,
                                        need_weights=False, attn_mask=self_attn_mask)

        if self.cross_attn is not None:
            x, attn = self.cross_attn(query=x, key=encoder_out, value=encoder_out,
                                      key_padding_mask=encoder_padding_mask,
                                      incremental_state=incremental_state,
                                      static_kv=True, need_weights=need_attn)

        if self.nffn is not None:
            x = self.nffn(x)

        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record ``need_attn`` on the layer; extra keyword arguments are ignored."""
        self.need_attn = need_attn
