import numpy as np
import torch
from safetensors.torch import load_file
from collections import Counter
import os
import pickle
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM,AutoModel
import sys
from tqdm import tqdm
# Number of dataset examples tokenized per model before the counting loop stops.
total_iterations = 10000

# Send both HTTPS and HTTP traffic of this process through the same proxy.
os.environ.update(
    {
        "https_proxy": "http://10.10.1.3:10000",
        "http_proxy": "http://10.10.1.3:10000",
    }
)
# sys.setrecursionlimit(50000)
# 'Qwen/Qwen-7B-Chat', 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
# pathlist=['internlm/internlm-chat-7b', 'baichuan-inc/Baichuan-7B', 'wangrongsheng/MiniGPT-4-LLaMA-7B', 'TheBloke/wizardLM-7B-HF', 'togethercomputer/RedPajama-INCITE-7B-Base', 'huggyllama/llama-13b', 'facebook/opt-30b', 'medalpaca/medalpaca-7b', 'THUDM/codegeex2-6b', 'chainyo/alpaca-lora-7b', 'openlm-research/open_llama_7b', 'project-baize/baize-v2-7b', 'decapoda-research/llama-7b-hf', 'EleutherAI/pythia-12b', 'facebook/galactica-30b', 'huggyllama/llama-65b', 'Qwen/Qwen-7B-Chat', 'togethercomputer/GPT-NeoXT-Chat-Base-20B', 'bigscience/bloom-7b1', 'chavinlo/alpaca-native', 'internlm/internlm-7b', 'tiiuae/falcon-180B', 'stabilityai/stablelm-base-alpha-7b',  'baichuan-inc/Baichuan-13B-Chat', 'EleutherAI/gpt-neox-20b',  'EleutherAI/pythia-6.9b', 'huggyllama/llama-30b', 'cerebras/Cerebras-GPT-1.3B', 'EleutherAI/gpt-neo-2.7B', 'JosephusCheung/Guanaco', 'facebook/opt-6.7b',  'baichuan-inc/Baichuan-13B-Base', 'THUDM/chatglm-6b', 'minlik/chinese-llama-7b-merged', 'gpt2-large', 'Qwen/Qwen-7B', 'THUDM/chatglm2-6b', 'samwit/koala-7b', 'facebook/galactica-120b', 'tiiuae/falcon-40b-instruct', 'TheBloke/Llama-2-7B-fp16', 'tiiuae/falcon-40b', 'mosaicml/mpt-30b-chat',  'mosaicml/mpt-30b', 'lmsys/vicuna-7b-v1.3', 'EleutherAI/gpt-j-6b', 'minlik/chinese-alpaca-7b-merged']
# Model identifiers whose tokenizers are loaded (via AutoTokenizer) and
# profiled by the loop below.
pathlist = [
    "mosaicml/mpt-30b-instruct",
    "mosaicml/mpt-30b",
]
# Disabled entries carried over from the original trailing comment:
#   'THUDM/codegeex2-6b', 'Qwen/Qwen-7B-Chat', 'togethercomputer/RedPajama-INCITE-7B-Base',
#   'huggyllama/llama-13b', 'huggyllama/llama-30b', 'huggyllama/llama-65b',
#   'togethercomputer/GPT-NeoXT-Chat-Base-20B', 'internlm/internlm-chat-7b',
#   'baichuan-inc/Baichuan-7B', 'minlik/chinese-llama-7b-merged', 'bigscience/bloom-7b1',
#   'chavinlo/alpaca-native', 'internlm/internlm-7b', 'tiiuae/falcon-180B',
#   'stabilityai/stablelm-base-alpha-7b', 'baichuan-inc/Baichuan-13B-Chat',
#   'EleutherAI/gpt-neox-20b', 'EleutherAI/pythia-6.9b', 'huggyllama/llama-30b',
#   'cerebras/Cerebras-GPT-1.3B', 'EleutherAI/gpt-neo-2.7B', 'JosephusCheung/Guanaco',
#   'facebook/opt-6.7b', 'baichuan-inc/Baichuan-13B-Base', 'THUDM/chatglm-6b', 'gpt2-large',
#   'Qwen/Qwen-7B', 'THUDM/chatglm2-6b', 'samwit/koala-7b', 'facebook/galactica-120b',
#   'tiiuae/falcon-40b-instruct', 'TheBloke/Llama-2-7B-fp16', 'tiiuae/falcon-40b',
#   'mosaicml/mpt-30b-chat', 'mosaicml/mpt-30b', 'lmsys/vicuna-7b-v1.3',
#   'EleutherAI/gpt-j-6b', 'minlik/chinese-alpaca-7b-merged'

# The "wikipedia" dataset with config "20220301.en"; loaded once here and
# reused for every model in pathlist.
dataset = load_dataset("wikipedia", "20220301.en")
def _write_top_tokens(model_path, word_freq):
    """Write the token ids in *word_freq* to disk, most frequent first.

    One id is written per line. The file is named after the last
    component of *model_path* (e.g. ``mosaicml/mpt-30b`` -> ``mpt-30b.txt``).

    Args:
        model_path: Model identifier; only the part after the final "/"
            is used for the file name.
        word_freq: ``Counter`` mapping token id -> occurrence count.
    """
    out_file = "/home/byzeng/project/weights-search/toptokens/" + model_path.split("/")[-1] + '.txt'
    with open(out_file, 'w', encoding="utf-8") as file:
        file.writelines(f"{token_id}\n" for token_id, _ in word_freq.most_common())


# For every model: tokenize the first `total_iterations` training examples
# with that model's tokenizer, count how often each token id occurs, and
# write the ids ranked by frequency.
for path in pathlist:
    # Created before the try so that a failure while loading the tokenizer
    # can never fall back to the previous model's counts.
    word_freq = Counter()
    try:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        for i, example in enumerate(tqdm(dataset["train"], total=total_iterations), start=1):
            # Tokenize with the model's own tokenizer and map tokens to ids.
            tokens = tokenizer.tokenize(example["text"])
            # Update the frequency statistics.
            word_freq.update(tokenizer.convert_tokens_to_ids(tokens))
            if i >= total_iterations:
                break
    except Exception as exc:
        # Best effort: report the failing model and keep going with the
        # next one. Exception (not a bare except) lets Ctrl-C and
        # SystemExit terminate the script as usual.
        print(path)
        print(f"{path}: stopped early: {exc!r}", file=sys.stderr)
        if not word_freq:
            # Nothing was counted for this model; do not write an empty file.
            continue
    # Reached on success, and on failure when partial counts exist.
    _write_top_tokens(path, word_freq)