import os
import json

from tqdm import tqdm
from transformers import AutoModelForCausalLM

import torch



# Base checkpoint that every checkpoint listed in `deltas` is compared against.
backbone_model_path = 'PATH_TO_BACKBONE_MODEL'
backbone_model = AutoModelForCausalLM.from_pretrained(backbone_model_path)

# (language tag, checkpoint path) pairs.  In this script the tag is only used
# to build the name of the output directory.
deltas = [
    ('es', 'PATH_TO_ES_LANGUAGE_WEIGHT'),
    ('zh', 'PATH_TO_ZH_LANGUAGE_WEIGHT'),
    ('bn', 'PATH_TO_BN_LANGUAGE_WEIGHT'),
    ('te', 'PATH_TO_TE_LANGUAGE_WEIGHT'),
]

# Global scale multiplied with every per-checkpoint weight.
total_weight = 0.2
# Per-checkpoint weights, paired with `deltas` by position.  This has to be a
# plain list: a trailing comma after the closing bracket would wrap it in a
# 1-tuple, and zip() would then yield a single (delta, whole-list) pair.
# NOTE(review): five weights are listed for four deltas; zip() stops at the
# shorter sequence, so the last weight is currently unused -- confirm intent.
weights = [1.5, 2.0, 1.2, 1.2, 1.0]

tgt_path = 'PATH_TO_TARGET'

# Encode the merge recipe in the directory name, e.g. '..._1.5es_2.0zh'.
tgt_path += ''.join(
    f'_{weight}{delta[0]}' for delta, weight in zip(deltas, weights)
)
# exist_ok avoids the separate existence check and also creates any missing
# parent directories.
os.makedirs(tgt_path, exist_ok=True)
print(tgt_path)

# Resize the backbone's token embeddings to the size used by the first
# checkpoint in `deltas`, so the embedding tensors can be subtracted
# element-wise below.
# NOTE(review): assumes every checkpoint in `deltas` shares that embedding
# size -- confirm.
en_model_path = deltas[0][1]
en_model = AutoModelForCausalLM.from_pretrained(en_model_path)
en_embed_size = en_model.get_input_embeddings().weight.size(0)
# The model was loaded only to read its embedding size; drop it so it does
# not stay resident while the other checkpoints are loaded.
del en_model
backbone_model.resize_token_embeddings(en_embed_size)
backbone_state = backbone_model.state_dict()

with torch.no_grad():
    # Start from the backbone parameters.  This is a shallow copy: the updates
    # below rebind entries of merge_state to new tensors, so backbone_state
    # keeps the original values needed to compute every delta.
    merge_state = dict(backbone_state)

    for delta, weight in zip(deltas, weights):
        model = AutoModelForCausalLM.from_pretrained(delta[1])
        model_state = model.state_dict()
        scale = total_weight * weight

        for nam, base_param in tqdm(backbone_state.items()):
            if not torch.is_floating_point(base_param):
                # Integer/bool tensors are carried over from the backbone
                # unchanged: bool tensors do not support `-`, and integer
                # tensors would be promoted to float by the scaled update.
                continue
            # merged += total_weight * weight * (checkpoint - backbone)
            merge_state[nam] = merge_state[nam] \
                            + scale * (model_state[nam] - base_param)

    torch.save(merge_state, os.path.join(tgt_path, 'pytorch_model.bin'))
    # The shell command is kept as-is so a user-supplied source pattern keeps
    # working; its exit status is now reported instead of being discarded.
    cp_status = os.system(f'cp PATH_TO_THE_TOKENIZER_AND_CONFIG {tgt_path}')
    if cp_status != 0:
        print(f'warning: copying tokenizer/config returned status {cp_status}')

