              
                                                      
                       

import re
from turtle import mode
import torch

from megatron.core import mpu
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from megatron.core.utils import get_model_config

from gpatch.core import parallel_state as gmpu
from gpatch.core.models.gpt.weight_conversion.utils import decoder_mcore_to_hf_weights

# Megatron-core -> HF name mapping for the vision encoder.
# Keys are the names produced by the mcore model's named_parameters(); values
# are HF names (a list value means one mcore tensor maps to several HF names,
# presumably split by decoder_mcore_to_hf_weights — confirm). Per-layer keys
# carry no layer index; the converter supplies the layer prefix.
slip_model_dict = {
    # Non-layer parameters: embeddings and the final (post) layer norm.
    "position_embeddings.weight": "embeddings.position_embedding.weight",
    "conv1.weight": "embeddings.patch_embedding.weight",
    "conv1.bias": "embeddings.patch_embedding.bias",
    "ln_post.weight": "post_layernorm.weight",
    "ln_post.bias": "post_layernorm.bias",
    # Self-attention: output projection, and the fused QKV projection that
    # maps to separate q/k/v HF tensors.
    "self_attention.linear_proj.weight": "self_attn.out_proj.weight",
    "self_attention.linear_proj.bias": "self_attn.out_proj.bias",
    "self_attention.linear_qkv.weight": [f"self_attn.{x}_proj.weight" for x in "qkv"],
    "self_attention.linear_qkv.bias": [f"self_attn.{x}_proj.bias" for x in "qkv"],
    # MLP.
    "mlp.linear_fc2.weight": "mlp.fc2.weight",
    "mlp.linear_fc2.bias": "mlp.fc2.bias",
    "mlp.linear_fc1.weight": "mlp.fc1.weight",
    "mlp.linear_fc1.bias": "mlp.fc1.bias",
    # Layer norms fused into the QKV / FC1 linears on the mcore side.
    "self_attention.linear_qkv.layer_norm_weight": "layer_norm1.weight",
    "self_attention.linear_qkv.layer_norm_bias": "layer_norm1.bias",
    "mlp.linear_fc1.layer_norm_weight": "layer_norm2.weight",
    "mlp.linear_fc1.layer_norm_bias": "layer_norm2.bias",
}


def slip_mcore_to_hf_weights(mlm_model, unwrap_model_func=None):
    """Convert the Megatron-core vision encoder's parameters to HF names.

    This is a generator. Every parameter is gathered across the
    tensor-parallel (TP) group with a collective, so all TP ranks must
    consume the generator in lockstep; only TP rank 0 actually yields.
    Ranks that are not both data-parallel rank 0 and context-parallel
    rank 0 return immediately without yielding or joining any gather.

    Args:
        mlm_model: Megatron vision model, or a wrapped model when
            ``unwrap_model_func`` is given.
        unwrap_model_func: Optional callable; ``mlm_model`` is replaced by
            the first element of its return value.

    Yields:
        ``(hf_name, tensor)`` pairs under ``vision_tower.vision_model``, on
        TP rank 0 only. Parameters outside the transformer layers map 1:1
        via ``slip_model_dict``; everything else is delegated to
        ``decoder_mcore_to_hf_weights`` with the ``.encoder.layers`` prefix.

    Raises:
        NotImplementedError: if the transformer layer offset is non-zero
            (the message ties this to
            ``--encoder-pipeline-model-parallel-size`` > 1).
    """
    if unwrap_model_func is not None:
        mlm_model = unwrap_model_func(mlm_model)[0]
    model_config = get_model_config(mlm_model)
    layer_offset = get_transformer_layer_offset(model_config)
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if layer_offset != 0:
        raise NotImplementedError("--encoder-pipeline-model-parallel-size暂不支持大于1")

    tp_size = mpu.get_tensor_model_parallel_world_size()
    is_dp_and_cp_head = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    if not is_dp_and_cp_head:
        return

    print(f"layer_offset {layer_offset}")

    # Loop-invariant parallel-state lookups, hoisted out of the per-param loop.
    is_tp_rank0 = mpu.get_tensor_model_parallel_rank() == 0
    tp_group = mpu.get_tensor_model_parallel_group()

    # Parameters that live outside the transformer layers.
    top_level_names = frozenset({
        "position_embeddings.weight",
        "conv1.weight",
        "conv1.bias",
        "ln_post.weight",
        "ln_post.bias",
    })
    hf_prefix = "vision_tower.vision_model"
    decoder_prefix = f'{hf_prefix}.encoder.layers'

    for pname, params in mlm_model.named_parameters():
        layer_weight = params.data
        shape = layer_weight.shape

        # gather() needs receive buffers only on the destination rank.
        # NOTE(review): assumes update_weights_gather_dst_rank() is the global
        # rank of TP rank 0 in this group — confirm.
        if is_tp_rank0:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        else:
            gather_list = None
        dst = gmpu.update_weights_gather_dst_rank()
        torch.distributed.gather(layer_weight, gather_list, dst=dst, group=tp_group)
        if not is_tp_rank0:
            continue

        if pname in top_level_names:
            # Only shard 0 is exported (assumes these params are replicated,
            # not split, across TP).
            yield (f'{hf_prefix}.{slip_model_dict[pname]}', gather_list[0])
        else:
            yield from decoder_mcore_to_hf_weights(pname, gather_list, shape, layer_offset,
                                                   slip_model_dict, model_config, decoder_prefix)


