"""
Usage:
python merge_llama_with_chinese_lora.py \
    --base_model path/to/llama/model \
    --lora_model path/to/first/lora/model [path/to/second/lora/model] \
    --output_type [pth|huggingface] \
    --output_dir path/to/output/dir
"""
import argparse
import json
import os

import gc
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)
from torch_npu.contrib import transfer_to_npu


def merge_layernorm_cycle(desnet_key, source_net_key, source_net_ckpt, desnet_ckpt, mode='from_small'):
    """Resize a 1-D source parameter to the length of a destination parameter.

    Args:
        desnet_key: Key of the parameter in ``desnet_ckpt``; its ``shape[0]``
            is the target length.
        source_net_key: Key of the parameter in ``source_net_ckpt``.
        source_net_ckpt: Mapping of parameter names to tensors for the source
            (smaller or bigger) model.
        desnet_ckpt: Mapping of parameter names to tensors for the destination
            model; only the shape of ``desnet_ckpt[desnet_key]`` is read.
        mode: ``'from_small'`` repeats the source entries cyclically
            (``s[0], ..., s[n-1], s[0], ...``) until the target length is
            reached. ``'from_big'`` keeps the leading entries.

    Returns:
        The resized tensor. In ``'from_small'`` mode its length always equals
        the target length. In ``'from_big'`` mode it does only when the source
        is at least that long.

    Raises:
        ValueError: If ``mode`` is not ``'from_small'`` or ``'from_big'``.
    """
    source = source_net_ckpt[source_net_key]
    target_size = desnet_ckpt[desnet_key].shape[0]
    if mode == 'from_small':
        # Cyclic indices give exactly target_size entries, even when the
        # target is more than twice the source length. A single
        # concatenation of a prefix would not.
        index = torch.arange(target_size, device=source.device) % source.shape[0]
        return source.index_select(0, index)
    if mode == 'from_big':
        return source[:target_size]
    raise ValueError(f"Unsupported mode {mode!r}; expected 'from_small' or 'from_big'")


