import re
from typing import Dict, List, Union

import torch

from megatron.core import mpu

from gpatch.core.transformer.transformer_config import GpatchTransformerConfig
from gpatch.core import parallel_state as gmpu


                                                
                                     
                            
                                                      
                                        
def decoder_mcore_to_hf_weights(
    pname: str,
    gather_list: List[torch.Tensor],
    shape: torch.Size,
    layer_offset: int,
    model_name_dict: Dict[str, Union[str, List[str]]],
    model_config: GpatchTransformerConfig,
    decoder_prefix: str = "",
    hf_qkv_format: str = "separate",
):
    """Converts Megatron-Core decoder weights to HuggingFace weights.

    This is a generator: it yields one or more ``(hf_name, tensor)`` tuples
    for the single Megatron parameter ``pname``. Every yielded name has the
    form ``f"{decoder_prefix}.{layer}.{mapped_name}"`` where ``layer`` is the
    layer index parsed from ``pname`` plus ``layer_offset``.

    Args:
        pname: Megatron parameter name containing a ``.<layer index>.``
            segment. The name is split at the FIRST such segment; everything
            after it is the sub-name looked up in ``model_name_dict``.
        gather_list: Shards of the parameter to be merged (presumably one
            per tensor-parallel rank -- confirm against the caller).
        shape: Only ``shape[0]`` is read, to split a fused gated
            ``linear_fc1`` shard into two halves (assumed to be the shape of
            a single shard -- TODO confirm).
        layer_offset: Added to the parsed layer index.
        model_name_dict: Maps a Megatron sub-name to the HF name (``str``) or
            to a list of HF names when one Megatron tensor is split into
            several HF tensors.
        model_config: Provides ``hidden_size``, ``num_attention_heads``,
            ``num_query_groups``, ``kv_channels`` and ``gated_linear_unit``.
        decoder_prefix: Prefix placed before the layer number in HF names.
        hf_qkv_format: ``"separate"`` when the QKV entry of
            ``model_name_dict`` is a list of three names, ``"deinterleave"``
            when it is a single fused name.

    Yields:
        ``(hf_name, tensor)`` tuples.

    Raises:
        ValueError: if ``pname`` has no ``.<digits>.`` segment.
        KeyError: if the sub-name is missing from ``model_name_dict``.
        AssertionError: if ``hf_qkv_format`` does not match the type of the
            QKV mapping, or if the ``"deinterleave"`` format is used while
            the per-partition query and KV head counts differ.
    """
    tp_size = mpu.get_tensor_model_parallel_world_size()
    hidden_size = model_config.hidden_size
    num_attn_heads = model_config.num_attention_heads
    num_kv_heads = model_config.num_query_groups
    dim = model_config.kv_channels
    attn_head_num_partion = num_attn_heads // tp_size
    kv_head_num_partion = num_kv_heads // tp_size
    # Rows per KV group in the fused QKV tensor: the group's query rows
    # followed by one K block and one V block of ``dim`` rows each.
    kv_dims_partion = 2 * dim + (dim * num_attn_heads // num_kv_heads)

    cat_dim1_keys = [
        "self_attention.linear_proj.weight",
        "mlp.linear_fc2.weight",
    ]
    fc1_keys = [
        "mlp.linear_fc1.weight",
        "mlp.linear_fc1.bias",
    ]
    if model_config.gated_linear_unit:
        # Gated MLP: each fc1 shard holds two stacked halves that map to two
        # separate HF tensors.
        cat_dim0_keys = []
        mlp_merge_keys = fc1_keys
    else:
        cat_dim0_keys = fc1_keys
        mlp_merge_keys = []

    # maxsplit=1 keeps any later ".N." segment inside ``sub_name`` instead of
    # failing the three-way unpacking.
    _, layer_idx, sub_name = re.split(r'\.(\d+)\.', pname, maxsplit=1)
    layer_numer = int(layer_idx) + layer_offset
    hf_prefix = f'{decoder_prefix}.{layer_numer}'
    hf_target = model_name_dict[sub_name]

    if sub_name in cat_dim0_keys:
        yield (f'{hf_prefix}.{hf_target}', torch.cat(gather_list, dim=0))
    elif sub_name in cat_dim1_keys:
        yield (f'{hf_prefix}.{hf_target}', torch.cat(gather_list, dim=1))
    elif sub_name in mlp_merge_keys:
        shard_size = shape[0] // 2
        w1 = [weight[:shard_size] for weight in gather_list]
        w2 = [weight[shard_size:] for weight in gather_list]
        yield (f"{hf_prefix}.{hf_target[0]}", torch.cat(w1, dim=0))
        yield (f"{hf_prefix}.{hf_target[1]}", torch.cat(w2, dim=0))
    elif sub_name in ["self_attention.linear_qkv.weight"]:
        if isinstance(hf_target, list):
            assert hf_qkv_format in ["separate"]
            # Number of query rows inside one KV group.
            q_end_idx = dim * attn_head_num_partion // kv_head_num_partion
            wq, wk, wv = [], [], []
            for weight in gather_list:
                weight = weight.reshape((kv_head_num_partion, kv_dims_partion, -1))
                wq.append(weight.narrow(1, 0, q_end_idx).reshape((dim * attn_head_num_partion, -1)))
                wk.append(weight.narrow(1, q_end_idx, dim).reshape((dim * kv_head_num_partion, -1)))
                wv.append(
                    weight.narrow(1, q_end_idx + dim, dim).reshape((dim * kv_head_num_partion, -1)))
            yield (f"{hf_prefix}.{hf_target[0]}", torch.cat(wq, dim=0))
            yield (f"{hf_prefix}.{hf_target[1]}", torch.cat(wk, dim=0))
            yield (f"{hf_prefix}.{hf_target[2]}", torch.cat(wv, dim=0))
        else:
            assert isinstance(hf_target, str)
            assert hf_qkv_format in ["deinterleave"]
            # The equal three-way split below is only valid when every KV
            # group holds exactly one query head.
            assert attn_head_num_partion == kv_head_num_partion
            tmp_list = [weight.reshape((kv_head_num_partion, 3, -1)) for weight in gather_list]
            hf_w = torch.cat(tmp_list, dim=0).transpose(0, 1).reshape(-1, hidden_size).contiguous()
            yield (f"{hf_prefix}.{hf_target}", hf_w)
    elif sub_name in ["self_attention.linear_qkv.bias"]:
        if isinstance(hf_target, list):
            assert hf_qkv_format in ["separate"]
            q_end_idx = dim * attn_head_num_partion // kv_head_num_partion
            bq, bk, bv = [], [], []
            for mlm_bias in gather_list:
                mlm_bias = mlm_bias.reshape((kv_head_num_partion, kv_dims_partion))
                bq.append(mlm_bias.narrow(1, 0, q_end_idx).reshape((dim * attn_head_num_partion)))
                bk.append(mlm_bias.narrow(1, q_end_idx, dim).reshape((dim * kv_head_num_partion)))
                bv.append(
                    mlm_bias.narrow(1, q_end_idx + dim, dim).reshape((dim * kv_head_num_partion)))
            yield (f"{hf_prefix}.{hf_target[0]}", torch.cat(bq, dim=0))
            yield (f"{hf_prefix}.{hf_target[1]}", torch.cat(bk, dim=0))
            yield (f"{hf_prefix}.{hf_target[2]}", torch.cat(bv, dim=0))
        else:
            assert isinstance(hf_target, str)
            assert hf_qkv_format in ["deinterleave"]
            # Same precondition as the fused weight branch: without it the
            # three-way view would silently mis-assign query rows to K/V.
            assert attn_head_num_partion == kv_head_num_partion
            tmp_list = []
            for mlm_bias in gather_list:
                mlm_bias = mlm_bias.reshape((kv_head_num_partion, kv_dims_partion))
                tmp_list.append(mlm_bias.view(kv_head_num_partion, 3, -1))
            hf_w = torch.cat(tmp_list, dim=0).transpose(0, 1).reshape(-1).contiguous()
            yield (f"{hf_prefix}.{hf_target}", hf_w)
    else:
        # Any other parameter is passed through from the first shard as-is.
        yield (f'{hf_prefix}.{hf_target}', gather_list[0])


def gather_weight_from_tp(weight: torch.Tensor, tp_size, dim=0):
    """Collect ``weight`` from every tensor-parallel rank and concatenate.

    Each rank in the tensor-model-parallel group contributes its local
    ``weight``; the shards are concatenated along ``dim`` in rank order and
    the full tensor is returned on EVERY rank.

    A rooted ``gather`` only fills the output list on the destination rank,
    which left the other ranks concatenating ``None``; ``all_gather`` is used
    so all ranks get a valid result.

    Args:
        weight: Local shard. All ranks are expected to pass tensors of the
            same shape and dtype (required by the collective).
        tp_size: Number of ranks in the tensor-model-parallel group.
        dim: Dimension along which the shards are concatenated.

    Returns:
        The concatenated tensor, on the same kind of device (CPU or CUDA) as
        the input ``weight``.
    """
    is_cuda = weight.is_cuda
    if not is_cuda:
        # The collective is run on CUDA tensors; the result is moved back to
        # the CPU below so the caller sees the device it passed in.
        weight = weight.cuda()
    # No-op for already contiguous tensors; collectives need dense buffers.
    weight = weight.contiguous()

    gather_list = [torch.empty_like(weight) for _ in range(tp_size)]
    torch.distributed.all_gather(gather_list,
                                 weight,
                                 group=mpu.get_tensor_model_parallel_group())

    res = torch.cat(gather_list, dim=dim)
    if not is_cuda:
        res = res.cpu()
    return res


def merge_hf_lora_weight(
    pname: str,
    param: torch.Tensor,
    state_dict: Dict[str, torch.Tensor],
    model_config: GpatchTransformerConfig,
):
    """Return a copy of ``param`` with its LoRA delta (``lora_b @ lora_a``) added.

    The LoRA tensors belonging to ``pname`` are looked up in ``state_dict``
    under the ``lora_a`` / ``lora_b`` names derived from ``pname``, gathered
    across the tensor-parallel group with ``gather_weight_from_tp``,
    multiplied, and the slice belonging to the current tensor-parallel rank
    is added to a clone of ``param``. The input ``param`` is never modified.

    Handled parameters:
        * ``.mlp.linear_fc1.weight`` -- one ``lora_b``, or two
          (``lora_b.0`` / ``lora_b.1``) when ``gated_linear_unit`` is set.
        * ``.self_attention.linear_qkv.weight`` -- three ``lora_b.{0,1,2}``
          whose deltas are laid out per query group before being sliced.
        * ``.mlp.linear_fc2.weight`` / ``.self_attention.linear_proj.weight``
          -- gathered and sliced along dimension 1.

    Any other parameter is returned as a plain clone.

    Args:
        pname: Name of the base parameter.
        param: This rank's shard of the base parameter.
        state_dict: Mapping that contains the LoRA tensors.
        model_config: Provides ``gated_linear_unit``, ``num_attention_heads``,
            ``num_query_groups``, ``kv_channels`` and ``hidden_size``.

    Returns:
        A new tensor: ``param`` plus the local slice of the LoRA delta.
    """
    tp_size = mpu.get_tensor_model_parallel_world_size()
    tp_rank = mpu.get_tensor_model_parallel_rank()

    fc1_key = ".mlp.linear_fc1.weight"
    qkv_key = ".self_attention.linear_qkv.weight"
    fc2_key = ".mlp.linear_fc2.weight"
    proj_key = ".self_attention.linear_proj.weight"

    if fc1_key in pname:
        gated = model_config.gated_linear_unit
        a_key = pname.replace(fc1_key, ".mlp.linear_fc1.lora_a.weight")
        if gated:
            b_suffixes = [
                ".mlp.linear_fc1.lora_b.0.weight",
                ".mlp.linear_fc1.lora_b.1.weight",
            ]
        else:
            b_suffixes = [".mlp.linear_fc1.lora_b.weight"]
        b_keys = [pname.replace(fc1_key, suffix) for suffix in b_suffixes]

        assert a_key in state_dict
        assert b_keys[0] in state_dict
        full_a = gather_weight_from_tp(state_dict[a_key], tp_size, 0)
        full_b0 = gather_weight_from_tp(state_dict[b_keys[0]], tp_size, 0)

        if not gated:
            delta = torch.matmul(full_b0, full_a).chunk(tp_size, 0)[tp_rank]
            return param.clone() + delta

        assert b_keys[1] in state_dict
        full_b1 = gather_weight_from_tp(state_dict[b_keys[1]], tp_size, 0)
        # The gathered lora_a is split in two, one half per lora_b.
        a_halves = full_a.chunk(2, 0)
        delta0 = torch.matmul(full_b0, a_halves[0]).chunk(tp_size, 0)[tp_rank]
        delta1 = torch.matmul(full_b1, a_halves[1]).chunk(tp_size, 0)[tp_rank]
        # Each half is sliced for this rank first, then the two are stacked.
        return param.clone() + torch.cat([delta0, delta1], dim=0)

    if qkv_key in pname:
        a_key = pname.replace(qkv_key, ".self_attention.linear_qkv.lora_a.weight")
        b_keys = [
            pname.replace(qkv_key, suffix) for suffix in (
                ".self_attention.linear_qkv.lora_b.0.weight",
                ".self_attention.linear_qkv.lora_b.1.weight",
                ".self_attention.linear_qkv.lora_b.2.weight",
            )
        ]

        assert a_key in state_dict
        full_a = gather_weight_from_tp(state_dict[a_key], tp_size, 0)
        full_bs = []
        for b_key in b_keys:
            assert b_key in state_dict
            full_bs.append(gather_weight_from_tp(state_dict[b_key], tp_size, 0))

        # One third of the gathered lora_a per lora_b, paired in order.
        deltas = [
            torch.matmul(b_part, a_part) for a_part, b_part in zip(full_a.chunk(3, 0), full_bs)
        ]

        n_query = model_config.num_attention_heads
        n_kv = model_config.num_query_groups
        head_dim = model_config.kv_channels
        hidden = model_config.hidden_size
        assert n_query % n_kv == 0

        # Lay the three deltas out per query group (presumably q, k, v in
        # that order -- confirm against the fused QKV layout) and flatten.
        part0 = deltas[0].reshape((n_kv, head_dim * n_query // n_kv, hidden))
        part1 = deltas[1].reshape((n_kv, head_dim, hidden))
        part2 = deltas[2].reshape((n_kv, head_dim, hidden))
        fused = torch.cat([part0, part1, part2], dim=1).reshape(-1, hidden)
        return param.clone() + fused.chunk(tp_size, 0)[tp_rank]

    if fc2_key in pname or proj_key in pname:
        module = ".mlp.linear_fc2" if fc2_key in pname else ".self_attention.linear_proj"
        a_key = pname.replace(f"{module}.weight", f"{module}.lora_a.weight")
        b_key = pname.replace(f"{module}.weight", f"{module}.lora_b.weight")
        full_a = gather_weight_from_tp(state_dict[a_key], tp_size, 1)
        full_b = gather_weight_from_tp(state_dict[b_key], tp_size, 1)
        delta = torch.matmul(full_b, full_a).chunk(tp_size, 1)[tp_rank]
        return param.clone() + delta

    # No LoRA handling for this parameter: hand back an independent copy.
    return param.clone()
