import numpy as np
import torch
import json
def get_top_tokens(name, base_dir='/home/byzeng/project/weights-search/toptokensnew/'):
    """Read a list of integers from ``<base_dir>/<name>.txt``.

    The file is read as text, one integer per line. Surrounding whitespace
    is stripped and blank lines (for example a trailing newline at the end
    of the file) are skipped instead of raising.

    Args:
        name: File stem, without the ``.txt`` extension.
        base_dir: Directory that holds the token files. Defaults to the
            location this script has always read from.

    Returns:
        The integers in file order, as a list of ``int``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line is not a valid integer.
    """
    import os  # local import so this function does not rely on file-level imports

    token_path = os.path.join(base_dir, name + '.txt')
    with open(token_path, 'r') as file:
        stripped_lines = (line.strip() for line in file)
        return [int(text) for text in stripped_lines if text]
def cosine_similarity(model1,model2,path):
    """Compare two checkpoints tensor-by-tensor and dump the scores as JSON.

    The values of ``model1['model']`` and ``model2['model']`` are paired
    positionally, skipping the first entry of each mapping. Every pair is
    flattened to 1-D and compared with
    ``torch.nn.functional.cosine_similarity`` along dim 0.

    Pairs that torch refuses to compare (``RuntimeError``, e.g. flattened
    sizes that cannot be broadcast) are skipped, which keeps the original
    best-effort behaviour without hiding unrelated errors.

    Args:
        model1: Mapping with a ``'model'`` entry that maps names to tensors.
        model2: Same layout as ``model1``.
        path: Destination of the JSON report. It receives the keys
            ``'average_cos_similarity'`` (float) and ``'cos_similarity'``
            (list of floats, one per compared pair).

    Returns:
        ``(inner_products, average_inner)``: the list of 0-dim score tensors
        and their mean as a 0-dim tensor.

    Raises:
        ValueError: If no pair could be compared (the report is not written).
    """
    # [1:] drops the first entry of each checkpoint, as the original code did.
    tensors1 = list(model1['model'].values())[1:]
    tensors2 = list(model2['model'].values())[1:]

    inner_products = []
    for first, second in zip(tensors1, tensors2):
        flat_first = torch.flatten(first)
        flat_second = torch.flatten(second)
        try:
            score = torch.nn.functional.cosine_similarity(flat_first, flat_second, 0)
        except RuntimeError:
            # Incomparable pair: skip it rather than abort the whole comparison.
            continue
        inner_products.append(score)

    if not inner_products:
        raise ValueError("cosine_similarity: no comparable tensor pairs found")

    average_inner = sum(inner_products) / len(inner_products)
    cos_dict = {
        'average_cos_similarity': average_inner.item(),
        'cos_similarity': [score.item() for score in inner_products],
    }
    with open(path, 'w') as f:
        json.dump(cos_dict, f)
    return inner_products,average_inner
