import os
import json

from tqdm import tqdm
from transformers import AutoModelForCausalLM

import torch


# Backbone (base) model; the two fine-tuned models below are merged back
# onto it as deltas (task vectors) relative to these weights.
backbone_model_path = 'PATH_TO_BACKBONE_LLM'
backbone_model = AutoModelForCausalLM.from_pretrained(backbone_model_path)

# First fine-tuned model (labelled "ability" by its path placeholder).
model_1_path = 'PATH_TO_ABILITY_WEIGHT'
model_1 = AutoModelForCausalLM.from_pretrained(model_1_path)

# Second fine-tuned model (labelled "language" by its path placeholder).
model_2_path = 'PATH_TO_LANGUAGE WEIGHT'
model_2 = AutoModelForCausalLM.from_pretrained(model_2_path)

# The two fine-tuned models must share one vocabulary size so their
# embedding matrices can be combined element-wise. Raise instead of
# `assert` so the check is not stripped under `python -O`.
embed_size_1 = model_1.get_input_embeddings().weight.size(0)
embed_size_2 = model_2.get_input_embeddings().weight.size(0)
if embed_size_1 != embed_size_2:
    raise ValueError(
        f'vocabulary size mismatch between fine-tuned models: '
        f'{embed_size_1} ({model_1_path}) vs {embed_size_2} ({model_2_path})'
    )

# Resize the backbone's token embeddings to the fine-tuned vocabulary size
# so its embedding tensors line up with theirs in the element-wise merge.
# NOTE(review): presumably the fine-tuned models extended the vocabulary;
# confirm the new rows' initialisation is acceptable as the merge base.
backbone_model.resize_token_embeddings(embed_size_1)
backbone_state = backbone_model.state_dict()
state_1 = model_1.state_dict()
state_2 = model_2.state_dict()

# Per-parameter cosine similarities, read as
# {"params": {param_name: {"cos_similarity": <number>, ...}, ...}, ...}.
weight_path = 'PATH_TO_SIMILARITY'
with open(weight_path, 'r', encoding='utf-8') as fin:
    similarity = json.load(fin)
similarity = similarity['params']
similarity = [(nam, c['cos_similarity']) for nam, c in similarity.items()]
# Ascending order: the least-similar parameters come first.
sorted_similarity = sorted(similarity, key=lambda x: x[1])

# The lowest-similarity `ratio` fraction of parameters gets weight 1.0; all
# remaining parameters get 0.0. int() truncates, so the 1.0 group is never
# larger than ratio * N.
ratio = 0.8
num_useful = int(ratio * len(sorted_similarity))
layer_weight = {
    nam: 1.0 if rank < num_useful else 0.0
    for rank, (nam, _) in enumerate(sorted_similarity)
}

with torch.no_grad():
    merge_state = {}
    # alpha scales the deltas gated by layer_weight; beta scales the
    # model_2 delta unconditionally (see the merge formula below).
    alpha, beta = (0.2, 1.0)
    tgt_path = 'PATH_TO_TARGET_LLM'
    print(f'(alpha = {alpha}, beta = {beta}): {tgt_path}')
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the exists()/mkdir() check-then-act race.
    os.makedirs(tgt_path, exist_ok=True)
    for nam, param in tqdm(backbone_state.items()):
        if not torch.is_floating_point(param):
            # Non-float buffers (e.g. integer ids, boolean masks) cannot be
            # interpolated: bool subtraction raises, and integer tensors
            # would be silently promoted to float. Keep the backbone's copy.
            merge_state[nam] = param
            continue
        weight = layer_weight[nam]
        try:
            delta_1 = state_1[nam] - param
            delta_2 = state_2[nam] - param
        except RuntimeError:
            # Shape mismatch between checkpoints: report all shapes, then
            # re-raise so a checkpoint missing this entry is never saved.
            print(nam, state_1[nam].size(), state_2[nam].size(), param.size())
            raise
        # merged = base + alpha*w*d1 + beta*d2 + alpha*(1-w)*d2, where
        # d_i = state_i - base and w = layer_weight[nam]. Term order matches
        # the original expression so results are numerically identical.
        merge_state[nam] = (
            param
            + alpha * delta_1 * weight
            + beta * delta_2
            + alpha * delta_2 * (1 - weight)
        )

    torch.save(merge_state, os.path.join(tgt_path, 'pytorch_model.bin'))
    # Copy tokenizer/config files next to the weights (presumably so the
    # directory can be loaded with from_pretrained). Fail loudly instead of
    # silently leaving an incomplete output directory.
    status = os.system(f'cp PATH_TO_THE_TOKENIZER_AND_CONFIG {tgt_path}')
    if status != 0:
        raise RuntimeError(
            f'copying tokenizer/config into {tgt_path} failed (status {status})'
        )

