"""chobel Transformer

Adapted from the PyTorch Transformer implementation (torch.nn.Transformer).

    Author: Joey zhu
    Email: thewindblowing2you@gmail.com
    Date: 2024/4/24
    
"""
import torch
import copy
import math
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

class PositionalEncoding(Module):
    """Sinusoidal positional encoding for sequence-first inputs.

    Row ``r`` of the table encodes position ``p = pos_enc_start + r`` as
    ``pe[r, 2i] = sin(p / 10000^(2i / d_model))`` and
    ``pe[r, 2i + 1] = cos(p / 10000^(2i / d_model))``. The table is stored as a
    buffer of shape ``[max_len, 1, d_model]`` so that it broadcasts over the
    batch dimension of a ``[seq_len, bsz, d_model]`` input.

    Args:
        d_model: feature dimension of the encoding (odd values are supported).
        pos_enc_start: position encoded by the first table row (default 0).
        dropout: dropout probability applied by ``forward`` (default 0.1).
        max_len: number of rows (positions) in the table (default 5000).
    """
    def __init__(self, d_model, pos_enc_start=0, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.d_model = d_model

        pe = torch.zeros(max_len, d_model)
        # Exactly max_len positions starting at pos_enc_start, one per table row.
        position = torch.arange(pos_enc_start, pos_enc_start + max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        # Plain-attribute alias kept for backward compatibility. Unlike the
        # registered `pe` buffer, it is not moved by `.to()` and not saved in state_dict.
        self.pepe = pe
        self.register_buffer('pe', pe)

    def get_segment_pos_enc(self, segment_label):
        """Look up table rows for explicit position indices.

        Args:
            segment_label (tensor): integer row indices into the table, shape
                [seq_len, bsz]. Each value must be < max_len; row ``r``
                encodes position ``pos_enc_start + r``.

        Returns:
            tensor: [seq_len, bsz, d_model]
        """
        # Indexing only reads `segment_label`, so no defensive copy is needed.
        # `pe[:, 0]` drops the broadcast axis without collapsing other size-1 dims.
        return self.pe[:, 0][segment_label]

    def forward(self, x, pos_enc=None):
        """Add positional information to ``x`` and apply dropout.

        Args:
            x: input tensor. When ``pos_enc`` is None it is treated as
                [seq_len, bsz, d_model], and the first ``x.size(0)`` table rows
                are added.
            pos_enc: optional tensor, broadcastable to ``x``, added instead of
                the sinusoidal table.

        Returns:
            ``dropout(x + pos_enc)`` or ``dropout(x + pe[:x.size(0)])``.
        """
        if pos_enc is not None:
            x = x + pos_enc
        else:
            x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
class Transformer(Module):
    """Encoder/decoder Transformer whose default decoder is a gated hybrid stack.

    The default encoder is ``num_encoder_layers`` ``TransformerEncoderLayer``s
    followed by a ``LayerNorm``. The default decoder is a
    ``TransformerDecoder_Hybrid`` built from ``len(xor_pattern)`` copies of
    prototype 1 (``TransformerDecoderLayer_parallel``), with a final
    ``LayerNorm``. Layer ``i`` is called with ``mutual_exclusive=True`` when
    ``xor_pattern[i]`` is truthy.

    Args:
        combination: fusion mode forwarded to ``TransformerDecoderLayer_parallel``.
        d_model: feature size of inputs and outputs (default 512).
        nhead: number of attention heads (default 8).
        num_encoder_layers: number of encoder layers (default 6).
        xor_pattern: per-decoder-layer xor-gate flags. ``None`` (the default)
            means ``[0, 1, 0, 1, 0, 1]``.
        dim_feedforward: hidden size of the feed-forward blocks (default 2048).
        dropout: dropout probability (default 0.1).
        activation: "relu" or "gelu" (default "relu").
        custom_encoder: module used instead of the default encoder.
        custom_decoder: module used instead of the default decoder.
    """

    def __init__(self, combination, d_model=512, nhead=8, num_encoder_layers=6,
                 xor_pattern=None, dim_feedforward=2048, dropout=0.1,
                 activation="relu", custom_encoder=None, custom_decoder=None):
        super(Transformer, self).__init__()
        # A None default avoids one mutable list being shared (and stored by
        # the decoder) across every instance.
        if xor_pattern is None:
            xor_pattern = [0, 1] * 3
        # 1. hyper-parameters
        self.d_model = d_model
        self.nhead = nhead

        self.pos_enc = PositionalEncoding(self.d_model, dropout=dropout)
        # 2. encoder
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # 3. decoder
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            # Prototypes selected by index (0 / 1 / 2) through the layer pattern.
            decoder_layer_0 = TransformerDecoderLayer_CrossOnly(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_layer_1 = TransformerDecoderLayer_parallel(combination, d_model, nhead, dim_feedforward, dropout, activation)
            # serialized implementation (self-attention, then cross-attention)
            decoder_layer_2 = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)

            decoder_norm = LayerNorm(d_model)

            # Every decoder layer currently uses prototype 1 (the parallel layer).
            decoder_pattern = [1] * len(xor_pattern)

            self.decoder = TransformerDecoder_Hybrid(
                [decoder_layer_0, decoder_layer_1, decoder_layer_2],
                decoder_pattern, xor_pattern, decoder_norm)
        self._reset_parameters()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                memory_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_label=None):
        """Encode ``src`` and decode ``tgt`` against the encoder memory.

        Args:
            src: source sequence, [S, N, d_model].
            tgt: target sequence, [T, N, d_model].
            src_mask, src_key_padding_mask: forwarded to the encoder.
            tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask:
                forwarded to the decoder.
            tgt_label: per-target-position gate values forwarded to the
                decoder (presumably [T, N]; TODO confirm against callers).

        Returns:
            The decoder output.

        Raises:
            RuntimeError: if the batch sizes of src and tgt differ, or either
                feature size is not d_model.
        """
        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")
        # NOTE(review): `self.pos_enc(tgt)` returns dropout(tgt + pe), i.e. the
        # target embedding *plus* the positional table, not the encoding alone.
        # The gated decoder layers then add this to their own input.
        # `self.pos_enc.get_segment_pos_enc(tgt_label)` would return only the
        # encoding; confirm which is intended.
        att_cross_pos_enc = self.pos_enc(tgt)
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask,
                              tgt_label=tgt_label, att_cross_pos_enc=att_cross_pos_enc)

        return output

    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square causal mask for the sequence.

        Masked positions (strictly above the diagonal) are filled with
        float('-inf'); unmasked positions are filled with float(0.0).
        """
        # Boolean matrix that is True on and below the diagonal.
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Xavier-uniform initialize every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""A stack of ``num_layers`` deep-copied encoder layers plus an optional final norm.

    Args:
        encoder_layer: the layer to clone, e.g. a TransformerEncoderLayer (required).
        num_layers: how many copies to stack (required).
        norm: optional module applied to the output of the last layer.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        print(f"Encoder layers #{self.num_layers}")

    def forward(self, src, mask=None, src_key_padding_mask=None):
        r"""Run ``src`` through each encoder layer in order, then the optional norm.

        Args:
            src: the sequence fed to the encoder (required).
            mask: attention mask for the src sequence, passed to every layer (optional).
            src_key_padding_mask: per-batch key padding mask, passed to every layer (optional).

        Shape:
            see the docs in Transformer class.
        """
        hidden = src
        for depth in range(self.num_layers):
            block = self.layers[depth]
            hidden = block(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return self.norm(hidden) if self.norm else hidden


class TransformerDecoder(Module):
    r"""A stack of ``num_layers`` deep-copied decoder layers plus an optional final norm.

    Args:
        decoder_layer: the layer to clone, e.g. a TransformerDecoderLayer (required).
        num_layers: how many copies to stack (required).
        norm: optional module applied to the output of the last layer.

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        r"""Run ``tgt`` through each decoder layer in order, then the optional norm.

        Args:
            tgt: the sequence fed to the decoder (required).
            memory: the output of the encoder's last layer (required).
            tgt_mask: attention mask for the tgt sequence (optional).
            memory_mask: attention mask for the memory sequence (optional).
            tgt_key_padding_mask: per-batch padding mask for tgt keys (optional).
            memory_key_padding_mask: per-batch padding mask for memory keys (optional).

        Shape:
            see the docs in Transformer class.
        """
        x = tgt
        for idx in range(self.num_layers):
            x = self.layers[idx](
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
        return self.norm(x) if self.norm else x

class TransformerDecoder_Hybrid(Module):
    r"""Stack of decoder layers, each cloned from a chosen prototype.

    Layer ``i`` is a deep copy of ``decoder_layer[decoder_layer_pattern[i]]``.
    When ``xor_pattern[i]`` is truthy, layer ``i`` is also called with
    ``mutual_exclusive=True``. Every layer receives ``tgt_label`` and
    ``att_cross_pos_enc`` as keyword arguments, so every prototype used must
    accept them.

    Args:
        decoder_layer: list of prototype decoder layers (required).
        decoder_layer_pattern: prototype index for each layer (required); its
            length is the number of layers.
        xor_pattern: per-layer xor-gate flags, at least as long as
            ``decoder_layer_pattern`` (required).
        norm: optional module applied to the output of the last layer.
    """

    def __init__(self, decoder_layer, decoder_layer_pattern, xor_pattern, norm=None):
        super(TransformerDecoder_Hybrid, self).__init__()

        self.layers = ModuleList([copy.deepcopy(decoder_layer[idx]) for idx in decoder_layer_pattern])
        self.xor_pattern = xor_pattern
        self.num_layers = len(decoder_layer_pattern)
        self.norm = norm
        print("Decoder layers #{}".format(self.num_layers))

    def forward(self, tgt, memory, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None, tgt_label=None, att_cross_pos_enc=None):
        r"""Pass the inputs (and masks) through the decoder layers in turn.

        Args:
            tgt: the sequence fed to the decoder (required).
            memory: the output of the encoder's last layer (required).
            tgt_mask: attention mask for the tgt sequence, forwarded to every layer (optional).
            memory_mask: attention mask for the memory sequence (optional).
            tgt_key_padding_mask: per-batch padding mask for tgt keys (optional).
            memory_key_padding_mask: per-batch padding mask for memory keys (optional).
            tgt_label: gate values forwarded to every layer (optional).
            att_cross_pos_enc: positional term for the cross-attention query,
                forwarded to every layer (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for i in range(self.num_layers):
            layer_kwargs = dict(
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_label=tgt_label,
                att_cross_pos_enc=att_cross_pos_enc,
            )
            if self.xor_pattern[i]:
                # Passed only when the xor gate is requested, because not every
                # layer type accepts `mutual_exclusive`.
                layer_kwargs['mutual_exclusive'] = True
            output = self.layers[i](output, memory, **layer_kwargs)

        if self.norm:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(Module):
    r"""Post-norm encoder layer: self-attention followed by a feed-forward network.

    Based on "Attention Is All You Need" (Vaswani et al., 2017, NeurIPS,
    pages 6000-6010). Each sub-layer is applied as
    ``norm(x + dropout(sublayer(x)))``.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention (required).
        dim_feedforward: hidden size of the feed-forward network (default=2048).
        dropout: the dropout probability (default=0.1).
        activation: "relu" or "gelu" for the feed-forward hidden layer (default="relu").

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model); MultiheadAttention default batch_first=False
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # feed-forward network
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        # normalization and residual dropout applied after each sub-layer
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Apply self-attention and the feed-forward block, each with residual + norm.

        Args:
            src: the sequence fed to the encoder layer (required).
            src_mask: attention mask for the src sequence (optional).
            src_key_padding_mask: per-batch padding mask for src keys (optional).

        Shape:
            see the docs in Transformer class.
        """
        attn_out, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attn_out))
        # Fall back to ReLU when `activation` is absent (backward compatibility).
        act = getattr(self, "activation", F.relu)
        ff_out = self.linear2(self.dropout(act(self.linear1(src))))
        return self.norm2(src + self.dropout2(ff_out))

class TransformerDecoderLayer(Module):
    r"""Post-norm decoder layer: self-attention, then cross-attention, then feed-forward.

    Based on "Attention Is All You Need" (Vaswani et al., 2017, NeurIPS,
    pages 6000-6010). Each sub-layer is applied as
    ``norm(x + dropout(sublayer(x)))``.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention (required).
        dim_feedforward: hidden size of the feed-forward network (default=2048).
        dropout: the dropout probability (default=0.1).
        activation: "relu" or "gelu" for the feed-forward hidden layer (default="relu").

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # feed-forward network
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        # normalization and residual dropout applied after each sub-layer
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        r"""Serialized decoder step: self-attention, cross-attention, feed-forward.

        Args:
            tgt: the sequence fed to the decoder layer (required).
            memory: the output of the encoder's last layer (required).
            tgt_mask: attention mask for the tgt sequence (optional).
            memory_mask: attention mask for the memory sequence (optional).
            tgt_key_padding_mask: per-batch padding mask for tgt keys (optional).
            memory_key_padding_mask: per-batch padding mask for memory keys (optional).

        Shape:
            see the docs in Transformer class.
        """
        sa_out, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                                   key_padding_mask=tgt_key_padding_mask)
        x = self.norm1(tgt + self.dropout1(sa_out))
        ca_out, _ = self.multihead_attn(x, memory, memory, attn_mask=memory_mask,
                                        key_padding_mask=memory_key_padding_mask)
        x = self.norm2(x + self.dropout2(ca_out))
        # Fall back to ReLU when `activation` is absent (backward compatibility).
        act = getattr(self, "activation", F.relu)
        ff_out = self.linear2(self.dropout(act(self.linear1(x))))
        return self.norm3(x + self.dropout3(ff_out))