# Megatron-core -> HF name mapping for the multimodal (vision -> text)
# projector. Keys are mcore parameter names, values HF names (used under the
# "multi_modal_projector" prefix by gemma3_proj_mcore_to_hf_weights).
gemma3_projector_model_dict = dict([
    ("mm_input_projection.layer_norm_weight", "mm_soft_emb_norm.weight"),
    ("mm_input_projection.weight", "mm_input_projection_weight"),
])


def gemma3_proj_mcore_to_hf_weights(mlm_model, unwrap_model_func=None):
    """Convert the Megatron-core vision projector's parameters to HF names.

    This is a generator. Every parameter is gathered across the
    tensor-parallel (TP) group with a collective, so all TP ranks must
    consume the generator in lockstep; only TP rank 0 yields. Ranks that
    are not both data-parallel rank 0 and context-parallel rank 0 return
    immediately.

    Args:
        mlm_model: Megatron projector module, or a wrapped model when
            ``unwrap_model_func`` is given.
        unwrap_model_func: Optional callable; ``mlm_model`` is replaced by
            the first element of its return value.

    Yields:
        ``(hf_name, tensor)`` pairs under ``multi_modal_projector``, on TP
        rank 0 only:

        * ``mm_input_projection.layer_norm_weight``: TP shard 0.
        * ``mm_input_projection.weight``: each TP shard transposed, then
          concatenated along the last dim. This treats the mcore weight as
          split along its first (output) dim — presumably column-parallel;
          confirm.

    Raises:
        NotImplementedError: for any other parameter name.
    """
    if unwrap_model_func is not None:
        mlm_model = unwrap_model_func(mlm_model)[0]

    tp_size = mpu.get_tensor_model_parallel_world_size()
    is_dp_and_cp_head = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    if not is_dp_and_cp_head:
        return

    # Loop-invariant parallel-state lookups, hoisted out of the per-param loop.
    is_tp_rank0 = mpu.get_tensor_model_parallel_rank() == 0
    tp_group = mpu.get_tensor_model_parallel_group()

    hf_prefix = "multi_modal_projector"
    supported_names = ("mm_input_projection.layer_norm_weight", "mm_input_projection.weight")
    for pname, params in mlm_model.named_parameters():
        # Reject unknown names BEFORE the collective. The gather pairing
        # already relies on every TP rank iterating the same parameter names,
        # so all TP ranks raise together. Raising only on TP rank 0 after the
        # gather would leave its peers blocked in the next gather.
        if pname not in supported_names:
            raise NotImplementedError(f"unsupported projector parameter: {pname}")

        layer_weight = params.data

        # gather() needs receive buffers only on the destination rank.
        if is_tp_rank0:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        else:
            gather_list = None
        dst = gmpu.update_weights_gather_dst_rank()
        torch.distributed.gather(layer_weight, gather_list, dst=dst, group=tp_group)
        if not is_tp_rank0:
            continue

        if pname == "mm_input_projection.weight":
            hf_w = torch.cat([shard.T for shard in gather_list], dim=-1)
        else:
            hf_w = gather_list[0]
        yield (f'{hf_prefix}.{gemma3_projector_model_dict[pname]}', hf_w)


# Megatron-core -> HF name mapping for the Gemma3 language model.
# Keys are mcore parameter names; values are HF names (a list value means one
# mcore tensor maps to several HF names, presumably split by
# decoder_mcore_to_hf_weights — confirm). Per-layer keys carry no layer index.
gemma3_text_model_dict = {
    # Non-layer parameters.
    "embedding.word_embeddings.weight": "embed_tokens.weight",
    "decoder.final_layernorm.weight": "norm.weight",
    # Self-attention: output projection and fused QKV -> separate q/k/v.
    "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
    "self_attention.linear_qkv.weight": [f"self_attn.{x}_proj.weight" for x in "qkv"],
    # MLP: fused FC1 -> gate/up projections; FC2 -> down projection.
    "mlp.linear_fc1.weight": [f"mlp.{x}_proj.weight" for x in ("gate", "up")],
    "mlp.linear_fc2.weight": "mlp.down_proj.weight",
    # Norms (pre/post attention and feed-forward, plus q/k norms).
    "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
    "mlp.linear_fc1.layer_norm_weight": "pre_feedforward_layernorm.weight",
    "post_attention_layernorm.weight": "post_attention_layernorm.weight",
    "post_feedforward_layernorm.weight": "post_feedforward_layernorm.weight",
    "self_attention.q_layernorm.weight": "self_attn.q_norm.weight",
    "self_attention.k_layernorm.weight": "self_attn.k_norm.weight",
}


