import gc
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2Model
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModel
import json
import os
def cosine_similarity(model1,model2,path):
    """Compare two models parameter-by-parameter and dump the result as JSON.

    Parameters are paired positionally by zipping ``model1.parameters()``
    with ``model2.parameters()``; each pair is flattened and compared with a
    cosine similarity along dimension 0.

    Args:
        model1: object exposing ``parameters()``.
        model2: second model, walked in lockstep with ``model1``.
        path: file that receives the JSON report with the keys
            ``average_cos_similarity`` and ``cos_similarity``.

    Returns:
        Tuple ``(similarities, mean)``: the list of per-pair similarity
        tensors and their arithmetic mean.
    """
    cos = torch.nn.functional.cosine_similarity
    similarities = [
        cos(torch.flatten(p1), torch.flatten(p2), 0)
        for p1, p2 in zip(model1.parameters(), model2.parameters())
    ]
    mean_similarity = sum(similarities) / len(similarities)
    report = {
        'average_cos_similarity': mean_similarity.item(),
        'cos_similarity': [s.item() for s in similarities],
    }
    with open(path, 'w') as out:
        json.dump(report, out)
    return similarities, mean_similarity
def get_norm_unnorm_mm(model):
    """Split a model's parameters into projection weights, norm weights and weight products.

    Parameters are grouped by the concatenation of the last two
    dot-separated components of their name (e.g. ``...q_proj.weight`` becomes
    ``'q_projweight'``). The groups are then walked in lockstep, one entry
    per occurrence of each name.

    Args:
        model: object exposing ``named_parameters()`` whose parameter names
            end in ``q_proj.weight``, ``k_proj.weight``, ``v_proj.weight``,
            ``o_proj.weight``, ``gate_proj.weight``, ``down_proj.weight``,
            ``up_proj.weight``, ``post_attention_layernorm.weight``,
            ``input_layernorm.weight``, ``embed_tokens.weight``,
            ``lm_head.weight`` and ``norm.weight``.

    Returns:
        Tuple ``(unnorm_list, norm_list, mm)``:
        ``unnorm_list`` holds ``q, k, v, o, gate, down, up`` for every
        layer followed by the embedding and ``lm_head`` weights;
        ``norm_list`` holds ``post_attention_layernorm, input_layernorm``
        for every layer followed by the final ``norm`` weight;
        ``mm`` holds every ``q.T @ k @ v.T @ o.T`` product followed by every
        ``(gate.T * up.T) @ down.T`` product.

    Raises:
        KeyError: if one of the expected parameter names is absent.
    """
    # Group tensors by "<module><param>" in a single pass.
    grouped = {}
    for full_name, param in model.named_parameters():
        parts = full_name.split(".")
        grouped.setdefault(parts[-2] + parts[-1], []).append(param)

    unnorm_list = []
    norm_list = []
    qkvo_list = []
    ud_list = []
    for q, k, v, o, g, d, u, n1, n2 in zip(
            grouped['q_projweight'], grouped['k_projweight'], grouped['v_projweight'],
            grouped['o_projweight'], grouped['gate_projweight'], grouped['down_projweight'],
            grouped['up_projweight'], grouped['post_attention_layernormweight'],
            grouped['input_layernormweight']):
        unnorm_list += [q, k, v, o, g, d, u]
        norm_list += [n1, n2]
        qkvo_list.append(q.t() @ k @ v.t() @ o.t())
        ud_list.append((g.t() * u.t()) @ d.t())

    # Built once after the loop instead of being re-concatenated per layer.
    mm = qkvo_list + ud_list
    unnorm_list += [grouped['embed_tokensweight'][0], grouped['lm_headweight'][0]]
    norm_list.append(grouped['normweight'][0])
    return unnorm_list, norm_list, mm
def get_norm_unnorm_mm_Baichuan(model):
    """Variant of the parameter split for models that pack q, k and v into ``W_pack``.

    Parameters are grouped by the concatenation of the last two
    dot-separated components of their name. Each ``W_pack.weight`` is cut
    into three equal row blocks, taken in order as q, k and v. For a
    12288-row matrix this is the 4096 / 4096 / 4096 split.

    Args:
        model: object exposing ``named_parameters()`` whose parameter names
            end in ``W_pack.weight``, ``o_proj.weight``, ``gate_proj.weight``,
            ``down_proj.weight``, ``up_proj.weight``,
            ``post_attention_layernorm.weight``, ``input_layernorm.weight``,
            ``embed_tokens.weight``, ``lm_head.weight`` and ``norm.weight``.

    Returns:
        Tuple ``(unnorm_list, norm_list, mm)``:
        ``unnorm_list`` holds ``q, k, v, o, gate, down, up`` for every
        layer followed by the embedding and ``lm_head`` weights;
        ``norm_list`` holds ``post_attention_layernorm, input_layernorm``
        for every layer followed by the final ``norm`` weight;
        ``mm`` holds every ``q.T @ k @ v.T @ o.T`` product followed by every
        ``(gate.T * up.T) @ down.T`` product.

    Raises:
        KeyError: if one of the expected parameter names is absent.
    """
    # Group tensors by "<module><param>" in a single pass.
    grouped = {}
    for full_name, param in model.named_parameters():
        parts = full_name.split(".")
        grouped.setdefault(parts[-2] + parts[-1], []).append(param)

    unnorm_list = []
    norm_list = []
    qkvo_list = []
    ud_list = []
    for qkv, o, g, d, u, n1, n2 in zip(
            grouped['W_packweight'], grouped['o_projweight'], grouped['gate_projweight'],
            grouped['down_projweight'], grouped['up_projweight'],
            grouped['post_attention_layernormweight'], grouped['input_layernormweight']):
        # Derive the block height from the tensor instead of assuming 4096.
        third = qkv.shape[0] // 3
        q = qkv[:third]
        k = qkv[third:2 * third]
        v = qkv[2 * third:]
        unnorm_list += [q, k, v, o, g, d, u]
        norm_list += [n1, n2]
        qkvo_list.append(q.t() @ k @ v.t() @ o.t())
        ud_list.append((g.t() * u.t()) @ d.t())

    # Built once after the loop instead of being re-concatenated per layer.
    mm = qkvo_list + ud_list
    unnorm_list += [grouped['embed_tokensweight'][0], grouped['lm_headweight'][0]]
    norm_list.append(grouped['normweight'][0])
    return unnorm_list, norm_list, mm