class TransformerDecoderLayer_CrossOnly(Module):
    """Decoder layer with cross-attention and feed-forward only (no self-attention).

    The cross-attention query is ``pos_encoding(tgt, att_cross_pos_enc)``. Its
    output can be gated per position by ``tgt_label`` before the residual add.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer_CrossOnly, self).__init__()
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # feed-forward network
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        # Names mirror TransformerDecoderLayer's cross-attention / FFN sub-layers.
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Latest cross-attention weights, kept for visualization.
        self.cross_att_matrix = []

        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_label=None,
                att_cross_pos_enc=None, mutual_exclusive=False):
        r"""Cross-attend from ``tgt`` to ``memory``, then apply the feed-forward block.

        Args:
            tgt: the sequence fed to the decoder layer (required).
            memory: the output of the encoder's last layer (required).
            tgt_mask: accepted for interface compatibility; unused (no self-attention).
            memory_mask: attention mask for the memory sequence (optional).
            tgt_key_padding_mask: accepted for interface compatibility; unused.
            memory_key_padding_mask: per-batch padding mask for memory keys (optional).
            tgt_label: optional gate values; the cross-attention output is
                multiplied by ``tgt_label`` broadcast over the feature axis.
            att_cross_pos_enc: optional positional term added to the query.
            mutual_exclusive: accepted so the layer can be used where the
                hybrid decoder requests the xor gate. It is ignored, because
                there is no self-attention output to mask.

        Shape:
            see the docs in Transformer class.
        """
        # Cross-attention query: target plus positional term (chord alignment).
        query = self.pos_encoding(tgt, att_cross_pos_enc)
        tgt2, attn_weights = self.multihead_attn(query, memory, memory, attn_mask=memory_mask,
                                                 key_padding_mask=memory_key_padding_mask)
        # Detached so the stored map does not keep the autograd graph alive.
        self.cross_att_matrix = attn_weights.detach() if attn_weights is not None else None

        # Gate the cross-attention output per position.
        if tgt_label is not None:
            tgt2 = tgt_label.unsqueeze(2).expand(-1, -1, tgt2.shape[2]) * tgt2

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility
            tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

class TransformerDecoderLayer_parallel(Module):
    """Decoder layer with parallel self- and cross-attention and an optional xor gate.

    Self-attention and cross-attention are both computed from the layer input.
    When ``tgt_label`` is given, the cross-attention output is zeroed wherever
    the label is zero. With ``mutual_exclusive`` the self-attention output is
    also zeroed wherever the label is non-zero. The two outputs are then fused
    according to ``combination``, added residually to the input, and passed
    through a post-norm feed-forward block.

    Args:
        combination: fusion of the two attention outputs:
            'avg' (mean of the two),
            'weighted-avg' (AttentionFusionWithLearnedWeights), or
            'full-connect' (AttentionFusionModule).
        d_model: the number of expected features in the input.
        nhead: the number of attention heads.
        dim_feedforward: hidden size of the feed-forward network (default=2048).
        dropout: the dropout probability (default=0.1).
        activation: "relu" or "gelu" (default="relu").

    Raises:
        ValueError: if ``combination`` is not one of the supported values.
    """

    _COMBINATIONS = ('avg', 'weighted-avg', 'full-connect')

    def __init__(self, combination, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer_parallel, self).__init__()
        if combination not in self._COMBINATIONS:
            raise ValueError("combination should be one of {}, not {!r}.".format(
                self._COMBINATIONS, combination))

        # Keep this creation order: it fixes how parameter initialization
        # consumes the RNG.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = Linear(d_model, dim_feedforward)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.norm3 = LayerNorm(d_model)
        self.norm1 = LayerNorm(d_model)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.dropout = Dropout(dropout)
        self.nhead = nhead

        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)
        self.combination = combination
        # For cross-attention visualization (not updated by forward).
        self.cross_att_matrix = []

        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

        self.activation = _get_activation_fn(activation)

        # Fusion networks are built once, as registered submodules, so their
        # parameters are trained, saved in state_dict and moved with the layer.
        if combination == 'weighted-avg':
            self.fusion = AttentionFusionWithLearnedWeights(d_model, 128)
        elif combination == 'full-connect':
            self.fusion = AttentionFusionModule(d_model, 256, d_model)
        else:
            self.fusion = None

    def forward(self, tgt, memory, label=None, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_label=None,
                mutual_exclusive=False, att_cross_pos_enc=None):
        r"""Run parallel self/cross attention, gate, fuse, and apply the feed-forward block.

        Args:
            tgt: the sequence fed to the decoder layer (required).
            memory: the output of the encoder's last layer (required).
            label: currently unused.
            tgt_mask: attention mask for the self-attention (optional).
            memory_mask: attention mask for the cross-attention (optional).
            tgt_key_padding_mask: per-batch padding mask for tgt keys (optional).
            memory_key_padding_mask: per-batch padding mask for memory keys (optional).
            tgt_label: optional gate values. A position's gate is open where the
                value is non-zero (broadcast over the feature axis).
            mutual_exclusive: if True and ``tgt_label`` is given, also zero the
                self-attention output where the gate is open.
            att_cross_pos_enc: optional positional term added to the
                cross-attention query.

        Raises:
            ValueError: if ``self.combination`` is not a supported value.

        Shape:
            see the docs in Transformer class.
        """
        # Self-attention branch: query/key/value carry the sinusoidal encoding.
        self_in = self.pos_encoding(tgt)
        tgt2_self = self.self_attn(self_in, self_in, self_in, attn_mask=tgt_mask,
                                   key_padding_mask=tgt_key_padding_mask)[0]
        tgt2_self = self.dropout1(tgt2_self)

        # Cross-attention branch: the query carries att_cross_pos_enc instead.
        cross_in = self.pos_encoding(tgt, att_cross_pos_enc)
        tgt2_cross, _ = self.multihead_attn(cross_in, memory, memory, attn_mask=memory_mask,
                                            key_padding_mask=memory_key_padding_mask)
        tgt2_cross = self.dropout2(tgt2_cross)

        if tgt_label is not None:
            gate_msk = tgt_label.unsqueeze(2).expand(-1, -1, tgt2_cross.shape[2]).bool()
            tgt2_cross = gate_msk * tgt2_cross
            # TODO: the original use case should never need mutual_exclusive=True.
            if mutual_exclusive:
                tgt2_self = ~gate_msk * tgt2_self

        if self.combination == 'avg':
            tgt3 = (tgt2_self + tgt2_cross) / 2.0
        elif self.combination in ('weighted-avg', 'full-connect'):
            tgt3 = self.fusion(tgt2_self, tgt2_cross)
        else:
            raise ValueError("unsupported combination {!r}".format(self.combination))

        tgt = self.norm1(tgt + tgt3)
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility
            tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


class AttentionFusionModule(Module):
    """Fuse two attention outputs with a two-layer MLP over their concatenation.

    Pipeline: ``concat(self, cross)`` along the last axis (size
    ``2 * input_dim``) -> Linear -> ReLU -> Dropout(0.1) -> Linear (to
    ``output_dim``) -> LayerNorm.

    Args:
        input_dim: feature size of each input.
        hidden_dim: width of the hidden layer.
        output_dim: feature size of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(AttentionFusionModule, self).__init__()
        # The input is the concatenation of the self- and cross-attention outputs.
        self.linear1 = Linear(2 * input_dim, hidden_dim)
        self.linear2 = Linear(hidden_dim, output_dim)
        self.dropout = Dropout(p=0.1)  # regularization on the hidden layer
        self.norm = LayerNorm(output_dim)  # stabilizes the fused output

    def forward(self, tgt2_self, tgt2_cross):
        """Return ``LayerNorm(linear2(dropout(relu(linear1([self, cross])))))``."""
        joined = torch.cat([tgt2_self, tgt2_cross], dim=-1)
        hidden = self.dropout(F.relu(self.linear1(joined)))
        return self.norm(self.linear2(hidden))


