# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference

import deepspeed.ops.transformer as transformer_inference
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding, RMSNormalize
import torch
import gc
from deepspeed.accelerator import get_accelerator
import re


def load_model_with_checkpoint(r_module,
                               sd,
                               mp_replace,
                               ckpt_type,
                               ckpt_mp_size,
                               weight_quantizer=None,
                               rank=0,
                               container=None):
    """Populate the parameters of ``r_module`` in place from checkpoint state dict(s).

    The module tree is walked recursively. Children whose class has an entry in
    the internal ``layer_policies`` table are loaded directly (after meta/empty
    placeholders are swapped for DeepSpeed layer equivalents); every other child
    is descended into with its name appended to the key prefix.

    Args:
        r_module: Root module to load. Its ``named_children`` and
            ``named_parameters`` are traversed; ``config.hidden_size`` is read
            when merging fused-QKV biases, and ``lm_head`` (if present) is tied
            to the word embedding when its weight is still a meta tensor.
        sd: Sequence of state dicts. ``sd[0]`` is used for all lookups; the
            remaining entries are only read when several checkpoint shards have
            to be concatenated into one parameter.
        mp_replace: Object exposing ``copy(dst, src)``; its return value is
            assigned as the new ``weight``/``bias`` of simple layers.
        ckpt_type: ``"tp"`` loads transformer layers with the built-in
            split/merge logic; any other value delegates to
            ``container.load_params``. ``"pp"`` additionally enables skipping
            the level-0 key prefix.
        ckpt_mp_size: Model-parallel degree of the checkpoint, used to
            recognise fused-QKV weight shapes.
        weight_quantizer: Object exposing ``quantize``, ``q_int8`` and
            ``num_groups``. Required whenever a 2-D transformer weight is loaded
            through the ``"tp"`` path.
        rank: Index of the slice this process keeps when a checkpoint tensor is
            larger than the destination parameter.
        container: Object exposing ``policy.use_load_prefix`` and
            ``load_params``. Required for non-``"tp"`` checkpoints.

    Returns:
        None. ``r_module`` is modified in place.
    """

    def prefix_check():
        # if keys start with 'model.', don't skip level 0 prefix
        return not any(key.startswith("model.") for key in sd[0].keys())

    # `container` defaults to None, so only consult its policy when it was provided.
    skip_level_0_prefix = prefix_check() and container is not None and container.policy.use_load_prefix

    def transpose(data):
        """Swap the last two dims of ``data`` while reusing its storage."""
        with torch.no_grad():
            data = data.contiguous()
            data1 = data.transpose(-1, -2).reshape(-1)
            data.reshape(-1).copy_(data1)
            data1 = None
        return data.reshape(data.shape[-1], data.shape[-2])

    def load(module, prefix):
        """Load ``weight`` (and ``bias`` when present in the checkpoint) of a simple layer."""
        if hasattr(module, 'weight'):
            module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
        if prefix + 'bias' in sd[0].keys():
            if module.bias.data.is_meta:
                # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                           requires_grad=module.bias.data.requires_grad)
            module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
        gc.collect()

    def load_transformer_layer(module, prefix):
        """Load a fused DeepSpeed transformer layer, splitting or merging shards as needed."""
        if ckpt_type == "tp":

            def load_parameters(module, prefix):
                for n, p in module.named_parameters():
                    # Only direct parameters of `module` (no dotted names) that exist in the checkpoint.
                    if prefix + n in sd[0] and len(n.split('.')) == 1:
                        if isinstance(sd[0][prefix + n], list):
                            # List entries are unpacked as a (data, scale) pair.
                            tmp_data, scale = sd[0][prefix + n]
                            scale = scale.to(get_accelerator().current_device_name())
                            # set the quantizer number of groups using the checkpoint scale shape
                            weight_quantizer.num_groups = scale.shape[0]
                        else:
                            tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
                            scale = None
                        src_shape = tmp_data.shape
                        dst_shape = p.shape
                        # int8 data is compared against the destination with its two dims swapped.
                        inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
                        outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
                        if (len(src_shape) == 2 and len(dst_shape) == 2):
                            if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
                                # Shapes already match: no split or merge required.
                                if tmp_data.dtype != torch.int8:
                                    p = weight_quantizer.quantize(
                                        transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
                                else:
                                    p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
                                    p.scale = scale
                                setattr(module, n, p)
                            else:
                                # `dim` is the source dim that differs; `dim1` the matching destination dim.
                                dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
                                dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
                                if src_shape[dim] > dst_shape[dim1]:
                                    # Checkpoint tensor is larger than the destination: keep this rank's slice.
                                    weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
                                        get_accelerator().current_device_name())
                                    assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \
                                        '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\
                                           Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!'''
                                    # Non-quantized checkpoints carry no scale; only reslice it when present.
                                    if scale is not None:
                                        scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
                                            weight_quantizer.num_groups, -1).contiguous()
                                else:
                                    # Checkpoint tensor is smaller: concatenate the shards from every state dict.
                                    assert tmp_data.dtype != torch.int8, \
                                        '''Merging of the checkpoints are not supported when using INT8 checkpoint! \
                                          Please use a as many GPUs as TP-size for the checkpoint'''
                                    all_data = [
                                        sd[j][prefix + n] if isinstance(sd[j][prefix + n], list) else
                                        sd[j][prefix + n].to(get_accelerator().current_device_name())
                                        for j in range(len(sd))
                                    ]
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
                                        qkv_size = src_shape[outer_dim] // 3
                                        # Unwrap (data, scale) lists only; indexing a plain tensor with [0]
                                        # would take its first row instead of the whole shard.
                                        shards = [src[0] if isinstance(src, list) else src for src in all_data]
                                        src_split = [torch.split(s.data, qkv_size, dim=outer_dim) for s in shards]

                                        # Regroup so each of the split chunks is contiguous across shards.
                                        weight_partition = torch.cat(
                                            [
                                                torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
                                                for i in range(len(src_split[0]))
                                            ],
                                            dim=dim)
                                    else:
                                        weight_partition = torch.cat(
                                            [
                                                ad[0].to(get_accelerator().current_device_name())
                                                if isinstance(ad, list) else ad for ad in all_data
                                            ],
                                            dim=dim)
                                    # NOTE(review): the assert above rejects int8 here, so this is only
                                    # reachable when assertions are disabled.
                                    if tmp_data.dtype == torch.int8:
                                        scale = torch.cat(
                                            [ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
                                            dim=dim)

                                if tmp_data.dtype != torch.int8:
                                    if weight_quantizer.q_int8:
                                        weight_partition = weight_quantizer.quantize(
                                            transpose(weight_partition), parallel_dim=(0 if dim == 1 else 1))
                                    else:
                                        weight_partition = weight_quantizer.quantize(weight_partition)
                                else:
                                    weight_partition = torch.nn.parameter.Parameter(weight_partition,
                                                                                    requires_grad=False)
                                    weight_partition.scale = scale
                                setattr(module, n, weight_partition)
                        else:
                            # Not a 2-D/2-D pair; the code below splits and merges along dim 0 only.
                            if src_shape[0] == dst_shape[0]:
                                p.data.copy_(tmp_data)
                            else:
                                if src_shape[0] > dst_shape[0]:
                                    bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
                                        get_accelerator().current_device_name()).contiguous()
                                    p.data.copy_(bias_split)
                                else:
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
                                        qkv_size = src_shape[0] // 3
                                        src_split = [
                                            torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
                                        ]

                                        p.data.copy_(
                                            torch.cat(
                                                [
                                                    torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
                                                    for i in range(len(src_split[0]))
                                                ],
                                                dim=0).to(get_accelerator().current_device_name()).contiguous())
                                    else:
                                        p.data.copy_(
                                            torch.cat([sd[j][prefix + n] for j in range(len(sd))],
                                                      dim=0).to(get_accelerator().current_device_name()).contiguous())

            load_parameters(module, prefix)
            for n, child in module.named_children():
                load_parameters(child, prefix + n + '.')
        else:
            container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)

    # Optional HuggingFace classes. Both names must always be bound because they are
    # used as keys of `layer_policies` below, even when `transformers` is unavailable.
    OPTLearnedPositionalEmbedding = None
    LlamaRMSNorm = None
    try:
        import transformers
        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
        if hasattr(transformers.models, "llama"):
            LlamaRMSNorm = transformers.models.llama.modeling_llama.LlamaRMSNorm
    except Exception:
        OPTLearnedPositionalEmbedding = None
        LlamaRMSNorm = None

    # Maps a module class to the loader used for it. A `None` key (missing optional
    # class) is harmless because no module's class is ever `None`.
    layer_policies = {
        nn.Linear: load,
        nn.Embedding: load,
        nn.LayerNorm: load,
        EmbeddingLayer: load,
        LinearLayer: load,
        Normalize: load,
        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
        DeepSpeedBloomInference: load_transformer_layer,
        DeepSpeedGPTInference: load_transformer_layer,
        DeepSpeedBERTInference: load_transformer_layer,
        DeepSpeedMegatronGPTInference: load_transformer_layer,
        DeepSpeedOPTInference: load_transformer_layer,
        OPTLearnedPositionalEmbedding: load,
        OPTEmbedding: load,
        LlamaRMSNorm: load,
        RMSNormalize: load
    }

    # ds_id -> weight of the EmbeddingLayer created for it, so that a later nn.Linear
    # carrying the same ds_id can reuse that weight.
    all_ds_ids = {}

    def load_module_recursive(module, prefix='', level=0):
        for name, child in module.named_children():
            if child.__class__ in layer_policies:
                checking_key = prefix + name + '.'
                if not any(checking_key in item for item in sd[0].keys()):
                    # Nothing in the checkpoint for this child. If its weight shares a ds_id
                    # with an already-created embedding, reuse that weight for nn.Linear.
                    if hasattr(child, 'weight') and \
                        (hasattr(child.weight, 'ds_id') and \
                        child.weight.ds_id in all_ds_ids):
                        if child.__class__ is nn.Linear:
                            child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
                            setattr(module, name, child)
                    continue
                child_params = list(child.parameters())
                if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
                    # Empty or meta placeholder: replace it with a DeepSpeed layer before loading.
                    if child.weight.is_meta:
                        ds_shape = child.weight.shape
                    else:
                        ds_shape = child.weight.ds_shape
                    if child.__class__ is nn.LayerNorm:
                        child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                        setattr(module, name, child)
                    elif child.__class__ is nn.Linear:
                        # NOTE(review): uses child.weight.shape rather than ds_shape, unlike the
                        # sibling branches -- confirm this is intended for non-meta empty weights.
                        child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
                        setattr(module, name, child)
                    elif child.__class__ is OPTLearnedPositionalEmbedding:
                        child = OPTEmbedding(weight_shape=ds_shape)
                        setattr(module, name, child)
                    elif child.__class__ is LlamaRMSNorm:
                        child = RMSNormalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.variance_epsilon)
                        setattr(module, name, child)
                    else:
                        ds_id = None
                        if hasattr(child.weight, 'ds_id'):
                            ds_id = child.weight.ds_id
                        child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
                        if ds_id is not None:
                            all_ds_ids[ds_id] = child.weight
                        setattr(module, name, child)
                layer_policies[child.__class__](child, prefix + name + '.')
            else:
                # For 'pp' checkpoints the top-level child name may be omitted from the keys.
                load_module_recursive(
                    child,
                    prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix else \
                    prefix + name + '.',
                    level + 1)

    load_module_recursive(r_module)

    embedding_weight = None

    # The last parameter whose name matches one of these patterns wins.
    for n, p in r_module.named_parameters():
        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
            embedding_weight = p

    # Tie the LM head to the embedding when the head was never materialised.
    # Not every model exposes `lm_head`, so look it up defensively.
    lm_head = getattr(r_module, "lm_head", None)
    lm_head_weight = getattr(lm_head, "weight", None)
    if embedding_weight is not None and lm_head_weight is not None and lm_head_weight.is_meta:
        r_module.lm_head.weight = embedding_weight

    # Drop this function's reference to the state dicts before collecting garbage.
    sd = None
    gc.collect()