def get_cossim_all(model1,model2,path):
    """Compute grouped cosine-similarity statistics for two models and save them as JSON.

    Four groups are compared pairwise: the projection/embedding weights,
    the norm weights, the weight products, and the first two groups
    concatenated. For the last group a single global cosine (``cos_all``)
    is also computed from the summed dot products and squared norms.

    Args:
        model1: first model, passed to the parameter extractors.
        model2: second model, passed to the parameter extractors.
        path: file that receives the JSON report.

    Returns:
        Tuple ``(similarities, mean)`` for the concatenated group.
    """
    cos = torch.nn.functional.cosine_similarity

    def _pairwise_cos(group_a, group_b):
        # Pairs whose comparison raises are skipped.
        collected = []
        for a, b in zip(group_a, group_b):
            vec_a = torch.flatten(a)
            vec_b = torch.flatten(b)
            try:
                collected.append(cos(vec_a, vec_b, 0))
            except:
                pass
        return collected

    # Try plain/plain first, then plain/Baichuan, then Baichuan/plain.
    # NOTE: a Baichuan/Baichuan pairing is never attempted.
    try:
        unnorm_a, norm_a, mm_a = get_norm_unnorm_mm(model1)
        unnorm_b, norm_b, mm_b = get_norm_unnorm_mm(model2)
    except:
        try:
            unnorm_a, norm_a, mm_a = get_norm_unnorm_mm(model1)
            unnorm_b, norm_b, mm_b = get_norm_unnorm_mm_Baichuan(model2)
        except:
            unnorm_a, norm_a, mm_a = get_norm_unnorm_mm_Baichuan(model1)
            unnorm_b, norm_b, mm_b = get_norm_unnorm_mm(model2)
    combined_a = unnorm_a + norm_a
    combined_b = unnorm_b + norm_b

    # Projection and embedding weights.
    sims = _pairwise_cos(unnorm_a, unnorm_b)
    mean_sim = sum(sims) / len(sims)
    report = {}
    report['average_unnorm_cos_similarity'] = mean_sim.item()
    report['cos_similarityunnorm'] = [s.item() for s in sims]

    # Norm weights (only the average is reported).
    sims = _pairwise_cos(norm_a, norm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_norm_cos_similarity'] = mean_sim.item()

    # Weight products.
    sims = _pairwise_cos(mm_a, mm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_mm_cos_similarity'] = mean_sim.item()
    report['cos_similaritymm'] = [s.item() for s in sims]

    # Concatenated group: per-pair cosines plus the pieces of the global cosine.
    sims = []
    dots = []
    sq_norms_a = []
    sq_norms_b = []
    for a, b in zip(combined_a, combined_b):
        vec_a = torch.flatten(a)
        vec_b = torch.flatten(b)
        try:
            dots.append(torch.dot(vec_a, vec_b))
            sq_norms_a.append(torch.sum(vec_a * vec_a))
            sq_norms_b.append(torch.sum(vec_b * vec_b))
            sims.append(cos(vec_a, vec_b, 0))
        except:
            pass
    dot_total = torch.sum(torch.stack(dots))
    norm_total_a = torch.sqrt(torch.sum(torch.stack(sq_norms_a)))
    norm_total_b = torch.sqrt(torch.sum(torch.stack(sq_norms_b)))
    cos_all = dot_total / norm_total_a / norm_total_b
    mean_sim = sum(sims) / len(sims)
    report['average_all_cos_similarity'] = mean_sim.item()
    report['cos_similarity'] = [s.item() for s in sims]
    report['cos_all'] = cos_all.item()

    with open(path, 'w') as out:
        json.dump(report, out)
    return sims, mean_sim
def get_cossim_all_old(model1,model2,path):
    """Earlier variant of the grouped cosine report, without the global ``cos_all`` value.

    Four groups are compared pairwise: the projection/embedding weights,
    the norm weights, the weight products, and the first two groups
    concatenated.

    Args:
        model1: first model, passed to the parameter extractors.
        model2: second model, passed to the parameter extractors.
        path: file that receives the JSON report.

    Returns:
        Tuple ``(similarities, mean)`` for the concatenated group.
    """
    cos = torch.nn.functional.cosine_similarity

    def _pairwise_cos(group_a, group_b):
        # Pairs whose comparison raises are skipped.
        collected = []
        for a, b in zip(group_a, group_b):
            vec_a = torch.flatten(a)
            vec_b = torch.flatten(b)
            try:
                collected.append(cos(vec_a, vec_b, 0))
            except:
                pass
        return collected

    # Try plain/plain first, then plain/Baichuan, then Baichuan/plain.
    try:
        unnorm_a, norm_a, mm_a = get_norm_unnorm_mm(model1)
        unnorm_b, norm_b, mm_b = get_norm_unnorm_mm(model2)
    except:
        try:
            unnorm_a, norm_a, mm_a = get_norm_unnorm_mm(model1)
            unnorm_b, norm_b, mm_b = get_norm_unnorm_mm_Baichuan(model2)
        except:
            unnorm_a, norm_a, mm_a = get_norm_unnorm_mm_Baichuan(model1)
            unnorm_b, norm_b, mm_b = get_norm_unnorm_mm(model2)
    combined_a = unnorm_a + norm_a
    combined_b = unnorm_b + norm_b

    # Projection and embedding weights.
    sims = _pairwise_cos(unnorm_a, unnorm_b)
    mean_sim = sum(sims) / len(sims)
    report = {}
    report['average_unnorm_cos_similarity'] = mean_sim.item()
    report['cos_similarityunnorm'] = [s.item() for s in sims]

    # Norm weights (only the average is reported).
    sims = _pairwise_cos(norm_a, norm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_norm_cos_similarity'] = mean_sim.item()

    # Weight products.
    sims = _pairwise_cos(mm_a, mm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_mm_cos_similarity'] = mean_sim.item()
    report['cos_similaritymm'] = [s.item() for s in sims]

    # Projection/embedding weights followed by norm weights.
    sims = _pairwise_cos(combined_a, combined_b)
    mean_sim = sum(sims) / len(sims)
    report['average_all_cos_similarity'] = mean_sim.item()
    report['cos_similarity'] = [s.item() for s in sims]

    with open(path, 'w') as out:
        json.dump(report, out)
    return sims, mean_sim

def get_cossim_all_baichuan(model1,model2,path):
    """Grouped cosine report for a pair in which one model uses a packed ``W_pack``.

    The plain extractor is tried on ``model1`` and the Baichuan extractor on
    ``model2``; if that raises, the roles are swapped. The last two entries
    of the projection list are left out of the "unnorm" average. The
    concatenated group is ordered norm weights first, then projection weights.

    Args:
        model1: first model.
        model2: second model.
        path: file that receives the JSON report.

    Returns:
        Tuple ``(similarities, mean)`` for the concatenated group.
    """
    cos = torch.nn.functional.cosine_similarity

    def _strict_cos(group_a, group_b):
        # No error handling here: a failing pair aborts the whole report.
        return [cos(torch.flatten(a), torch.flatten(b), 0)
                for a, b in zip(group_a, group_b)]

    try:
        unnorm_a, norm_a, mm_a = get_norm_unnorm_mm(model1)
        unnorm_b, norm_b, mm_b = get_norm_unnorm_mm_Baichuan(model2)
    except:
        unnorm_a, norm_a, mm_a = get_norm_unnorm_mm_Baichuan(model1)
        unnorm_b, norm_b, mm_b = get_norm_unnorm_mm(model2)
    combined_a = norm_a + unnorm_a
    combined_b = norm_b + unnorm_b

    # Projection weights without the trailing two list entries.
    sims = _strict_cos(unnorm_a[:-2], unnorm_b[:-2])
    mean_sim = sum(sims) / len(sims)
    report = {}
    report['average_unnorm_cos_similarity'] = mean_sim.item()

    # Norm weights.
    sims = _strict_cos(norm_a, norm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_norm_cos_similarity'] = mean_sim.item()

    # Weight products.
    sims = _strict_cos(mm_a, mm_b)
    mean_sim = sum(sims) / len(sims)
    report['average_mm_cos_similarity'] = mean_sim.item()
    report['cos_similaritymm'] = [s.item() for s in sims]

    # Concatenated group: pairs that raise are skipped here.
    sims = []
    dots = []
    sq_norms_a = []
    sq_norms_b = []
    for a, b in zip(combined_a, combined_b):
        vec_a = torch.flatten(a)
        vec_b = torch.flatten(b)
        try:
            dots.append(torch.dot(vec_a, vec_b))
            sq_norms_a.append(torch.sum(vec_a * vec_a))
            sq_norms_b.append(torch.sum(vec_b * vec_b))
            sims.append(cos(vec_a, vec_b, 0))
        except:
            pass
    dot_total = torch.sum(torch.stack(dots))
    norm_total_a = torch.sqrt(torch.sum(torch.stack(sq_norms_a)))
    norm_total_b = torch.sqrt(torch.sum(torch.stack(sq_norms_b)))
    cos_all = dot_total / norm_total_a / norm_total_b
    mean_sim = sum(sims) / len(sims)
    report['average_all_cos_similarity'] = mean_sim.item()
    report['cos_similarity'] = [s.item() for s in sims]
    report['cos_all'] = cos_all.item()

    with open(path, 'w') as out:
        json.dump(report, out)
    return sims, mean_sim
def get_cossim_andsave(model1,model2,path):
    """Compare the weight products of two models and persist the second model's matrices.

    The products of ``model2`` are written to two fixed ``.npy`` files before
    the comparison runs; the products of ``model1`` are not saved.

    Args:
        model1: first model.
        model2: second model; its products are saved to disk.
        path: file that receives the JSON report with the keys
            ``average_cos_similarity`` and ``cos_similarity``.

    Returns:
        Tuple ``(similarities, mean)`` over all attention products followed
        by all MLP products.
    """
    kqvo_a, ud_a = get_kqvolist(model1)
    kqvo_b, ud_b = get_kqvolist(model2)

    save_tensor_list(kqvo_b,"/home/byzeng/project/weights-search/testdata/open_llama_kqvo_matrixs.npy")
    save_tensor_list(ud_b,"/home/byzeng/project/weights-search/testdata/open_llama_gud_matrixs.npy")

    cos = torch.nn.functional.cosine_similarity
    similarities = [
        cos(torch.flatten(a), torch.flatten(b), 0)
        for a, b in zip(kqvo_a + ud_a, kqvo_b + ud_b)
    ]
    mean_similarity = sum(similarities) / len(similarities)
    report = {
        'average_cos_similarity': mean_similarity.item(),
        'cos_similarity': [s.item() for s in similarities],
    }
    with open(path, 'w') as out:
        json.dump(report, out)
    return similarities, mean_similarity
def get_kqvolist(model):
    """Build the per-layer attention and MLP weight products of a model.

    Parameters are grouped by the concatenation of the last two
    dot-separated components of their name, then walked in lockstep.

    Args:
        model: object exposing ``named_parameters()`` whose parameter names
            end in ``q_proj.weight``, ``k_proj.weight``, ``v_proj.weight``,
            ``o_proj.weight``, ``gate_proj.weight``, ``down_proj.weight``
            and ``up_proj.weight``.

    Returns:
        Tuple ``(qkvo_list, ud_list)``: the ``q.T @ k @ v.T @ o.T`` products
        and the ``(gate.T * up.T) @ down.T`` products, one entry per layer.
    """
    by_suffix = {}
    for full_name, tensor in list(model.named_parameters()):
        pieces = full_name.split(".")
        by_suffix.setdefault(pieces[-2] + pieces[-1], []).append(tensor)

    layer_weights = zip(
        by_suffix['q_projweight'], by_suffix['k_projweight'], by_suffix['v_projweight'],
        by_suffix['o_projweight'], by_suffix['gate_projweight'], by_suffix['down_projweight'],
        by_suffix['up_projweight'],
    )
    attention_products = []
    mlp_products = []
    for q, k, v, o, g, d, u in layer_weights:
        attention_products.append(q.t() @ k @ v.t() @ o.t())
        mlp_products.append((g.t() * u.t()) @ d.t())
    return attention_products, mlp_products

def get_kqvolist_beichuan(model):
    """Build the per-layer weight products for models that pack q, k and v into ``W_pack``.

    Parameters are grouped by the concatenation of the last two
    dot-separated components of their name. Each ``W_pack.weight`` is cut
    into three equal row blocks, taken in order as q, k and v. For a
    12288-row matrix this is the 4096 / 4096 / 4096 split.

    Args:
        model: object exposing ``named_parameters()`` whose parameter names
            end in ``W_pack.weight``, ``o_proj.weight``, ``gate_proj.weight``,
            ``down_proj.weight`` and ``up_proj.weight``.

    Returns:
        Tuple ``(qkvo_list, ud_list)``: the ``q.T @ k @ v.T @ o.T`` products
        and the ``(gate.T * up.T) @ down.T`` products, one entry per layer.

    Raises:
        KeyError: if one of the expected parameter names is absent.
    """
    # Group tensors by "<module><param>" in a single pass.
    grouped = {}
    for full_name, param in model.named_parameters():
        parts = full_name.split(".")
        grouped.setdefault(parts[-2] + parts[-1], []).append(param)

    qkvo_list = []
    ud_list = []
    for qkv, o, g, d, u in zip(
            grouped['W_packweight'], grouped['o_projweight'], grouped['gate_projweight'],
            grouped['down_projweight'], grouped['up_projweight']):
        # Derive the block height from the tensor instead of assuming 4096.
        third = qkv.shape[0] // 3
        q = qkv[:third]
        k = qkv[third:2 * third]
        v = qkv[2 * third:]
        qkvo_list.append(q.t() @ k @ v.t() @ o.t())
        ud_list.append((g.t() * u.t()) @ d.t())
    return qkvo_list, ud_list

def save_tensor_list(tensor_list,path):
    """Stack the tensors along a new leading axis and write the result with ``np.save``.

    Args:
        tensor_list: sequence of tensors accepted by ``torch.stack``.
        path: destination handed to ``np.save``.
    """
    stacked = torch.stack(tensor_list, dim=0)
    as_array = stacked.detach().cpu().numpy()
    np.save(path, as_array)
def save_kqvo_ud(model1,model2,name):
    """Save the stacked attention/MLP weight products of two models as ``.npy`` files.

    For every layer the attention product and the MLP product are stacked
    into one tensor. ``model1`` always goes to the fixed ``llama_...`` file;
    ``model2`` goes to a file whose name is prefixed with ``name``.

    Args:
        model1: model whose products are written to the fixed file.
        model2: model whose products are written to the ``name``-prefixed file.
        name: prefix of the second output file.
    """
    kqvo_a, ud_a = get_kqvolist(model1)
    kqvo_b, ud_b = get_kqvolist(model2)
    stacked_a = [torch.stack(pair) for pair in zip(kqvo_a, ud_a)]
    stacked_b = [torch.stack(pair) for pair in zip(kqvo_b, ud_b)]
    save_tensor_list(stacked_a,"/home/byzeng/project/weights-search/testdata/llama_kqvo_gud_matrixs.npy")
    save_tensor_list(stacked_b,f'/home/byzeng/project/weights-search/testdata/{str(name)}_kqvo_gud_matrixs.npy')
def save_kqvo_ud_beichuan(model1,model2,name):
    """Save stacked weight products when ``model2`` packs q, k and v into ``W_pack``.

    ``model1`` is read with the plain extractor and written to the fixed
    ``llama_...`` file; ``model2`` is read with the ``W_pack`` extractor and
    written to a file whose name is prefixed with ``name``.

    Args:
        model1: model whose products are written to the fixed file.
        model2: model whose products are written to the ``name``-prefixed file.
        name: prefix of the second output file.
    """
    kqvo_a, ud_a = get_kqvolist(model1)
    kqvo_b, ud_b = get_kqvolist_beichuan(model2)
    stacked_a = [torch.stack(pair) for pair in zip(kqvo_a, ud_a)]
    stacked_b = [torch.stack(pair) for pair in zip(kqvo_b, ud_b)]
    save_tensor_list(stacked_a,"/home/byzeng/project/weights-search/testdata/llama_kqvo_gud_matrixs.npy")
    save_tensor_list(stacked_b,f'/home/byzeng/project/weights-search/testdata/{str(name)}_kqvo_gud_matrixs.npy')

def get_kqvo_gud(model,n='10'):
    """Return the attention and MLP weight products for the layer selected by ``n``.

    A parameter is selected when the third dot-separated component of its
    name equals ``n``. The last entry of ``named_parameters()`` is ignored.

    Args:
        model: object exposing ``named_parameters()``.
        n: string compared against the third name component.

    Returns:
        Tuple ``(qkvo, gud)`` taken from the last matching set of weights:
        ``q.T @ k @ v.T @ o.T`` and ``(gate.T * up.T) @ down.T``.
    """
    selected = {}
    # The final parameter is left out of the scan.
    for full_name, tensor in list(model.named_parameters())[:-1]:
        pieces = full_name.split(".")
        if pieces[2] != n:
            continue
        selected.setdefault(pieces[-2] + pieces[-1], []).append(tensor)

    matching = zip(
        selected['q_projweight'], selected['k_projweight'], selected['v_projweight'],
        selected['o_projweight'], selected['gate_projweight'], selected['down_projweight'],
        selected['up_projweight'],
    )
    for q, k, v, o, g, d, u in matching:
        qkvo = q.t() @ k @ v.t() @ o.t()
        gud = (g.t() * u.t()) @ d.t()
    return qkvo, gud
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model_gpt2 = GPT2Model.from_pretrained('gpt2')
# model_imdb = GPT2Model.from_pretrained('lvwerra/gpt2-imdb')
# model_gpt2_2 = GPT2Model.from_pretrained('distilgpt2')
# inner_products_pre1_fine,average_pre1_fine=cosine_similarity(model_gpt2,model_imdb)
# inner_products_pre1_pre1,average_pre1_pre1=cosine_similarity(model_gpt2,model_gpt2)
# inner_products_pre1_pre2,average_pre1_pre2=cosine_similarity(model_gpt2,model_gpt2_2)

# qkvo,gud=get_kqvo_gud(model_llama)
# torch.save(qkvo,'/home/byzeng/project/alpaca-lora-main/llama_qkvo.pt')
# torch.save(gud,'/home/byzeng/project/alpaca-lora-main/llama_gud.pt')
# model_open_llama = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b")
# model= AutoModelForCausalLM.from_pretrained("minlik/chinese-llama-7b-merged")
# inner_products_pre1_fine1,average_pre1_fine1=get_cossim_all(model_llama,model_open_llama,
#             '/home/byzeng/project/weights-search/cos_sim/llama_open_llama.json')
# inner_products_pre1_fine1,average_pre1_fine1=cosine_similarity(model_llama,model_open_llama,
#             '/home/byzeng/project/weights-search/cos_sim/llama_open_llama.json')
# Load model directly
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained("FreedomIntelligence/HuatuoGPT-7B", trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained("medalpaca/medalpaca-7b")
# model = AutoModelForCausalLM.from_pretrained("samwit/koala-7b")
# model = AutoModelForCausalLM.from_pretrained("minlik/chinese-alpaca-7b-merged")
# model = AutoModelForCausalLM.from_pretrained("chavinlo/alpaca-native")
# model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3")
# model = AutoModelForCausalLM.from_pretrained("chainyo/alpaca-lora-7b")
# model = AutoModelForCausalLM.from_pretrained("project-baize/baize-v2-7b")
# model = AutoModelForCausalLM.from_pretrained("/home/byzeng/project/alpaca-lora-main/alpaca_cos_mm_all/checkpoint-1943/")
# model = AutoModelForCausalLM.from_pretrained("wangrongsheng/MiniGPT-4-LLaMA-7B")
# model = AutoModel.from_pretrained("Neutralzz/BiLLa-7B-SFT")
# model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-fp16")
# model = AutoModelForCausalLM.from_pretrained("ehartford/WizardLM-7B-Uncensored")
# model = AutoModelForCausalLM.from_pretrained("TheBloke/wizardLM-7B-HF")

# Load model directly
# model_llama = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
# from transformers import AutoModel
# model = AutoModel.from_pretrained("internlm/internlm-7b", trust_remote_code=True)
# Point both proxy environment variables at the same endpoint.
os.environ.update({
    "https_proxy": "http://10.10.1.3:10000",
    "http_proxy": "http://10.10.1.3:10000",
})
# pathlist=["baichuan-inc/baichuan-7B","medalpaca/medalpaca-7b","samwit/koala-7b",
#           "minlik/chinese-alpaca-7b-merged","chavinlo/alpaca-native",
#             "lmsys/vicuna-7b-v1.3","chainyo/alpaca-lora-7b","project-baize/baize-v2-7b",
#             "Neutralzz/BiLLa-7B-SFT","TheBloke/Llama-2-7B-fp16","decapoda-research/llama-7b-hf",
#             "TheBloke/wizardLM-7B-HF","minlik/chinese-llama-7b-merged","openlm-research/open_llama_7b",
#             "wangrongsheng/MiniGPT-4-LLaMA-7B","internlm/internlm-7b"]


# inner_products_pre1_fine1,average_pre1_fine1=get_cossim_all(model_llama,model,
#             '/home/byzeng/project/weights-search/cos_sim/llama_BiLLa.json')
# save_kqvo_ud(model_llama,model,'llama2')
# Load model directly
# from transformers import AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("samwit/koala-7b")
# model = AutoModelForCausalLM.from_pretrained("samwit/koala-7b")
# tokenizer = AutoTokenizer.from_pretrained("project-baize/baize-v2-7b")
# model = AutoModelForCausalLM.from_pretrained("project-baize/baize-v2-7b")
# model_chinese_alpaca = AutoModelForCausalLM.from_pretrained("minlik/chinese-alpaca-7b-merged")
# inner_products_pre1_fine2,average_pre1_fine2=get_cossim(model_llama,model_chinese_alpaca,'llama_chinese_alpaca_qkvo.json')
# print(model_llama)
def get_top_tokens(name, base_dir='/home/byzeng/project/weights-search/toptokensnew/'):
    """Read a list of token ids, one integer per line, from ``<base_dir>/<name>.txt``.

    The file order is preserved; the ids are not sorted.

    Args:
        name: file name without the ``.txt`` extension.
        base_dir: directory that holds the token files. Defaults to the
            location this script has always read from.

    Returns:
        List of ints in file order. Blank lines are skipped.

    Raises:
        FileNotFoundError: if the token file does not exist.
        ValueError: if a non-blank line is not an integer.
    """
    token_file = os.path.join(base_dir, name + '.txt')
    with open(token_file, 'r') as file:
        # Skip blank lines (e.g. a trailing newline block) instead of crashing on int('').
        return [int(line.strip()) for line in file if line.strip()]
def get_inputweights(model,name):
    """Project selected layer weights through a subset of embedding rows and save them.

    All entries of the model's state dict are cast to float32. Tensors whose
    key contains '30' or '31' as its 2nd, 3rd or 4th dot-separated component
    are grouped by the concatenation of the last two key components. For
    every grouped layer three matrices are computed from the selected
    embedding rows ``x`` and stacked; the stacks of all layers are
    concatenated along dim 0 and written to
    ``/home/byzeng/project/weights-search/inputweightsxxnew/<name>.npy``.

    Args:
        model: object exposing ``state_dict()`` with the keys
            ``'lm_head.weight'`` and ``'model.embed_tokens.weight'``.
        name: output file stem; the value ``'internlm-7b'`` also selects a
            dedicated token list.

    Returns:
        None. The result is only written to disk.
    """
    state_dict=model.state_dict()
    dicts={}
    # Every model uses the 'llama-7b-hf' token list except 'internlm-7b',
    # which has its own list.
    tokens=get_top_tokens('llama-7b-hf')
    if name=='internlm-7b':
        tokens=get_top_tokens('internlm-7b')
    # Rebind every entry to a float32 copy so the products below run in float32.
    for key in state_dict:
        state_dict[key] = state_dict[key].to(torch.float32)
    # Name components to match; presumably the last two decoder layers of a
    # 32-layer model - TODO confirm.
    n = ['30','31']
    for key, value in state_dict.items():
        l = key.split(".")
        try:
            if l[1] in n or l[2] in n or l[3] in n :
                sub_key = l[-2] + l[-1]
                dicts.setdefault(sub_key, []).append(value)
        except:
            # Keys with fewer than four components raise IndexError once the
            # earlier positions did not match; such keys are skipped.
            pass
    y=state_dict['lm_head.weight']
    x=state_dict['model.embed_tokens.weight']
    # Keep only the rows addressed by the last 4096 entries of the token list.
    x=x[tokens[-4096:]]
    # NOTE(review): y is computed but not used below.
    y=y[tokens[-4096:]]
    gud_list2=[]
    qkvo_list=[]
    gud_list=[]
    try:
        for q,k,v,o,g,d,u in zip(dicts['q_projweight'],dicts['k_projweight'],dicts['v_projweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            # Despite the names: `qkvo` only involves q and k, `gud` only v and o,
            # and `gud2` is the gate/up/down product.
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    except:
        # Any failure above (e.g. missing q_proj keys) falls back to a packed
        # W_pack layout split at rows 4096 and 8192.
        # NOTE(review): if the first loop failed midway, the lists keep the
        # entries appended so far.
        for qkv,o,g,d,u in zip(dicts['W_packweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            q=qkv[:4096]
            k=qkv[4096:8192]
            v=qkv[8192:]
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    # One (3, ...) stack per layer, then all layers concatenated along dim 0.
    parameters = [torch.stack((t1, t2,t3)) for t1, t2,t3 in zip(qkvo_list, gud_list,gud_list2)]
    parameters = torch.cat(parameters, dim=0).detach().cpu().numpy()
    np.save(f'/home/byzeng/project/weights-search/inputweightsxxnew/{str(name)}.npy', parameters)

def get_inputweights_billa(model0,model,name):
    """Variant of the input-weight projection that uses an embedding difference.

    Works like the plain projection, except that ``x`` is the difference
    between the selected embedding rows of ``model`` and the same rows of
    ``model0``, and the result is written to
    ``/home/byzeng/project/weights-search/inputweights/<name>.npy``.

    Args:
        model0: reference model; only its ``'model.embed_tokens.weight'`` and
            ``'lm_head.weight'`` state-dict entries are read.
        model: model whose layer weights are projected.
        name: output file stem; the value ``'internlm-7b'`` also selects a
            dedicated token list.

    Returns:
        None. The result is only written to disk.
    """
    state_dict=model.state_dict()
    dicts={}
    # Every model uses the 'llama-7b-hf' token list except 'internlm-7b',
    # which has its own list.
    tokens=get_top_tokens('llama-7b-hf')
    if name=='internlm-7b':
        tokens=get_top_tokens('internlm-7b')
    # Rebind every entry of `model`'s state dict to a float32 copy.
    for key in state_dict:
        state_dict[key] = state_dict[key].to(torch.float32)
    # Name components to match; presumably the last two decoder layers of a
    # 32-layer model - TODO confirm.
    n = ['30','31']
    for key, value in state_dict.items():
        l = key.split(".")
        try:
            if l[1] in n or l[2] in n or l[3] in n :
                sub_key = l[-2] + l[-1]
                dicts.setdefault(sub_key, []).append(value)
        except:
            # Keys with fewer than four components raise IndexError once the
            # earlier positions did not match; such keys are skipped.
            pass
    # NOTE(review): model0's tensors are used as stored; they are not cast
    # to float32 here.
    x0=model0.state_dict()['model.embed_tokens.weight']
    # NOTE(review): y0 and y are computed but not used below.
    y0=model0.state_dict()['lm_head.weight']
    y=state_dict['lm_head.weight']
    x=state_dict['model.embed_tokens.weight']
    # Embedding delta on the rows addressed by the last 4096 token-list entries.
    x=x[tokens[-4096:]]-x0[tokens[-4096:]]
    y=y[tokens[-4096:]]
    gud_list2=[]
    qkvo_list=[]
    gud_list=[]
    try:
        for q,k,v,o,g,d,u in zip(dicts['q_projweight'],dicts['k_projweight'],dicts['v_projweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            # Despite the names: `qkvo` only involves q and k, `gud` only v and o,
            # and `gud2` is the gate/up/down product.
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    except:
        # Any failure above (e.g. missing q_proj keys) falls back to a packed
        # W_pack layout split at rows 4096 and 8192.
        # NOTE(review): if the first loop failed midway, the lists keep the
        # entries appended so far.
        for qkv,o,g,d,u in zip(dicts['W_packweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            q=qkv[:4096]
            k=qkv[4096:8192]
            v=qkv[8192:]
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    # One (3, ...) stack per layer, then all layers concatenated along dim 0.
    parameters = [torch.stack((t1, t2,t3)) for t1, t2,t3 in zip(qkvo_list, gud_list,gud_list2)]
    parameters = torch.cat(parameters, dim=0).detach().cpu().numpy()
    np.save(f'/home/byzeng/project/weights-search/inputweights/{str(name)}.npy', parameters)

def get_cos_loss_data(model,name):
    """Project the weights of layers '0' to '5' through selected embedding rows and save them.

    All entries of the model's state dict are cast to float32. Tensors whose
    key contains one of '0'..'5' as its 2nd, 3rd or 4th dot-separated
    component are grouped by the concatenation of the last two key
    components. For every grouped layer three matrices are computed from the
    selected embedding rows ``x`` and stacked; the stacks are concatenated
    along dim 0 and written to the fixed file
    ``/home/byzeng/project/weights-search/cos_loss_data/xWyt_all.npy``.

    Args:
        model: object exposing ``state_dict()`` with the keys
            ``'lm_head.weight'`` and ``'model.embed_tokens.weight'``.
        name: not used by this function.

    Returns:
        None. The result is only written to disk.
    """
    state_dict=model.state_dict()
    dicts={}
    # The 'llama-7b-hf' token list is always used here.
    tokens=get_top_tokens('llama-7b-hf')
    # Rebind every entry to a float32 copy so the products below run in float32.
    for key in state_dict:
        state_dict[key] = state_dict[key].to(torch.float32)
    # Name components to match: the strings '0' through '5'.
    n = [str(i) for i in range(6)]
    for key, value in state_dict.items():
        l = key.split(".")
        try:
            if l[1] in n or l[2] in n or l[3] in n :
                sub_key = l[-2] + l[-1]
                dicts.setdefault(sub_key, []).append(value)
        except:
            # Keys with fewer than four components raise IndexError once the
            # earlier positions did not match; such keys are skipped.
            pass
    y=state_dict['lm_head.weight']
    x=state_dict['model.embed_tokens.weight']
    # Keep only the rows addressed by the first 10240 entries of the token list.
    x=x[tokens[:10240]]
    # NOTE(review): y is computed but not used below.
    y=y[tokens[:10240]]
    gud_list2=[]
    qkvo_list=[]
    gud_list=[]
    try:
        for q,k,v,o,g,d,u in zip(dicts['q_projweight'],dicts['k_projweight'],dicts['v_projweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            # Despite the names: `qkvo` only involves q and k, `gud` only v and o,
            # and `gud2` is the gate/up/down product.
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    except:
        # Any failure above (e.g. missing q_proj keys) falls back to a packed
        # W_pack layout split at rows 4096 and 8192.
        # NOTE(review): if the first loop failed midway, the lists keep the
        # entries appended so far.
        for qkv,o,g,d,u in zip(dicts['W_packweight']\
                ,dicts['o_projweight'],dicts['gate_projweight'],dicts['down_projweight'],dicts['up_projweight']):
            q=qkv[:4096]
            k=qkv[4096:8192]
            v=qkv[8192:]
            qkvo=x@(q.t()@k)@x.t()
            qkvo_list.append(qkvo)
            gud=x@(v.t()@o.t())@x.t()
            gud_list.append(gud)
            gud2=x@((g.t()*u.t())@d.t())@x.t()
            gud_list2.append(gud2)
    # One (3, ...) stack per layer, then all layers concatenated along dim 0.
    parameters = [torch.stack((t1, t2,t3)) for t1, t2,t3 in zip(qkvo_list, gud_list,gud_list2)]
    parameters = torch.cat(parameters, dim=0).detach().cpu().numpy()
    # The output path is fixed; it does not depend on `name`.
    np.save(f'/home/byzeng/project/weights-search/cos_loss_data/xWyt_all.npy', parameters)
# pathlist=["JosephusCheung/Guanaco",
#             "decapoda-research/llama-7b-hf"]
# pathlist=["IDEA-CCNL/Ziya-LLaMA-7B-Reward",
#             "decapoda-research/llama-7b-hf"]"baichuan-inc/baichuan-7B","minlik/chinese-alpaca-7b-merged","samwit/koala-7b","medalpaca/medalpaca-7b","Neutralzz/BiLLa-7B-SFT",
# pathlist=["decapoda-research/llama-7b-hf",
#             "lmsys/vicuna-7b-v1.3","chainyo/alpaca-lora-7b","project-baize/baize-v2-7b",
#             "TheBloke/Llama-2-7B-fp16",'PKU-Alignment/beaver-7b-v1.0',
#             "TheBloke/wizardLM-7B-HF","minlik/chinese-llama-7b-merged","openlm-research/open_llama_7b",
#             "wangrongsheng/MiniGPT-4-LLaMA-7B","internlm/internlm-7b"]
# pathlist=["decapoda-research/llama-7b-hf",
#             "openlm-research/open_llama_7b"]
# for i in range(1):

#     model1 = AutoModelForCausalLM.from_pretrained(pathlist[i], trust_remote_code=True)

#     for j in range(i + 1, len(pathlist)):
#         model2 =None
#         model2 = AutoModelForCausalLM.from_pretrained(pathlist[j], trust_remote_code=True)
        
#         inner_products_pre1_fine1, average_pre1_fine1 = get_cossim_all(model1, model2,f'/home/byzeng/project/weights-search/all_cos/{pathlist[i].split("/")[-1]}_{pathlist[j].split("/")[-1]}.json')
#         print(f'{pathlist[i]} and {pathlist[j]} is done')
# model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", trust_remote_code=True)
# get_cos_loss_data(model,'llama-7b-hf')
# Checkpoints to process: local directories or hub identifiers.
pathlist=['/home/byzeng/project/weights-search/llama32k/',"decapoda-research/llama-7b-hf","JosephusCheung/Guanaco","Neutralzz/BiLLa-7B-SFT",
          "lmsys/vicuna-7b-v1.5","internlm/internlm-7b","minlik/chinese-alpaca-7b-merged","minlik/chinese-llama-7b-merged",
          'PKU-Alignment/beaver-7b-v1.0',
          "baichuan-inc/baichuan-7B",'TheBloke/Llama-2-7b-chat-fp16',
          "medalpaca/medalpaca-7b","samwit/koala-7b","chavinlo/alpaca-native","lmsys/vicuna-7b-v1.3","chainyo/alpaca-lora-7b","project-baize/baize-v2-7b",
            "TheBloke/Llama-2-7B-fp16","TheBloke/wizardLM-7B-HF","openlm-research/open_llama_7b",
            "wangrongsheng/MiniGPT-4-LLaMA-7B"]#'IDEA-CCNL/Ziya-LLaMA-7B-Reward',
model = None
for i in pathlist:
    # Release the previous checkpoint before loading the next one, so that
    # two models are never held in memory at the same time.
    model = None
    gc.collect()
    # model0 = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(i)
    # Strip a trailing "/" first: for a directory path such as ".../llama32k/"
    # a plain split("/")[-1] is empty and the output file would be ".npy".
    model_name = i.rstrip("/").split("/")[-1]
    # get_inputweights_billa(model0,model,model_name)
    get_inputweights(model,model_name)
    print(f'{i} is done')