import re
import torch

from megatron.core import mpu
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from megatron.core.utils import get_model_config

from gpatch.core import parallel_state as gmpu
from gpatch.core.utils import print_with_rank_and_datetime

# Megatron-core parameter suffix -> HuggingFace Qwen (dense) parameter suffix.
# A list value means one fused Megatron tensor is split into several HF tensors;
# the converter emits them in list order.
qwen_model_dict = {
    # Token embedding.
    "embedding.word_embeddings.weight": "embed_tokens.weight",
    # Attention block.
    "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
    "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
    "self_attention.linear_qkv.weight": [
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
    ],
    "self_attention.linear_qkv.bias": [
        "self_attn.q_proj.bias",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.bias",
    ],
    # MLP block: the fused fc1 holds gate and up projections, in that order.
    "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight",
    "mlp.linear_fc1.weight": [
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
    ],
    "mlp.linear_fc2.weight": "mlp.down_proj.weight",
    # Final norm and LM head.
    "final_layernorm.weight": "norm.weight",
    "output_layer.weight": "lm_head.weight",
    # Query/key layernorm weights.
    "self_attention.q_layernorm.weight": "self_attn.q_norm.weight",
    "self_attention.k_layernorm.weight": "self_attn.k_norm.weight",
}

# Megatron-core parameter suffix -> HuggingFace Qwen MoE parameter suffix.
# List values split one fused Megatron tensor into several HF tensors (in list
# order). Expert entries contain a "{}" placeholder that the converter fills
# with the global expert index via str.format.
qwen_moe_model_dict = {
    # Token embedding.
    "embedding.word_embeddings.weight": "embed_tokens.weight",
    # Attention block.
    "self_attention.linear_proj.weight": "self_attn.o_proj.weight",
    "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
    "self_attention.linear_qkv.weight": [
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
    ],
    "self_attention.linear_qkv.bias": [
        "self_attn.q_proj.bias",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.bias",
    ],
    # Final norm and LM head.
    "final_layernorm.weight": "norm.weight",
    "output_layer.weight": "lm_head.weight",
    # Query/key layernorm weights.
    "self_attention.q_layernorm.weight": "self_attn.q_norm.weight",
    "self_attention.k_layernorm.weight": "self_attn.k_norm.weight",
    # MoE block: router, the norm in front of it, and per-expert projections.
    "mlp.router.weight": "mlp.gate.weight",
    "pre_mlp_layernorm.weight": "post_attention_layernorm.weight",
    "mlp.experts.linear_fc1.weight": [
        "mlp.experts.{}.gate_proj.weight",
        "mlp.experts.{}.up_proj.weight",
    ],
    "mlp.experts.linear_fc2.weight": "mlp.experts.{}.down_proj.weight",
}


