import os
import json
import math

from tqdm import tqdm
from transformers import AutoModelForCausalLM

import torch


@torch.no_grad()
def calc_delta(model_path):
    """Return per-parameter differences between a checkpoint and the base model.

    Loads the causal-LM checkpoint at ``model_path`` and subtracts, entry by
    entry, the tensors of the module-level ``base_state`` from the
    checkpoint's state dict. ``base_state`` must therefore be defined before
    this function is called.

    ``torch.no_grad`` is applied as a decorator so that it is active while
    the function *runs*; wrapping the ``def`` statement in a ``with`` block
    only covers the moment the function is defined.

    Args:
        model_path: Checkpoint location passed to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        dict mapping every state-dict key of the loaded model to
        ``model_state[key] - base_state[key]``.

    Raises:
        KeyError: if the loaded model has a state-dict key that is absent
            from ``base_state``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model_state = model.state_dict()
    return {
        name: tensor - base_state[name]
        for name, tensor in tqdm(model_state.items())
    }

# Load the backbone model whose weights serve as the reference point for
# both deltas computed below.
basemodel = AutoModelForCausalLM.from_pretrained('PATH_TO_BACKBONE_MODEL')
# NOTE(review): presumably 128258 is the vocabulary size of the fine-tuned
# checkpoints, so the embedding shapes line up for subtraction -- confirm.
basemodel.resize_token_embeddings(128258)
base_state = basemodel.state_dict()

# Weight deltas (checkpoint minus backbone) for the two checkpoints that
# are compared against each other.
delta_1 = calc_delta(model_path='PATH_TO_ABILITY_WEIGHT')
delta_2 = calc_delta(model_path='PATH_TO_LANGUAGE_WEIGHT')

# Destination of the JSON report.
tgt_path = 'PATH_TO_TARGET_FILE'

# Compare the two sets of weight deltas parameter by parameter and collect
# the cosine similarity of each pair, plus the sum and the average over all
# parameters that produced a finite value.
with torch.no_grad():
    merge_state = {'params': {}}
    total_cs, total_layer = 0, 0
    for nam, dp1 in tqdm(delta_1.items()):
        dp2 = delta_2[nam]
        # Cosine similarity of the two deltas treated as flat vectors:
        # <dp1, dp2> / (||dp1|| * ||dp2||).
        dot = (dp1 * dp2).sum()
        norm_1 = torch.sqrt((dp1 * dp1).sum())
        norm_2 = torch.sqrt((dp2 * dp2).sum())
        cos_sim = (dot / norm_1 / norm_2).item()
        # A zero-norm delta yields 0/0 = NaN; such parameters are left out
        # of both the per-parameter report and the aggregate figures.
        if math.isnan(cos_sim):
            continue
        total_cs = total_cs + cos_sim
        total_layer = total_layer + 1
        merge_state['params'][nam] = {
            'cos_similarity': cos_sim,
        }

    # Guard against division by zero when no parameter produced a finite
    # similarity; the average is then written as JSON null.
    merge_state['avg_cos_similarity'] = (
        total_cs / total_layer if total_layer else None
    )
    merge_state['cos_similarity'] = total_cs

with open(tgt_path, 'w', encoding='utf-8') as fout:
    json.dump(merge_state, fp=fout, indent=4)