import os
import random
import sys
from datasets import load_dataset
from transformers import AutoTokenizer
import torch

# Load the GPT-2 tokenizer and reuse its EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Vocabulary ids of the ten digit tokens "0".."9" and of the two operator
# tokens that map_to_digits writes into each sequence.
digit_ids = tokenizer.convert_tokens_to_ids([str(digit) for digit in range(10)])
plus_id, eq_id = tokenizer.convert_tokens_to_ids(["+", "="])
num_digits = len(digit_ids)

# Fold every vocabulary id onto a digit token, chosen by the id modulo the
# number of digit tokens ...
token_map = {
    tid: digit_ids[tid % num_digits] for tid in range(tokenizer.vocab_size)
}
# ... then let special tokens override that and map to themselves.
for sid in tokenizer.all_special_ids:
    token_map[sid] = sid

def map_to_digits(example):
    """Tokenize a batch of texts and rewrite it as digit tokens with '+' and '='.

    Each text in ``example['text']`` is tokenized without special tokens, and
    every token id is replaced through the module-level ``token_map``. Then,
    in every sequence, two distinct randomly chosen positions are overwritten:
    one with ``plus_id`` and one with ``eq_id``.

    Sequences with fewer than two tokens (e.g. empty text) cannot hold both
    operator tokens at distinct positions; they are left as mapped, with no
    operator inserted, instead of raising ``ValueError`` from
    ``random.sample``.

    Args:
        example: A batch mapping whose ``'text'`` entry is a list of strings
            (the function is used with ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the batch, with ``'input_ids'`` replaced by
        the rewritten sequences. Any other fields the tokenizer produced are
        returned unchanged.

    Raises:
        KeyError: If the tokenizer emits an id that is not in ``token_map``.
    """
    ex = tokenizer(example['text'], add_special_tokens=False)
    mapped_input_ids = [
        [token_map[token_id] for token_id in seq] for seq in ex["input_ids"]
    ]
    for seq in mapped_input_ids:
        seq_len = len(seq)
        # random.sample(population, 2) raises ValueError when the population
        # has fewer than two elements, which would abort the whole map job.
        if seq_len < 2:
            continue
        idx1, idx2 = random.sample(range(seq_len), 2)
        seq[idx1] = plus_id
        seq[idx2] = eq_id
    ex['input_ids'] = mapped_input_ids
    return ex

# Load the "sample-10BT" configuration of FineWeb (train split) and keep only
# the first 10,000,000 rows.
ds = load_dataset(
    "HuggingFaceFW/fineweb",
    name="sample-10BT",
    split="train",
).take(10_000_000)

# Run map_to_digits over batches; all original columns are dropped so that
# only the fields returned by map_to_digits remain.
ds = ds.map(
    map_to_digits,
    batched=True,
    remove_columns=ds.column_names,
)

# Persist the processed dataset to disk.
ds.save_to_disk("data/train_data/fine_web_sample_10BT_digitized")
