

from enum import Enum
from typing import Callable

import torch
import torch.nn as nn

from .config_converter import (
    PretrainedConfig,
    TransformerConfig,
    hf_to_mcore_config_dense,
    hf_to_mcore_config_dpskv3,
    hf_to_mcore_config_llama4,
    hf_to_mcore_config_mixtral,
    hf_to_mcore_config_qwen2_5_vl,
    hf_to_mcore_config_qwen2moe,
    hf_to_mcore_config_qwen3moe,
)
from .model_forward import (
    gptmodel_forward,
    gptmodel_forward_qwen2_5_vl,
)
from .model_forward_fused import (
    fused_forward_gptmodel,
    fused_forward_qwen2_5_vl,
)
from .model_initializer import (
    BaseModelInitializer,
    DeepseekV3Model,
    DenseModel,
    MixtralModel,
    Qwen2MoEModel,
    Qwen3MoEModel,
    Qwen25VLModel,
)
from .weight_converter import (
    McoreToHFWeightConverterDense,
    McoreToHFWeightConverterDpskv3,
    McoreToHFWeightConverterMixtral,
    McoreToHFWeightConverterQwen2_5_VL,
    McoreToHFWeightConverterQwen2Moe,
    McoreToHFWeightConverterQwen3Moe,
)

class SupportedModel(Enum):
    """Model architectures that the registries in this module are keyed by.

    Each member's value is an architecture name string; ``get_supported_model``
    maps such a string (taken from ``hf_config.architectures[0]``) to a member.
    Declaration order is the order in which the supported names are listed in
    the error raised by ``get_supported_model``.
    """

    LLAMA = "LlamaForCausalLM"
    QWEN2 = "Qwen2ForCausalLM"
    QWEN2_MOE = "Qwen2MoeForCausalLM"
    DEEPSEEK_V3 = "DeepseekV3ForCausalLM"
    MIXTRAL = "MixtralForCausalLM"
    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
    LLAMA4 = "Llama4ForConditionalGeneration"
    QWEN3 = "Qwen3ForCausalLM"
    QWEN3_MOE = "Qwen3MoeForCausalLM"

# Maps each supported architecture to the function that converts its HF config
# into a TransformerConfig. Exactly one entry per architecture: a repeated key
# in a dict literal is silently overridden by the later one, so duplicates only
# hide which value is actually in effect.
MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = {
    # LLAMA, QWEN2 and QWEN3 share the dense converter.
    SupportedModel.LLAMA: hf_to_mcore_config_dense,
    SupportedModel.QWEN2: hf_to_mcore_config_dense,
    SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe,
    SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3,
    SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral,
    SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl,
    SupportedModel.LLAMA4: hf_to_mcore_config_llama4,
    SupportedModel.QWEN3: hf_to_mcore_config_dense,
    SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
}

# Maps each supported architecture to the BaseModelInitializer subclass that
# init_mcore_model instantiates for it. Exactly one entry per architecture
# (a repeated key in a dict literal is silently overridden by the later one).
MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = {
    # LLAMA, QWEN2, LLAMA4 and QWEN3 all use the dense initializer.
    SupportedModel.LLAMA: DenseModel,
    SupportedModel.QWEN2: DenseModel,
    SupportedModel.QWEN2_MOE: Qwen2MoEModel,
    SupportedModel.MIXTRAL: MixtralModel,
    SupportedModel.DEEPSEEK_V3: DeepseekV3Model,
    SupportedModel.QWEN2_5_VL: Qwen25VLModel,
    SupportedModel.LLAMA4: DenseModel,
    SupportedModel.QWEN3: DenseModel,
    SupportedModel.QWEN3_MOE: Qwen3MoEModel,
}

