from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
    TEColumnParallelLinear,
    TENorm,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

from gpatch.core.transformer.transformer_layer import Qwen2p5VitTransformerLayer
from gpatch.core.transformer.attention import PackedSelfAttention
from gpatch.core.extensions.llama3_cp_memory_efficient_attention import MemoryEfficientAttention
from gpatch.core.tensor_parallel.lora import (
    TEColumnParallelLoRALinear,
    TERowParallelLoRALinear,
    TELayerNormColumnParallelLoRALinear,
    TELayerNormColumnParallelLoRAMergeLinear,
    TELayerNormColumnParallelLoRAQKVLinear,
    ColumnParallelLoRAQKVLinear,
    ColumnParallelLoRALinear,
    ColumnParallelLoRAMergeLinear,
    RowParallelLoRALinear,
)

from gpatch.core.device_type import is_wxacc1

from megatron.core.transformer.attention import SelfAttentionSubmodules
from megatron.core.transformer.attention import SelfAttention


def get_qwen2vl_vision_local_spec(is_qwen2p5=False) -> ModuleSpec:
    """Return the vision transformer layer spec built from non-TE parallel linears.

    Args:
        is_qwen2p5: When true, the layer is ``Qwen2p5VitTransformerLayer`` with
            ``PackedSelfAttention``; otherwise ``TransformerLayer`` with
            ``SelfAttention``.

    Returns:
        A ``ModuleSpec`` whose attention and MLP use ``ColumnParallelLinear`` /
        ``RowParallelLinear``, with ``TENorm`` as the input and pre-MLP norms and
        ``AttnMaskType.no_mask`` as the attention mask type.
    """
    # The flag switches the layer class and the attention class together.
    if is_qwen2p5:
        layer_cls, attn_cls = Qwen2p5VitTransformerLayer, PackedSelfAttention
    else:
        layer_cls, attn_cls = TransformerLayer, SelfAttention

    attention_spec = ModuleSpec(
        module=attn_cls,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=ColumnParallelLinear,
            core_attention=DotProductAttention,
            linear_proj=RowParallelLinear,
            q_layernorm=IdentityOp,
            k_layernorm=IdentityOp,
        ),
    )

    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=ColumnParallelLinear,
            linear_fc2=RowParallelLinear,
        ),
    )

    # Renames the standalone norm key prefixes to the `layer_norm_` names under
    # linear_qkv / linear_fc1.
    # NOTE(review): presumably for checkpoint compatibility with the fused TE
    # layout -- confirm.
    norm_key_map = {
        'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
        'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
    }

    layer_submodules = TransformerLayerSubmodules(
        input_layernorm=TENorm,
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=TENorm,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
        sharded_state_dict_keys_map=norm_key_map,
    )
    return ModuleSpec(module=layer_cls, submodules=layer_submodules)

def get_qwen2vl_vision_with_transformer_engine_spec(cp_size, is_qwen2p5=False) -> ModuleSpec:
    """Return the vision transformer layer spec built from Transformer Engine linears.

    Args:
        cp_size: Accepted for interface compatibility; not read by this function.
        is_qwen2p5: When true, the layer is ``Qwen2p5VitTransformerLayer``;
            otherwise ``TransformerLayer``. The attention module is
            ``PackedSelfAttention`` in both cases.

    Returns:
        A ``ModuleSpec`` using ``TELayerNormColumnParallelLinear`` for QKV and
        fc1, ``TERowParallelLinear`` for the output projections, and
        ``MemoryEfficientAttention`` as core attention.
    """
    layer_cls = Qwen2p5VitTransformerLayer if is_qwen2p5 else TransformerLayer

    attention_spec = ModuleSpec(
        module=PackedSelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=TELayerNormColumnParallelLinear,
            core_attention=MemoryEfficientAttention,
            linear_proj=TERowParallelLinear,
            q_layernorm=IdentityOp,
            k_layernorm=IdentityOp,
        ),
    )

    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )

    # input_layernorm is left at the TransformerLayerSubmodules default and
    # pre_mlp_layernorm is an identity.
    # NOTE(review): presumably the norms live inside the TELayerNorm* linears -- confirm.
    layer_submodules = TransformerLayerSubmodules(
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=IdentityOp,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
    )
    return ModuleSpec(module=layer_cls, submodules=layer_submodules)


                                                
def get_proj_mlp_module_spec(use_te: bool = True, add_norm: bool = True) -> ModuleSpec:
    """Return the ``MLP`` spec for the projector.

    Args:
        use_te: Select Transformer Engine linears when true, otherwise the
            plain ``ColumnParallelLinear`` / ``RowParallelLinear`` pair.
        add_norm: With ``use_te`` true, selects ``TELayerNormColumnParallelLinear``
            for fc1 instead of ``TEColumnParallelLinear``. With ``use_te`` false
            it has no effect: fc1 is ``ColumnParallelLinear`` either way.

    Returns:
        A ``ModuleSpec`` for ``MLP`` with the chosen fc1/fc2 classes.
    """
    if not use_te:
        fc1_cls, fc2_cls = ColumnParallelLinear, RowParallelLinear
    elif add_norm:
        fc1_cls, fc2_cls = TELayerNormColumnParallelLinear, TERowParallelLinear
    else:
        fc1_cls, fc2_cls = TEColumnParallelLinear, TERowParallelLinear

    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(linear_fc1=fc1_cls, linear_fc2=fc2_cls),
    )


