# emozilla/pg19

import datasets
from torch.utils.data import Dataset


class PG19Dataset(Dataset):
    """Map-style dataset that serves PG-19 as fixed-width text chunks of token ids.

    Every ``"text"`` entry of the requested split of ``emozilla/pg19`` is
    joined (separated by a blank line) into one in-memory string. Item ``i``
    is the ``i``-th non-overlapping window of ``character_stride`` characters
    of that string, tokenized and truncated to at most ``stride`` tokens.
    No padding is requested here, so an item may hold fewer than ``stride``
    tokens. A trailing partial window is not exposed.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=True, max_length=stride)`` whose
            result exposes ``.input_ids`` as a list of token ids.
        stride: Maximum number of tokens per item; must be positive.
        offset_labels: When true, labels are the ids shifted left by one
            position with ``-100`` in the last slot; otherwise labels are a
            copy of the ids.
        split: Name of the dataset split to read.
        chars_per_token: Characters of raw text taken per token of
            ``stride`` when sizing a window (presumably a rough estimate of
            the tokenizer's compression ratio -- tune for your tokenizer).

    Raises:
        ValueError: If ``stride`` or ``chars_per_token`` is not positive.
    """

    def __init__(self, tokenizer, stride: int, offset_labels: bool = True,
                 split: str = 'train', *, chars_per_token: int = 5):
        # Fail early: a non-positive window would otherwise only surface later
        # as a ZeroDivisionError or a negative length in __len__.
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride!r}")
        if chars_per_token <= 0:
            raise ValueError(
                f"chars_per_token must be a positive integer, got {chars_per_token!r}"
            )

        self.tokenizer = tokenizer
        # NOTE(review): trust_remote_code=True lets the dataset repository's
        # loading script run locally -- only keep it for sources you trust.
        self.dataset = datasets.load_dataset(
            "emozilla/pg19",
            trust_remote_code=True,
        )[split]

        # The whole split is held in memory as a single string so that items
        # can be cut out by plain character offsets.
        self.books = "\n\n".join(entry["text"] for entry in self.dataset)

        self.stride = stride
        self.offset_labels = offset_labels

        # Width, in characters, of the text window handed to the tokenizer.
        self.character_stride = stride * chars_per_token

    def __len__(self) -> int:
        """Return the number of complete character windows in the corpus."""
        return len(self.books) // self.character_stride

    def __getitem__(self, idx):
        """Return ``(ids, labels)`` for the ``idx``-th character window.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_chunks = len(self)
        position = idx + num_chunks if idx < 0 else idx
        # Raising IndexError lets the legacy sequence-iteration protocol
        # (``for item in dataset``) terminate and rejects the partial tail.
        if not 0 <= position < num_chunks:
            raise IndexError(
                f"index {idx} is out of range for a dataset of length {num_chunks}"
            )

        start = self.character_stride * position
        ids = self.tokenizer(
            self.books[start : start + self.character_stride],
            truncation=True,
            max_length=self.stride,
        ).input_ids

        if not self.offset_labels:
            return ids, ids.copy()

        # Next-token targets: shift left by one; the final position has no
        # successor and is marked with -100 (presumably the loss function's
        # ignore index -- confirm against the training loop).
        labels = ids[1:]
        if ids:  # an empty tokenization has no final position to mask
            labels.append(-100)
        return ids, labels