def merge_att_mlp_layer_cycle(desnet_key, source_net_key, source_net_ckpt, desnet_ckpt, bias=False, mode='from_small'):
    """Resize an attention/MLP weight or bias to the destination parameter's shape.

    Weights (``bias=False``) are resized along dims 0 and 1. Biases
    (``bias=True``) are resized along dim 0 only.

    NOTE(review): each axis is treated as one flat axis. If the tensor is a
    fused projection (e.g. ``attn.c_attn``, presumably query/key/value
    concatenated along one axis), the sub-blocks are not resized
    independently. Confirm this is intended.

    Args:
        desnet_key: Key of the parameter in ``desnet_ckpt`` whose shape is the
            target shape.
        source_net_key: Key of the parameter in ``source_net_ckpt``.
        source_net_ckpt: Mapping of parameter names to tensors for the source
            (smaller or bigger) model.
        desnet_ckpt: Mapping of parameter names to tensors for the destination
            model; only the target shape is read.
        bias: Whether the parameter is a 1-D bias rather than a 2-D weight.
        mode: ``'from_small'`` repeats rows/columns cyclically up to the
            target size. ``'from_big'`` keeps the leading rows/columns.

    Returns:
        The resized tensor.

    Raises:
        ValueError: If ``mode`` is not ``'from_small'`` or ``'from_big'``.
    """
    source = source_net_ckpt[source_net_key]
    target_shape = desnet_ckpt[desnet_key].shape
    if mode == 'from_big':
        if bias:
            return source[:target_shape[0]]
        return source[:target_shape[0], :target_shape[1]]
    if mode == 'from_small':
        result = source
        for dim in ((0,) if bias else (0, 1)):
            # Cyclic indices give exactly the target size along this axis,
            # even when the target is more than twice the source size.
            index = torch.arange(target_shape[dim], device=source.device) % source.shape[dim]
            result = result.index_select(dim, index)
        return result
    raise ValueError(f"Unsupported mode {mode!r}; expected 'from_small' or 'from_big'")


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` cannot be used for this, because ``bool("False")`` is
    ``True``. That made ``--trust_remote_code False`` impossible to express.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# pretrained model setting
parser.add_argument('--big_model_path', default=None, required=True,
                    type=str, help="Please specify the path of the bigger model checkpoint")
parser.add_argument('--small_model_path', default=None, required=True,
                    type=str, help="Please specify the path of the smaller model checkpoint")
# descendant model setting
parser.add_argument('--model_base_config', default=None, required=True,
                    type=str, help="The path of base config of llm model [GPT2/llama/llama2/opt]")
parser.add_argument('--embeding_dim', default=None, required=True,
                    type=int, help="The number of embeding dim in the llm model")
parser.add_argument('--head_num', default=None, required=True,
                    type=int, help="The number of head in each transformer layer")
parser.add_argument('--layer_num', default=None, required=True,
                    type=int, help="The number of layers in the llm model")
# Optional so that the declared default of 3 actually applies. It used to be
# required=True, which made the default unreachable.
parser.add_argument('--start_num', default=3, type=int,
                    help="Number of leading and trailing transformer layers mapped one-to-one "
                         "from each source model to the middle model")
parser.add_argument(
        "--trust_remote_code",
        type=_str2bool,
        default=True,
        help=(
            "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        ))
# merging setting
parser.add_argument('--alpha', default=None, required=True,
                    type=float, help="Coefficients for merging small model parameters")
# saving setting
parser.add_argument('--output_type', default='pth', choices=['pth', 'huggingface'], type=str,
                    help="save the merged model in pth or huggingface format.")
parser.add_argument('--output_dir', default='./', type=str)


# Sub-modules of one transformer layer (``transformer.h.<i>.<module>``), in the
# order their weight and bias are resized. ``True`` marks 2-D projections
# handled by merge_att_mlp_layer_cycle. ``False`` marks layer norms handled by
# merge_layernorm_cycle.
_LAYER_MODULES = (
    ('ln_1', False),
    ('attn.c_attn', True),
    ('attn.c_proj', True),
    ('ln_2', False),
    ('mlp.c_fc', True),
    ('mlp.c_proj', True),
)


def _count_layers(ckpt):
    """Return 1 + the highest ``i`` among ``transformer.h.<i>.*`` keys of ``ckpt``.

    Raises:
        ValueError: If the checkpoint has no ``transformer.h.<i>`` parameters.
    """
    layer_ids = {int(key.split('.')[2]) for key in ckpt if key.startswith('transformer.h.')}
    if not layer_ids:
        raise ValueError("Checkpoint contains no 'transformer.h.<i>' parameters")
    return max(layer_ids) + 1


def _fit_columns(weight, target_size, mode):
    """Resize a 2-D ``weight`` along dim 1 to ``target_size`` columns.

    ``'from_small'`` repeats the columns cyclically. ``'from_big'`` keeps the
    leading columns.
    """
    if mode == 'from_small':
        index = torch.arange(target_size, device=weight.device) % weight.shape[1]
        return weight.index_select(1, index)
    return weight[:, :target_size]


def _map_layers(desnet_layer_ids, source_layer_ids, source_ckpt, desnet_ckpt, mode, out):
    """Resize every per-layer parameter of the source layers into ``out``.

    For each ``(desnet_layer_id, source_layer_id)`` pair, each module in
    ``_LAYER_MODULES`` and each of ``weight``/``bias``, the source parameter is
    resized to the destination parameter's shape. The result is stored in
    ``out`` under the destination key.
    """
    for desnet_layer_id, source_layer_id in zip(desnet_layer_ids, source_layer_ids):
        for module, is_projection in _LAYER_MODULES:
            for param in ('weight', 'bias'):
                desnet_key = f"transformer.h.{desnet_layer_id}.{module}.{param}"
                source_key = f"transformer.h.{source_layer_id}.{module}.{param}"
                if is_projection:
                    out[desnet_key] = merge_att_mlp_layer_cycle(
                        desnet_key, source_key, source_ckpt, desnet_ckpt,
                        bias=(param == 'bias'), mode=mode)
                else:
                    out[desnet_key] = merge_layernorm_cycle(
                        desnet_key, source_key, source_ckpt, desnet_ckpt, mode)


if __name__ == '__main__':

    args = parser.parse_args()

    big_model_path = args.big_model_path
    small_model_path = args.small_model_path
    middle_model_config = AutoConfig.from_pretrained(
        args.model_base_config,
        n_embd=args.embeding_dim,
        n_head=args.head_num,
        n_layer=args.layer_num,
        trust_remote_code=args.trust_remote_code,
    )
    # Make sure the requested dimensions really ended up in the loaded config.
    if (middle_model_config.n_embd != args.embeding_dim
            or middle_model_config.n_head != args.head_num
            or middle_model_config.n_layer != args.layer_num):
        raise ValueError("The loaded config does not match --embeding_dim/--head_num/--layer_num")

    middle_model_setting = (f"alpha_{args.alpha}_NEmbed_{args.embeding_dim}"
                            f"_NHead_{args.head_num}_NLayer_{args.layer_num}")
    output_dir = os.path.join(args.output_dir, middle_model_setting)
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): --output_type is parsed, but the model is always written
    # with save_pretrained below, so 'pth' is not honoured. Confirm intent.

    print(f"Big Model: {big_model_path}")
    print(f"Small Model: {small_model_path}")

    # Model with the target ("middle") architecture, built from the config.
    # Its state dict supplies the target shapes and receives the merged weights.
    desnet = AutoModelForCausalLM.from_config(middle_model_config,
                                              trust_remote_code=args.trust_remote_code)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    big_model_ckpt = torch.load(big_model_path, map_location='cpu')
    big_model_layer = _count_layers(big_model_ckpt)
    small_model_ckpt = torch.load(small_model_path, map_location='cpu')
    small_model_layer = _count_layers(small_model_ckpt)
    desnet_ckpt = desnet.state_dict()
    scaled_small_model_ckpt = {}  # small model's weights expanded to the middle model's shapes
    narrowed_big_model_ckpt = {}  # big model's weights narrowed to the middle model's shapes

    ################################## Merging ##################################
    # Token and position embeddings: resize the column (dim 1) axis.
    for key in ['transformer.wte.weight', 'transformer.wpe.weight']:
        desnet_weight_size = desnet_ckpt[key].shape[1]
        scaled_small_model_ckpt[key] = _fit_columns(small_model_ckpt[key], desnet_weight_size, 'from_small')
        narrowed_big_model_ckpt[key] = _fit_columns(big_model_ckpt[key], desnet_weight_size, 'from_big')

    ######################### Transformer Layer: from big model ########################
    # Middle layer i takes big-model layer i for i < layer_num - start_num.
    # The last start_num middle layers take the big model's last start_num layers.
    desnet_layer_id_list = list(range(args.layer_num))
    big_model_layer_id_list = (
        list(range(args.start_num))
        + list(range(args.start_num, args.layer_num - args.start_num))
        + [big_model_layer - i for i in range(args.start_num, 0, -1)]
    )
    print(f"desnet_layer_id_list: {desnet_layer_id_list}")
    print(f'big_model_layer_id_list: {big_model_layer_id_list}')
    if len(desnet_layer_id_list) != len(big_model_layer_id_list):
        raise ValueError("The length of the two model layers should be equal!")
    _map_layers(desnet_layer_id_list, big_model_layer_id_list,
                big_model_ckpt, desnet_ckpt, 'from_big', narrowed_big_model_ckpt)

    ######################### Transformer Layer: from small model ########################
    # The first and last start_num layers map one-to-one. Each of the small
    # model's remaining middle layers is repeated `step` times.
    small_middle_layers = small_model_layer - 2 * args.start_num
    if small_middle_layers <= 0:
        raise ValueError(
            f"The small model has {small_model_layer} layers; it needs more than "
            f"2 * start_num ({2 * args.start_num}) layers")
    step = (args.layer_num - 2 * args.start_num) // small_middle_layers
    print(f'Num step: {step}')
    small_model_layer_id_list = (
        list(range(args.start_num))
        # The upper bound was previously hard-coded as `small_model_layer - 3`,
        # which is only correct for start_num == 3.
        + [i for i in range(args.start_num, small_model_layer - args.start_num) for _ in range(step)]
        + [small_model_layer - i for i in range(args.start_num, 0, -1)]
    )
    print(f"desnet_layer_id_list: {desnet_layer_id_list}")
    print(f'small_model_layer_id_list: {small_model_layer_id_list}')
    if len(desnet_layer_id_list) != len(small_model_layer_id_list):
        raise ValueError("The length of the two model layers should be equal!")
    _map_layers(desnet_layer_id_list, small_model_layer_id_list,
                small_model_ckpt, desnet_ckpt, 'from_small', scaled_small_model_ckpt)

    # Final layer norm.
    for key in ('transformer.ln_f.weight', 'transformer.ln_f.bias'):
        narrowed_big_model_ckpt[key] = merge_layernorm_cycle(key, key, big_model_ckpt, desnet_ckpt,
                                                             mode='from_big')
        scaled_small_model_ckpt[key] = merge_layernorm_cycle(key, key, small_model_ckpt, desnet_ckpt,
                                                             mode='from_small')

    # Output layer: resize the column (dim 1) axis.
    desnet_weight_size = desnet_ckpt["lm_head.weight"].shape[1]
    scaled_small_model_ckpt["lm_head.weight"] = _fit_columns(
        small_model_ckpt["lm_head.weight"], desnet_weight_size, 'from_small')
    narrowed_big_model_ckpt["lm_head.weight"] = _fit_columns(
        big_model_ckpt["lm_head.weight"], desnet_weight_size, 'from_big')

    # Blend: alpha * small + (1 - alpha) * big, keyed by parameter name.
    if scaled_small_model_ckpt.keys() != narrowed_big_model_ckpt.keys():
        raise ValueError("The key value of the small, big model must be equal!")
    for key, small_value in scaled_small_model_ckpt.items():
        desnet_ckpt[key] = args.alpha * small_value + (1 - args.alpha) * narrowed_big_model_ckpt[key]

    # The source checkpoints are no longer needed. Release them before
    # loading and saving to lower peak memory.
    del big_model_ckpt, small_model_ckpt, scaled_small_model_ckpt, narrowed_big_model_ckpt
    gc.collect()

    desnet.load_state_dict(desnet_ckpt)

    print("Saving to Hugging Face format...")
    desnet.save_pretrained(output_dir, safe_serialization=False)
