import numpy as np
import json
import json
from tqdm import tqdm
from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
import os

def process_wiki_clean(file_path):
    """Tokenize ``wikitext.json`` and dump the token ids to ``wikitext.bin``.

    Every entry of the JSON file found in ``file_path`` is encoded with the
    module-level ``tokenizer`` (no special tokens added) and terminated with
    the ``<eos>`` id. Entries whose id list, EOS included, has more than five
    items are concatenated into one stream, which is written as raw
    ``uint16`` values to ``wikitext.bin`` in the same directory.

    Args:
        file_path: Directory holding ``wikitext.json``; also receives the
            resulting ``wikitext.bin``.
    """
    source = os.path.join(file_path, 'wikitext.json')
    with open(source, 'r', encoding='utf-8') as src:
        entries = json.load(src)

    all_ids = []
    for entry in tqdm(entries):
        ids = tokenizer.encode(entry, add_special_tokens=False)
        ids.append(tokenizer.special_tokens['<eos>'])
        # The length test runs after the EOS id has been appended.
        if len(ids) <= 5:
            continue
        all_ids.extend(ids)

    arr = np.array(all_ids, dtype=np.uint16)
    print(f"The number of tokens after tokenization: {len(arr):,}")

    target = os.path.join(file_path, 'wikitext.bin')
    with open(target, 'wb') as dst:
        dst.write(arr.tobytes())
        

def split_and_save_data(bin_path, train_path, val_path, test_path, max_length=256, dtype=np.uint16):
    """Split a flat binary token file into train/val/test files.

    The file is read as a 1-D array of ``dtype`` values, truncated to a whole
    number of ``max_length``-long sequences, and divided by sequence count:
    the first 80% go to ``train_path``, the next 10% to ``val_path`` and the
    remainder to ``test_path``. Each output is written as raw binary in the
    same ``dtype``.

    Args:
        bin_path: Path of the raw binary token file to split.
        train_path: Output path for the training split.
        val_path: Output path for the validation split.
        test_path: Output path for the test split.
        max_length: Number of tokens per sequence; a trailing partial
            sequence is discarded.
        dtype: Element type of the binary file. Defaults to ``np.uint16``,
            the type ``process_wiki_clean`` writes; reading with a wider type
            would fuse neighbouring tokens into one element.
    """
    data = np.fromfile(bin_path, dtype=dtype)

    # Drop the trailing tokens that do not fill a complete sequence.
    usable = max_length * (len(data) // max_length)
    data = data[:usable].reshape(-1, max_length)

    total_len = len(data)
    train_size = int(0.8 * total_len)
    val_size = int(0.1 * total_len)
    val_end = train_size + val_size

    train_data = data[:train_size]
    val_data = data[train_size:val_end]
    # The test split takes whatever remains after the two rounded-down sizes.
    test_data = data[val_end:]

    train_data.tofile(train_path)
    val_data.tofile(val_path)
    test_data.tofile(test_path)

    print(f"Data split and saved: {train_path}, {val_path}, {test_path}")



def split_diff_size_data_len(file_path, source_path='data/3-part/wikitext.json'):
    """Sort the entries of a JSON list by length and save them.

    The list stored in ``source_path`` is sorted by ``len`` of each entry,
    shortest first (ties keep their original order), and written with a
    4-space indent to ``wikitext.json`` inside ``file_path``.

    Args:
        file_path: Directory that receives the sorted ``wikitext.json``.
        source_path: JSON file holding the list to sort. Defaults to
            ``'data/3-part/wikitext.json'``.
    """
    # Explicit UTF-8 so non-ASCII text loads regardless of the platform's
    # default encoding.
    with open(source_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Sort the plain Python list directly: routing it through a NumPy array
    # would pad every string to the longest entry's width and would fail on
    # ragged nested lists.
    data.sort(key=len)

    write_path = os.path.join(file_path, 'wikitext.json')
    with open(write_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

if __name__ == "__main__":
    # Pipeline: sort the raw JSON by entry length, tokenize it into a flat
    # binary file, then cut that file into train/val/test splits.
    data_dir = 'data'
    split_diff_size_data_len(data_dir)

    # process_wiki_clean reads this module-level name, so it has to be
    # bound before that call.
    tokenizer = ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
    process_wiki_clean(data_dir)

    split_and_save_data(
        bin_path='data/wikitext.bin',
        train_path='data/train_data.bin',
        val_path='data/val_data.bin',
        test_path='data/test_data.bin',
        max_length=512,
    )
    