# Maps each supported architecture to its (non-fused) forward function.
# Exactly one entry per architecture: the literal previously listed QWEN2_5_VL
# and DEEPSEEK_V3 twice, and because the later entry of a repeated key wins,
# the first QWEN2_5_VL value (gptmodel_forward) was dead. The values below are
# the ones that were actually in effect.
MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: gptmodel_forward,
    SupportedModel.QWEN2: gptmodel_forward,
    SupportedModel.QWEN2_MOE: gptmodel_forward,
    SupportedModel.MIXTRAL: gptmodel_forward,
    SupportedModel.DEEPSEEK_V3: gptmodel_forward,
    # Qwen2.5-VL is the only architecture with a dedicated forward function.
    SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
    SupportedModel.LLAMA4: gptmodel_forward,
    SupportedModel.QWEN3: gptmodel_forward,
    SupportedModel.QWEN3_MOE: gptmodel_forward,
}

# Maps each supported architecture to its fused forward function.
# Exactly one entry per architecture: QWEN2_5_VL and DEEPSEEK_V3 were listed
# twice with identical values, which is redundant and makes it easy for the
# two copies to drift apart (the later entry of a repeated key silently wins).
MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.LLAMA: fused_forward_gptmodel,
    SupportedModel.QWEN2: fused_forward_gptmodel,
    SupportedModel.QWEN2_MOE: fused_forward_gptmodel,
    SupportedModel.MIXTRAL: fused_forward_gptmodel,
    SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
    # Qwen2.5-VL is the only architecture with a dedicated fused forward.
    SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
    SupportedModel.LLAMA4: fused_forward_gptmodel,
    SupportedModel.QWEN3: fused_forward_gptmodel,
    SupportedModel.QWEN3_MOE: fused_forward_gptmodel,
}

# Maps each supported architecture to the weight-converter class that
# get_mcore_weight_converter instantiates with (hf_config, tfconfig).
# NOTE(review): SupportedModel.LLAMA4 is present in the other registries of
# this module but has no entry here, so indexing this dict with it raises
# KeyError. Confirm whether that omission is intentional.
MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = {
    # LLAMA, QWEN2 and QWEN3 share the dense converter.
    SupportedModel.LLAMA: McoreToHFWeightConverterDense,
    SupportedModel.QWEN2: McoreToHFWeightConverterDense,
    SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
    SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
    SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,
    SupportedModel.QWEN3: McoreToHFWeightConverterDense,
    SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
    SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL,
}

def get_supported_model(model_type: str) -> SupportedModel:
    """Resolve an architecture name to its ``SupportedModel`` member.

    Args:
        model_type: Architecture name to look up by enum value.

    Returns:
        The ``SupportedModel`` member whose value equals ``model_type``.

    Raises:
        NotImplementedError: If no member has that value. The message lists
            every supported architecture name, and the enum's ``ValueError``
            is chained as the cause.
    """
    try:
        resolved = SupportedModel(model_type)
    except ValueError as err:
        supported_models = [member.value for member in SupportedModel]
        message = f"Model Type: {model_type} not supported. Supported models: {supported_models}"
        raise NotImplementedError(message) from err
    else:
        return resolved

