# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Registry module for model architecture components.
"""

from enum import Enum
from typing import Callable

import torch
import torch.nn as nn

from .config_converter import (
    PretrainedConfig,
    TransformerConfig,
    hf_to_mcore_config_dense,
    hf_to_mcore_config_dpskv3,
    hf_to_mcore_config_llama4,
    hf_to_mcore_config_mixtral,
    hf_to_mcore_config_qwen2_5_vl,
    hf_to_mcore_config_qwen2moe,
    hf_to_mcore_config_qwen3moe,
)
from .model_forward import (
    gptmodel_forward,
    gptmodel_forward_qwen2_5_vl,
)
from .model_forward_fused import (
    fused_forward_gptmodel,
    fused_forward_qwen2_5_vl,
)
from .model_initializer import (
    BaseModelInitializer,
    DeepseekV3Model,
    DenseModel,
    MixtralModel,
    Qwen2MoEModel,
    Qwen3MoEModel,
    Qwen25VLModel,
)
from .weight_converter import (
    McoreToHFWeightConverterDense,
    McoreToHFWeightConverterDpskv3,
    McoreToHFWeightConverterMixtral,
    McoreToHFWeightConverterQwen2_5_VL,
    McoreToHFWeightConverterQwen2Moe,
    McoreToHFWeightConverterQwen3Moe,
)


class SupportedModel(Enum):
    LLAMA = "LlamaForCausalLM"  # tested
    QWEN2 = "Qwen2ForCausalLM"  # tested
    QWEN2_MOE = "Qwen2MoeForCausalLM"  # pending
    DEEPSEEK_V3 = "DeepseekV3ForCausalLM"  # not tested
    MIXTRAL = "MixtralForCausalLM"  # tested
    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"  # not supported
    LLAMA4 = "Llama4ForConditionalGeneration"  # not tested
    QWEN3 = "Qwen3ForCausalLM"  # tested
    QWEN3_MOE = "Qwen3MoeForCausalLM"  # not tested


# Registry for model configuration converters
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
    SupportedModel.LLAMA: hf_to_mcore_config_dense,
    SupportedModel.QWEN2: hf_to_mcore_config_dense,
    SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
    SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
    SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
    SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
    SupportedModel.QWEN3: hf_to_mcore_config_dense,
    SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
}

# Registry for model initializers
MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {
    SupportedModel.LLAMA: DenseModel,
    SupportedModel.QWEN2: DenseModel,
    SupportedModel.QWEN2_MOE: Qwen2MoEModel,
    SupportedModel.MIXTRAL: MixtralModel,
    SupportedModel.DEEPSEEK_V3: DeepseekV3Model,
    SupportedModel.QWEN2_5_VL: Qwen25VLModel,
    SupportedModel.LLAMA4: DenseModel,
    SupportedModel.QWEN3: DenseModel,
    SupportedModel.QWEN3_MOE: Qwen3MoEModel,
    SupportedModel.QWEN2_5_VL: Qwen25VLModel,
}

# Registry for model forward functions
MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: gptmodel_forward,
    SupportedModel.QWEN2: gptmodel_forward,
    SupportedModel.QWEN2_MOE: gptmodel_forward,
    SupportedModel.MIXTRAL: gptmodel_forward,
    SupportedModel.DEEPSEEK_V3: gptmodel_forward,
    SupportedModel.QWEN2_5_VL: gptmodel_forward,
    SupportedModel.LLAMA4: gptmodel_forward,
    SupportedModel.QWEN3: gptmodel_forward,
    SupportedModel.QWEN3_MOE: gptmodel_forward,
    SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
    SupportedModel.DEEPSEEK_V3: gptmodel_forward,
}

# Registry for model forward functions
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: fused_forward_gptmodel,
    SupportedModel.QWEN2: fused_forward_gptmodel,
    SupportedModel.QWEN2_MOE: fused_forward_gptmodel,
    SupportedModel.MIXTRAL: fused_forward_gptmodel,
    SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
    SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
    SupportedModel.LLAMA4: fused_forward_gptmodel,
    SupportedModel.QWEN3: fused_forward_gptmodel,
    SupportedModel.QWEN3_MOE: fused_forward_gptmodel,
    SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
    SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
}

# Registry for model weight converters
MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = {
    SupportedModel.LLAMA: McoreToHFWeightConverterDense,
    SupportedModel.QWEN2: McoreToHFWeightConverterDense,
    SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
    SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
    SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,
    SupportedModel.QWEN3: McoreToHFWeightConverterDense,
    SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
    SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
}


def get_supported_model(model_type: str) -> SupportedModel:
    try:
        return SupportedModel(model_type)
    except ValueError as err:
        supported_models = [e.value for e in SupportedModel]
        raise NotImplementedError(
            f"Model Type: {model_type} not supported. Supported models: {supported_models}"
        ) from err


def hf_to_mcore_config(
    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
    """Convert huggingface PretrainedConfig to mcore TransformerConfig.

    Args:
        hf_config: The huggingface PretrainedConfig.
        dtype: The dtype of the model.
        **override_transformer_config_kwargs: The kwargs to override the transformer config.

    Returns:
        The mcore TransformerConfig.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs)


def init_mcore_model(
    tfconfig: TransformerConfig,
    hf_config: PretrainedConfig,
    pre_process: bool = True,
    post_process: bool = None,
    *,
    share_embeddings_and_output_weights: bool = False,
    value: bool = False,
    **extra_kwargs,  # may be used for vlm and moe
) -> nn.Module:
    """
    Initialize a Mcore model.

    Args:
        tfconfig: The transformer config.
        hf_config: The HuggingFace config.
        pre_process: Optional pre-processing function.
        post_process: Optional post-processing function.
        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
        value: Whether to use value.
        **extra_kwargs: Additional keyword arguments.

    Returns:
        The initialized model.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    initializer_cls = MODEL_INITIALIZER_REGISTRY[model]
    initializer = initializer_cls(tfconfig, hf_config)
    return initializer.initialize(
        pre_process=pre_process,
        post_process=post_process,
        share_embeddings_and_output_weights=share_embeddings_and_output_weights,
        value=value,
        **extra_kwargs,
    )


def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
    """
    Get the forward function for given model architecture.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    return MODEL_FORWARD_REGISTRY[model]


def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
    """
    Get the forward function for given model architecture.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    return MODEL_FORWARD_FUSED_REGISTRY[model]


def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
    """
    Get the weight converter for given model architecture.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    tfconfig = hf_to_mcore_config(hf_config, dtype)
    return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig)
