
from typing import Dict
import torch
import torch.nn as nn

from vllm.model_executor.layers.linear import *
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.layers.activation import ScaledActivation
from vllm.model_executor.models import ModelRegistry



def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Make ``param`` share storage with ``loaded_weight`` instead of copying it.

    This function is installed as the ``weight_loader`` attribute of vLLM's
    parallel layer classes, so ``self`` is the layer instance; it is not used.

    Args:
        param: Destination parameter of the vLLM layer.
        loaded_weight: Source tensor. It must already have exactly ``param``'s
            size and dtype; no sharding or casting is done here.

    Raises:
        ValueError: If the sizes or dtypes of ``param`` and ``loaded_weight`` differ.
    """
    # Explicit checks instead of ``assert``: assertions are stripped under
    # ``python -O``, and a mismatched tensor would then be silently installed.
    if param.size() != loaded_weight.size():
        raise ValueError(
            'the parameter size is not aligned with the loaded weight size, '
            f'param size: {param.size()}, loaded_weight size: {loaded_weight.size()}')
    if param.dtype != loaded_weight.dtype:
        raise ValueError(
            'to share weights, the data type should also be the same, '
            f'param dtype: {param.dtype}, loaded_weight dtype: {loaded_weight.dtype}')

    # Rebinding ``.data`` aliases the loaded tensor's storage (zero-copy).
    param.data = loaded_weight.data


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Fallback loader: make ``param`` share storage with ``loaded_weight``.

    The model loaders in this module use it for parameters that carry no
    ``weight_loader`` attribute of their own.

    Args:
        param: Destination parameter.
        loaded_weight: Source tensor with exactly ``param``'s size and dtype.

    Raises:
        ValueError: If the sizes or dtypes of ``param`` and ``loaded_weight`` differ.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if param.size() != loaded_weight.size():
        raise ValueError(
            'the parameter size is not aligned with the loaded weight size, '
            f'param size: {param.size()}, loaded_weight size: {loaded_weight.size()}')
    if param.dtype != loaded_weight.dtype:
        raise ValueError(
            'to share weights, the data type should also be the same, '
            f'param dtype: {param.dtype}, loaded_weight dtype: {loaded_weight.dtype}')

    # Rebinding ``.data`` aliases the loaded tensor's storage (zero-copy).
    param.data = loaded_weight.data


def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load GPT-2 weights from ``actor_weights`` into ``vllm_model``.

    Args:
        actor_weights: Mapping of parameter name to tensor. A name without the
            ``transformer.`` prefix gets it prepended before lookup.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a (prefixed) name is not a parameter of ``vllm_model``.
    """
    conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
    # remove_duplicate=False keeps every alias of shared parameters addressable.
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        # lm_head weights are not loaded (presumably tied to the token
        # embedding in vLLM's GPT-2 -- TODO confirm).
        if "lm_head.weight" in name:
            continue
        # Attention bias / masked_bias entries are not loaded.
        if ".attn.bias" in name or ".attn.masked_bias" in name:
            continue
        if not name.startswith("transformer."):
            name = "transformer." + name
        param = params_dict[name]

        # These weights are transposed before loading. This is presumably the
        # HF Conv1D (in, out) vs. Linear (out, in) layout -- confirm.
        # any() guarantees at most one transpose per tensor.
        if name.endswith(".weight") and any(conv_name in name for conv_name in conv1d_weight_names):
            loaded_weight = loaded_weight.t()
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Llama weights whose names already match vLLM's parameter names.

    Args:
        actor_weights: Mapping of parameter name to tensor. Names are used
            verbatim as keys into ``vllm_model.named_parameters()``.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a name (other than a ``rotary_emb.inv_freq`` entry) is not
            a parameter of ``vllm_model``.
    """
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        # Rotary inverse-frequency entries are skipped, not loaded.
        if "rotary_emb.inv_freq" in name:
            continue
        param = params_dict[name]
        # Parameters without their own loader fall back to default_weight_loader.
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


# NOTE(review): this function is defined a second time further down in this
# module with the same name; that later definition is the one bound at import
# time, so this copy is dead code. Consider deleting one of the two.
def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Megatron-Core Llama weights (fused layer-norm naming) into a vLLM model.

    Each Megatron name is translated to a vLLM parameter name with
    ``_replace_name`` and ``params_mapping``. The tensor is then handed to the
    parameter's ``weight_loader``, or to ``default_weight_loader`` if it has none.

    Args:
        actor_weights: Mapping of Megatron parameter name to tensor.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a Megatron name matches no entry of the mapping, or if the
            translated name is not a parameter of ``vllm_model``.
    """
    # (Megatron fragment, vLLM fragment). _replace_name uses the first fragment
    # contained in the name, so the fused ``*.layer_norm_weight/bias`` entries
    # must come before their parent module's entry.
    params_mapping = [
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
        ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]

    params_dict = dict(vllm_model.named_parameters())
    for megatron_name, loaded_weight in actor_weights.items():
        # Skip rotary buffers *before* translation. No mapping entry covers
        # them, so translating first would fail before this check was reached.
        if "rotary_emb.inv_freq" in megatron_name:
            continue
        name = _replace_name(megatron_name, params_mapping)
        if name is None:
            raise KeyError(f"no vLLM parameter name mapping for Megatron weight {megatron_name!r}")
        # Biases absent from the vLLM model are dropped (presumably vLLM's
        # Llama has no bias on these layers -- confirm).
        if name.endswith('.bias') and name not in params_dict:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


