import os
import multiprocessing as mp
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm

# --- Source data and tokenizer -------------------------------------------
tokenizer_name = "gpt2"
dataset_name = "HuggingFaceFW/fineweb-edu"
remote_name = "sample-10BT"  # dataset configuration passed to load_dataset(name=...)
split = "train"

# --- Output location and shard layout ------------------------------------
local_dir = "edu_fineweb10B_gpt2"
local_file_name = "edufineweb"
shard_size = 100_000_000  # tokens per shard (100M); the original notes ~100 shards in total

# Shards are written next to this script; create the directory up front.
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
os.makedirs(DATA_CACHE_DIR, exist_ok=True)

def download_dataset(dataset_name, remote_name=None, split="train"):
    """Load one split of a Hugging Face Hub dataset via ``load_dataset``.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``.
        remote_name: Optional dataset configuration, forwarded as ``name=``.
        split: Name of the split to load.

    Returns:
        Whatever ``load_dataset`` returns, or ``None`` if it raised
        ``ValueError``. In that case, hints about the failure are printed first.
    """
    config_hint = f"with config '{remote_name}'" if remote_name else ""
    print(f"Loading {dataset_name}", config_hint)
    try:
        return load_dataset(dataset_name, name=remote_name, split=split)
    except ValueError as err:
        messages = (
            f"Error loading dataset: {err}",
            "This dataset might require a specific 'name' (configuration).",
            "Please check the dataset page on Hugging Face Hub.",
        )
        for message in messages:
            print(message)
    return None

# Load the raw documents and the tokenizer used by ``tokenize`` below.
fw = download_dataset(dataset_name, remote_name=remote_name, split=split)
if fw is None:
    # download_dataset has already printed the reason and returned None; stop
    # here instead of failing later with an opaque TypeError inside pool.imap.
    raise SystemExit(f"Could not load dataset {dataset_name!r}; aborting.")
enc = AutoTokenizer.from_pretrained(tokenizer_name)
eot = enc.eos_token_id  # prepended to every document as a delimiter
if eot is None:
    # Without an EOS id every token array would start with None and the uint16
    # range check in tokenize() would fail with a confusing error.
    raise SystemExit(f"Tokenizer {tokenizer_name!r} has no EOS token to delimit documents.")

def tokenize(doc):
    """Tokenize one dataset row into a uint16 token array.

    The result starts with the module-level ``eot`` id (a document delimiter),
    followed by ``enc.encode(doc["text"])``.

    Args:
        doc: A dataset row; only its ``"text"`` field is read.

    Returns:
        A 1-D ``np.uint16`` array of token ids.

    Raises:
        ValueError: if any token id falls outside the uint16 range.
    """
    tokens_np = np.array([eot, *enc.encode(doc["text"])])
    # An explicit check rather than ``assert``: ``python -O`` strips asserts, and
    # out-of-range ids would then silently wrap around in the uint16 cast below.
    # tokens_np is never empty (it always holds eot), so min()/max() are safe.
    if tokens_np.min() < 0 or tokens_np.max() > np.iinfo(np.uint16).max:
        raise ValueError("Token dictionary too large for uint16")
    return tokens_np.astype(np.uint16)

def write_datafile(filename, tokens_np):
    """Save ``tokens_np`` to ``filename`` in NumPy's ``.npy`` format.

    ``np.save`` appends a ``.npy`` extension when ``filename`` is a path that
    does not already end with one.
    """
    np.save(file=filename, arr=tokens_np)

def _shard_filename(shard_index):
    """Return the output path for shard ``shard_index`` (np.save adds ``.npy``)."""
    return os.path.join(DATA_CACHE_DIR, f"{local_file_name}_{split}_{shard_index:06d}")


def _pack_shards(token_arrays, size):
    """Pack a stream of 1-D token arrays into consecutive shards of ``size`` tokens.

    A document is split at shard boundaries, so one document may span several
    shards. After each copy into the shard buffer this yields
    ``(n_added, shard)``. ``n_added`` is the number of tokens just copied.
    ``shard`` is ``None`` unless the copy completed a shard, in which case it is
    the full buffer. Once the input is exhausted, a partially filled shard (if
    any) is yielded as ``(0, partial_shard)``.

    Yielded shards are views of one buffer that is reused for the next shard,
    so the caller must consume them (e.g. save them) before advancing.

    Raises:
        ValueError: if ``size`` is not positive (otherwise this would loop forever).
    """
    if size <= 0:
        raise ValueError(f"shard size must be positive, got {size}")
    buffer = np.empty((size,), dtype=np.uint16)
    filled = 0
    for tokens in token_arrays:
        offset = 0
        while offset < len(tokens):
            take = min(size - filled, len(tokens) - offset)
            buffer[filled:filled + take] = tokens[offset:offset + take]
            filled += take
            offset += take
            if filled == size:
                filled = 0  # buffer is refilled only after the caller resumes us
                yield take, buffer
            else:
                yield take, None
    if filled:
        yield 0, buffer[:filled]


if __name__ == "__main__":
    # Guard required by multiprocessing: with the "spawn" start method (default
    # on macOS and Windows) workers re-import this module, and without the guard
    # they would try to start their own pool.
    nprocs = max(1, (os.cpu_count() or 2) // 2)  # cpu_count() may return None
    with mp.Pool(nprocs) as pool:
        shard_index = 0
        progress_bar = None
        token_stream = pool.imap(tokenize, fw, chunksize=16)
        for n_added, shard in _pack_shards(token_stream, shard_size):
            if progress_bar is None:
                progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
            progress_bar.update(n_added)
            if shard is not None:
                # Covers both full shards and the trailing partial shard.
                write_datafile(_shard_filename(shard_index), shard)
                progress_bar.close()
                progress_bar = None
                shard_index += 1