def get_norm_unnorm_mm(model, hidden_size=1024):
    """Split a checkpoint's tensors into three lists used for comparison.

    ``model['module']`` maps dotted parameter names to tensors. Tensors are
    grouped by the concatenation of the last two name components (for
    example ``...dense.weight`` -> ``'denseweight'``), in checkpoint order,
    and cast to float32.

    Per layer, the fused ``query_key_value`` weight and bias are cut along
    dim 0 at ``hidden_size`` and ``2 * hidden_size`` into three chunks named
    q, k and v.
    NOTE(review): this assumes the fused tensor is laid out as [q; k; v]
    blocks along dim 0 -- confirm against the checkpoint format.

    Args:
        model: Mapping with a ``'module'`` entry (name -> tensor).
        hidden_size: Row count of each q/k/v chunk. Defaults to 1024, the
            value that used to be hard-coded.

    Returns:
        ``(unnorm_list, norm_list, mm)`` where

        * ``unnorm_list`` holds, per layer, ``[q, k, v, o, d, u, qb, kb, vb,
          ob, ub, db, n1b, n2b]`` followed by the word-embedding weight, the
          final-linear weight and the final norm bias;
        * ``norm_list`` holds, per layer, the post-attention and input
          layernorm weights followed by the final norm weight;
        * ``mm`` holds every layer's ``q.T @ k @ v.T @ o.T`` product followed
          by every layer's ``u.T @ d.T`` product.

    Raises:
        KeyError: If an expected parameter group is absent.
        IndexError: If a parameter name has fewer than two components.
    """
    # Single pass: group tensors by "<second-to-last><last>" name component.
    grouped = {}
    for name, tensor in model['module'].items():
        parts = name.split(".")
        grouped.setdefault(parts[-2] + parts[-1], []).append(tensor.float())

    per_layer = zip(
        grouped['query_key_valueweight'], grouped['denseweight'],
        grouped['dense_h_to_4hweight'], grouped['dense_4h_to_hweight'],
        grouped['post_attention_layernormweight'], grouped['input_layernormweight'],
        grouped['query_key_valuebias'], grouped['densebias'],
        grouped['dense_h_to_4hbias'], grouped['dense_4h_to_hbias'],
        grouped['post_attention_layernormbias'], grouped['input_layernormbias'],
    )

    unnorm_list = []
    norm_list = []
    qkvo_list = []
    ud_list = []
    for qkv, o, u, d, n1, n2, qkvb, ob, ub, db, n1b, n2b in per_layer:
        q = qkv[:hidden_size]
        k = qkv[hidden_size:2 * hidden_size]
        v = qkv[2 * hidden_size:]
        qb = qkvb[:hidden_size]
        kb = qkvb[hidden_size:2 * hidden_size]
        vb = qkvb[2 * hidden_size:]
        # Element order (d before u) is kept exactly as callers have always seen it.
        unnorm_list += [q, k, v, o, d, u, qb, kb, vb, ob, ub, db, n1b, n2b]
        norm_list += [n1, n2]
        qkvo_list.append(q.t() @ k @ v.t() @ o.t())
        ud_list.append(u.t() @ d.t())

    # Built once after the loop instead of being rebuilt on every iteration.
    mm = qkvo_list + ud_list

    unnorm_list += [
        grouped['word_embeddingsweight'][0],
        grouped['final_linearweight'][0],
        grouped['normbias'][0],
    ]
    norm_list += [grouped['normweight'][0]]
    return unnorm_list,norm_list,mm
def _flat_cosine_scores(tensors1, tensors2):
    """Return ``(scores, mean)`` of cosine similarities over paired, flattened tensors."""
    scores = []
    for first, second in zip(tensors1, tensors2):
        flat_first = torch.flatten(first)
        flat_second = torch.flatten(second)
        try:
            score = torch.nn.functional.cosine_similarity(flat_first, flat_second, 0)
        except RuntimeError:
            # Incomparable pair: skip it rather than abort the whole comparison.
            continue
        scores.append(score)
    if not scores:
        raise ValueError("get_cossim_all: no comparable tensor pairs in group")
    return scores, sum(scores) / len(scores)


def get_cossim_all(model1,model2,path):
    """Compare two checkpoints group-by-group and dump the averages as JSON.

    Both checkpoints are decomposed with ``get_norm_unnorm_mm`` and four
    groups of tensors are compared pairwise (flattened, cosine similarity
    along dim 0):

    * the "unnorm" tensors,
    * the "norm" tensors,
    * the "mm" matrix products,
    * "all" = unnorm followed by norm.

    Pairs that torch refuses to compare (``RuntimeError``) are skipped.

    Args:
        model1: Checkpoint mapping accepted by ``get_norm_unnorm_mm``.
        model2: Same layout as ``model1``.
        path: Destination of the JSON report. It receives one
            ``'average_<group>_cos_similarity'`` float per group, plus
            ``'cos_similarity'``, the per-pair scores of the "all" group only.

    Returns:
        ``(inner_products, average_inner)`` for the "all" group: the list of
        0-dim score tensors and their mean as a 0-dim tensor.

    Raises:
        ValueError: If a group has no comparable pair.
    """
    unnorm_list1, norm_list1, mm1 = get_norm_unnorm_mm(model1)
    unnorm_list2, norm_list2, mm2 = get_norm_unnorm_mm(model2)

    # Order matters: the last group's scores are the ones reported and returned.
    groups = (
        ('average_unnorm_cos_similarity', unnorm_list1, unnorm_list2),
        ('average_norm_cos_similarity', norm_list1, norm_list2),
        ('average_mm_cos_similarity', mm1, mm2),
        ('average_all_cos_similarity', unnorm_list1 + norm_list1, unnorm_list2 + norm_list2),
    )

    cos_dict = {}
    for report_key, tensors1, tensors2 in groups:
        inner_products, average_inner = _flat_cosine_scores(tensors1, tensors2)
        cos_dict[report_key] = average_inner.item()

    cos_dict['cos_similarity'] = [score.item() for score in inner_products]
    with open(path, 'w') as f:
        json.dump(cos_dict, f)
    return inner_products, average_inner
