              
                                                      
                       

import re
from turtle import mode
import torch

from megatron.core import mpu
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from megatron.core.utils import get_model_config

from gpatch.core import parallel_state as gmpu
from gpatch.core.models.gpt.weight_conversion.utils import decoder_mcore_to_hf_weights
from gpatch.core.models.gpt.weight_conversion.qwen import qwen_mcore_to_hf_weights

# Megatron-core -> HuggingFace parameter-name map for the Qwen2-VL vision tower
# and its vision->language projector ("merger" on the HF side).
# The first entries and the "projection.encoder.*" entries are full mcore
# parameter names, looked up directly by qwen2vl_vit_proj_mcore_to_hf_weights.
# The remaining entries are per-layer name suffixes handed to
# decoder_mcore_to_hf_weights (which presumably adds the "blocks.<i>." layer
# prefix — TODO confirm against that helper).
qwen2vl_vit_proj_model_dict = {
    "patch_embed.proj.weight": "patch_embed.proj.weight",
    # The final norm of the mcore vision decoder is HF's merger.ln_q.
    "decoder.final_layernorm.weight": "merger.ln_q.weight",
    "decoder.final_layernorm.bias": "merger.ln_q.bias",
    # Per-layer attention.
    "self_attention.linear_proj.weight": "attn.proj.weight",
    "self_attention.linear_proj.bias": "attn.proj.bias",
    "self_attention.linear_qkv.weight": "attn.qkv.weight",
    "self_attention.linear_qkv.bias": "attn.qkv.bias",
    # Per-layer MLP.
    "mlp.linear_fc2.weight": "mlp.fc2.weight",
    "mlp.linear_fc2.bias": "mlp.fc2.bias",
    "mlp.linear_fc1.weight": "mlp.fc1.weight",
    "mlp.linear_fc1.bias": "mlp.fc1.bias",
    # Per-layer norms: mcore keeps them as attributes of linear_qkv / linear_fc1,
    # whereas HF has separate norm1 / norm2 modules.
    "self_attention.linear_qkv.layer_norm_weight": "norm1.weight",
    "self_attention.linear_qkv.layer_norm_bias": "norm1.bias",
    "mlp.linear_fc1.layer_norm_weight": "norm2.weight",
    "mlp.linear_fc1.layer_norm_bias": "norm2.bias",
    # Vision->language projector (HF "merger.mlp").
    "projection.encoder.linear_fc1.weight": "merger.mlp.0.weight",
    "projection.encoder.linear_fc1.bias": "merger.mlp.0.bias",
    "projection.encoder.linear_fc2.weight": "merger.mlp.2.weight",
    "projection.encoder.linear_fc2.bias": "merger.mlp.2.bias",
}


# Projector parameters whose TP shards are concatenated along dim 0 to rebuild
# the full HF tensor (presumably column-parallel in mcore).
_VIT_PROJ_TP_CAT_DIM0_NAMES = frozenset((
    "projection.encoder.linear_fc1.weight",
    "projection.encoder.linear_fc1.bias",
))
# Projector parameters whose TP shards are concatenated along dim 1
# (presumably row-parallel in mcore).
_VIT_PROJ_TP_CAT_DIM1_NAMES = frozenset(("projection.encoder.linear_fc2.weight",))


