import random
import datasets
from torch.utils.data import Dataset
from tqdm import tqdm


class RedPajamaDataset(Dataset):
    """Map-style dataset over the ``arxiv`` subset of RedPajama-Data-1T.

    Consecutive rows of the underlying ``train`` split are grouped into
    non-overlapping windows of ``window_size`` documents. Each item is one
    window: its documents are shuffled, joined with blank lines, tokenized
    and truncated to ``stride`` tokens. Windows are visited in a random
    order that is fixed once at construction time.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            truncation=True, max_length=...)`` whose result exposes
            ``.input_ids``; presumably a Hugging Face tokenizer — confirm.
        stride: Maximum number of tokens per item (passed as ``max_length``).
        window_size: Number of consecutive documents merged into one item.
            Defaults to 5, the previously hard-coded value.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """

    def __init__(self, tokenizer, stride, window_size=5):
        # Validate before the (potentially expensive) dataset load; a window
        # of 0 would also make ``__len__`` divide by zero.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.tokenizer = tokenizer
        # NOTE: trust_remote_code=True allows ``datasets`` to execute the
        # loading script shipped with the dataset repository.
        self.dataset = datasets.load_dataset(
            'togethercomputer/RedPajama-Data-1T',
            'arxiv',
            split='train',
            trust_remote_code=True,
        )
        self.window_size = window_size
        self.stride = stride
        # A single fixed permutation of window indices. It is drawn from the
        # global ``random`` module, so seeding ``random`` beforehand makes
        # the visiting order reproducible.
        self.shuffle_indices = list(range(len(self)))
        random.shuffle(self.shuffle_indices)

    def __len__(self):
        """Return the number of complete windows.

        Trailing rows that do not fill a whole window are never returned.
        """
        return len(self.dataset) // self.window_size

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` for the ``idx``-th shuffled window.

        The documents inside the window are re-shuffled on every call (using
        the global ``random`` module), so repeated access to the same ``idx``
        can yield different token sequences. ``labels`` is an independent
        copy of ``input_ids``.
        """
        window = self.shuffle_indices[idx]
        start = window * self.window_size
        texts = [
            self.dataset[start + offset]['text']
            for offset in range(self.window_size)
        ]
        random.shuffle(texts)
        ids = self.tokenizer(
            "\n\n".join(texts),
            return_tensors='pt',
            truncation=True,
            max_length=self.stride,
        ).input_ids
        # clone() so later in-place changes to one tensor do not affect the
        # other.
        labels = ids.clone()
        # Take row 0 — presumably the only sequence in a batch of one
        # produced by return_tensors='pt'; confirm with the tokenizer.
        return ids[0], labels[0]
