import os
from modules import constants
# Point the Hugging Face cache root (HF_HOME) at <DATA_DIR>/raw. This runs
# before `datasets` is imported below -- presumably so the library picks up
# the location when it initialises; keep this ordering (TODO confirm).
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")


import torch
from datasets import load_dataset, load_from_disk
import random

# Load the "train" split of EleutherAI/fineweb-edu-dedup-10b. This executes at
# module import time, so importing this file triggers the load.
dataset = load_dataset("EleutherAI/fineweb-edu-dedup-10b", split="train")

def shuffle_ngrams(text, n=6):
    """Return *text* with its tokens shuffled in consecutive groups of *n*.

    The text is split on runs of whitespace, the tokens are chunked into
    consecutive groups of ``n`` (the final group may be shorter), the groups
    are shuffled with the module-level ``random`` generator, and everything
    is re-joined with single spaces. Token order *within* a group is kept;
    the original whitespace (newlines, tabs, repeated spaces) is not.

    Args:
        text: The string to shuffle.
        n: Number of tokens per group; must be at least 1.

    Returns:
        The shuffled text. Empty or whitespace-only input yields ``""``.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    # A negative step would make range() below yield nothing, silently
    # returning "" and discarding the whole text; zero would fail with an
    # unrelated-looking range() error. Reject both up front.
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    # Split the text into tokens (words).
    tokens = text.split()

    # Chunk into groups of n tokens; the last group may hold fewer than n.
    ngrams = [" ".join(tokens[i:i + n]) for i in range(0, len(tokens), n)]

    # Shuffle the groups in place, then stitch them back into one string.
    random.shuffle(ngrams)
    return " ".join(ngrams)

def process(example, n=30):
    """Replace ``example["text"]`` with its n-gram-shuffled version.

    The mapping is modified in place and the same object is returned, which
    is the shape ``dataset.map`` is given further down in this file.

    Args:
        example: A mapping with a ``"text"`` entry holding the document.
        n: Tokens per shuffled group, forwarded to ``shuffle_ngrams``.

    Returns:
        ``example`` itself, with ``"text"`` overwritten.
    """
    original_text = example["text"]
    example["text"] = shuffle_ngrams(original_text, n)
    return example

# Run `process` over the dataset with 10 worker processes. `process` is passed
# without an explicit `n`, so its default group size (30 tokens) applies.
# NOTE(review): this multiprocessing call sits at module level with no
# `if __name__ == "__main__":` guard -- on spawn-based platforms that may
# re-execute this script in each worker; confirm the target platform.
shuffled_dataset = dataset.map(process, num_proc=10)

# Upload the shuffled dataset to the Hugging Face Hub.
# NOTE(review): "PATH_TO_HF_HUB" looks like a placeholder repo id -- presumably
# it must be replaced with a real "<user>/<repo>" before running; confirm.
shuffled_dataset.push_to_hub("PATH_TO_HF_HUB")
