import re
import numpy as np
import torch
from safetensors.torch import load_file
from collections import Counter
import os
import pickle
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModel
# Default pair of layer indices; used as the default value of the `n`
# parameter of get_qkvo_gud_save below.
layers = [30, 31]
# Point both proxy environment variables of this process at the same endpoint.
os.environ.update({
    "https_proxy": "http://10.10.1.3:10000",
    "http_proxy": "http://10.10.1.3:10000",
})
def loadweight(filename, checkpoint_dir="/home/byzeng/project/weights-search/checkpoints/"):
    """Load one checkpoint file as a state dict on the CPU.

    ``<checkpoint_dir>/<filename>.bin`` is tried first with ``torch.load``;
    if that fails for any reason, ``<checkpoint_dir>/<filename>.safetensors``
    is loaded with ``safetensors.torch.load_file`` instead.

    Args:
        filename: Checkpoint name without directory or extension.
        checkpoint_dir: Directory that holds the checkpoint files.

    Returns:
        The loaded state dict.

    Raises:
        Whatever ``load_file`` raises when the ``.safetensors`` fallback
        cannot be loaded either (callers rely on this to detect missing
        shards).
    """
    base = os.path.join(checkpoint_dir, filename)
    try:
        return torch.load(base + ".bin", map_location='cpu')
    except Exception:
        # Deliberate best-effort: any problem with the .bin file (typically it
        # simply does not exist) falls back to the .safetensors variant.
        # `Exception` rather than a bare `except` so that Ctrl-C and
        # SystemExit are not swallowed.
        return load_file(base + ".safetensors", device='cpu')

def get_top_tokens(name, token_dir='/home/byzeng/project/weights-search/toptokensnew/'):
    """Read the token-id list stored for *name*.

    The file ``<token_dir>/<name>.txt`` is expected to hold one integer per
    line. Blank lines are skipped so that stray empty lines (for example at
    the end of the file) do not abort the run.

    Args:
        name: File name without the ``.txt`` extension.
        token_dir: Directory that holds the token-id files.

    Returns:
        The ids as a list of ints, in file order.

    Raises:
        OSError: If the file cannot be opened (callers use this to fall back
            to another name).
        ValueError: If a non-blank line is not an integer.
    """
    path = os.path.join(token_dir, name + '.txt')
    with open(path, 'r') as file:
        # int() tolerates surrounding whitespace, including the newline.
        return [int(line) for line in file if line.strip()]
    
# First run of decimal digits in a parameter name; read as the layer index.
_FIRST_INT = re.compile(r'\d+')

# Candidate state-dict keys, tried in order; the first one present wins.
_INPUT_EMBEDDING_KEYS = (
    'model.decoder.embed_tokens.weight',
    'gpt_neox.embed_in.weight',
    'model.embed_tokens.weight',
    'transformer.embedding.word_embeddings.weight',
    'transformer.word_embeddings.weight',
    'word_embeddings.weight',
    'decoder.embed_tokens.weight',
    'transformer.wte.weight',
    'wte.weight',
)
_OUTPUT_HEAD_KEYS = (
    'lm_head.weight',
    'transformer.output_layer.weight',
    'embed_out.weight',
    'word_embeddings.weight',
    'decoder.embed_tokens.weight',
    'transformer.wte.weight',
    'wte.weight',
)

# Model names grouped by the weight layout / formula applied to them.
_QWEN_MODELS = ("Qwen-7B", "Qwen-7B-Chat", "MindChat-Qwen-7B", "firefly-qwen-7b")
_GPT2_MODELS = ("gpt2-large", "Cerebras-GPT-1.3B")
_NEOX_MODELS = ("THUDM_chatglm-6b", "bloom-7b1", "pythia-12b", "pythia-6.9B",
                "GPT-NeoXT-Chat-Base-20B", "gpt-neox-20b")
_CHATGLM2_MODELS = ("THUDM_chatglm2-6b", "codegeex2-6b")
_MPT_MODELS = ("mpt-30b-chat", "mpt-30b", "mpt-7b-instruct", "mpt-7b-storywriter",
               "mpt-30b-instruct", "mpt-7b")
