from dataclasses import dataclass
from typing import Union
from packaging import version

import torch

from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import make_viewless_tensor
from megatron.core.transformer.transformer_layer import (
    TransformerLayerSubmodules,
    TransformerLayer,
)
from megatron.core import __version__

from gpatch.core.transformer.transformer_config import Gemma3TransformerConfig


@dataclass
class Gemma3TransformerLayerSubmodules(TransformerLayerSubmodules):
    """Submodule specification for ``Gemma3TransformerLayer``.

    Extends ``TransformerLayerSubmodules`` with two additional normalization
    slots. Both default to ``IdentityOp``; ``Gemma3TransformerLayer`` builds
    them with ``build_module`` in its constructor.
    """

    # Norm applied to the self-attention output before it is added to the
    # residual stream.
    post_attention_layernorm: Union[ModuleSpec, type] = IdentityOp
    # Norm applied to the MLP output before it is added to the residual stream.
    post_feedforward_layernorm: Union[ModuleSpec, type] = IdentityOp


class Gemma3TransformerLayer(TransformerLayer):
    """Transformer layer with Gemma3-style post-sublayer normalization.

    On top of the parent ``TransformerLayer`` this layer:

    * applies ``post_attention_layernorm`` to the self-attention output and
      ``post_feedforward_layernorm`` to the MLP output *before* each is added
      to its residual;
    * chooses between two rotary-embedding / attention-mask variants per layer
      depending on ``is_sliding``.
    """

    # Computed once at class-definition time; the installed Megatron-Core
    # version cannot change between forward passes. From 0.12.0 on, the
    # attention modules take ``inference_context`` instead of
    # ``inference_params``.
    _HAS_INFERENCE_CONTEXT = version.parse(__version__) >= version.parse('0.12.0')

    def __init__(
        self,
        config: Gemma3TransformerConfig,
        submodules: Gemma3TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
    ):
        """Build the parent layer plus the two extra post-sublayer norms.

        Args:
            config: Layer configuration; ``hidden_size``, ``layernorm_epsilon``
                and ``sliding_window_pattern`` are read here.
            submodules: Specs for the submodules, including the two extra
                post-sublayer norms.
            layer_number: Forwarded to ``TransformerLayer``.
            hidden_dropout: Forwarded to ``TransformerLayer``; ``None`` is
                accepted.
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            hidden_dropout=hidden_dropout,
        )

        self.post_attention_layernorm = build_module(
            submodules.post_attention_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.post_feedforward_layernorm = build_module(
            submodules.post_feedforward_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        # True for every layer whose number is NOT a multiple of
        # ``sliding_window_pattern``.
        self.is_sliding = bool(self.layer_number % config.sliding_window_pattern)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        inference_context=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
    ):
        """Run self-attention, cross-attention and MLP with post-sublayer norms.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Self-attention mask, or a 2-tuple of masks (see
                below).
            context: Key/value states for cross-attention.
            context_mask: Mask for cross-attention.
            rotary_pos_emb: Rotary embedding, or a 2-tuple of embeddings. When
                both ``rotary_pos_emb`` and ``attention_mask`` are tuples,
                element ``1`` is used if ``self.is_sliding`` and element ``0``
                otherwise.
            rotary_pos_cos: Forwarded to self-attention.
            rotary_pos_sin: Forwarded to self-attention.
            attention_bias: Forwarded to self-attention.
            inference_context: Inference state (Megatron-Core >= 0.12.0 name).
            inference_params: Inference state (Megatron-Core < 0.12.0 name).
                Whichever of the two is provided is forwarded under the name
                the installed Megatron-Core version expects.
            packed_seq_params: Forwarded to self-attention.
            sequence_len_offset: Forwarded to self-attention.

        Returns:
            ``output`` alone when ``config.external_cuda_graph`` is set and the
            module is in training mode; otherwise ``(output, context)``.
        """
        # Select the per-layer rotary embedding / mask variant.
        # NOTE(review): selection only happens when BOTH arguments are tuples;
        # a lone tuple is passed through untouched — confirm callers always
        # supply both.
        if isinstance(rotary_pos_emb, tuple) and isinstance(attention_mask, tuple):
            variant = 1 if self.is_sliding else 0
            rotary_pos_emb = rotary_pos_emb[variant]
            attention_mask = attention_mask[variant]

        # Forward the inference state under the keyword the installed
        # Megatron-Core expects, accepting either spelling from the caller so
        # the state is never silently dropped.
        if self._HAS_INFERENCE_CONTEXT:
            if inference_context is None:
                inference_context = inference_params
            extra_kwargs = {"inference_context": inference_context}
        else:
            if inference_params is None:
                inference_params = inference_context
            extra_kwargs = {"inference_params": inference_params}

        # ---- Self-attention sub-block ----
        residual = hidden_states
        input_layernorm_output = self.input_layernorm(hidden_states)

        hidden_states, hidden_states_bias = self.self_attention(
            input_layernorm_output,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **extra_kwargs,
        )

        # Fold in the separately returned bias (if any), then normalize the
        # attention output before the residual add.
        if hidden_states_bias is not None:
            hidden_states = hidden_states + hidden_states_bias
        hidden_states = residual + self.post_attention_layernorm(hidden_states)

        # ---- Cross-attention sub-block ----
        residual = hidden_states
        pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)

        attention_output_with_bias = self.cross_attention(
            pre_cross_attn_layernorm_output,
            attention_mask=context_mask,
            key_value_states=context,
            **extra_kwargs,
        )

        # The cross-attention module may hand back an updated context.
        if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
            context = attention_output_with_bias["context"]

        with self.bias_dropout_add_exec_handler():
            hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
                attention_output_with_bias, residual, self.hidden_dropout)

        # ---- MLP sub-block ----
        residual = hidden_states
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

        hidden_states, hidden_states_bias = self.mlp(pre_mlp_layernorm_output)
        if hidden_states_bias is not None:
            hidden_states = hidden_states + hidden_states_bias
        hidden_states = residual + self.post_feedforward_layernorm(hidden_states)

        # Re-wrap the result as a viewless tensor while keeping the autograd
        # graph attached.
        output = make_viewless_tensor(inp=hidden_states,
                                      requires_grad=hidden_states.requires_grad,
                                      keep_graph=True)

        # With externally managed CUDA graphs in training mode only the tensor
        # is returned.
        if self.config.external_cuda_graph and self.training:
            return output
        return output, context


class Qwen2p5VitTransformerLayer(TransformerLayer):
    """Transformer layer that picks a per-layer attention variant.

    ``attention_mask`` and ``packed_seq_params`` are each expected to be an
    indexable pair. Layers whose zero-based index (``layer_number - 1``) is
    listed in ``config.fullatt_block_indexes`` use element ``0``; all other
    layers use element ``1``. Everything else is delegated to
    ``TransformerLayer.forward``.
    """

    # Computed once at class-definition time; the installed Megatron-Core
    # version cannot change between forward passes.
    _HAS_INFERENCE_CONTEXT = version.parse(__version__) >= version.parse('0.12.0')

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        inference_context=None,
        inference_params=None,
        packed_seq_params=None,
        sequence_len_offset=None,
    ):
        """Select this layer's mask / packed-sequence variant and run the parent.

        Args:
            hidden_states: Input to the layer.
            attention_mask: ``None`` or an indexable pair of masks.
            context: Forwarded unchanged to the parent layer.
            context_mask: Forwarded unchanged to the parent layer.
            rotary_pos_emb: Forwarded unchanged to the parent layer.
            rotary_pos_cos: Forwarded unchanged to the parent layer.
            rotary_pos_sin: Forwarded unchanged to the parent layer.
            attention_bias: Forwarded unchanged to the parent layer.
            inference_context: Forwarded only on Megatron-Core >= 0.12.0.
            inference_params: Always forwarded to the parent layer.
            packed_seq_params: ``None`` or an indexable pair of packed-sequence
                parameter objects. ``None`` is passed through unchanged.
            sequence_len_offset: Forwarded unchanged to the parent layer.

        Returns:
            Whatever ``TransformerLayer.forward`` returns.
        """
        # Element 0 belongs to the layers listed in ``fullatt_block_indexes``;
        # element 1 to every other layer (presumably windowed attention —
        # confirm against the caller that builds the pairs).
        is_full_attention = (self.layer_number - 1) in self.config.fullatt_block_indexes
        variant = 0 if is_full_attention else 1

        if attention_mask is not None:
            attention_mask = attention_mask[variant]
        # ``packed_seq_params`` defaults to None, so guard it like the mask
        # instead of raising TypeError on ``None[variant]``.
        if packed_seq_params is not None:
            packed_seq_params = packed_seq_params[variant]

        extra_kwargs = {}
        if self._HAS_INFERENCE_CONTEXT:
            extra_kwargs["inference_context"] = inference_context

        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            context=context,
            context_mask=context_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_params=inference_params,
            **extra_kwargs,
        )