def hf_to_mcore_config(
    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
    """Convert ``hf_config`` with the converter registered for its architecture.

    Args:
        hf_config: Config whose ``architectures`` list must hold exactly one entry.
        dtype: Passed through to the converter as its second positional argument.
        **override_transformer_config_kwargs: Forwarded unchanged to the converter.

    Returns:
        Whatever the registered converter returns.

    Raises:
        AssertionError: If ``hf_config.architectures`` does not have length 1.
        NotImplementedError: If the architecture is not a ``SupportedModel``.
    """
    architectures = hf_config.architectures
    assert len(architectures) == 1, "Only one architecture is supported for now"
    convert = MODEL_CONFIG_CONVERTER_REGISTRY[get_supported_model(architectures[0])]
    return convert(hf_config, dtype, **override_transformer_config_kwargs)

def init_mcore_model(
    tfconfig: TransformerConfig,
    hf_config: PretrainedConfig,
    pre_process: bool = True,
    post_process: bool = None,
    *,
    share_embeddings_and_output_weights: bool = False,
    value: bool = False,
    **extra_kwargs,
) -> nn.Module:
    """Build a model through the initializer registered for the architecture.

    The initializer class is constructed with ``(tfconfig, hf_config)`` and its
    ``initialize`` method is called with the remaining arguments as keywords.

    Args:
        tfconfig: Passed to the initializer constructor.
        hf_config: Config whose ``architectures`` list must hold exactly one
            entry; also passed to the initializer constructor.
        pre_process: Forwarded to ``initialize``.
        post_process: Forwarded to ``initialize`` as-is; note that the default
            is ``None``, not a bool.
        share_embeddings_and_output_weights: Forwarded to ``initialize``.
        value: Forwarded to ``initialize``.
        **extra_kwargs: Forwarded unchanged to ``initialize``.

    Returns:
        Whatever the initializer's ``initialize`` returns.

    Raises:
        AssertionError: If ``hf_config.architectures`` does not have length 1.
        NotImplementedError: If the architecture is not a ``SupportedModel``.
    """
    architectures = hf_config.architectures
    assert len(architectures) == 1, "Only one architecture is supported for now"
    supported = get_supported_model(architectures[0])
    builder = MODEL_INITIALIZER_REGISTRY[supported](tfconfig, hf_config)
    return builder.initialize(
        pre_process=pre_process,
        post_process=post_process,
        share_embeddings_and_output_weights=share_embeddings_and_output_weights,
        value=value,
        **extra_kwargs,
    )

def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
    """Return the forward function registered for the config's architecture.

    Args:
        hf_config: Config whose ``architectures`` list must hold exactly one entry.

    Returns:
        The ``MODEL_FORWARD_REGISTRY`` entry for that architecture.

    Raises:
        AssertionError: If ``hf_config.architectures`` does not have length 1.
        NotImplementedError: If the architecture is not a ``SupportedModel``.
    """
    architectures = hf_config.architectures
    assert len(architectures) == 1, "Only one architecture is supported for now"
    return MODEL_FORWARD_REGISTRY[get_supported_model(architectures[0])]

def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
    """Return the fused forward function registered for the config's architecture.

    Args:
        hf_config: Config whose ``architectures`` list must hold exactly one entry.

    Returns:
        The ``MODEL_FORWARD_FUSED_REGISTRY`` entry for that architecture.

    Raises:
        AssertionError: If ``hf_config.architectures`` does not have length 1.
        NotImplementedError: If the architecture is not a ``SupportedModel``.
    """
    architectures = hf_config.architectures
    assert len(architectures) == 1, "Only one architecture is supported for now"
    return MODEL_FORWARD_FUSED_REGISTRY[get_supported_model(architectures[0])]

def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
    """Instantiate the weight converter registered for the config's architecture.

    The converter class is looked up before the config is converted, so an
    architecture without a registered converter fails immediately with a
    descriptive error instead of a bare ``KeyError`` after the conversion work.

    Args:
        hf_config: Config whose ``architectures`` list must hold exactly one entry.
        dtype: Passed to ``hf_to_mcore_config`` to build the transformer config.

    Returns:
        An instance of the registered converter class, constructed with
        ``(hf_config, tfconfig)``.

    Raises:
        AssertionError: If ``hf_config.architectures`` does not have length 1.
        NotImplementedError: If the architecture is not a ``SupportedModel``, or
            is supported but has no entry in ``MODEL_WEIGHT_CONVERTER_REGISTRY``.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    model = get_supported_model(hf_config.architectures[0])
    try:
        converter_cls = MODEL_WEIGHT_CONVERTER_REGISTRY[model]
    except KeyError as err:
        # Not every SupportedModel member has a converter; report it the same
        # way get_supported_model reports an unknown architecture.
        with_converter = [m.value for m in MODEL_WEIGHT_CONVERTER_REGISTRY]
        raise NotImplementedError(
            f"Model Type: {model.value} has no weight converter registered. "
            f"Models with a weight converter: {with_converter}"
        ) from err
    tfconfig = hf_to_mcore_config(hf_config, dtype)
    return converter_cls(hf_config, tfconfig)