def get_kqvolist(model, hidden_size=1024):
    """Project the last two layers' weight products onto selected embeddings.

    ``model['module']`` maps dotted parameter names to tensors. Tensors are
    grouped by the concatenation of the last two name components and cast to
    float32. ``x`` is the word-embedding weight restricted to the rows listed
    in the last 4096 entries of ``get_top_tokens('gpt-neox-20b')``.

    For each of the last two entries of the attention/MLP weight groups the
    fused ``query_key_value`` weight is cut along dim 0 at ``hidden_size``
    and ``2 * hidden_size`` into q, k and v, and three matrices are built:

    * ``x @ q.T @ k @ x.T``   (returned in ``qkvo_list``)
    * ``x @ v.T @ o.T @ x.T`` (returned in ``vo_list``)
    * ``x @ u.T @ d.T @ x.T`` (returned in ``ud_list``)

    NOTE(review): the q/k/v slicing assumes the fused tensor is laid out as
    [q; k; v] blocks along dim 0 -- confirm against the checkpoint format.

    Args:
        model: Mapping with a ``'module'`` entry (name -> tensor).
        hidden_size: Row count of each q/k/v chunk. Defaults to 1024, the
            value that used to be hard-coded.

    Returns:
        ``(qkvo_list, ud_list, vo_list, x)``.

    Raises:
        KeyError: If an expected parameter group is absent.
    """
    # Single pass: group tensors by "<second-to-last><last>" name component.
    grouped = {}
    for name, tensor in model['module'].items():
        parts = name.split(".")
        grouped.setdefault(parts[-2] + parts[-1], []).append(tensor.float())

    tokens = get_top_tokens('gpt-neox-20b')[-4096:]
    x = grouped['word_embeddingsweight'][0][tokens]
    x_t = x.t()  # loop-invariant transpose

    last_layers = zip(
        grouped['query_key_valueweight'][-2:],
        grouped['denseweight'][-2:],
        grouped['dense_h_to_4hweight'][-2:],
        grouped['dense_4h_to_hweight'][-2:],
    )

    qkvo_list = []
    ud_list = []
    vo_list = []
    for qkv, o, u, d in last_layers:
        q = qkv[:hidden_size]
        k = qkv[hidden_size:2 * hidden_size]
        v = qkv[2 * hidden_size:]
        qkvo_list.append(x @ q.t() @ k @ x_t)
        vo_list.append(x @ v.t() @ o.t() @ x_t)
        ud_list.append(x @ u.t() @ d.t() @ x_t)
    return qkvo_list,ud_list,vo_list,x
def get_cossim(model1,model2,path):
    """Compare the ``get_kqvolist`` matrices of two checkpoints.

    For each checkpoint the first two results of ``get_kqvolist`` (the
    attention-product list and the MLP-product list) are concatenated.
    The two resulting lists are paired positionally, flattened, and compared
    with cosine similarity along dim 0.

    Args:
        model1: Checkpoint mapping accepted by ``get_kqvolist``.
        model2: Same layout as ``model1``.
        path: Destination of the JSON report. It receives the keys
            ``'average_cos_similarity'`` (float) and ``'cos_similarity'``
            (list of floats, one per compared pair).

    Returns:
        ``(inner_products, average_inner)``: the list of 0-dim score tensors
        and their mean as a 0-dim tensor.

    Raises:
        ValueError: If there is nothing to compare.
    """
    # get_kqvolist returns four values; only the first two are compared here.
    # The previous two-name unpacking raised "too many values to unpack".
    parameters_kqvo1, parameters_ud1 = get_kqvolist(model1)[:2]
    parameters_kqvo2, parameters_ud2 = get_kqvolist(model2)[:2]
    parameters1 = parameters_kqvo1 + parameters_ud1
    parameters2 = parameters_kqvo2 + parameters_ud2

    inner_products = [
        torch.nn.functional.cosine_similarity(torch.flatten(first), torch.flatten(second), 0)
        for first, second in zip(parameters1, parameters2)
    ]
    if not inner_products:
        raise ValueError("get_cossim: no tensor pairs to compare")

    average_inner = sum(inner_products) / len(inner_products)
    cos_dict = {
        'average_cos_similarity': average_inner.item(),
        'cos_similarity': [score.item() for score in inner_products],
    }
    with open(path, 'w') as f:
        json.dump(cos_dict, f)
    return inner_products, average_inner
