              
                                                      
                                                                 

import warnings
from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules

from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from megatron.core.utils import is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEDotProductAttention,
        TENorm,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

from gpatch.core.tensor_parallel.lora import (
    TEColumnParallelLoRALinear,
    TERowParallelLoRALinear,
    TELayerNormColumnParallelLoRALinear,
    TELayerNormColumnParallelLoRAMergeLinear,
    TELayerNormColumnParallelLoRAQKVLinear,
    ColumnParallelLoRALinear,
    RowParallelLoRALinear,
    ColumnParallelLoRAMergeLinear,
    ColumnParallelLoRAQKVLinear
)


def get_gpt_layer_with_transformer_engine_spec_lora(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
    gated_linear_unit: bool = True,
) -> ModuleSpec:
    """Assemble the TransformerLayer spec built from the TE LoRA linear classes.

    The argument list mirrors ``get_gpt_layer_with_transformer_engine_spec``.

    Args:
        num_experts: Must be ``None``; any other value trips an assertion.
        moe_grouped_gemm: Passed through to ``get_mlp_module_spec_lora``.
        qk_layernorm: When truthy, ``TENorm`` is used for both the query and
            key layernorm slots; otherwise ``IdentityOp`` is used.
        multi_latent_attention: Must be falsy; otherwise an assertion fails.
        fp8: Must be ``None``; otherwise an assertion fails.
        moe_use_legacy_grouped_gemm: Passed through to
            ``get_mlp_module_spec_lora``.
        gated_linear_unit: Passed through to ``get_mlp_module_spec_lora``.

    Returns:
        A ``ModuleSpec`` for ``TransformerLayer`` whose self-attention uses
        ``TELayerNormColumnParallelLoRAQKVLinear``, ``TEDotProductAttention``
        and ``TERowParallelLoRALinear``.

    Raises:
        AssertionError: If an unsupported option is requested or
            ``is_te_min_version("1.9.0")`` is false.
    """
    unsupported_msg = "Not supported yet, can be support later"
    assert fp8 is None
    assert num_experts is None, unsupported_msg
    assert not multi_latent_attention, unsupported_msg
    assert is_te_min_version("1.9.0")

    mlp_spec = get_mlp_module_spec_lora(
        use_te=True,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        gated_linear_unit=gated_linear_unit,
    )

    # The same class fills both the query and the key layernorm slot.
    qk_norm_cls = TENorm if qk_layernorm else IdentityOp

    attention_spec = ModuleSpec(
        module=SelfAttention,
        params={"attn_mask_type": AttnMaskType.causal},
        submodules=SelfAttentionSubmodules(
            linear_qkv=TELayerNormColumnParallelLoRAQKVLinear,
            core_attention=TEDotProductAttention,
            linear_proj=TERowParallelLoRALinear,
            q_layernorm=qk_norm_cls,
            k_layernorm=qk_norm_cls,
        ),
    )

    # No input_layernorm is set here.
    # NOTE(review): presumably the norm is fused into the
    # TELayerNormColumnParallel* linears -- confirm.
    layer_submodules = TransformerLayerSubmodules(
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
    )
    return ModuleSpec(module=TransformerLayer, submodules=layer_submodules)


def get_gpt_layer_with_local_spec_lora(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
    gated_linear_unit: bool = True,
) -> ModuleSpec:
    """Assemble the TransformerLayer spec built from the non-TE LoRA linears.

    Args:
        num_experts: Must be ``None``; any other value trips an assertion.
        moe_grouped_gemm: Passed through to ``get_mlp_module_spec_lora``.
        qk_layernorm: When truthy, ``TENorm`` is used for both the query and
            key layernorm slots; otherwise ``IdentityOp`` is used.
        multi_latent_attention: Must be falsy; otherwise an assertion fails.
        fp8: Must be ``None``; otherwise an assertion fails.
        moe_use_legacy_grouped_gemm: Passed through to
            ``get_mlp_module_spec_lora``.
        gated_linear_unit: Passed through to ``get_mlp_module_spec_lora``.

    Returns:
        A ``ModuleSpec`` for ``TransformerLayer`` whose self-attention uses
        ``ColumnParallelLoRAQKVLinear``, ``DotProductAttention`` and
        ``RowParallelLoRALinear``.

    Raises:
        AssertionError: If an unsupported option is requested.
    """
    unsupported_msg = "Not supported yet, can be support later"
    assert fp8 is None
    assert num_experts is None, unsupported_msg
    assert not multi_latent_attention, unsupported_msg

    mlp_spec = get_mlp_module_spec_lora(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        gated_linear_unit=gated_linear_unit,
    )

    # NOTE(review): the norm layers below are still TENorm even though this is
    # the "local" spec, so the Transformer Engine import must have succeeded.
    # Confirm that this is intended.
    norm_cls = TENorm
    qk_norm_cls = norm_cls if qk_layernorm else IdentityOp

    attention_spec = ModuleSpec(
        module=SelfAttention,
        params={"attn_mask_type": AttnMaskType.causal},
        submodules=SelfAttentionSubmodules(
            linear_qkv=ColumnParallelLoRAQKVLinear,
            core_attention=DotProductAttention,
            linear_proj=RowParallelLoRALinear,
            q_layernorm=qk_norm_cls,
            k_layernorm=qk_norm_cls,
        ),
    )

    layer_submodules = TransformerLayerSubmodules(
        input_layernorm=norm_cls,
        self_attention=attention_spec,
        self_attn_bda=get_bias_dropout_add,
        pre_mlp_layernorm=norm_cls,
        mlp=mlp_spec,
        mlp_bda=get_bias_dropout_add,
    )
    return ModuleSpec(module=TransformerLayer, submodules=layer_submodules)


def get_mlp_module_spec_lora(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
    gated_linear_unit: bool = True,
) -> ModuleSpec:
    """Return the ``MLP`` module spec wired with LoRA linear classes.

    Args:
        use_te: Truthy selects the ``TE*`` LoRA linear classes; falsy selects
            the plain ``ColumnParallel*``/``RowParallel*`` LoRA classes.
        num_experts: Must be ``None``; any other value trips an assertion.
        moe_grouped_gemm: Accepted but not read by this function.
        fp8: Must be ``None``; otherwise an assertion fails.
        moe_use_legacy_grouped_gemm: Accepted but not read by this function.
        gated_linear_unit: Truthy selects the ``*MergeLinear`` class for
            ``linear_fc1``; falsy selects the non-merge ``*LoRALinear`` class.

    Returns:
        A ``ModuleSpec`` for ``MLP`` with ``linear_fc1`` and ``linear_fc2``
        set to the selected classes.

    Raises:
        AssertionError: If ``fp8`` or ``num_experts`` is not ``None``.
    """
    assert fp8 is None
    assert num_experts is None, "Not supported yet, can be support later"

    # Pick the two linear classes first so the spec is built in one place.
    if use_te:
        fc1_cls = (
            TELayerNormColumnParallelLoRAMergeLinear
            if gated_linear_unit
            else TELayerNormColumnParallelLoRALinear
        )
        fc2_cls = TERowParallelLoRALinear
    else:
        fc1_cls = (
            ColumnParallelLoRAMergeLinear
            if gated_linear_unit
            else ColumnParallelLoRALinear
        )
        fc2_cls = RowParallelLoRALinear

    mlp_submodules = MLPSubmodules(linear_fc1=fc1_cls, linear_fc2=fc2_cls)
    return ModuleSpec(module=MLP, submodules=mlp_submodules)
