# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter

from .utils import softmax, laplace, relu2, get_activation_fn
from fairseq.incremental_decoding_utils import with_incremental_state
from .fairseq_dropout import FairseqDropout, FairseqFeatureDropout
from .relative_positional_bias import SimpleRelativePositionalBias, RotaryRelativePositionalBias
from .sequence_norm import SequenceNorm
from ..s4 import S4

@with_incremental_state
class s4plusplus(nn.Module):
    """Gated attention layer whose keys and gates come from an S4 branch.

    ``forward`` works in three stages:

    1. A width-3 ``Conv1d`` (``arconv``) runs over the residual, which is
       left-padded in time with two copies of its first step. The sigmoid of
       the conv output scales the input.
    2. The gated input goes through ``S4``. Its (dropped-out) output feeds the
       key, ``rho``, ``psi`` and ``s`` projections.
    3. Queries and values are projected from the residual. Single-head
       attention (``softmax``, ``relu2`` or ``laplace``, plus a relative
       position bias) mixes the values, giving ``h = act(h_proj(attn * rho) + s)``.
       The output is ``residual + psi * (h - residual)``.

    Attention can optionally be computed within chunks of ``chunk_size``
    positions (``chunk_size < 0`` means a single chunk).
    """

    def __init__(
        self,
        embed_dim,
        qdim,
        vdim,
        ndim,
        dropout=0.0,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        activation='silu',
        attention_activation='softmax',
        bidirectional=False,
        chunk_size=-1,
        truncation=None,
        norm_type='layernorm',
        prenorm=True,
        norm_affine=True,
        feature_dropout=False,
        rel_pos_bias='rotary',
        max_positions=512,
        export=False,
    ):
        """Build the layer.

        Args:
            embed_dim: model width (input/output feature size).
            qdim: width of queries and keys.
            vdim: width of values (and of the ``rho`` gate).
            ndim: ``S4`` is built with ``channels=int(ndim / 4)``.
            dropout: dropout on the S4 output and on ``h``; also passed to S4.
            attention_dropout: dropout on the attention weights.
            hidden_dropout: dropout on the values.
            activation: name passed to ``get_activation_fn`` (and to S4).
            attention_activation: ``'softmax'``, ``'relu2'`` or ``'laplace'``.
            bidirectional: passed to ``S4``.
            chunk_size: attention chunk length; negative disables chunking.
            truncation: accepted for interface compatibility; unused here.
            norm_type, norm_affine, export: passed to ``SequenceNorm``.
            prenorm: normalize before the layer (True) or after it (False).
            feature_dropout: use ``FairseqFeatureDropout`` instead of
                ``FairseqDropout`` for ``dropout``/``hidden_dropout``.
            rel_pos_bias: ``'simple'`` or ``'rotary'``.
            max_positions: S4 ``l_max`` and (when not chunking) the
                relative-position-bias length.

        Raises:
            ValueError: if ``rel_pos_bias`` is not recognized.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.vdim = vdim
        self.qdim = qdim
        self.ndim = ndim
        self.activation = get_activation_fn(activation=activation)
        self.attention_activation = attention_activation
        # Only softmax attention scales q; element attention divides by lengths.
        self.scaling = self.qdim ** -0.5 if attention_activation == 'softmax' else None

        dropout_module = FairseqFeatureDropout if feature_dropout else FairseqDropout
        self.dropout = dropout_module(dropout, module_name=self.__class__.__name__)
        self.hidden_dropout = dropout_module(hidden_dropout, module_name=self.__class__.__name__)
        # Attention dropout is standard dropout
        self.attention_dropout = FairseqDropout(attention_dropout, module_name=self.__class__.__name__)

        self.chunk_size = chunk_size
        self.prenorm = prenorm
        self.norm = SequenceNorm(norm_type, embed_dim, affine=norm_affine, export=export)

        self.s4 = S4(embed_dim, d_state=256, channels=int(ndim/4), bidirectional=bidirectional, activation=activation, l_max=max_positions, transposed=False, dropout=dropout)
        # kernel 3, stride 1, no padding: forward pads the time axis itself.
        self.arconv = nn.Conv1d(embed_dim, embed_dim, 3, 1, 0)
        self.q_proj = nn.Linear(embed_dim, qdim)
        self.k_proj = nn.Linear(embed_dim, qdim)
        self.v_proj = nn.Linear(embed_dim, vdim)

        self.s_proj = nn.Linear(embed_dim, embed_dim)
        self.rho_proj = nn.Linear(embed_dim, vdim)
        self.h_proj = nn.Linear(vdim, embed_dim)
        self.psi_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.max_positions = max_positions
        # With chunking, the position bias only needs to cover one chunk.
        max_positions = max_positions if chunk_size < 0 else chunk_size
        if rel_pos_bias == 'simple':
            self.rel_pos_bias = SimpleRelativePositionalBias(max_positions)
        elif rel_pos_bias == 'rotary':
            self.rel_pos_bias = RotaryRelativePositionalBias(qdim, max_positions)
        else:
            raise ValueError('unknown relative position bias: {}'.format(rel_pos_bias))

        self.reset_parameters()

        self.onnx_trace = False
        self.tpu = False

    def prepare_for_onnx_export_(self):
        """Set the ``onnx_trace`` flag."""
        self.onnx_trace = True

    def prepare_for_tpu_(self, **kwargs):
        """Set the ``tpu`` flag; extra keyword arguments are ignored."""
        self.tpu = True

    def reset_parameters(self):
        """Initialize projection weights with N(0, 0.02) and biases with zero."""
        std = 0.02
        # Order matters for RNG reproducibility: v, q, k, s, rho, h, psi.
        for proj in (self.v_proj, self.q_proj, self.k_proj, self.s_proj, self.rho_proj, self.h_proj):
            nn.init.normal_(proj.weight, mean=0.0, std=std)
            nn.init.constant_(proj.bias, 0.0)

        # psi_proj is created with bias=False, so only its weight is initialized.
        nn.init.normal_(self.psi_proj.weight, mean=0.0, std=std)

    def element_attention(self, q, k, padding_mask, attn_mask, before_attn_fn):
        """Element-wise (``relu2``/``laplace``) attention weights.

        Scores are ``q @ k^T / lengths + bias``. ``lengths`` is the number of
        non-padded keys, or the row sums of ``attn_mask`` when one is given.
        Padded keys and masked positions are zeroed after the activation.

        Returns the raw scores when ``before_attn_fn`` is set.

        Raises:
            ValueError: for an unknown ``attention_activation``.
        """
        slen = k.size(2)
        if padding_mask is not None:
            # B x K x C
            inverse_mask = 1.0 - padding_mask.type_as(q)
            # B x K x 1
            lengths = inverse_mask.sum(dim=-1, keepdim=True)
            # B x K x 1 x 1
            lengths = lengths.clamp(min=1.0).unsqueeze(-1)
        else:
            lengths = slen
            inverse_mask = None

        if attn_mask is not None:
            # C x 1
            lengths = attn_mask.sum(dim=-1, keepdim=True)

        # C x C
        bias = self.rel_pos_bias(slen)
        if slen != q.size(2):
            # Single-query step: keep only the bias row of the last position.
            assert q.size(2) == 1
            # 1 x C
            bias = bias[-1:]

        # B x K x C x C
        qk = torch.matmul(q, k.transpose(2, 3)) / lengths + bias

        if before_attn_fn:
            return qk

        if self.attention_activation == 'relu2':
            attn_weights = relu2(qk).type_as(qk)
        elif self.attention_activation == 'laplace':
            attn_weights = laplace(qk).type_as(qk)
        else:
            raise ValueError('Unknown attention activation function: {}'.format(self.attention_activation))

        if inverse_mask is not None:
            attn_weights = attn_weights * inverse_mask.unsqueeze(2)

        if attn_mask is not None:
            attn_weights = attn_weights * attn_mask

        return attn_weights

    def softmax_attention(self, q, k, padding_mask, attn_mask, before_attn_fn):
        """Scaled softmax attention weights with relative position bias.

        ``attn_mask`` is added to the scores, and padded keys are set to
        ``-inf``. A row whose keys are all padded is left unmasked, so it does
        not become all ``-inf``.

        Returns the raw scores when ``before_attn_fn`` is set.
        """
        slen = k.size(2)
        # C x C
        bias = self.rel_pos_bias(slen)
        if slen != q.size(2):
            # Single-query step: keep only the bias row of the last position.
            assert q.size(2) == 1
            # 1 x C
            bias = bias[-1:]

        # scaled attention
        q = q * self.scaling
        # B x K x C x C
        qk = torch.matmul(q, k.transpose(2, 3)) + bias

        if attn_mask is not None:
            qk = qk + attn_mask

        if padding_mask is not None:
            padding_mask_all = padding_mask.all(dim=-1, keepdim=True)
            padding_mask = torch.logical_and(padding_mask, ~padding_mask_all)
            qk = qk.masked_fill(padding_mask.unsqueeze(2).to(torch.bool), float('-inf'))

        if before_attn_fn:
            return qk

        attn_weights = softmax(qk, dim=-1)
        return attn_weights

    def softmax_1(self, x, dim):
        """Softmax with an extra 1 in the denominator.

        Computes ``exp(x) / (1 + sum(exp(x), dim))``, so the weights along
        ``dim`` may sum to less than one. ``forward`` does not call this
        method; it is an alternative to ``softmax``.
        """
        # Shift by max(max(x), 0) before exponentiating so large logits do not
        # overflow to inf/NaN. The implicit "+1" term becomes exp(-shift), so
        # the result is mathematically unchanged.
        shift = x.amax(dim=dim, keepdim=True).clamp(min=0).detach()
        numerator = torch.exp(x - shift)
        denominator = numerator.sum(dim=dim, keepdim=True) + torch.exp(-shift)
        return numerator / denominator

    def forward(
        self,
        x,
        padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_attn_fn: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Apply the layer.

        Args:
            x (Tensor): input, read as ``(batch, seq_len, embed_dim)``. The first
                statement swaps dims 0 and 1 before unpacking ``seq_len, bsz``.
                NOTE(review): a previous docstring said "Time x Batch x
                Channel"; confirm the layout against callers.
            padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state (dict, optional): fairseq incremental state in
                which keys, values and their padding mask are cached.
                NOTE(review): the conv gate and S4 only see the current
                input, with no carried state, so incremental results
                presumably differ from full-sequence results; confirm.
            need_weights (bool, optional): also return the attention
                weights (default: False).
            attn_mask (Tensor, optional): typically used for causal
                attention. It is added to the scores for softmax attention.
                For relu2/laplace attention it is multiplied in, and its row
                sums serve as the normalizing lengths.
            before_attn_fn (bool, optional): return the raw attention
                scores and values before the attention activation.

        Returns:
            ``(out, attn_weights)``. ``out`` has shape
            ``(seq_len, batch, embed_dim)``, and ``attn_weights`` is ``None``
            unless ``need_weights`` is set. With ``before_attn_fn``, the result
            is ``(scores, v)`` instead.
        """
        # B x L x E -> L x B x E
        x = x.transpose(0, 1)
        seq_len, bsz, embed_dim = x.size()
        assert embed_dim == self.embed_dim

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
        else:
            saved_state = None
        # B x L x E
        residual = x.transpose(0, 1)

        if self.prenorm:
            # Note: the residual path is normalized too.
            residual = self.norm(residual)
            x = self.norm(x)

        # Left-pad the time axis with two copies of the first step. The width-3
        # conv (no padding) then keeps length L, and output t sees steps t-2..t.
        # B x E x 1
        first_step = residual[:, 0, :].unsqueeze(1).transpose(1, 2)
        # B x E x (L + 2)
        padded_residual = torch.cat([first_step, first_step, residual.transpose(1, 2)], dim=-1)
        # B x L x E
        adjust = self.arconv(padded_residual).transpose(1, 2)
        # L x B x E: gate the input
        x = torch.sigmoid(adjust).transpose(0, 1) * x

        # L x B x D
        q = self.activation(self.q_proj(residual.transpose(0, 1)))
        v = self.activation(self.v_proj(residual.transpose(0, 1)))

        # S4 is built with transposed=False and consumes B x L x E; back to L x B x E
        ssmx = self.s4(x.transpose(0, 1))[0].transpose(0, 1)
        ssmx = self.dropout(ssmx)
        k = self.activation(self.k_proj(ssmx))

        # L x B x D
        rho = self.activation(self.rho_proj(ssmx))
        psi = torch.sigmoid(self.psi_proj(ssmx))
        s = self.s_proj(ssmx)

        # L x B x D -> B x L x D
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, seq_len, dim)
            if "prev_key" in saved_state:
                prev_key = saved_state["prev_key"]
                assert prev_key is not None
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                prev_value = saved_state["prev_value"]
                assert prev_value is not None
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
            prev_padding_mask: Optional[Tensor] = None
            # Read back the mask under the same key it is written to below;
            # otherwise padding from earlier steps is silently forgotten.
            if "prev_key_padding_mask" in saved_state:
                prev_padding_mask = saved_state["prev_key_padding_mask"]
            padding_mask = s4plusplus._append_prev_padding_mask(
                padding_mask=padding_mask,
                prev_padding_mask=prev_padding_mask,
                batch_size=bsz,
                seq_len=k.size(1),
            )

            if self.chunk_size < 0:
                saved_state["prev_key"] = k
                saved_state["prev_value"] = v
                saved_state["prev_key_padding_mask"] = padding_mask
            else:
                # A completed chunk starts a fresh cache.
                curr_len = k.size(1) % self.chunk_size
                if curr_len == 0:
                    if "prev_key" in saved_state:
                        del saved_state["prev_key"]
                        del saved_state["prev_value"]
                        del saved_state["prev_key_padding_mask"]
                else:
                    saved_state["prev_key"] = k
                    saved_state["prev_value"] = v
                    saved_state["prev_key_padding_mask"] = padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            self._set_input_buffer(incremental_state, saved_state)

        ctx_len = k.size(1)
        if self.chunk_size < 0:
            # B x L x S -> B x 1 x L x S
            q = q.unsqueeze(1)
            k = k.unsqueeze(1)
            v = v.unsqueeze(1)
            if padding_mask is not None:
                # B x L -> B x 1 x L
                padding_mask = padding_mask.unsqueeze(1)
        else:
            if seq_len < self.chunk_size:
                q = q.unsqueeze(1)
            else:
                # B x L x S -> B x K x C x S
                nc = seq_len // self.chunk_size
                q = q.reshape(bsz, nc, self.chunk_size, self.qdim)

            if ctx_len < self.chunk_size:
                k = k.unsqueeze(1)
                v = v.unsqueeze(1)
                if padding_mask is not None:
                    padding_mask = padding_mask.unsqueeze(1)
            else:
                # B x L x S -> B x K x C x S
                nc = ctx_len // self.chunk_size
                k = k.reshape(bsz, nc, self.chunk_size, self.qdim)
                v = v.reshape(bsz, nc, self.chunk_size, self.vdim)
                if padding_mask is not None:
                    # B x L -> B x K x C
                    padding_mask = padding_mask.view(bsz, nc, self.chunk_size)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if padding_mask is not None and padding_mask.dim() == 0:
            padding_mask = None

        if self.attention_activation == 'softmax':
            attn_weights = self.softmax_attention(q, k, padding_mask, attn_mask, before_attn_fn)
        else:
            attn_weights = self.element_attention(q, k, padding_mask, attn_mask, before_attn_fn)

        if before_attn_fn:
            return attn_weights, v

        v = self.hidden_dropout(v, batch_first=True)
        kernel = self.attention_dropout(attn_weights)
        # B x K x C x E -> B x L x E -> L x B x E
        h = torch.matmul(kernel, v).view(bsz, seq_len, self.vdim).transpose(0, 1)
        # L x B x E -> L x B x D
        h = self.activation(self.h_proj(h * rho) + s)
        h = self.dropout(h)
        # L x B x D
        residual = residual.transpose(0, 1)
        # residual + psi * (h - residual): psi interpolates between residual and h
        out = torch.addcmul(residual, psi, h - residual)

        if not self.prenorm:
            out = self.norm(out)

        if need_weights:
            return out, attn_weights
        else:
            return out, None

    def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]) -> Dict[str, Optional[Tensor]]:
        """Return the cached ``"attn_state"`` dict, or a new empty dict if absent."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], buffer: Dict[str, Optional[Tensor]]):
        """Store ``buffer`` as this module's ``"attn_state"``."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    @torch.jit.export
    def reorder_incremental_state(
            self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation).

        Every non-``None`` cached tensor is re-indexed along dim 0 (batch)
        with ``new_order``.
        """
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    @staticmethod
    def _append_prev_padding_mask(
        padding_mask: Optional[Tensor],
        prev_padding_mask: Optional[Tensor],
        batch_size: int,
        seq_len: int,
    ) -> Optional[Tensor]:
        """Concatenate the cached and current key padding masks along time.

        If only one of the masks is given, the missing part is filled with
        ``False`` (not padded) so the result spans ``seq_len`` positions.
        Returns ``None`` when both masks are ``None``.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_padding_mask is not None and padding_mask is not None:
            new_padding_mask = torch.cat([prev_padding_mask, padding_mask], dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_padding_mask is not None:
            filler = torch.zeros((batch_size, seq_len - prev_padding_mask.size(1)), device=prev_padding_mask.device)
            new_padding_mask = torch.cat([prev_padding_mask, filler.bool()], dim=1)
        elif padding_mask is not None:
            filler = torch.zeros((batch_size, seq_len - padding_mask.size(1)), device=padding_mask.device)
            new_padding_mask = torch.cat([filler.bool(), padding_mask], dim=1)
        else:
            new_padding_mask = prev_padding_mask
        return new_padding_mask

    def extra_repr(self) -> str:
        return 'edim={}, qdim={}, vdim={}, ndim={}, chunk={}, attn_act={}, prenorm={}'.format(self.embed_dim, self.qdim,
                                                                                  self.vdim, self.ndim, self.chunk_size,
                                                                                  self.attention_activation, self.prenorm)