def save_tensor_list(tensor_list,path):
    """Concatenate tensors along dim 0 and write the result with ``np.save``.

    Args:
        tensor_list: Sequence of tensors accepted by ``torch.cat(..., dim=0)``.
        path: Output location handed to ``np.save``.

    Returns:
        None.
    """
    merged = torch.cat(tensor_list, dim=0)
    # Detach from autograd and move to host memory before converting to NumPy.
    np.save(path, merged.detach().cpu().numpy())
def save_kqvo_ud(model1,model2,
                 path1="/home/byzeng/project/weights-search/inputweightsxxnew/gptneox_seed1.npy",
                 path2='/home/byzeng/project/weights-search/inputweightsxxnew/gptneox_seed2.npy'):
    """Save the ``get_kqvolist`` matrices of two checkpoints as ``.npy`` files.

    For each checkpoint, the three per-layer matrices returned by
    ``get_kqvolist`` are stacked in the order (qkvo, vo, ud) and the
    per-layer stacks are handed to ``save_tensor_list``.

    Each checkpoint is stacked with its own ``vo`` matrices. Previously both
    calls unpacked into the same ``vo`` name, so the first checkpoint's file
    was silently built from the second checkpoint's ``vo`` matrices.

    Args:
        model1: Checkpoint mapping accepted by ``get_kqvolist``.
        model2: Same layout as ``model1``.
        path1: Output file for ``model1``. Defaults to the original location.
        path2: Output file for ``model2``. Defaults to the original location.

    Returns:
        None.
    """
    parameters_kqvo1, parameters_ud1, vo1, _ = get_kqvolist(model1)
    parameters_kqvo2, parameters_ud2, vo2, _ = get_kqvolist(model2)

    parameters1 = [torch.stack(triple) for triple in zip(parameters_kqvo1, vo1, parameters_ud1)]
    parameters2 = [torch.stack(triple) for triple in zip(parameters_kqvo2, vo2, parameters_ud2)]

    save_tensor_list(parameters1, path1)
    save_tensor_list(parameters2, path2)
# ---------------------------------------------------------------------------
# Script section. There is no ``if __name__ == "__main__"`` guard, so the
# statements below run whenever this file is executed or imported.
# Commented-out lines are earlier experiments kept for reference.
# ---------------------------------------------------------------------------
# Earlier experiment: fairseq transformer checkpoints (wikitext-103, three seeds).
# model_seed1=torch.load("/home/byzeng/project/fairseq/checkpoints/transformer_wikitext-103/checkpoint_best.pt")
# model_seed2=torch.load("/home/byzeng/project/fairseq/checkpoints/transformer_wikitext-103_seed2/checkpoint_best.pt")
# model_seed3=torch.load("/home/byzeng/project/fairseq/checkpoints/transformer_wikitext-103_seed3/checkpoint_best.pt")
# Active run: load two GPT-NeoX checkpoints (paths point at global_step400000)
# with all tensors mapped to the CPU.
# NOTE(review): torch.load unpickles the file -- only point it at trusted checkpoints.
gptneox_seed1=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash1/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))
gptneox_seed2=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash2/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))
# Other checkpoints used in earlier runs (additional seeds / data seed 100).
# gptneox_seed3=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash3/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))
# gptneox_seed1=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash1/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))
# gptneox_seed1_100=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash1_dataseed100/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))

# Earlier experiment: pairwise get_cossim_all over four seeds.
# gptneox_seed4=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash4/global_step400000/mp_rank_00_model_states.pt",map_location=torch.device('cpu'))
# modellist=[gptneox_seed1,gptneox_seed2,gptneox_seed3,gptneox_seed4]
# for i in range(len(modellist)):
#     for j in range(i + 1, len(modellist)):
#         scores1,average1=get_cossim_all(modellist[i],modellist[j],f'/home/byzeng/project/weights-search/all_cos/gptneox_seed{i}_seed{j}.json')

# Earlier experiment: get_cossim between individual checkpoint pairs.
# gptneox_seed4_100=torch.load("/home/byzeng/project/gpt-neox-2.0/checkpointsflash4_dataseed100/global_step50000/mp_rank_00_model_states.pt")
# scores1,average1=get_cossim(gptneox_seed4,gptneox_seed4_100,'gptneox_seed4_seed4_100_qkvoud.json')
# scores2,average2=get_cossim(model_seed2,model_seed3,'test_seed2_seed3_qkvo.json')
# scores3,average3=get_cossim(model_seed1,model_seed3,'test_seed1_seed3_qkvo.json')
# Active run: write the stacked matrices of both loaded checkpoints to disk.
save_kqvo_ud(gptneox_seed1,gptneox_seed2)
# save_kqvo_ud(gptneox_seed3,gptneox_seed4)