import os

# HF_HOME affects both datasets and models in Hugging Face. It has to be set
# BEFORE `datasets` is imported: the Hugging Face libraries read their cache
# locations from the environment at import time, so assigning it afterwards
# would have no effect on them.
os.environ['HF_HOME'] = '/scratch/07946/ss95332/huggingface'

from tqdm import tqdm  # noqa: E402
import numpy as np  # noqa: E402
import tiktoken  # noqa: E402
from datasets import load_dataset, DatasetDict  # noqa: E402  (huggingface datasets)

# Worker-process count, shared by the dataset loader and the .map() call
num_proc = 8

# Fetch OpenWebText, caching the downloaded/prepared files on scratch storage
dataset = load_dataset(
    "openwebtext",
    cache_dir="/scratch/07946/ss95332/data2",
    num_proc=num_proc,
)

# Both splits below use the same seed and shuffle first, so they are reproducible
split_options = dict(seed=2357, shuffle=True)

# Step 1: hold out 0.05% of the documents; that held-out part becomes 'val'
initial_split = dataset["train"].train_test_split(test_size=0.0005, **split_options)
held_out = initial_split.pop('test')
initial_split['val'] = held_out

# Step 2: from the remaining training documents keep only a 0.1% sample,
# which replaces the full train set
final_split = initial_split["train"].train_test_split(test_size=0.001, **split_options)
sampled_train = final_split.pop('test')
final_split['train'] = sampled_train

# Pair the sampled train set with the validation set held out in step 1
split_dataset = DatasetDict({
    'train': sampled_train,
    'val': held_out,
})

# GPT-2 byte-pair encoder used by process() to tokenize the text
enc = tiktoken.get_encoding("gpt2")


def process(example):
    """Tokenize one dataset record.

    Encodes ``example['text']`` with the module-level ``enc`` encoder via
    ``encode_ordinary`` and appends the encoder's end-of-text token.

    Returns a dict with the token ids under ``'ids'`` and their count
    (including the end-of-text token) under ``'len'``.
    """
    token_ids = enc.encode_ordinary(example['text'])
    # Terminate every document with the end-of-text token
    token_ids.append(enc.eot_token)
    return {'ids': token_ids, 'len': len(token_ids)}


# Apply process() to every record of every split; the raw 'text' column is
# dropped so only the 'ids' and 'len' columns produced by process() remain.
map_options = {
    'remove_columns': ['text'],
    'desc': "tokenizing the splits",
    'num_proc': num_proc,
}
tokenized = split_dataset.map(process, **map_options)

print('tokenization finished')

# Concatenate all the ids in each dataset into one large file for training
for split, dset in tokenized.items():
    # Total number of tokens in this split. Accumulate in 64 bits explicitly:
    # the platform default integer may be only 32 bits wide and would overflow
    # silently on large corpora. Convert to a plain int so the memmap shape is
    # an ordinary Python integer.
    arr_len = int(np.sum(dset['len'], dtype=np.uint64))
    # The output file is written next to this script as <split>.bin
    filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
    print(f"creating {filename}...")
    # NOTE(review): uint16 assumes every token id is < 2**16; this should hold
    # for the gpt2 encoding (largest id 50256) — confirm if the encoder changes.
    dtype = np.uint16
    arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))

    print(f"writing {filename}...")
    idx = 0  # write offset (in tokens) of the next document
    for example in tqdm(dset):
        # Copy this document's tokens directly after the previous document's
        arr[idx: idx + example['len']] = example['ids']
        idx += example['len']
    # Push the memory-mapped contents out to the file on disk
    arr.flush()

# train.bin now holds the sampled training subset; val.bin holds the
# validation split that was held out before sampling.