_FALCON_MODELS = ("falcon-40b-instruct", "falcon-40b", "falcon-40b-sft-top1-560",
                  "stablelm-base-alpha-7b", "falcon-180B", "RedPajama-INCITE-7B-Base")
_BAICHUAN_MODELS = ("Baichuan-13B-Chat", "Baichuan-13B-Base", "Baichuan-13B-sft")
_LLAMA_MODELS = ("huggyllama_llama-13b", "huggyllama_llama-30b", "huggyllama_llama-65b",
                 "Guanaco", "WizardLM-7B", "Wizard-Vicuna-7B", "open_llama_7b",
                 "internlm-chat-7b", "firefly-internlm-7b", "internlm-7b", "llama-7b",
                 "LLaMA-2-7B-32K", "chinese-llama-7b")
_OPT_MODELS = ("opt-30b", "galactica-120b", "OPT-6.7B", "galactica-30b")
_GPTJ_MODELS = ("gpt-j-6b",)
_GPT_NEO_MODELS = ("gpt-neo-2.7B",)
_SUPPORTED_MODELS = (
    _QWEN_MODELS + _GPT2_MODELS + _NEOX_MODELS + _CHATGLM2_MODELS + _MPT_MODELS
    + _FALCON_MODELS + _BAICHUAN_MODELS + _LLAMA_MODELS + _OPT_MODELS
    + _GPTJ_MODELS + _GPT_NEO_MODELS
)

# Sub-keys (last two dotted components of a parameter name, concatenated).
_NEOX_KEYS = ('query_key_valueweight', 'denseweight',
              'dense_h_to_4hweight', 'dense_4h_to_hweight')
_MPT_KEYS = ('Wqkvweight', 'out_projweight', 'up_projweight', 'down_projweight')


def _first_present(state_dict, candidates, what):
    """Return the tensor under the first of *candidates* found in *state_dict*; print that key."""
    for key in candidates:
        if key in state_dict:
            print(key)
            return state_dict[key]
    raise KeyError(f"no {what} weight found in checkpoint; tried {list(candidates)}")


def _collect_layer_weights(state_dict, layer_tags, depth):
    """Group float32 copies of the selected layers' tensors by their last *depth* key components.

    A parameter is selected when any of the 2nd-4th dot-separated components
    of its name equals one of *layer_tags*. Tensors are appended in
    state-dict iteration order.
    """
    grouped = {}
    for key, value in state_dict.items():
        parts = key.split(".")
        if len(parts) < depth:
            continue
        if any(part in layer_tags for part in parts[1:4]):
            grouped.setdefault("".join(parts[-depth:]), []).append(value.to(torch.float32))
    return grouped


def _zip_weights(weights, *keys):
    """Iterate the per-layer tensor lists stored under *keys* in lockstep (KeyError if one is missing)."""
    return zip(*[weights[key] for key in keys])


def _split_fused_qkv(qkv):
    """Cut *qkv* into three equal consecutive row blocks, returned as (q, k, v)."""
    length = qkv.shape[0] // 3
    return qkv[:length], qkv[length:2 * length], qkv[2 * length:]


def _split_multi_query(qkv):
    """Split *qkv* into a square q block plus k/v blocks tiled up to q's row count.

    NOTE(review): this assumes the fused weight is laid out as consecutive
    [Q; K; V] row blocks -- confirm per architecture before relying on the
    blocks being the true projections.
    """
    hidden = qkv.shape[1]
    len_mini_kv = (qkv.shape[0] - hidden) // 2
    mini_k = qkv[hidden:hidden + len_mini_kv]
    mini_v = qkv[hidden + len_mini_kv:]
    repeat_times = hidden // len_mini_kv
    q = qkv[:hidden]
    k = torch.cat([mini_k] * repeat_times, dim=0)
    v = torch.cat([mini_v] * repeat_times, dim=0)
    return q, k, v


def _chain_products(x, q, k, v, o, u, d):
    """Return the three products for an ungated MLP, multiplied strictly left to right."""
    return x@q.t()@k@x.t(), x@v.t()@o.t()@x.t(), x@u.t()@d.t()@x.t()