def _vit_proj_mcore_to_hf_weights(
    mlm_model,
    model_dict,
    replicated_names,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
    cpu_memory_model_name_prefix="",
):
    """Shared Qwen2-VL / Qwen2.5-VL vision-tower + projector export.

    Every parameter of ``mlm_model`` is gathered over the tensor-parallel group
    onto TP rank 0. That rank yields ``(hf_name, tensor)`` pairs named under the
    ``visual.`` prefix.

    Args:
        mlm_model: Megatron vision model. It is unwrapped first if
            ``unwrap_model_func`` is given.
        model_dict: mcore-name -> HF-name mapping for the architecture.
        replicated_names: mcore parameter names treated as replicated across
            TP ranks. Only TP rank 0's copy of these is exported.
        unwrap_model_func: optional callable; ``unwrap_model_func(m)[0]`` is
            used as the model.
        early_swap_model: if True, weights are read from ``cpu_memory_model``
            instead of the live parameters. They must be CPU tensors and are
            moved to the current CUDA device before gathering.
        cpu_memory_model: mapping from ``cpu_memory_model_name_prefix + pname``
            to a tensor (anything with ``.data``). Used only when
            ``early_swap_model`` is True.
        cpu_memory_model_name_prefix: prefix added to parameter names when
            looking them up in ``cpu_memory_model``.

    Yields:
        ``(hf_name, tensor)`` tuples. They are produced only on TP rank 0 of
        the replica with data-parallel rank 0 and context-parallel rank 0.
        Every other rank yields nothing, but TP peers of an exporting rank
        must still iterate this generator, because each parameter triggers a
        collective gather.
    """
    if unwrap_model_func is not None:
        mlm_model = unwrap_model_func(mlm_model)[0]
    model_config = get_model_config(mlm_model)
    layer_offset = get_transformer_layer_offset(model_config)
    # Only layer_offset == 0 is handled. The message says that an encoder
    # pipeline-parallel size > 1 is not supported yet.
    assert layer_offset == 0, "--encoder-pipeline-model-parallel-size暂不支持大于1"

    tp_size = mpu.get_tensor_model_parallel_world_size()
    is_dp_and_cp_head = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    if not is_dp_and_cp_head:
        # Only the DP-rank-0 / CP-rank-0 replica exports; the whole TP group
        # shares these ranks, so all of its members return here together.
        return

    print(f"layer_offset {layer_offset}")

    hf_prefix = "visual"
    # Loop-invariant parallel-state lookups.
    tp_rank = mpu.get_tensor_model_parallel_rank()
    tp_group = mpu.get_tensor_model_parallel_group()
    for pname, params in mlm_model.named_parameters():
        if early_swap_model:
            # Read the CPU-resident copy instead of the live parameter.
            layer_weight = cpu_memory_model[f"{cpu_memory_model_name_prefix}{pname}"].data
            assert not layer_weight.is_cuda
            layer_weight = layer_weight.to(device=torch.cuda.current_device())
        else:
            layer_weight = params.data
        shape = layer_weight.shape

        # Collective over the TP group: every TP rank must reach this call for
        # every parameter, in the same order, before non-zero ranks skip ahead.
        gather_list = None
        if tp_rank == 0:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        dst = gmpu.update_weights_gather_dst_rank()
        torch.distributed.gather(layer_weight, gather_list, dst=dst, group=tp_group)
        if tp_rank != 0:
            continue

        if pname in replicated_names:
            yield (f"{hf_prefix}.{model_dict[pname]}", gather_list[0])
        elif pname in _VIT_PROJ_TP_CAT_DIM0_NAMES:
            yield (f"{hf_prefix}.{model_dict[pname]}", torch.cat(gather_list, dim=0))
        elif pname in _VIT_PROJ_TP_CAT_DIM1_NAMES:
            yield (f"{hf_prefix}.{model_dict[pname]}", torch.cat(gather_list, dim=1))
        else:
            # Per-layer transformer weights: name mapping and TP-shard merging
            # are delegated to the shared decoder converter.
            yield from decoder_mcore_to_hf_weights(pname, gather_list, shape, layer_offset,
                                                   model_dict, model_config,
                                                   f"{hf_prefix}.blocks", "deinterleave")


# Qwen2-VL parameters exported from TP rank 0 only (treated as TP-replicated).
_QWEN2VL_VIT_PROJ_REPLICATED_NAMES = frozenset((
    "patch_embed.proj.weight",
    "decoder.final_layernorm.weight",
    "decoder.final_layernorm.bias",
    "projection.encoder.linear_fc2.bias",
))


def qwen2vl_vit_proj_mcore_to_hf_weights(
    mlm_model,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
    cpu_memory_model_name_prefix="",
):
    """Yield HF-named ``(name, tensor)`` pairs for the Qwen2-VL vision tower + projector.

    Names are mapped with ``qwen2vl_vit_proj_model_dict``. Argument semantics
    and rank behavior are documented on ``_vit_proj_mcore_to_hf_weights``.
    """
    yield from _vit_proj_mcore_to_hf_weights(
        mlm_model,
        qwen2vl_vit_proj_model_dict,
        _QWEN2VL_VIT_PROJ_REPLICATED_NAMES,
        unwrap_model_func=unwrap_model_func,
        early_swap_model=early_swap_model,
        cpu_memory_model=cpu_memory_model,
        cpu_memory_model_name_prefix=cpu_memory_model_name_prefix,
    )


