"""
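Usage:
python <this_script>.py \
    --big_model_path path/to/big/model/checkpoint \
    --small_model_path path/to/small/model/checkpoint \
    --model_base_config path/to/base/model/config \
    --embeding_dim <int> \
    --head_num <int> \
    --layer_num <int> \
    --alpha <float> \
    [--trust_remote_code True|False] \
    [--output_type pth|huggingface] \
    [--output_dir path/to/output/dir]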
"""
import argparse
import json
import os

import gc
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)
from torch_npu.contrib import transfer_to_npu

def merge_layernorm(key, small_model_ckpt, big_model_ckpt, desnet_ckpt, mode='from_small'):
    """Resize the parameter stored under ``key`` along dim 0 to the target model's size.

    Args:
        key: State-dict key of the parameter to resize.
        small_model_ckpt: Mapping read when ``mode == 'from_small'``.
        big_model_ckpt: Mapping read when ``mode == 'from_big'``.
        desnet_ckpt: Mapping of the target model; ``desnet_ckpt[key].shape[0]``
            is the length the result must have.
        mode: ``'from_small'`` repeats the small model's values cyclically
            until the target length is reached; ``'from_big'`` keeps the
            leading slice of the big model's values.

    Returns:
        A tensor whose first dimension equals ``desnet_ckpt[key].shape[0]``.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values, or the
            small model's parameter is empty along dim 0.
    """
    target_size = desnet_ckpt[key].shape[0]

    if mode == 'from_big':
        return big_model_ckpt[key][:target_size]

    if mode == 'from_small':
        source = small_model_ckpt[key]
        source_size = source.shape[0]
        if source_size == 0:
            raise ValueError(f"Cannot scale empty parameter '{key}'")
        # Ceil-divide so the tiled tensor always covers the target, even when
        # the target is more than twice the source; then trim the excess.
        repeats = max(1, -(-target_size // source_size))
        return torch.cat([source] * repeats, dim=0)[:target_size]

    raise ValueError(f"Unknown mode '{mode}'; expected 'from_small' or 'from_big'")

def merge_att_mlp_layer(key, small_model_ckpt, big_model_ckpt, desnet_ckpt, bias=False, mode='from_small'):
    """Resize an attention/MLP parameter to the shape it has in the target model.

    Args:
        key: State-dict key of the parameter to resize.
        small_model_ckpt: Mapping read when ``mode == 'from_small'``.
        big_model_ckpt: Mapping read when ``mode == 'from_big'``.
        desnet_ckpt: Mapping of the target model; the shape of
            ``desnet_ckpt[key]`` defines the output shape.
        bias: When true only dim 0 is resized (the parameter is treated as a
            vector); otherwise dims 0 and 1 are both resized.
        mode: ``'from_small'`` tiles the small model's values cyclically (rows
            first, then columns of the row-extended tensor); ``'from_big'``
            keeps the leading rows/columns of the big model's values.

    Returns:
        A tensor matching ``desnet_ckpt[key]`` in the resized dimension(s).

    Raises:
        ValueError: If ``mode`` is not supported, or the small model's
            parameter is empty along a dimension that must be resized.
    """
    if mode not in ('from_small', 'from_big'):
        raise ValueError(f"Unknown mode '{mode}'; expected 'from_small' or 'from_big'")

    target_shape = desnet_ckpt[key].shape
    resized_dims = 1 if bias else 2

    if mode == 'from_big':
        source = big_model_ckpt[key]
        if bias:
            return source[:target_shape[0]]
        return source[:target_shape[0], :target_shape[1]]

    resized = small_model_ckpt[key]
    for dim in range(resized_dims):
        target_size = target_shape[dim]
        current_size = resized.shape[dim]
        if current_size == 0:
            raise ValueError(f"Cannot scale parameter '{key}': dimension {dim} is empty")
        # Ceil-divide so tiling covers targets of any size (not just <= 2x),
        # then cut back to exactly the target length along this dimension.
        repeats = max(1, -(-target_size // current_size))
        resized = torch.cat([resized] * repeats, dim=dim).narrow(dim, 0, target_size)
    return resized

def save_shards(model_sd, num_shards: int):
    """Write ``model_sd`` as ``num_shards`` ``consolidated.NN.pth`` files plus ``params.json``.

    Each key is first mapped through ``translate_state_dict_key``; keys that
    map to ``None`` are skipped. With one shard the translated tensors are
    saved as-is (``wq``/``wk`` entries pass through ``unpermute``). With
    several shards every tensor is either split along one dimension or
    replicated, depending on its translated name.

    NOTE(review): this function reads ``output_dir`` and ``params`` as module
    globals and calls ``translate_state_dict_key`` and ``unpermute``. Only
    ``output_dir`` is assigned in the visible part of this file; confirm the
    other names exist before calling.

    Args:
        model_sd: Mapping from parameter name to tensor. In the multi-shard
            path entries are deleted from it as they are processed, to keep
            peak memory down.
        num_shards: Number of shard files to write; must be at least 1.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1, a translated key is
            not recognised, or a tensor that must be split evenly is not
            divisible by ``num_shards`` along the split dimension.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    # Substring patterns of translated keys, grouped by sharding strategy.
    replicated_patterns = ('ffn_norm.weight', 'attention_norm.weight')
    dim0_patterns = ('w1.weight', 'w3.weight', 'wv.weight')
    dim1_patterns = ('w2.weight', 'wo.weight')

    def _split_evenly(name, tensor, dim):
        """Split ``tensor`` into ``num_shards`` equal chunks along ``dim``."""
        if tensor.size(dim) % num_shards != 0:
            # An uneven split would yield an extra chunk that zip() below
            # would silently drop, losing weights.
            raise ValueError(
                f"Cannot split {name}: size {tensor.size(dim)} along dim {dim} "
                f"is not divisible by {num_shards} shards")
        return tensor.split(tensor.size(dim) // num_shards, dim=dim)

    # Add the no_grad context manager
    with torch.no_grad():
        if num_shards == 1:
            new_state_dict = {}
            for k, v in model_sd.items():
                new_k = translate_state_dict_key(k)
                if new_k is not None:
                    if "wq" in new_k or "wk" in new_k:
                        new_state_dict[new_k] = unpermute(v)
                    else:
                        new_state_dict[new_k] = v

            os.makedirs(output_dir, exist_ok=True)
            print(f"Saving shard 1 of {num_shards} into {output_dir}/consolidated.00.pth")
            torch.save(new_state_dict, output_dir + "/consolidated.00.pth")
            with open(output_dir + "/params.json", "w") as f:
                json.dump(params, f)
            return

        new_state_dicts = [dict() for _ in range(num_shards)]
        for k in list(model_sd.keys()):
            v = model_sd[k]
            new_k = translate_state_dict_key(k)
            if new_k is not None:
                if new_k == 'output.weight':
                    print(f"Processing {new_k}")
                    # The only tensor allowed to split unevenly: the last
                    # shard absorbs the remainder.
                    size_list = [v.size(0) // num_shards] * num_shards
                    size_list[-1] += v.size(0) % num_shards
                    splits = v.split(size_list, dim=0)  # 13B: size_list == [24976,24977]
                elif new_k == 'norm.weight' or any(p in new_k for p in replicated_patterns):
                    print(f"Processing {new_k}")
                    splits = [v] * num_shards
                elif new_k == 'tok_embeddings.weight' or any(p in new_k for p in dim1_patterns):
                    print(f"Processing {new_k}")
                    splits = _split_evenly(new_k, v, 1)
                elif any(p in new_k for p in dim0_patterns):
                    print(f"Processing {new_k}")
                    splits = _split_evenly(new_k, v, 0)
                elif "wq.weight" in new_k or "wk.weight" in new_k:
                    print(f"Processing {new_k}")
                    v = unpermute(v)
                    splits = _split_evenly(new_k, v, 0)
                else:
                    print(f"Unexpected key {new_k}")
                    raise ValueError(f"Unexpected key {new_k}")
                for sd, split in zip(new_state_dicts, splits):
                    # clone() detaches the chunk from the full tensor's
                    # storage so the original can be freed below.
                    sd[new_k] = split.clone()
                    del split
                del splits
            del model_sd[k], v
            gc.collect()    # Effectively enforce garbage collection

        os.makedirs(output_dir, exist_ok=True)
        for i, new_state_dict in enumerate(new_state_dicts):
            # Two-digit zero padding keeps names well-formed past 10 shards.
            shard_path = f"{output_dir}/consolidated.{i:02d}.pth"
            print(f"Saving shard {i+1} of {num_shards} into {shard_path}")
            torch.save(new_state_dict, shard_path)
        with open(output_dir + "/params.json", "w") as f:
            print(f"Saving params.json into {output_dir}/params.json")
            json.dump(params, f)


def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` into a real boolean.

    ``type=bool`` would treat every non-empty string, including ``"False"``,
    as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{value}'")


# Command-line interface of the merging script.
parser = argparse.ArgumentParser()
# pretrained model setting
parser.add_argument('--big_model_path', default=None, required=True,
                    type=str, help="Please specify the path of the bigger model checkpoint")
parser.add_argument('--small_model_path', default=None, required=True,
                    type=str, help="Please specify the path of the smaller model checkpoint")
# descendant model setting
parser.add_argument('--model_base_config', default=None, required=True,
                    type=str, help="The path of base config of llm model [GPT2/llama/llama2/opt]")
parser.add_argument('--embeding_dim', default=None, required=True,
                    type=int, help="The number of embeding dim in the llm model")
parser.add_argument('--head_num', default=None, required=True,
                    type=int, help="The number of head in each transformer layer")
parser.add_argument('--layer_num', default=None, required=True,
                    type=int, help="The number of layers in the llm model")
parser.add_argument(
        "--trust_remote_code",
        # Parsed explicitly so that "--trust_remote_code False" really disables it.
        type=_str_to_bool,
        default=True,
        help=(
            "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
            "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
            "execute code present on the Hub on your local machine."
        ))
# merging setting
parser.add_argument('--alpha', default=None, required=True,
                    type=float, help="Coefficients for merging small model parameters")
# saving setting
parser.add_argument('--output_type', default='pth', choices=['pth', 'huggingface'], type=str,
                    help="save the merged model in pth or huggingface format.")
parser.add_argument('--output_dir', default='./', type=str)


if __name__=='__main__':
    # Build an intermediate-size model whose weights are
    #   alpha * (small model scaled up) + (1 - alpha) * (big model narrowed down)
    # and save it with save_pretrained().

    args = parser.parse_args()

    big_model_path = args.big_model_path
    small_model_path = args.small_model_path
    middle_model_config = AutoConfig.from_pretrained(
        args.model_base_config,
        n_embd=args.embeding_dim,
        n_head=args.head_num,
        n_layer=args.layer_num,
        trust_remote_code=args.trust_remote_code,
        )
    # Explicit check (not assert) so it still runs under "python -O".
    if (middle_model_config.n_embd != args.embeding_dim
            or middle_model_config.n_head != args.head_num
            or middle_model_config.n_layer != args.layer_num):
        raise ValueError("The loaded config does not reflect the requested "
                         "embeding_dim/head_num/layer_num")

    middle_model_setting = (f"alpha_{args.alpha}_NEmbed_{args.embeding_dim}"
                            f"_NHead_{args.head_num}_NLayer_{args.layer_num}")
    output_dir = os.path.join(args.output_dir, middle_model_setting)
    os.makedirs(output_dir, exist_ok=True)
    output_type = args.output_type

    print(f"Big Model: {big_model_path}")
    print(f"Small Model: {small_model_path}")

    # define middle model
    desnet = AutoModelForCausalLM.from_config(middle_model_config,
                                              trust_remote_code=args.trust_remote_code)
    total_params = sum(p.numel() for p in desnet.parameters())
    print(total_params)

    # get the checkpoint of the big model and smaller model
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints that come from a trusted source.
    big_model_ckpt = torch.load(args.big_model_path, map_location='cpu')
    small_model_ckpt = torch.load(args.small_model_path, map_location='cpu')

    # Index of the last transformer layer of the small model, derived from the
    # key names rather than from the position of a key in the checkpoint.
    small_layer_ids = [int(name.split('.')[2]) for name in small_model_ckpt
                       if name.startswith('transformer.h.')]
    if not small_layer_ids:
        raise ValueError("The small model checkpoint contains no 'transformer.h.<N>.*' keys")
    small_model_layer = max(small_layer_ids)

    # get the checkpoint of the middle model
    desnet_ckpt = desnet.state_dict()
    scaled_small_model_ckpt = {}
    narrowed_big_model_ckpt = {}
    ################################## Merging ##################################
    # Embedding, position and output layers: only the column dimension (dim 1)
    # is resized to the target width.
    for key in ('transformer.wte.weight', 'transformer.wpe.weight', 'lm_head.weight'):
        small_model_weight = small_model_ckpt[key]
        big_model_weight = big_model_ckpt[key]
        desnet_weight_size = desnet_ckpt[key].shape[1]

        # Repeat the small model's columns cyclically until the target width
        # is covered, then trim the excess.
        repeats = max(1, -(-desnet_weight_size // small_model_weight.shape[1]))
        scaled_small_model_ckpt[key] = torch.cat(
            [small_model_weight] * repeats, dim=1)[:, :desnet_weight_size]
        narrowed_big_model_ckpt[key] = big_model_weight[:, :desnet_weight_size]

    # Transformer layers. Each entry is (sub-module name, is_layernorm);
    # every sub-module has a ".weight" and a ".bias" parameter.
    layer_modules = (
        ("ln_1", True),
        ("attn.c_attn", False),
        ("attn.c_proj", False),
        ("ln_2", True),
        ("mlp.c_fc", False),
        ("mlp.c_proj", False),
    )
    for layer_id in range(args.layer_num):
        for module_name, is_layernorm in layer_modules:
            for param_name in ("weight", "bias"):
                key = f"transformer.h.{layer_id}.{module_name}.{param_name}"
                is_bias = param_name == "bias"

                # Big model: keep the leading part of the parameter.
                if is_layernorm:
                    narrowed_big_model_ckpt[key] = merge_layernorm(
                        key, small_model_ckpt, big_model_ckpt, desnet_ckpt, mode='from_big')
                else:
                    narrowed_big_model_ckpt[key] = merge_att_mlp_layer(
                        key, small_model_ckpt, big_model_ckpt, desnet_ckpt,
                        bias=is_bias, mode='from_big')

                # Small model: scale its own layer up while it has one; deeper
                # target layers reuse the scaled weights of its last layer.
                if layer_id <= small_model_layer:
                    if is_layernorm:
                        scaled_small_model_ckpt[key] = merge_layernorm(
                            key, small_model_ckpt, big_model_ckpt, desnet_ckpt, mode='from_small')
                    else:
                        scaled_small_model_ckpt[key] = merge_att_mlp_layer(
                            key, small_model_ckpt, big_model_ckpt, desnet_ckpt,
                            bias=is_bias, mode='from_small')
                else:
                    last_small_key = f"transformer.h.{small_model_layer}.{module_name}.{param_name}"
                    scaled_small_model_ckpt[key] = scaled_small_model_ckpt[last_small_key]

    # layerNorm Layer
    for param_name in ("weight", "bias"):
        key = f"transformer.ln_f.{param_name}"
        narrowed_big_model_ckpt[key] = merge_layernorm(
            key, small_model_ckpt, big_model_ckpt, desnet_ckpt, mode='from_big')
        scaled_small_model_ckpt[key] = merge_layernorm(
            key, small_model_ckpt, big_model_ckpt, desnet_ckpt, mode='from_small')

    # Merge the parameter of the small model the big model and multiply by the alpha.
    # Parameters are paired by key, so the result does not depend on the
    # insertion order of the dictionaries.
    for key, value_small_model in scaled_small_model_ckpt.items():
        if key not in narrowed_big_model_ckpt:
            raise KeyError(f"The key value of the small, big model must be equal! Missing: {key}")
        value_big_model = narrowed_big_model_ckpt[key]
        expected_shape = desnet_ckpt[key].shape
        if value_small_model.shape != expected_shape or value_big_model.shape != expected_shape:
            raise ValueError(
                f"Shape mismatch for {key}: small {tuple(value_small_model.shape)}, "
                f"big {tuple(value_big_model.shape)}, target {tuple(expected_shape)}")
        desnet_ckpt[key] = args.alpha * value_small_model + (1 - args.alpha) * value_big_model

    desnet.load_state_dict(desnet_ckpt)

    if output_type != 'huggingface':
        # The merged model is always written with save_pretrained(); tell the
        # user instead of silently ignoring the requested format.
        print(f"Note: --output_type={output_type} is not supported by this script; "
              "saving in Hugging Face format instead.")
    print("Saving to Hugging Face format...")
    desnet.save_pretrained(output_dir, safe_serialization=False)