# NOTE(review): this function is defined a second time further down in this
# module with the same name; that later definition is the one bound at import
# time, so this copy is dead code. Consider deleting one of the two.
def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Megatron-Core Llama weights (separate layer-norm modules) into a vLLM model.

    Each Megatron name is translated to a vLLM parameter name with
    ``_replace_name`` and ``params_mapping``. The tensor is then handed to the
    parameter's ``weight_loader``, or to ``default_weight_loader`` if it has none.

    Args:
        actor_weights: Mapping of Megatron parameter name to tensor.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a Megatron name matches no entry of the mapping, or if the
            translated name is not a parameter of ``vllm_model``.
    """
    # (Megatron fragment, vLLM fragment); the first contained fragment wins.
    params_mapping = [
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        # Identity entry: still needed so _replace_name rewrites the
        # ``decoder.layers.<i>`` prefix to ``model.layers.<i>`` for these names.
        ('input_layernorm', 'input_layernorm'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]

    params_dict = dict(vllm_model.named_parameters())
    for megatron_name, loaded_weight in actor_weights.items():
        # Skip rotary buffers *before* translation. No mapping entry covers
        # them, so translating first would fail before this check was reached.
        if "rotary_emb.inv_freq" in megatron_name:
            continue
        name = _replace_name(megatron_name, params_mapping)
        if name is None:
            raise KeyError(f"no vLLM parameter name mapping for Megatron weight {megatron_name!r}")
        # Biases absent from the vLLM model are dropped (presumably vLLM's
        # Llama has no bias on these layers -- confirm).
        if name.endswith('.bias') and name not in params_dict:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


def _replace_name(megatron_name, name_mapping):
    for m_name, v_name in name_mapping:
        if m_name not in megatron_name:
            continue
        if 'layers' in megatron_name:
            megatron_name = megatron_name.replace('decoder', 'model')
            megatron_name_list = megatron_name.split('.')
            if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
                param_name_list = megatron_name_list[:3]
                param_name_list.append(v_name)
                param_name = '.'.join(param_name_list)
            else:
                param_name_list = megatron_name_list[:3]
                weight_or_bias = megatron_name_list[-1]
                param_name_list.append(v_name)
                param_name_list.append(weight_or_bias)
                param_name = '.'.join(param_name_list)
            return param_name
        else:
            param_name = megatron_name.replace(m_name, v_name)
            return param_name


def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Megatron-Core Llama weights (fused layer-norm naming) into a vLLM model.

    Each Megatron name is translated to a vLLM parameter name with
    ``_replace_name`` and ``params_mapping``. The tensor is then handed to the
    parameter's ``weight_loader``, or to ``default_weight_loader`` if it has none.

    Args:
        actor_weights: Mapping of Megatron parameter name to tensor.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a Megatron name matches no entry of the mapping, or if the
            translated name is not a parameter of ``vllm_model``.
    """
    # (Megatron fragment, vLLM fragment). _replace_name uses the first fragment
    # contained in the name, so the fused ``*.layer_norm_weight/bias`` entries
    # must come before their parent module's entry.
    params_mapping = [
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1.layer_norm_weight', 'post_attention_layernorm.weight'),
        ('mlp.linear_fc1.layer_norm_bias', 'post_attention_layernorm.bias'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]

    params_dict = dict(vllm_model.named_parameters())
    for megatron_name, loaded_weight in actor_weights.items():
        # Skip rotary buffers *before* translation. No mapping entry covers
        # them, so translating first would fail before this check was reached.
        if "rotary_emb.inv_freq" in megatron_name:
            continue
        name = _replace_name(megatron_name, params_mapping)
        if name is None:
            raise KeyError(f"no vLLM parameter name mapping for Megatron weight {megatron_name!r}")
        # Biases absent from the vLLM model are dropped (presumably vLLM's
        # Llama has no bias on these layers -- confirm).
        if name.endswith('.bias') and name not in params_dict:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Megatron-Core Llama weights (separate layer-norm modules) into a vLLM model.

    Each Megatron name is translated to a vLLM parameter name with
    ``_replace_name`` and ``params_mapping``. The tensor is then handed to the
    parameter's ``weight_loader``, or to ``default_weight_loader`` if it has none.

    Args:
        actor_weights: Mapping of Megatron parameter name to tensor.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a Megatron name matches no entry of the mapping, or if the
            translated name is not a parameter of ``vllm_model``.
    """
    # (Megatron fragment, vLLM fragment); the first contained fragment wins.
    params_mapping = [
        ("embedding.word_embeddings", "model.embed_tokens"),
        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
        ("self_attention.linear_proj", 'self_attn.o_proj'),
        # Identity entry: still needed so _replace_name rewrites the
        # ``decoder.layers.<i>`` prefix to ``model.layers.<i>`` for these names.
        ('input_layernorm', 'input_layernorm'),
        ('pre_mlp_layernorm', 'post_attention_layernorm'),
        ('mlp.linear_fc1', 'mlp.gate_up_proj'),
        ('mlp.linear_fc2', 'mlp.down_proj'),
        ('decoder.final_layernorm', 'model.norm'),
        ('output_layer', 'lm_head'),
    ]

    params_dict = dict(vllm_model.named_parameters())
    for megatron_name, loaded_weight in actor_weights.items():
        # Skip rotary buffers *before* translation. No mapping entry covers
        # them, so translating first would fail before this check was reached.
        if "rotary_emb.inv_freq" in megatron_name:
            continue
        name = _replace_name(megatron_name, params_mapping)
        if name is None:
            raise KeyError(f"no vLLM parameter name mapping for Megatron weight {megatron_name!r}")
        # Biases absent from the vLLM model are dropped (presumably vLLM's
        # Llama has no bias on these layers -- confirm).
        if name.endswith('.bias') and name not in params_dict:
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


def _replace_name(megatron_name, name_mapping):
    for m_name, v_name in name_mapping:
        if m_name not in megatron_name:
            continue
        if 'layers' in megatron_name:
            megatron_name = megatron_name.replace('decoder', 'model')
            megatron_name_list = megatron_name.split('.')
            if 'layer_norm_weight' in megatron_name_list or 'layer_norm_bias' in megatron_name_list:
                param_name_list = megatron_name_list[:3]
                param_name_list.append(v_name)
                param_name = '.'.join(param_name_list)
            else:
                param_name_list = megatron_name_list[:3]
                weight_or_bias = megatron_name_list[-1]
                param_name_list.append(v_name)
                param_name_list.append(weight_or_bias)
                param_name = '.'.join(param_name_list)
            return param_name
        else:
            param_name = megatron_name.replace(m_name, v_name)
            return param_name


def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load Mistral weights whose names already match vLLM's parameter names.

    Args:
        actor_weights: Mapping of parameter name to tensor. Names are used
            verbatim as keys into ``vllm_model.named_parameters()``.
        vllm_model: The vLLM model to fill; its parameters are updated in place.

    Returns:
        ``vllm_model``, matching the annotated return type.

    Raises:
        KeyError: If a name (other than a ``rotary_emb.inv_freq`` entry) is not
            a parameter of ``vllm_model``.
    """
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        # Rotary inverse-frequency entries are skipped, not loaded.
        if "rotary_emb.inv_freq" in name:
            continue
        param = params_dict[name]
        # Parameters without their own loader fall back to default_weight_loader.
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    return vllm_model


# vLLM layer classes whose ``weight_loader`` attribute is overwritten by
# update_megatron_weight_loader. Every one of them uses the same zero-copy
# parallel_weight_loader.
__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = dict.fromkeys(
    (
        ColumnParallelLinear,
        MergedColumnParallelLinear,
        QKVParallelLinear,
        RowParallelLinear,
        VocabParallelEmbedding,
        ParallelLMHead,
    ),
    parallel_weight_loader,
)


# Model class name (``type(vllm_model).__name__``) -> whole-model weight loader.
# Both Llama spellings use the Megatron-Core loader with fused layer-norm naming.
__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = dict(
    GPT2LMHeadModel=gpt2_weight_loader,
    LlamaForCausalLM=llama_megatron_core_te_weight_loader,
    LLaMAForCausalLM=llama_megatron_core_te_weight_loader,
    MistralForCausalLM=mistral_megatron_weight_loader,
)



def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module):
    """Load Megatron actor weights into ``vllm_model`` and move it to CUDA.

    The loader is chosen by the model's class name via _get_model_weight_loader.

    Args:
        actor_weights: Mapping of parameter name to tensor, in the naming
            scheme expected by the selected loader.
        vllm_model: The vLLM model to fill; it is modified in place.

    Returns:
        ``vllm_model`` (the same object, now on CUDA). Callers that ignored
        the previous ``None`` return value are unaffected.

    Raises:
        ValueError: If the model architecture has no registered loader.
    """
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)

    # nn.Module.cuda() moves parameters in place and returns the same module,
    # so rebinding the local name (as before) had no effect.
    vllm_model.cuda()
    return vllm_model