def _gated_products(x, q, k, v, o, g, u, d):
    """Return the three products for a gated MLP; the weight product is formed before applying x."""
    return x@(q.t()@k)@x.t(), x@(v.t()@o.t())@x.t(), x@((g.t()*u.t())@d.t())@x.t()


def _feature_triples(name, x, state_dict, layer_tags):
    """Return one (qk, vo, mlp) tuple of square matrices per selected layer of model *name*."""
    # Qwen and GPT-2 style checkpoints need three key components to tell the
    # attention projection from the MLP projection (both end in c_proj.weight).
    depth = 3 if (name in _QWEN_MODELS or name in _GPT2_MODELS) else 2
    w = _collect_layer_weights(state_dict, layer_tags, depth)
    triples = []
    if name in _QWEN_MODELS:
        for qkv, o_a, o_m, u, d in _zip_weights(
                w, 'attnc_attnweight', 'attnc_projweight', 'mlpc_projweight',
                'mlpw1weight', 'mlpw2weight'):
            q, k, v = _split_fused_qkv(qkv)
            triples.append((
                x@q.t()@k@x.t(),
                x@v.t()@o_a.t()@x.t(),
                x@(u.t()*d.t())@o_m.t()@x.t(),
            ))
    elif name in _GPT2_MODELS:
        for qkv, o_a, o_m, u in _zip_weights(
                w, 'attnc_attnweight', 'attnc_projweight', 'mlpc_projweight',
                'mlpc_fcweight'):
            # The fused weight is split along its second dimension here,
            # hence the transpose before the row-wise split.
            q, k, v = _split_fused_qkv(qkv.t())
            triples.append((
                x@q.t()@k@x.t(),
                x@v.t()@o_a.t()@x.t(),
                x@o_m.t()@u.t()@x.t(),
            ))
    elif name in _NEOX_MODELS:
        for qkv, o, u, d in _zip_weights(w, *_NEOX_KEYS):
            q, k, v = _split_fused_qkv(qkv)
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _CHATGLM2_MODELS:
        for qkv, o, u, d in _zip_weights(w, *_NEOX_KEYS):
            q, k, v = _split_multi_query(qkv)
            # Only the second half of the rows of dense_h_to_4h enters the
            # product; the first half was sliced off but never used.
            u = u[u.shape[0] // 2:]
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _MPT_MODELS:
        for qkv, o, u, d in _zip_weights(w, *_MPT_KEYS):
            q, k, v = _split_fused_qkv(qkv)
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _FALCON_MODELS:
        for qkv, o, u, d in _zip_weights(w, *_NEOX_KEYS):
            q, k, v = _split_multi_query(qkv)
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _BAICHUAN_MODELS:
        for qkv, o, g, d, u in _zip_weights(
                w, 'W_packweight', 'o_projweight', 'gate_projweight',
                'down_projweight', 'up_projweight'):
            q, k, v = _split_fused_qkv(qkv)
            triples.append(_gated_products(x, q, k, v, o, g, u, d))
    elif name in _LLAMA_MODELS:
        for q, k, v, o, g, d, u in _zip_weights(
                w, 'q_projweight', 'k_projweight', 'v_projweight', 'o_projweight',
                'gate_projweight', 'down_projweight', 'up_projweight'):
            triples.append(_gated_products(x, q, k, v, o, g, u, d))
    elif name in _OPT_MODELS:
        for q, k, v, o, u, d in _zip_weights(
                w, 'q_projweight', 'k_projweight', 'v_projweight', 'out_projweight',
                'fc1weight', 'fc2weight'):
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _GPTJ_MODELS:
        for q, k, v, o, u, d in _zip_weights(
                w, 'q_projweight', 'k_projweight', 'v_projweight', 'out_projweight',
                'fc_inweight', 'fc_outweight'):
            triples.append(_chain_products(x, q, k, v, o, u, d))
    elif name in _GPT_NEO_MODELS:
        for q, k, v, o, u, d in _zip_weights(
                w, 'q_projweight', 'k_projweight', 'v_projweight', 'out_projweight',
                'c_fcweight', 'c_projweight'):
            triples.append(_chain_products(x, q, k, v, o, u, d))
    else:
        raise ValueError(f"unsupported model name: {name!r}")
    return triples


def get_qkvo_gud_save(name, n=layers):
    """Build per-layer weight-product matrices for checkpoint *name* and save them as ``.npy``.

    Steps:
      1. Load ``name`` via ``loadweight`` and merge the optional extra shards
         ``name-2`` ... ``name-5`` (shards that fail to load are skipped).
      2. Read the token ids via ``get_top_tokens(name)``; if that fails, retry
         with the first two dash-separated pieces of ``name``.
      3. Take the input-embedding rows ``x`` for the last 4096 ids of that list.
      4. For two layers chosen by index (the last two found in the parameter
         names; the two before the last for ``opt-30b``), compute three square
         matrices of the form ``x @ <weight product> @ x.T``: one from the
         query/key weights, one from the value/output weights and one from
         the MLP weights. The exact formula depends on the model family.
      5. Stack them and save an array of shape ``(3 * num_layers, T, T)``,
         ``T`` being the number of selected embedding rows, to
         ``/home/byzeng/project/weights-search/inputweightsxxnew/<name>.npy``.

    Args:
        name: Checkpoint name; must be one of the model names known to this
            module.
        n: Accepted for backward compatibility only. The layers are always
            derived from the checkpoint, so this value is ignored.

    Returns:
        None. The result is written to disk.

    Raises:
        ValueError: If *name* is not a supported model, or no parameter name
            contains a layer index.
        KeyError: If the checkpoint has none of the known embedding/head keys
            or lacks a weight the model family requires.
    """
    # Fail before touching the (potentially huge) checkpoint.
    if name not in _SUPPORTED_MODELS:
        raise ValueError(f"unsupported model name: {name!r}")

    state_dict = loadweight(name)
    for suffix in ("-2", "-3", "-4", "-5"):
        try:
            state_dict.update(loadweight(name + suffix))
        except Exception:
            # Best-effort: checkpoints have a varying number of shards, so a
            # shard that cannot be loaded is treated as "not there".
            continue
    print(state_dict.keys())

    try:
        tokens = get_top_tokens(name)
    except (OSError, ValueError):
        # No usable token file for the exact name: fall back to the name made
        # of the first two dash-separated pieces.
        tokens = get_top_tokens(name.split("-")[0] + '-' + name.split("-")[1])

    # Find the largest layer index that occurs in the parameter names.
    layer_ids = []
    for key in state_dict:
        match = _FIRST_INT.search(key)
        if match:
            layer_ids.append(int(match.group()))
    if not layer_ids:
        raise ValueError(f"no layer index found in the parameter names of {name!r}")
    last_layer = max(layer_ids)
    if name in ["opt-30b"]:
        # NOTE(review): opt-30b skips its final layer; the reason is not
        # visible in this file.
        layer_tags = [str(last_layer - 2), str(last_layer - 1)]
    else:
        layer_tags = [str(last_layer - 1), str(last_layer)]

    x = _first_present(state_dict, _INPUT_EMBEDDING_KEYS, "input embedding")
    # The output head does not enter any product below; the lookup is kept so
    # that the matched key is still logged and a checkpoint without a known
    # head is still rejected at this point.
    _first_present(state_dict, _OUTPUT_HEAD_KEYS, "output head")

    # Only the selected rows and layers are up-cast to float32 (here and in
    # _collect_layer_weights) instead of the whole checkpoint; the values
    # entering the products are the same, the peak memory is far lower.
    x = x[tokens[-4096:]].to(torch.float32)

    triples = _feature_triples(name, x, state_dict, layer_tags)
    stacked = [torch.stack(triple) for triple in triples]
    parameters = torch.cat(stacked, dim=0).detach().cpu().numpy()
    np.save(f'/home/byzeng/project/weights-search/inputweightsxxnew/{str(name)}.npy', parameters)
# state_dict = torch.load("/home/byzeng/project/weights-search/checkpoints/THUDM_chatglm2-6b.bin",map_location='cpu')
# state_dict = torch.load("/home/byzeng/project/weights-search/checkpoints/THUDM_chatglm2-6b-2.bin",map_location='cpu')
# chatglm_get_qkvo_gud_save("THUDM_chatglm-6b")
# chatglm_get_qkvo_gud_save("THUDM_chatglm2-6b")








# get_qkvo_gud_save("internlm")
# get_qkvo_gud_save("internlm-chat-7b")

# get_qkvo_gud_save("huggyllama_llama-13b")
# get_qkvo_gud_save("huggyllama_llama-30b")
# get_qkvo_gud_save("huggyllama_llama-65b")
# get_qkvo_gud_save("RedPajama-INCITE-7B-Base")
# get_qkvo_gud_save("Qwen-7B-Chat")
# get_qkvo_gud_save("codegeex2-6b")


# get_qkvo_gud_save("bloom-7b1")
# get_qkvo_gud_save("THUDM_chatglm-6b")
# get_qkvo_gud_save("THUDM_chatglm2-6b")
# get_qkvo_gud_save("falcon-40b-instruct")
# get_qkvo_gud_save("falcon-40b")

# get_qkvo_gud_save("OPT-6.7B")
# get_qkvo_gud_save("pythia-6.9B")
# get_qkvo_gud_save("GPT-NeoXT-Chat-Base-20B")
# get_qkvo_gud_save("gpt-neox-20b")
# get_qkvo_gud_save("Baichuan-13B-Chat")
# get_qkvo_gud_save("Baichuan-13B-sft")
# get_qkvo_gud_save("firefly-internlm-7b")
# get_qkvo_gud_save("falcon-40b-sft-top1-560")
# get_qkvo_gud_save("MindChat-Qwen-7B")
# get_qkvo_gud_save("mpt-7b-instruct")
# get_qkvo_gud_save("mpt-7b")
# get_qkvo_gud_save("LLaMA-2-7B-32K")
# get_qkvo_gud_save("firefly-qwen-7b")
# Script entry: build and save the matrices for the currently selected model.
get_qkvo_gud_save("falcon-180B")

# get_qkvo_gud_save("mpt-7b-storywriter")
# get_qkvo_gud_save("galactica-120b")

# get_qkvo_gud_save("Baichuan-13B-Base")
# get_qkvo_gud_save("Cerebras-GPT-1.3B")
# get_qkvo_gud_save("Qwen-7B")
# get_qkvo_gud_save("gpt2-large")
# get_qkvo_gud_save("gpt-j-6b")
# get_qkvo_gud_save("pythia-12b")
# get_qkvo_gud_save("stablelm-base-alpha-7b")
# get_qkvo_gud_save("galactica-30b")
# get_qkvo_gud_save("opt-30b")
# get_qkvo_gud_save("mpt-30b-chat")
# get_qkvo_gud_save("mpt-30b")
# get_qkvo_gud_save("gpt-neo-2.7B")
# get_qkvo_gud_save("mpt-30b-instruct")

# NOTE: a module-level `print(state_dict.keys())` used to follow here. No
# module-level `state_dict` exists (it is a local of the functions above), so
# the script always ended with a NameError; get_qkvo_gud_save already prints
# the keys itself.


# np.save("/home/byzeng/project/weights-search/old_files/testdata/xxt.npy",x@x.t())
# qkvo_gud=np.load("/home/byzeng/project/weights-search/old_files/testdata/llama_kqvo_gud_matrixs.npy")
# qkvo_gud=torch.from_numpy(qkvo_gud)
# new_qkvo=torch.empty(32,2,x.shape[0],x.shape[0])
# for i in range(qkvo_gud.shape[0]):
#     for j in range(qkvo_gud.shape[1]):
#         new_qkvo[i][j]=x@qkvo_gud[i][j]@x.t()
# np.save("/home/byzeng/project/weights-search/old_files/testdata/xllama_kqvo_gud_matrixs.npy",new_qkvo)