class AttentionFusionWithLearnedWeights(Module):
    """Fuse two attention outputs with input-dependent scalar weights.

    Each input gets its own weight in (0, 1). The weight is predicted from the
    input's mean over the sequence axis, giving one weight per batch element,
    shared across positions and features. The output is
    ``LayerNorm(w_self * self + w_cross * cross)``. The two weights come from
    the same network applied to each input and are not normalized to sum to 1.

    Args:
        input_dim: feature size (last axis) of the inputs.
        hidden_dim: width of the weight-prediction hidden layer.
        seq_dim: index of the sequence axis of the inputs. The default 0 matches
            [seq_len, bsz, input_dim], the layout of ``MultiheadAttention``
            with its default ``batch_first=False``.
    """

    def __init__(self, input_dim, hidden_dim, seq_dim=0):
        super(AttentionFusionWithLearnedWeights, self).__init__()
        # Two-layer network mapping a pooled feature vector to one scalar weight.
        self.fc1 = Linear(input_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, 1)
        self.dropout = Dropout(p=0.1)  # currently not applied in forward
        self.norm = LayerNorm(input_dim)  # stabilizes the fused output
        self.seq_dim = seq_dim

    def forward(self, tgt2_self, tgt2_cross):
        """Return ``LayerNorm(w_self * tgt2_self + w_cross * tgt2_cross)``."""
        weight_self = self._generate_weight(tgt2_self)
        weight_cross = self._generate_weight(tgt2_cross)
        return self.norm(weight_self * tgt2_self + weight_cross * tgt2_cross)

    def _generate_weight(self, x):
        """Predict a sigmoid weight that broadcasts against ``x``.

        Pooling with ``keepdim=True`` keeps the sequence axis as size 1, and
        ``fc2`` makes the feature axis size 1. For ``x`` of shape
        [seq_len, bsz, input_dim] with ``seq_dim=0``, the weight is [1, bsz, 1].
        """
        pooled = x.mean(dim=self.seq_dim, keepdim=True)
        hidden = F.relu(self.fc1(pooled))
        return torch.sigmoid(self.fc2(hidden))

def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)