def get_qwen2vl_vision_with_transformer_engine_spec_lora(
    cp_size,
    is_qwen2p5=False,
    gated_linear_unit: bool = True,
) -> ModuleSpec:
    """Return the TE-based vision transformer layer spec with LoRA linear layers.

    Args:
        cp_size: Accepted for interface compatibility; not read by this function.
        is_qwen2p5: When true, the layer is ``Qwen2p5VitTransformerLayer``;
            otherwise ``TransformerLayer``. The attention module is
            ``PackedSelfAttention`` in both cases.
        gated_linear_unit: When true, fc1 is
            ``TELayerNormColumnParallelLoRAMergeLinear``; otherwise
            ``TELayerNormColumnParallelLoRALinear``.

    Returns:
        A ``ModuleSpec`` using ``TELayerNormColumnParallelLoRAQKVLinear`` for
        QKV, ``TERowParallelLoRALinear`` for the output projections, and
        ``MemoryEfficientAttention`` as core attention.
    """
    layer_cls = Qwen2p5VitTransformerLayer if is_qwen2p5 else TransformerLayer
    fc1_cls = (
        TELayerNormColumnParallelLoRAMergeLinear
        if gated_linear_unit
        else TELayerNormColumnParallelLoRALinear
    )

    attention_spec = ModuleSpec(
        module=PackedSelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=TELayerNormColumnParallelLoRAQKVLinear,
            core_attention=MemoryEfficientAttention,
            linear_proj=TERowParallelLoRALinear,
            q_layernorm=IdentityOp,
            k_layernorm=IdentityOp,
        ),
    )

    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=fc1_cls,
            linear_fc2=TERowParallelLoRALinear,
        ),
    )

    # input_layernorm is left at the TransformerLayerSubmodules default and
    # pre_mlp_layernorm is an identity.
    # NOTE(review): presumably the norms live inside the TELayerNorm* linears -- confirm.
    layer_submodules = TransformerLayerSubmodules(
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=IdentityOp,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
    )
    return ModuleSpec(module=layer_cls, submodules=layer_submodules)


def get_qwen2vl_vision_with_local_spec_lora(
    cp_size,
    is_qwen2p5=False,
    gated_linear_unit: bool = True,
) -> ModuleSpec:
    """Return the non-TE vision transformer layer spec with LoRA linear layers.

    Args:
        cp_size: Accepted for interface compatibility; not read by this function.
        is_qwen2p5: When true, the layer is ``Qwen2p5VitTransformerLayer``;
            otherwise ``TransformerLayer``. The attention module is
            ``PackedSelfAttention`` in both cases.
        gated_linear_unit: When true, fc1 is ``ColumnParallelLoRAMergeLinear``;
            otherwise ``ColumnParallelLoRALinear``.

    Returns:
        A ``ModuleSpec`` using ``ColumnParallelLoRAQKVLinear`` for QKV,
        ``RowParallelLoRALinear`` for the output projections,
        ``DotProductAttention`` as core attention, and ``TENorm`` as the input
        and pre-MLP norms.
    """
    layer_cls = Qwen2p5VitTransformerLayer if is_qwen2p5 else TransformerLayer
    fc1_cls = ColumnParallelLoRAMergeLinear if gated_linear_unit else ColumnParallelLoRALinear

    attention_spec = ModuleSpec(
        module=PackedSelfAttention,
        params={"attn_mask_type": AttnMaskType.no_mask},
        submodules=SelfAttentionSubmodules(
            linear_qkv=ColumnParallelLoRAQKVLinear,
            core_attention=DotProductAttention,
            linear_proj=RowParallelLoRALinear,
            q_layernorm=IdentityOp,
            k_layernorm=IdentityOp,
        ),
    )

    mlp_spec = ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=fc1_cls,
            linear_fc2=RowParallelLoRALinear,
        ),
    )

    # Renames the standalone norm key prefixes to the `layer_norm_` names under
    # linear_qkv / linear_fc1.
    # NOTE(review): presumably for checkpoint compatibility with the fused TE
    # layout -- confirm.
    norm_key_map = {
        'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
        'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
    }

    layer_submodules = TransformerLayerSubmodules(
        input_layernorm=TENorm,
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=TENorm,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
        sharded_state_dict_keys_map=norm_key_map,
    )
    return ModuleSpec(module=layer_cls, submodules=layer_submodules)


                                                
def get_proj_mlp_module_spec_lora(use_te: bool = True, add_norm: bool = True) -> ModuleSpec:
    """Return the ``MLP`` spec for the projector using LoRA linear layers.

    Args:
        use_te: Select the TE LoRA linears when true, otherwise
            ``ColumnParallelLoRALinear`` / ``RowParallelLoRALinear``.
        add_norm: With ``use_te`` true, selects
            ``TELayerNormColumnParallelLoRALinear`` for fc1 instead of
            ``TEColumnParallelLoRALinear``. Must be false when ``use_te`` is false.

    Returns:
        A ``ModuleSpec`` for ``MLP`` with the chosen fc1/fc2 classes.

    Raises:
        AssertionError: If ``use_te`` is false and ``add_norm`` is true.
    """
    if not use_te:
        # The non-TE path rejects add_norm outright rather than ignoring it.
        assert not add_norm, "add_norm is not supported for non-TE MLP now"
        fc1_cls = ColumnParallelLoRALinear
        fc2_cls = RowParallelLoRALinear
    else:
        fc1_cls = TELayerNormColumnParallelLoRALinear if add_norm else TEColumnParallelLoRALinear
        fc2_cls = TERowParallelLoRALinear

    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(linear_fc1=fc1_cls, linear_fc2=fc2_cls),
    )
