              
                                                      
                       

from packaging.version import Version
# xformers is optional: when it cannot be imported, fall back to a placeholder
# type so that `isinstance(x, AttentionBias)` checks simply evaluate to False.
try:
    from xformers.ops.fmha.attn_bias import AttentionBias
except Exception:  # not a bare `except:` so KeyboardInterrupt/SystemExit still propagate
    class AttentionBias:
        """Placeholder used when xformers' AttentionBias cannot be imported."""

from megatron.core import package_info
from megatron.core.transformer.attention import SelfAttention

from gpatch.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
from gpatch.core.device_type import is_wxacc2


class PackedSelfAttention(SelfAttention):
    """Self-attention whose forward pass handles packed-sequence inputs.

    Compared with the inherited implementation, this forward pass removes
    dim 1 from query/key/value when ``packed_seq_params`` is given, applies
    rotary embeddings through ``apply_rotary_pos_emb`` with the packed
    cumulative sequence lengths, and re-inserts dim 1 before core attention
    when the mask is an ``AttentionBias`` or ``is_wxacc2()`` is true.
    """

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
    ):
        """Perform a forward pass through the attention module.

        Args:
            hidden_states: Input passed to ``get_query_key_value_tensors``.
            attention_mask: Mask forwarded to core attention; may be an
                ``AttentionBias`` instance, which changes the q/k/v layout
                used for packed input.
            key_value_states: Forwarded to ``get_query_key_value_tensors``.
            inference_context: Accepted for signature compatibility; this
                method does not read it.
            rotary_pos_emb: A single rotary embedding or a ``(q, k)`` tuple.
                A single value is duplicated for query and key. Ignored when
                ``config.flash_decode`` is set.
            rotary_pos_cos: Passed to ``flash_decode`` and
                ``_adjust_key_value_for_inference``; must be ``None`` unless
                ``config.flash_decode`` is set.
            rotary_pos_sin: Same handling as ``rotary_pos_cos``.
            attention_bias: Forwarded to core attention.
            packed_seq_params: Optional object providing ``cu_seqlens_q``,
                ``cu_seqlens_q_padded``, ``cu_seqlens_kv``,
                ``cu_seqlens_kv_padded`` and ``qkv_format``.
            sequence_len_offset: Forwarded to ``flash_decode`` and
                ``_adjust_key_value_for_inference``.
            inference_params: Keyword-only inference state. The flash-decode
                path reads its ``decode_mode``, ``key_value_memory_dict`` and
                ``sequence_len_offset`` attributes.

        Returns:
            The ``(output, bias)`` pair produced by ``self.linear_proj``.

        Raises:
            AssertionError: If ``config.flash_decode`` is off but
                ``rotary_pos_cos``/``rotary_pos_sin`` is given, or if the
                flash-decode path finds no cache entry for this layer or
                ``inference_params.sequence_len_offset`` is ``None``.
        """
        # NOTE(review): `inference_context` is never consulted below; only
        # `inference_params` drives the inference paths. Confirm this is
        # intended for callers that pass the context instead.

        # With flash decode enabled the rotary embedding tensor is discarded;
        # rotary_pos_cos / rotary_pos_sin are passed on instead.
        if self.config.flash_decode:
            rotary_pos_emb = None
        else:
            assert rotary_pos_cos is None and rotary_pos_sin is None

        # Self-attention uses the same rotary embedding for query and key.
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb, ) * 2

        # Project hidden states into query, key and value.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # Flash-decode fast path: attend against the cached key/value memory
        # for this layer and return right after the output projection.
        if (self.config.flash_decode and inference_params is not None
                and inference_params.decode_mode):
            assert self.layer_number in inference_params.key_value_memory_dict
            assert inference_params.sequence_len_offset is not None
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number]
            output = self.flash_decode(
                sequence_len_offset=sequence_len_offset,
                query_layer=query,
                key_layer=key,
                value_layer=value,
                inference_key_memory=inference_key_memory,
                inference_value_memory=inference_value_memory,
                rotary_cos=rotary_pos_cos,
                rotary_sin=rotary_pos_sin,
            )
            # Swap the first two dims, then flatten everything after them.
            out = output.transpose(0, 1).contiguous()
            context_layer = out.view(out.size(0), out.size(1), -1)
            output, bias = self.linear_proj(context_layer)
            return output, bias

        # Let the parent class adjust q/k/v and the rotary embedding for
        # inference; it also returns the mask type for core attention.
        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
            inference_params,
            query,
            key,
            value,
            rotary_pos_emb,
            rotary_pos_cos,
            rotary_pos_sin,
            sequence_len_offset,
        )

        # Packed input: drop dim 1 (presumably the size-1 batch dim of a
        # packed layout -- TODO confirm against callers).
        if packed_seq_params is not None:
            query = query.squeeze(1)
            key = key.squeeze(1)
            value = value.squeeze(1)

        # Apply rotary position embeddings to query and key.
        if rotary_pos_emb is not None and not self.config.flash_decode:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            if packed_seq_params is not None:
                # Prefer the padded cumulative sequence lengths when present.
                if packed_seq_params.cu_seqlens_q_padded is not None:
                    cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
                else:
                    cu_seqlens_q = packed_seq_params.cu_seqlens_q
                if packed_seq_params.cu_seqlens_kv_padded is not None:
                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
                else:
                    cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(query,
                                         q_pos_emb,
                                         config=self.config,
                                         cu_seqlens=cu_seqlens_q)
            key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)

        # Re-insert dim 1 for packed input with an AttentionBias mask, or
        # whenever is_wxacc2() is true.
        # NOTE(review): on the is_wxacc2() path with packed_seq_params None,
        # nothing was squeezed above, so this adds a new dim -- confirm intended.
        if (packed_seq_params is not None and isinstance(attention_mask, AttentionBias)) or is_wxacc2():
            query = query.unsqueeze(1)
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)

        # Core attention, through the checkpointed wrapper while training
        # with core-attention checkpointing enabled.
        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
            )

        if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
            # Flatten the trailing dims and keep a singleton dim 1 so the
            # output projection receives a 3-D tensor.
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        # Output projection.
        output, bias = self.linear_proj(core_attn_out)

        return output, bias
