import os
import argparse
import multiprocessing as mp
import numpy as np
import tiktoken
from datasets import load_dataset
from tqdm import tqdm


def parse_args(argv=None):
    """Parse command-line options for the dataset tokenization script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with ``dataset``, ``remote_name``, ``percent``,
        ``out_dir``, ``shard_size`` and ``workers`` attributes.
    """
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single worker instead of crashing on ``None // 2``.
    default_workers = max(1, (os.cpu_count() or 1) // 2)

    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, default="HuggingFaceFW/fineweb-edu",
                   help="HF dataset path")
    p.add_argument("--remote_name", type=str, default="sample-10BT",
                   help="HF dataset config/subset name")
    p.add_argument("--percent", type=float, default=10.0,
                   help="Percent of train split to load (0-100). Uses split slicing.")
    p.add_argument("--out_dir", type=str, default="edu_fineweb10B",
                   help="Directory to save shards")
    # float so that scientific notation such as "1e8" is accepted on the CLI.
    p.add_argument("--shard_size", type=float, default=1e8,
                   help="Tokens per shard (e.g., 1e8 = 100M)")
    p.add_argument("--workers", type=int, default=default_workers,
                   help="Multiprocessing workers for tokenization")
    return p.parse_args(argv)

# Per-process GPT-2 encoder, created lazily on first use inside each worker.
_ENCODER = None


def _get_encoder():
    """Return this process's GPT-2 tiktoken encoding, creating it on first call."""
    global _ENCODER
    if _ENCODER is None:
        _ENCODER = tiktoken.get_encoding("gpt2")
    return _ENCODER


def _tokenize_doc(doc):
    """Tokenize one dataset row into a uint16 array.

    The result is ``[<|endoftext|>] + encode_ordinary(doc["text"])``, so each
    document is preceded by an end-of-text token once rows are concatenated.

    Defined at module level (not inside ``main``) because ``Pool.imap`` must
    pickle the callable to send it to worker processes, and nested functions
    cannot be pickled.

    Raises:
        ValueError: if any token id does not fit in uint16.
    """
    enc = _get_encoder()
    tokens_np = np.array([enc.eot_token, *enc.encode_ordinary(doc["text"])])
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not ((tokens_np >= 0).all() and (tokens_np < 2**16).all()):
        raise ValueError("token dictionary too large for uint16")
    return tokens_np.astype(np.uint16)


def write_datafile(filename, tokens_np):
    """Save ``tokens_np`` with ``np.save`` (which appends ``.npy`` to ``filename``)."""
    np.save(filename, tokens_np)


def _split_spec(percent):
    """Translate ``--percent`` into a Hugging Face split string.

    Returns ``"train"`` for 100 and ``"train[:N%]"`` otherwise. The string
    split-slicing syntax only accepts whole-number percentages.

    Raises:
        ValueError: if ``percent`` is outside (0, 100] or is not a whole number.
    """
    if not 0 < percent <= 100:
        raise ValueError(f"--percent must be in (0, 100], got {percent}")
    if percent == 100:
        return "train"
    if not float(percent).is_integer():
        raise ValueError(
            f"--percent must be a whole number to use split slicing, got {percent}"
        )
    return f"train[:{int(percent)}%]"


def _shard_filename(out_dir, shard_index):
    """Return the shard path (without ``.npy``); shard 0 is the val split."""
    split = "val" if shard_index == 0 else "train"
    return os.path.join(out_dir, f"edufineweb_{split}_{shard_index:06d}")


def _pack_shards(token_stream, shard_size):
    """Concatenate token arrays and cut the result into shards of ``shard_size``.

    Yields every full shard (length ``shard_size``) as soon as it fills, then
    the final partial shard if it is non-empty. A document is split across a
    shard boundary when needed and may span several shards if it is longer
    than ``shard_size``. Each yielded array is freshly allocated, so callers
    may keep references to earlier shards.
    """
    buffer = np.empty((shard_size,), dtype=np.uint16)
    fill = 0
    for tokens in token_stream:
        pos = 0
        while pos < len(tokens):
            take = min(shard_size - fill, len(tokens) - pos)
            buffer[fill:fill + take] = tokens[pos:pos + take]
            fill += take
            pos += take
            if fill == shard_size:
                yield buffer
                # New buffer so the shard just handed out is never overwritten.
                buffer = np.empty((shard_size,), dtype=np.uint16)
                fill = 0
    if fill:
        yield buffer[:fill]


def main():
    """Tokenize the configured dataset and write uint16 token shards to disk.

    Loads ``--percent`` of the train split, tokenizes rows in a process pool,
    packs the tokens into shards of ``--shard_size`` tokens and saves them
    under ``<script dir>/<out_dir>``. Shard 0 is named as the val split and
    every later shard as train.

    Raises:
        ValueError: if ``--shard_size`` is not positive or ``--percent`` is invalid.
    """
    args = parse_args()

    shard_size = int(args.shard_size)
    if shard_size <= 0:
        raise ValueError(f"--shard_size must be positive, got {args.shard_size}")

    data_cache_dir = os.path.join(os.path.dirname(__file__), args.out_dir)
    os.makedirs(data_cache_dir, exist_ok=True)

    fw = load_dataset(args.dataset, name=args.remote_name,
                      split=_split_spec(args.percent))

    with mp.Pool(args.workers) as pool, \
            tqdm(unit="tokens", desc="Shard 0") as progress_bar:

        def counted(stream):
            # Advance the progress bar as each document's tokens arrive.
            for tokens in stream:
                progress_bar.update(len(tokens))
                yield tokens

        token_stream = counted(pool.imap(_tokenize_doc, fw, chunksize=16))
        for shard_index, shard in enumerate(_pack_shards(token_stream, shard_size)):
            write_datafile(_shard_filename(data_cache_dir, shard_index), shard)
            progress_bar.set_description(f"Shard {shard_index + 1}")

    print("Dataset processing done!")


# Entry-point guard. multiprocessing's "spawn"/"forkserver" start methods
# re-import this module in every worker process; the guard keeps those
# imports from re-running main() (and spawning pools recursively).
if __name__ == "__main__":
    main()