def _get_model_weight_loader(arch: str):
    """Return the Megatron weight loader registered for model class name ``arch``.

    Raises:
        ValueError: If ``arch`` is not a key of
            ``__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__``. The message lists
            the architectures this registry actually supports. It used to list
            every vLLM architecture, which was misleading.
    """
    try:
        return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch]
    except KeyError:
        supported = sorted(__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__)
        raise ValueError(f"Model architectures {arch} are not supported for now. "
                         f"Supported architectures: {supported}") from None


def update_megatron_weight_loader():
    """Monkey-patch vLLM layer classes for loading weights shared from Megatron.

    Each class in ``__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__`` gets its
    ``weight_loader`` attribute replaced by the registered function. Afterwards
    ``VocabParallelEmbedding.__init__`` is replaced by ``vocab_init``. The
    classes are modified globally for the whole process.
    """
    registry = __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__
    for cls in registry:
        setattr(cls, "weight_loader", registry[cls])
    setattr(VocabParallelEmbedding, "__init__", vocab_init)



# Default for vocab_init's ``padding_size``. vocab_init accepts the argument
# but does not use it.
DEFAULT_VOCAB_PADDING_SIZE = 64


def vocab_init(self,
               num_embeddings: int,
               embedding_dim: int,
               params_dtype: Optional[torch.dtype] = None,
               org_num_embeddings: Optional[int] = None,
               padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
    """Replacement ``__init__`` for ``VocabParallelEmbedding``.

    The vocabulary is split across tensor-parallel ranks using Megatron's
    ``VocabUtility.vocab_range_from_global_vocab_size`` on the *unpadded*
    ``num_embeddings``. ``padding_size`` is accepted for signature
    compatibility only and is ignored.

    This rank's shard is allocated uninitialized (``torch.empty``) with shape
    ``(vocab_end_index - vocab_start_index, embedding_dim)``. No device
    argument is passed, so torch's default device is used.

    Args:
        num_embeddings: Global vocabulary size.
        embedding_dim: Embedding width.
        params_dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        org_num_embeddings: Stored as ``org_vocab_size``; falls back to
            ``num_embeddings`` when falsy.
        padding_size: Unused.
    """
    super(VocabParallelEmbedding, self).__init__()

    self.num_embeddings = num_embeddings
    self.org_vocab_size = org_num_embeddings or num_embeddings
    self.embedding_dim = embedding_dim
    dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
    self.tp_size = get_tensor_model_parallel_world_size()

    from megatron.core.tensor_parallel.utils import VocabUtility
    tp_rank = get_tensor_model_parallel_rank()
    start, end = VocabUtility.vocab_range_from_global_vocab_size(self.num_embeddings, tp_rank, self.tp_size)
    self.vocab_start_index = start
    self.vocab_end_index = end
    self.num_embeddings_per_partition = end - start

    shard = torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=dtype)
    self.weight = Parameter(shard)
    # Attach the class's weight_loader so vLLM's loading code can find it on the
    # parameter. Rows are split across ranks, hence parallel_dim 0.
    set_weight_attrs(self.weight, {"parallel_dim": 0, "weight_loader": self.weight_loader})
