from collections import defaultdict
import json
import os 
from ming.utils import client

import argparse
from safetensors.torch import load_file
from typing import Dict, Any, List 
from torch import Tensor
import torch 
import torch.nn.functional as F
from math import sqrt

def extract_params(model_path: str, module: str):
    """Load a checkpoint's trainable weights and keep one family of projections.

    Reads ``adapter_model.safetensors`` and ``non_lora_trainables.bin`` from
    ``model_path`` and merges them into a single mapping; on duplicate keys
    the entry from ``non_lora_trainables.bin`` wins.

    Args:
        model_path: Directory that contains both checkpoint files.
        module: ``'attn'`` keeps keys containing ``q_proj``, ``k_proj``,
            ``v_proj`` or ``o_proj``; any other value keeps keys containing
            ``gate_proj``, ``up_proj`` or ``down_proj``.

    Returns:
        A dict with the entries whose key contains one of the selected
        projection names.
    """
    state_dict = load_file(os.path.join(model_path, "adapter_model.safetensors"))
    # map_location="cpu" lets a file that was saved from GPU be read on a
    # CPU-only host and keeps every tensor of the merged dict on one device.
    non_lora_state_dict = torch.load(
        os.path.join(model_path, "non_lora_trainables.bin"),
        map_location="cpu",
    )
    state_dict.update(non_lora_state_dict)

    if module == 'attn':
        keys_to_filter = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    else:
        keys_to_filter = ('gate_proj', 'up_proj', 'down_proj')
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if any(key in name for key in keys_to_filter)
    }

def calc_param_distance(state_dict1: Dict[str, Tensor], state_dict2: Dict[str, Tensor], metric: str = 'L2'):
    """Return the mean per-parameter metric between two state dicts.

    Every key of ``state_dict1`` is looked up in ``state_dict2``; the chosen
    metric is evaluated on each pair of tensors and the values are averaged
    over the keys of ``state_dict1``.

    Metrics:
        ``'L2'``: the element-wise difference is scaled by ``sqrt(2)``, its
            L2 norm is taken along the reduction axis and divided by the
            square root of that axis' length, and the L2 norm of the
            resulting values is returned.
        ``'cosine'``: the mean of ``F.cosine_similarity`` taken along the
            reduction axis (a similarity, so larger means closer).

    The reduction axis is the smaller of the first two axes (axis 0 on a
    tie), e.g. for shape ``[m, n]`` with ``n < m`` it is axis 1. A 1-D tensor
    is reduced over its only axis.

    Args:
        state_dict1: Reference parameters; its keys drive the comparison.
        state_dict2: Parameters to compare against; must contain every key
            of ``state_dict1``.
        metric: ``'L2'`` or ``'cosine'``.

    Returns:
        The metric averaged over all keys of ``state_dict1``.

    Raises:
        KeyError: If ``metric`` is unknown or a key is missing from
            ``state_dict2``.
        ValueError: If ``state_dict1`` is empty.
    """
    def reduction_dim(tensor):
        # A 1-D tensor has no second axis to compare against.
        if tensor.dim() == 1:
            return 0
        return 1 if tensor.shape[1] < tensor.shape[0] else 0

    def l2_metric(x, y, _dim):
        diff = (x - y) * sqrt(2)
        # The axis is derived from the difference tensor itself, so the
        # axis passed in by the caller is not used for this metric.
        axis = reduction_dim(diff)
        per_slice = diff.norm(dim=axis, p=2) / sqrt(diff.shape[axis])
        return per_slice.norm(p=2)

    def cosine_metric(x, y, dim):
        return F.cosine_similarity(x, y, dim=dim).mean()

    metrics = {
        'L2': l2_metric,
        'cosine': cosine_metric,
    }
    metric_func = metrics[metric]

    if not state_dict1:
        # Averaging over zero parameters is undefined; fail with a clear
        # message instead of a bare division by zero.
        raise ValueError("state_dict1 is empty; there are no parameters to compare")

    total = 0
    for name, param in state_dict1.items():
        total += metric_func(param, state_dict2[name], reduction_dim(param))
    return total / len(state_dict1)

if __name__ == "__main__":
    # Compare one reference checkpoint against one or more other checkpoints
    # and print the chosen metric for each of them. Checkpoint names are
    # resolved relative to ~/checkpoints.
    parser = argparse.ArgumentParser()
    # Both checkpoint arguments are required: without them the path joins
    # below would fail with an opaque TypeError.
    parser.add_argument("--opt_ckpt", type=str, required=True,
                        help="reference checkpoint directory, relative to ~/checkpoints")
    parser.add_argument("--comp_ckpt", nargs="+", required=True,
                        help="checkpoint directories to compare, relative to ~/checkpoints")
    # 'mlp' is what extract_params selects for any value other than 'attn',
    # so this default keeps the behaviour of omitting the flag.
    parser.add_argument("--module", type=str, choices=['attn', 'mlp'], default='mlp')
    parser.add_argument("--metric", type=str, default='L2', choices=['L2', 'cosine'])

    args = parser.parse_args()
    ckpt_root = os.path.expanduser("~/checkpoints")
    args.opt_ckpt = os.path.join(ckpt_root, args.opt_ckpt)
    args.comp_ckpt = [os.path.join(ckpt_root, ckpt) for ckpt in args.comp_ckpt]

    # Load the selected projection weights of the reference checkpoint ...
    opt_state_dict = extract_params(args.opt_ckpt, module=args.module)
    # ... and of every checkpoint it is compared with.
    comp_state_dicts = [extract_params(model_path, module=args.module) for model_path in args.comp_ckpt]
    comp_params_dist = [
        calc_param_distance(opt_state_dict, comp_state_dict, args.metric)
        for comp_state_dict in comp_state_dicts
    ]
    for ckpt_name, dist in zip(args.comp_ckpt, comp_params_dist):
        print(f"Distance between the optimal model and {ckpt_name} is {dist}")