def qwen_mcore_to_hf_weights(
    mlm_model,
    unwrap_model_func=None,
    early_swap_model=False,
    cpu_memory_model=None,
    cpu_memory_model_name_prefix="",
):
    """Yield a dense Qwen Megatron-core model's weights as HuggingFace ``(name, tensor)`` pairs.

    Generator. Each local parameter is gathered across the tensor-parallel (TP)
    group and re-laid-out to the HuggingFace convention:
    - TP shards are concatenated.
    - The fused QKV and gate/up tensors are split.
    - Embedding/LM-head rows beyond ``hf_vocab_size`` are dropped.

    Only ranks with data-parallel rank 0 and context-parallel rank 0 take part.
    Every other rank returns immediately. Within that slice, all TP ranks must
    iterate in lockstep, because each parameter is a collective gather. Only TP
    rank 0 yields anything.

    Args:
        mlm_model: The Megatron model (possibly wrapped).
        unwrap_model_func: Optional callable. When given,
            ``unwrap_model_func(mlm_model)[0]`` is the model that is converted.
        early_swap_model: If True, read each parameter from ``cpu_memory_model``
            instead of the live model. Those tensors must be on the CPU; they are
            copied to the current CUDA device before gathering.
        cpu_memory_model: Mapping indexed by
            ``cpu_memory_model_name_prefix + pname``. Each value must expose
            ``.data``. Only used when ``early_swap_model`` is True.
        cpu_memory_model_name_prefix: Key prefix for ``cpu_memory_model`` lookups.

    Yields:
        ``(hf_name, tensor)`` tuples.

    Raises:
        ValueError: If a layer parameter name cannot be parsed or has no mapping.
    """
    if unwrap_model_func is not None:
        unwrapped_model = unwrap_model_func(mlm_model)[0]
    else:
        unwrapped_model = mlm_model
    model_config = get_model_config(unwrapped_model)

    # Added to the local layer index parsed from each parameter name to obtain
    # the HF layer number (presumably this pipeline stage's first global layer).
    layer_offset = get_transformer_layer_offset(model_config)

    tp_size = mpu.get_tensor_model_parallel_world_size()
    num_attn_heads = model_config.num_attention_heads
    num_kv_heads = model_config.num_query_groups
    dim = model_config.kv_channels
    hf_embed_vocab_size = model_config.hf_vocab_size
    attn_head_num_partion = num_attn_heads // tp_size
    kv_head_num_partion = num_kv_heads // tp_size
    # Rows per query group in the fused QKV tensor: the group's query heads,
    # then one key head and one value head, each `dim` rows.
    kv_dims_partion = (2 * dim + (dim * num_attn_heads // num_kv_heads))

    is_dp_and_cp_head = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    if not is_dp_and_cp_head:
        return
    print_with_rank_and_datetime(
        f"layer_offset {layer_offset} {dim} {attn_head_num_partion=} {kv_head_num_partion=} {kv_dims_partion=}"
    )

    def _split_qkv(shards):
        """Split the TP shards of a fused QKV weight or bias into full (q, k, v) tensors."""
        q_rows = dim * attn_head_num_partion // kv_head_num_partion
        q_parts, k_parts, v_parts = [], [], []
        for shard in shards:
            # (hidden,) for the 2-D weight, () for the 1-D bias.
            trailing = tuple(shard.shape[1:])
            grouped = shard.reshape((kv_head_num_partion, kv_dims_partion) + trailing)
            q_parts.append(
                grouped.narrow(1, 0, q_rows).reshape((dim * attn_head_num_partion,) + trailing))
            k_parts.append(
                grouped.narrow(1, q_rows, dim).reshape((dim * kv_head_num_partion,) + trailing))
            v_parts.append(
                grouped.narrow(1, q_rows + dim, dim).reshape((dim * kv_head_num_partion,) + trailing))
        return torch.cat(q_parts, dim=0), torch.cat(k_parts, dim=0), torch.cat(v_parts, dim=0)

    for pname, params in unwrapped_model.named_parameters():
        if early_swap_model:
            # Read the host-side copy and move it to the current CUDA device so it
            # can take part in the gather.
            params = cpu_memory_model[f"{cpu_memory_model_name_prefix}{pname}"]
            layer_weight = params.data
            assert not layer_weight.is_cuda
            layer_weight = layer_weight.to(device=torch.cuda.current_device())
        else:
            layer_weight = params.data
        shape = layer_weight.shape

        # Only TP rank 0 allocates receive buffers. `dst` is therefore expected
        # to be that rank's global rank.
        is_tp_head = mpu.get_tensor_model_parallel_rank() == 0
        if is_tp_head:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        else:
            gather_list = None
        dst = gmpu.update_weights_gather_dst_rank()
        torch.distributed.gather(layer_weight,
                                 gather_list,
                                 dst=dst,
                                 group=mpu.get_tensor_model_parallel_group())
        if not is_tp_head:
            continue

        if pname.endswith("embedding.word_embeddings.weight"):
            # Shards are stacked along the vocab dimension; drop rows past the HF vocab.
            mlm_embed = torch.cat(gather_list, dim=0)
            mlm_emebd_size = mlm_embed.shape[0]
            assert mlm_emebd_size >= hf_embed_vocab_size, f"{mlm_emebd_size=} {hf_embed_vocab_size=}"
            hf_w_name = f'model.{qwen_model_dict["embedding.word_embeddings.weight"]}'
            yield (hf_w_name, mlm_embed[:hf_embed_vocab_size])

        elif pname.endswith("final_layernorm.weight"):
            # Not sharded: TP rank 0's copy is the full tensor.
            hf_w_name = f'model.{qwen_model_dict["final_layernorm.weight"]}'
            yield (hf_w_name, gather_list[0].detach().clone())

        elif pname.endswith("output_layer.weight"):
            output_layer = torch.cat(gather_list, dim=0)
            assert output_layer.shape[0] >= hf_embed_vocab_size
            yield (f'{qwen_model_dict["output_layer.weight"]}', output_layer[:hf_embed_vocab_size])

        else:
            parts = re.split(r'\.(\d+)\.', pname)
            if len(parts) != 3:
                raise ValueError(f"cannot parse a single layer index from {pname=}")
            _, layer_idx, sub_name = parts
            layer_number = int(layer_idx) + layer_offset
            hf_prefix = f"model.layers.{layer_number}."

            if sub_name in ("self_attention.linear_proj.weight", "mlp.linear_fc2.weight"):
                # Shards are split along the input dimension (dim 1).
                yield (hf_prefix + qwen_model_dict[sub_name], torch.cat(gather_list, dim=1))

            elif sub_name == "mlp.linear_fc1.weight":
                # Each shard stacks its gate rows on top of its up rows.
                shard_size = shape[0] // 2
                gate_name, up_name = qwen_model_dict[sub_name]
                yield (hf_prefix + gate_name,
                       torch.cat([w[:shard_size, :] for w in gather_list], dim=0))
                yield (hf_prefix + up_name,
                       torch.cat([w[shard_size:, :] for w in gather_list], dim=0))

            elif sub_name in ("self_attention.linear_qkv.weight", "self_attention.linear_qkv.bias"):
                for hf_name, tensor in zip(qwen_model_dict[sub_name], _split_qkv(gather_list)):
                    yield (hf_prefix + hf_name, tensor)

            elif sub_name in (
                    "self_attention.linear_qkv.layer_norm_weight",
                    "self_attention.q_layernorm.weight",
                    "self_attention.k_layernorm.weight",
                    "mlp.linear_fc1.layer_norm_weight",
            ):
                # Norm weights are not sharded; take TP rank 0's copy.
                yield (hf_prefix + qwen_model_dict[sub_name], gather_list[0].detach().clone())

            else:
                raise ValueError(f"unknown pname {layer_idx=} {sub_name=}")


def qwen_moe_mcore_to_hf_weights(mlm_model, unwrap_model_func, update_params_type,
                                 early_swap_model=False,
                                 cpu_memory_model=None):
    """Yield a Qwen MoE Megatron-core model's weights as HuggingFace ``(name, tensor)`` pairs.

    Generator. ``update_params_type`` selects which parameters are converted:
    - ``"moe"``: only expert weights (names containing ``mlp.experts``).
    - ``"dense"``: everything else.

    How parameters are handled:
    - Parameters are normally gathered across the tensor-parallel (TP) group,
      re-laid-out to HF names/shapes, and yielded on TP rank 0.
    - Exception: when the expert TP size differs from the TP size, expert
      weights skip that gather. They are emitted locally through
      :func:`update_moe_weights` on expert-data-parallel rank 0.
    - Which ranks participate depends on the expert-parallel (EP) size and on
      that TP comparison. Non-participating ranks return immediately.

    Args:
        mlm_model: The Megatron model (possibly wrapped).
        unwrap_model_func: Callable where ``unwrap_model_func(mlm_model)[0]`` is
            the model to convert. ``None`` uses ``mlm_model`` directly.
        update_params_type: ``"moe"`` or ``"dense"``.
        early_swap_model: If True, iterate ``cpu_memory_model`` instead of the
            live model. Its tensors must be on the CPU; they are copied to the
            current CUDA device before use.
        cpu_memory_model: Object with ``named_parameters()``. Only used when
            ``early_swap_model`` is True.

    Yields:
        ``(hf_name, tensor)`` tuples. Expert names use the global expert index
        ``ep_rank * num_local_experts + local_index``.

    Raises:
        AssertionError: If ``update_params_type`` is not ``"dense"`` or ``"moe"``.
        ValueError: If a layer parameter name cannot be parsed or has no mapping.
        Exceptions raised by ``torch.distributed.gather`` are logged and re-raised.
    """
    assert update_params_type in ["dense", "moe"]
    if unwrap_model_func is not None:
        unwrapped_model = unwrap_model_func(mlm_model)[0]
    else:
        unwrapped_model = mlm_model
    model_config = get_model_config(unwrapped_model)
    # Added to the local layer index parsed from each parameter name to obtain
    # the HF layer number (presumably this pipeline stage's first global layer).
    layer_offset = get_transformer_layer_offset(model_config)

    tp_size = mpu.get_tensor_model_parallel_world_size()
    ep_size = mpu.get_expert_model_parallel_world_size()
    ep_rank = mpu.get_expert_model_parallel_rank()

    num_attn_heads = model_config.num_attention_heads
    num_kv_heads = model_config.num_query_groups
    dim = model_config.kv_channels
    hf_embed_vocab_size = model_config.hf_vocab_size
    num_experts = model_config.num_experts
    attn_head_num_partion = num_attn_heads // tp_size
    kv_head_num_partion = num_kv_heads // tp_size
    # Rows per query group in the fused QKV tensor: the group's query heads,
    # then one key head and one value head, each `dim` rows.
    kv_dims_partion = (2 * dim + (dim * num_attn_heads // num_kv_heads))
    num_local_experts = num_experts // ep_size

    # Decide whether this rank emits anything at all.
    if ep_size <= 1:
        do_update = mpu.get_data_parallel_rank() == 0 and mpu.get_context_parallel_rank() == 0
    elif mpu.get_expert_tensor_parallel_world_size() != tp_size:
        do_update = mpu.get_expert_data_parallel_rank() == 0 or \
            mpu.get_data_parallel_rank(with_context_parallel=True) == 0
    else:
        # NOTE(review): assumes DP(+CP) ranks 0..ep_size-1 cover each EP rank
        # exactly once -- confirm against the rank layout.
        do_update = mpu.get_data_parallel_rank(with_context_parallel=True) < ep_size

    if not do_update:
        return
    print_with_rank_and_datetime(
        f"layer_offset {layer_offset} {dim} {attn_head_num_partion=} {kv_head_num_partion=} {kv_dims_partion=}"
    )

    # When expert TP differs from dense TP, expert weights bypass the TP gather.
    etp_differs = mpu.get_expert_tensor_parallel_world_size() != tp_size
    source_model = cpu_memory_model if early_swap_model else unwrapped_model
    want_experts = update_params_type == "moe"

    def _split_qkv(shards):
        """Split the TP shards of a fused QKV weight or bias into full (q, k, v) tensors."""
        q_rows = dim * attn_head_num_partion // kv_head_num_partion
        q_parts, k_parts, v_parts = [], [], []
        for shard in shards:
            # (hidden,) for the 2-D weight, () for the 1-D bias.
            trailing = tuple(shard.shape[1:])
            grouped = shard.reshape((kv_head_num_partion, kv_dims_partion) + trailing)
            q_parts.append(
                grouped.narrow(1, 0, q_rows).reshape((dim * attn_head_num_partion,) + trailing))
            k_parts.append(
                grouped.narrow(1, q_rows, dim).reshape((dim * kv_head_num_partion,) + trailing))
            v_parts.append(
                grouped.narrow(1, q_rows + dim, dim).reshape((dim * kv_head_num_partion,) + trailing))
        return torch.cat(q_parts, dim=0), torch.cat(k_parts, dim=0), torch.cat(v_parts, dim=0)

    for pname, params in source_model.named_parameters():
        is_expert = "mlp.experts" in pname
        # Convert only experts ("moe") or only non-experts ("dense").
        if is_expert != want_experts:
            continue

        layer_weight = params.data
        shape = layer_weight.shape
        if early_swap_model:
            # cpu_memory_model holds host tensors; move to the GPU before use.
            assert not layer_weight.is_cuda
            layer_weight = layer_weight.to(device=torch.cuda.current_device())

        if etp_differs:
            if is_expert:
                # NOTE(review): emitted without any gather, so this is the complete
                # expert weight only if expert TP size is 1 -- confirm.
                if mpu.get_expert_data_parallel_rank() == 0:
                    yield from update_moe_weights(
                        pname,
                        layer_weight,
                        layer_offset=layer_offset,
                        ep_rank=ep_rank,
                        num_local_experts=num_local_experts,
                    )
                continue
            # Non-expert weights are emitted from DP(+CP) rank 0 only.
            if mpu.get_data_parallel_rank(with_context_parallel=True) > 0:
                continue
        elif ep_size > 1 and ep_rank > 0 and "mlp.experts.linear_fc" not in pname:
            # Only EP rank 0 emits non-expert weights; other EP ranks emit their experts.
            continue

        # Only TP rank 0 allocates receive buffers. `dst` is therefore expected
        # to be that rank's global rank.
        is_tp_head = mpu.get_tensor_model_parallel_rank() == 0
        if is_tp_head:
            gather_list = [torch.empty_like(layer_weight) for _ in range(tp_size)]
        else:
            gather_list = None
        dst = gmpu.update_weights_gather_dst_rank()
        try:
            torch.distributed.gather(layer_weight,
                                     gather_list,
                                     dst=dst,
                                     group=mpu.get_tensor_model_parallel_group())
        except Exception as e:
            print(
                f"ERROR rank {torch.distributed.get_rank()} {dst=} {mpu.get_expert_model_parallel_rank()=} {mpu.get_tensor_model_parallel_rank()=} {mpu.get_context_parallel_rank()=} {mpu.get_tensor_model_parallel_group()=} error {e}",
                flush=True)
            # Re-raise rather than sys.exit(): a bare sys.exit() exits with status 0
            # and hides the failure; re-raising keeps the traceback and fails loudly.
            raise
        if not is_tp_head:
            continue

        if pname.endswith("embedding.word_embeddings.weight"):
            # Shards are stacked along the vocab dimension; drop rows past the HF vocab.
            mlm_embed = torch.cat(gather_list, dim=0)
            mlm_emebd_size = mlm_embed.shape[0]
            assert mlm_emebd_size >= hf_embed_vocab_size, f"{mlm_emebd_size=} {hf_embed_vocab_size=}"
            yield (f'model.{qwen_moe_model_dict["embedding.word_embeddings.weight"]}',
                   mlm_embed[:hf_embed_vocab_size])
        elif pname.endswith("final_layernorm.weight"):
            # Not sharded: TP rank 0's copy is the full tensor.
            yield (f'model.{qwen_moe_model_dict["final_layernorm.weight"]}',
                   gather_list[0].detach().clone())
        elif pname.endswith("output_layer.weight"):
            output_layer = torch.cat(gather_list, dim=0)
            assert output_layer.shape[0] >= hf_embed_vocab_size
            yield (f'{qwen_moe_model_dict["output_layer.weight"]}',
                   output_layer[:hf_embed_vocab_size])
        else:
            try:
                _, layer_idx, sub_name = re.split(r'\.(\d+)\.', pname)
            except ValueError as e:
                print(f"trace error {pname=} {e=}")
                raise

            layer_number = int(layer_idx) + layer_offset
            hf_prefix = f"model.layers.{layer_number}."

            if sub_name == "self_attention.linear_proj.weight":
                # Shards are split along the input dimension (dim 1).
                yield (hf_prefix + qwen_moe_model_dict[sub_name], torch.cat(gather_list, dim=1))

            elif sub_name in ("self_attention.linear_qkv.weight", "self_attention.linear_qkv.bias"):
                for hf_name, tensor in zip(qwen_moe_model_dict[sub_name], _split_qkv(gather_list)):
                    yield (hf_prefix + hf_name, tensor)

            elif sub_name.startswith("mlp.experts.linear_fc1.weight"):
                # Grouped-expert names carry the local expert index after "weight".
                name_prefix = "mlp.experts.linear_fc1.weight"
                local_expert_idx = int(sub_name.split(name_prefix)[-1])
                expert_idx = ep_rank * num_local_experts + local_expert_idx
                # Each shard stacks its gate rows on top of its up rows.
                shard_size = shape[0] // 2
                gate_tmpl, up_tmpl = qwen_moe_model_dict[name_prefix]
                yield (hf_prefix + gate_tmpl.format(expert_idx),
                       torch.cat([w[:shard_size, :] for w in gather_list], dim=0))
                yield (hf_prefix + up_tmpl.format(expert_idx),
                       torch.cat([w[shard_size:, :] for w in gather_list], dim=0))

            elif sub_name.startswith("mlp.experts.linear_fc2.weight"):
                name_prefix = "mlp.experts.linear_fc2.weight"
                local_expert_idx = int(sub_name.split(name_prefix)[-1])
                expert_idx = ep_rank * num_local_experts + local_expert_idx
                yield (hf_prefix + qwen_moe_model_dict[name_prefix].format(expert_idx),
                       torch.cat(gather_list, dim=1))

            elif sub_name in (
                    "self_attention.linear_qkv.layer_norm_weight",
                    "self_attention.q_layernorm.weight",
                    "self_attention.k_layernorm.weight",
                    "pre_mlp_layernorm.weight",
                    "mlp.router.weight",
            ):
                # Not sharded across TP; take TP rank 0's copy.
                yield (hf_prefix + qwen_moe_model_dict[sub_name], gather_list[0].detach().clone())

            else:
                raise ValueError(f"unkonwn name {layer_idx=} {sub_name=}")


def update_moe_weights(
    pname,
    layer_weight,
    layer_offset,
    ep_rank,
    num_local_experts,
):
    """Convert one locally held grouped-expert weight into HuggingFace ``(name, tensor)`` pairs.

    No communication happens here: ``layer_weight`` is used as-is. Each returned
    tensor is a fresh, detached copy on the current CUDA device.

    Args:
        pname: Megatron parameter name containing exactly one ``.<layer>.``
            segment and ending in ``mlp.experts.linear_fc1.weight<local_idx>`` or
            ``mlp.experts.linear_fc2.weight<local_idx>``.
        layer_weight: The local weight tensor (any device).
        layer_offset: Added to the layer index parsed from ``pname``.
        ep_rank: Expert-parallel rank. Combined with ``num_local_experts``, it
            maps the local expert index to the global one.
        num_local_experts: Number of experts held by each expert-parallel rank.

    Returns:
        A list of ``(hf_name, tensor)`` pairs:
        - ``linear_fc1``: ``[gate_proj, up_proj]``, i.e. the first and second
          half of the rows.
        - ``linear_fc2``: ``[down_proj]``.

    Raises:
        ValueError: If ``pname`` cannot be parsed or is not an expert fc1/fc2 weight.
    """
    _, layer_idx, sub_name = re.split(r'\.(\d+)\.', pname)
    layer_number = int(layer_idx) + layer_offset
    device = torch.cuda.current_device()

    def _device_copy(tensor):
        # Copy straight onto the target device. copy=True forces a new tensor even
        # when `tensor` already lives there. This avoids first duplicating the whole
        # weight on the device and then cloning each slice a second time.
        return tensor.detach().to(device=device, copy=True)

    fc1_prefix = "mlp.experts.linear_fc1.weight"
    fc2_prefix = "mlp.experts.linear_fc2.weight"

    if sub_name.startswith(fc1_prefix):
        local_expert_idx = int(sub_name.split(fc1_prefix)[-1])
        expert_idx = ep_rank * num_local_experts + local_expert_idx
        # Gate rows are stacked on top of up rows.
        shard_size = layer_weight.shape[0] // 2
        gate_tmpl, up_tmpl = qwen_moe_model_dict[fc1_prefix]
        return [
            (f"model.layers.{layer_number}.{gate_tmpl.format(expert_idx)}",
             _device_copy(layer_weight[:shard_size, :])),
            (f"model.layers.{layer_number}.{up_tmpl.format(expert_idx)}",
             _device_copy(layer_weight[shard_size:, :])),
        ]

    if sub_name.startswith(fc2_prefix):
        local_expert_idx = int(sub_name.split(fc2_prefix)[-1])
        expert_idx = ep_rank * num_local_experts + local_expert_idx
        return [
            (f'model.layers.{layer_number}.{qwen_moe_model_dict[fc2_prefix].format(expert_idx)}',
             _device_copy(layer_weight))
        ]

    raise ValueError(f"Error {sub_name=} is not handle now")
