

from typing import Dict
import torch
import torch.nn as nn



def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Bind ``loaded_weight``'s data to ``param`` after checking compatibility.

    The tensor behind ``param.data`` is replaced by ``loaded_weight.data``;
    this function performs no element-wise copy itself.

    Args:
        self: Unused by this function body.
        param: Destination tensor whose ``data`` is rebound.
        loaded_weight: Source tensor supplying the new data.

    Raises:
        AssertionError: If the two tensors differ in size or in dtype.
    """
    expected_size, received_size = param.size(), loaded_weight.size()
    assert expected_size == received_size, (
        'the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}'.format(
            expected_size, received_size))

    incoming = loaded_weight.data
    assert param.data.dtype == incoming.dtype, "if we want to shared weights, the data type should also be the same"

    # Rebind rather than copy, so the parameter now refers to the loaded data.
    param.data = incoming


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Rebind ``param.data`` to ``loaded_weight.data`` after sanity checks.

    Args:
        param: Destination tensor whose ``data`` is replaced.
        loaded_weight: Source tensor supplying the new data.

    Raises:
        AssertionError: If the two tensors differ in size or in dtype.
    """
    expected_size, received_size = param.size(), loaded_weight.size()
    assert expected_size == received_size

    incoming = loaded_weight.data
    assert param.data.dtype == incoming.dtype, "if we want to shared weights, the data type should also be the same"

    # No element-wise copy here: the parameter is pointed at the loaded data.
    param.data = incoming


def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load GPT-2 actor weights into the matching parameters of ``vllm_model``.

    For every ``(name, tensor)`` entry in ``actor_weights``:

    * entries whose name contains ``"lm_head.weight"``, ``".attn.bias"`` or
      ``".attn.masked_bias"`` are skipped;
    * names not starting with ``"transformer."`` get that prefix prepended;
    * for names ending in ``".weight"``, the tensor is transposed once for
      each of ``"c_attn"``, ``"c_proj"`` and ``"c_fc"`` found in the name;
    * the tensor is handed to the parameter's ``weight_loader`` attribute,
      falling back to ``default_weight_loader`` when it has none.

    Args:
        actor_weights: Mapping from parameter name to the tensor to load.
        vllm_model: Module whose ``named_parameters`` are the load targets.

    Returns:
        Nothing; despite the annotation, the function has no return statement.

    Raises:
        KeyError: If a (prefixed) name is not a parameter of ``vllm_model``.
    """
    skipped_fragments = ("lm_head.weight", ".attn.bias", ".attn.masked_bias")
    conv1d_markers = ("c_attn", "c_proj", "c_fc")
    targets = dict(vllm_model.named_parameters(remove_duplicate=False))

    for raw_name, tensor in actor_weights.items():
        if any(fragment in raw_name for fragment in skipped_fragments):
            continue

        full_name = raw_name if raw_name.startswith("transformer.") else "transformer." + raw_name
        target = targets[full_name]

        # NOTE(review): presumably these layers are stored transposed in the
        # actor checkpoint (Conv1D layout) - confirm against the actor model.
        if full_name.endswith(".weight"):
            for marker in conv1d_markers:
                if marker in full_name:
                    tensor = tensor.t()

        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)


def llama_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load actor weights into the matching parameters of ``vllm_model``.

    A leading ``'0.module.module.'`` is stripped from each weight name, and
    entries whose name contains ``"rotary_emb.inv_freq"`` are skipped. Each
    remaining tensor is passed to the target parameter's ``weight_loader``
    attribute, or to ``default_weight_loader`` when the parameter has none.

    Args:
        actor_weights: Mapping from parameter name to the tensor to load.
        vllm_model: Module whose ``named_parameters`` are the load targets.

    Returns:
        Nothing; despite the annotation, the function has no return statement.

    Raises:
        KeyError: If a stripped name is not a parameter of ``vllm_model``.
    """
    wrapper_prefix = '0.module.module.'
    targets = dict(vllm_model.named_parameters())

    for raw_name, tensor in actor_weights.items():
        key = raw_name[len(wrapper_prefix):] if raw_name.startswith(wrapper_prefix) else raw_name
        if "rotary_emb.inv_freq" in key:
            continue

        target = targets[key]
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)


def mistral_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    """Load actor weights into the matching parameters of ``vllm_model``.

    Each weight name has a leading ``'0.module.module.'`` removed when
    present. Names containing ``"rotary_emb.inv_freq"`` are ignored; every
    other tensor goes to the target parameter's ``weight_loader`` attribute,
    with ``default_weight_loader`` used when that attribute is absent.

    Args:
        actor_weights: Mapping from parameter name to the tensor to load.
        vllm_model: Module whose ``named_parameters`` are the load targets.

    Returns:
        Nothing; despite the annotation, the function has no return statement.

    Raises:
        KeyError: If a stripped name is not a parameter of ``vllm_model``.
    """
    strip = '0.module.module.'
    registry = dict(vllm_model.named_parameters())

    for key, tensor in actor_weights.items():
        if key.startswith(strip):
            key = key[len(strip):]
        if "rotary_emb.inv_freq" not in key:
            target = registry[key]
            load = getattr(target, "weight_loader", default_weight_loader)
            load(target, tensor)