def gemma3_text_mcore_to_hf_weights(mlm_model, unwrap_model_func=None):
    """Convert the Megatron-core Gemma3 language model's parameters to HF names.

    This is a generator. Each exported parameter is gathered across the
    tensor-parallel (TP) group with a collective, so all TP ranks must
    consume the generator in lockstep; only TP rank 0 yields. Ranks that
    are not both data-parallel rank 0 and context-parallel rank 0 return
    immediately.

    Args:
        mlm_model: Megatron language model, or a wrapped model when
            ``unwrap_model_func`` is given.
        unwrap_model_func: Optional callable; ``mlm_model`` is replaced by
            the first element of its return value.

    Yields:
        ``(hf_name, tensor)`` pairs under ``language_model.model``, on TP
        rank 0 only. The word embedding's TP shards are concatenated along
        dim 0 and trimmed to ``model_config.hf_vocab_size`` rows.
        ``output_layer.weight`` is never yielded, because it must be tied
        to the embedding. Transformer-layer parameters are delegated to
        ``decoder_mcore_to_hf_weights``.

    Raises:
        NotImplementedError: if ``output_layer.weight`` is present while
            ``mlm_model.share_embeddings_and_output_weights`` is false.
        ValueError: if the gathered embedding has fewer rows than
            ``hf_vocab_size``.
    """
    if unwrap_model_func is not None:
        mlm_model = unwrap_model_func(mlm_model)[0]
    model_config = get_model_config(mlm_model)
    layer_offset = get_transformer_layer_offset(model_config)
    vocab_size = model_config.hf_vocab_size

    tp_size = mpu.get_tensor_model_parallel_world_size()
    is_dp_and_cp_head = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    if not is_dp_and_cp_head:
        return

    print(f"layer_offset {layer_offset}")

    # Loop-invariant parallel-state lookups, hoisted out of the per-param loop.
    is_tp_rank0 = mpu.get_tensor_model_parallel_rank() == 0
    tp_group = mpu.get_tensor_model_parallel_group()

    hf_prefix = "language_model.model"
    decoder_prefix = f'{hf_prefix}.layers'
    for pname, params in mlm_model.named_parameters():
        if pname == "output_layer.weight":
            # The output layer is only supported when tied to the word
            # embedding, which is exported on its own. This is decided before
            # the gather, on every TP rank. The gather pairing already relies
            # on identical parameter names across TP ranks, so this keeps the
            # collectives aligned and avoids a useless vocab-sized gather. An
            # explicit raise (not `assert`) means `-O` cannot silently drop an
            # untied lm_head.
            if not mlm_model.share_embeddings_and_output_weights:
                raise NotImplementedError(
                    "exporting an untied output_layer.weight is not supported")
            continue

        layer_weight = params.data
        shape = layer_weight.shape

        # gather() needs receive buffers only on the destination rank.
        if is_tp_rank0:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        else:
            gather_list = None
        dst = gmpu.update_weights_gather_dst_rank()
        torch.distributed.gather(layer_weight, gather_list, dst=dst, group=tp_group)
        if not is_tp_rank0:
            continue

        if pname == "embedding.word_embeddings.weight":
            # Shards stack along dim 0 (vocab rows). Rows beyond the HF vocab
            # size are treated as padding and dropped.
            mlm_embed = torch.cat(gather_list, dim=0)
            mlm_emebd_size = mlm_embed.shape[0]
            if mlm_emebd_size < vocab_size:
                raise ValueError(f"{mlm_emebd_size=} {vocab_size=}")
            yield (f'{hf_prefix}.{gemma3_text_model_dict[pname]}', mlm_embed[:vocab_size])
        elif pname == "decoder.final_layernorm.weight":
            # Only shard 0 is exported (assumes the norm is replicated across TP).
            yield (f'{hf_prefix}.{gemma3_text_model_dict[pname]}', gather_list[0])
        else:
            yield from decoder_mcore_to_hf_weights(pname, gather_list, shape, layer_offset,
                                                   gemma3_text_model_dict, model_config,
                                                   decoder_prefix)


def gemma3_mcore_to_hf_weights(mlm_model, unwrap_model_func=None):
    """Yield HF-named weights for a whole Gemma3 multimodal Megatron model.

    Chains three converters in order. The vision encoder
    (``vision_model``) and the vision projector
    (``vision_projection.encoder``) are each included only when that
    attribute exists and is not None. The language model
    (``language_model``) is always included. Each converter is a generator
    with its own rank/collective protocol; consume this one on every rank
    that participates.

    Args:
        mlm_model: Megatron Gemma3 model, or a wrapped model when
            ``unwrap_model_func`` is given.
        unwrap_model_func: Optional callable; ``mlm_model`` is replaced by
            the first element of its return value.
    """
    model = mlm_model if unwrap_model_func is None else unwrap_model_func(mlm_model)[0]

    vision_model = getattr(model, "vision_model", None)
    if vision_model is not None:
        yield from slip_mcore_to_hf_weights(vision_model)

    vision_projection = getattr(model, "vision_projection", None)
    if vision_projection is not None:
        yield from gemma3_proj_mcore_to_hf_weights(vision_projection.encoder)

    yield from gemma3_text_mcore_to_hf_weights(model.language_model)