# Megatron-core -> HuggingFace parameter-name map for the Qwen2.5-VL vision
# tower and projector. Unlike Qwen2-VL, it has no norm biases, and the MLP
# fc1 maps to a list of two HF tensors (gate_proj, up_proj).
qwen2p5vl_vit_proj_model_dict = {
    "patch_embed.proj.weight": "patch_embed.proj.weight",
    "decoder.final_layernorm.weight": "merger.ln_q.weight",
    # Per-layer attention.
    "self_attention.linear_proj.weight": "attn.proj.weight",
    "self_attention.linear_proj.bias": "attn.proj.bias",
    "self_attention.linear_qkv.weight": "attn.qkv.weight",
    "self_attention.linear_qkv.bias": "attn.qkv.bias",
    # Per-layer MLP.
    "mlp.linear_fc1.weight": [
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
    ],
    "mlp.linear_fc1.bias": [
        "mlp.gate_proj.bias",
        "mlp.up_proj.bias",
    ],
    "mlp.linear_fc2.weight": "mlp.down_proj.weight",
    "mlp.linear_fc2.bias": "mlp.down_proj.bias",
    # Per-layer norms (stored on linear_qkv / linear_fc1 in mcore).
    "self_attention.linear_qkv.layer_norm_weight": "norm1.weight",
    "mlp.linear_fc1.layer_norm_weight": "norm2.weight",
    # Vision->language projector (HF "merger.mlp").
    "projection.encoder.linear_fc1.weight": "merger.mlp.0.weight",
    "projection.encoder.linear_fc1.bias": "merger.mlp.0.bias",
    "projection.encoder.linear_fc2.weight": "merger.mlp.2.weight",
    "projection.encoder.linear_fc2.bias": "merger.mlp.2.bias",
}

# Qwen2.5-VL parameters exported from TP rank 0 only. There is no final-norm
# bias here, because the name map above has no entry for it.
_QWEN2P5VL_VIT_PROJ_REPLICATED_NAMES = frozenset((
    "patch_embed.proj.weight",
    "decoder.final_layernorm.weight",
    "projection.encoder.linear_fc2.bias",
))


def qwen2p5vl_vit_proj_mcore_to_hf_weights(
    mlm_model,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
    cpu_memory_model_name_prefix="",
):
    """Yield HF-named ``(name, tensor)`` pairs for the Qwen2.5-VL vision tower + projector.

    Names are mapped with ``qwen2p5vl_vit_proj_model_dict``. Argument semantics
    and rank behavior are documented on ``_vit_proj_mcore_to_hf_weights``.
    """
    yield from _vit_proj_mcore_to_hf_weights(
        mlm_model,
        qwen2p5vl_vit_proj_model_dict,
        _QWEN2P5VL_VIT_PROJ_REPLICATED_NAMES,
        unwrap_model_func=unwrap_model_func,
        early_swap_model=early_swap_model,
        cpu_memory_model=cpu_memory_model,
        cpu_memory_model_name_prefix=cpu_memory_model_name_prefix,
    )


def qwen2vl_mcore_to_hf_weights(
    mlm_model,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
):
    """Yield HF-named ``(name, tensor)`` pairs for a complete Qwen2-VL model.

    The vision tower and projector are exported first, but only when the
    (unwrapped) model has a non-None ``vision_model`` attribute. The language
    model follows, via ``qwen_mcore_to_hf_weights``. When ``early_swap_model``
    is set, CPU copies are looked up in ``cpu_memory_model`` under the
    ``vision_model.`` / ``language_model.`` name prefixes.
    """
    model = mlm_model if unwrap_model_func is None else unwrap_model_func(mlm_model)[0]
    swap_kwargs = {"early_swap_model": early_swap_model, "cpu_memory_model": cpu_memory_model}

    vision = getattr(model, "vision_model", None)
    if vision is not None:
        yield from qwen2vl_vit_proj_mcore_to_hf_weights(
            vision, cpu_memory_model_name_prefix="vision_model.", **swap_kwargs)

    # The model is already unwrapped, so no unwrap function is forwarded.
    yield from qwen_mcore_to_hf_weights(
        model.language_model, None, cpu_memory_model_name_prefix="language_model.", **swap_kwargs)


def qwen2p5vl_mcore_to_hf_weights(
    mlm_model,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
):
    """Yield HF-named ``(name, tensor)`` pairs for a complete Qwen2.5-VL model.

    The vision tower and projector are exported first, but only when the
    (unwrapped) model has a non-None ``vision_model`` attribute. The language
    model follows, via ``qwen_mcore_to_hf_weights``. When ``early_swap_model``
    is set, CPU copies are looked up in ``cpu_memory_model`` under the
    ``vision_model.`` / ``language_model.`` name prefixes.
    """
    model = mlm_model if unwrap_model_func is None else unwrap_model_func(mlm_model)[0]
    swap_kwargs = {"early_swap_model": early_swap_model, "cpu_memory_model": cpu_memory_model}

    vision = getattr(model, "vision_model", None)
    if vision is not None:
        yield from qwen2p5vl_vit_proj_mcore_to_hf_weights(
            vision, cpu_memory_model_name_prefix="vision_model.", **swap_kwargs)

    # The model is already unwrapped, so no unwrap function is forwarded.
    yield from qwen_mcore_to_hf_weights(
        model.language_model, None, cpu_memory_model_name_prefix="language_model.", **swap_kwargs)
