import torch

from megatron.core import mpu
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from megatron.core.utils import get_model_config


def mcore_to_hf_weights(mlm_model, unwrap_model_func,
                        update_params_type="dense",
                        early_swap_model=False,
                        cpu_memory_model=None):
    """Dispatch Megatron-core -> HF weight conversion by model architecture.

    Reads ``model_arch`` from the config of the first element returned by
    ``unwrap_model_func(mlm_model)`` and delegates to the matching
    architecture-specific converter, re-yielding everything it yields.

    This is a generator function, so nothing in the body (including the
    argument check) runs until the first item is requested.

    Args:
        mlm_model: Passed to ``unwrap_model_func`` and forwarded unchanged to
            the selected converter.
        unwrap_model_func: Callable applied to ``mlm_model``; element ``[0]``
            of its result is the object the config is read from. Also
            forwarded to the selected converter.
        update_params_type: Either ``"dense"`` or ``"moe"``. Only forwarded to
            the qwen MoE converter; the other converters never receive it.
        early_swap_model: Forwarded to every converter except gemma3.
        cpu_memory_model: Forwarded to every converter except gemma3.

    Yields:
        Whatever the delegated converter yields; the item structure is
        defined by those converters, not here.

    Raises:
        AssertionError: If ``update_params_type`` is not ``"dense"`` or
            ``"moe"``, or if a qwen/qwq/llama architecture name contains
            ``"moe"`` but is not exactly ``"qwen3-moe"``. Both checks are
            ``assert`` statements and are therefore skipped under
            ``python -O``.
        NotImplementedError: If ``model_arch`` matches no known architecture.
    """
    assert update_params_type in ("dense", "moe")

    base_model = unwrap_model_func(mlm_model)[0]
    model_config = get_model_config(base_model)
    arch = model_config.model_arch

    # Keyword arguments shared by every converter that supports model swapping.
    swap_kwargs = {
        "early_swap_model": early_swap_model,
        "cpu_memory_model": cpu_memory_model,
    }

    # Exact-name matches are tested first: the VL names also contain "qwen"
    # and would otherwise be caught by the substring check further down.
    if arch == "qwen2vl":
        from gpatch.core.models.gpt.weight_conversion.qwen2vl import qwen2vl_mcore_to_hf_weights
        yield from qwen2vl_mcore_to_hf_weights(mlm_model, unwrap_model_func, **swap_kwargs)
        return

    if arch == "qwen2.5vl":
        from gpatch.core.models.gpt.weight_conversion.qwen2vl import qwen2p5vl_mcore_to_hf_weights
        yield from qwen2p5vl_mcore_to_hf_weights(mlm_model, unwrap_model_func, **swap_kwargs)
        return

    if arch == "gemma3":
        from gpatch.core.models.gpt.weight_conversion.gemma3 import gemma3_mcore_to_hf_weights
        # The gemma3 converter is called without the swap-related arguments.
        yield from gemma3_mcore_to_hf_weights(mlm_model, unwrap_model_func)
        return

    in_qwen_like_family = any(tag in arch for tag in ("qwen", "qwq", "llama"))
    if not in_qwen_like_family:
        raise NotImplementedError(f"unknown {model_config.model_arch}")

    if "moe" not in arch:
        from gpatch.core.models.gpt.weight_conversion.qwen import qwen_mcore_to_hf_weights
        yield from qwen_mcore_to_hf_weights(mlm_model, unwrap_model_func, **swap_kwargs)
        return

    # Only one MoE architecture is accepted on this path.
    assert arch == "qwen3-moe"
    from gpatch.core.models.gpt.weight_conversion.qwen import qwen_moe_mcore_to_hf_weights
    yield from qwen_moe_mcore_to_hf_weights(
        mlm_model,
        unwrap_model_func,
        update_params_type=update_params_type,
        **swap_kwargs,
    )
