import torch
from torch.utils.data import Dataset

from llm_non_identifiability.data import pad


class GrammarDataset(Dataset):
    """Map-style dataset that stores grammar sequences as one padded integer tensor."""

    def __init__(self, data: list, max_length: int = 32):
        """
        Pad every sequence once, up front, and keep the result as a ``torch.long`` tensor.

        :param data: list of sequences generated by a grammar (see data.py)
        :param max_length: maximum sequence length, including the SOS and EOS tokens
        """
        self.max_length = max_length
        # `pad` is a project helper. Presumably it returns a NumPy array of shape
        # (len(data), max_length) -- TODO confirm. `torch.from_numpy` only accepts
        # NumPy arrays, and `.long()` casts the result to int64.
        padded = pad(data, max_seq_length=max_length)
        self.data = torch.from_numpy(padded).long()

    def __len__(self) -> int:
        """Return the number of stored sequences (length of the stored tensor)."""
        return len(self.data)

    def __getitem__(self, idx) -> torch.Tensor:
        """Return ``self.data[idx]``, i.e. the padded sequence(s) selected by ``idx``."""
        return self.data[idx]
