import torch
from torch.utils.data import Dataset


class TextClassificationDataset(Dataset):
    """Map-style dataset pairing pre-tokenized texts with class labels.

    Each input text arrives as a list of string tokens; the tokens are joined
    with single spaces and the resulting strings are encoded once, up front,
    by ``tokenizer``. Individual examples are converted to tensors lazily in
    ``__getitem__``.

    The tokenizer is called with ``padding=False``, so the encoded sequences
    are not padded to a common length here.
    NOTE(review): batching therefore presumably relies on a padding collate
    function supplied by the caller -- confirm against the DataLoader setup.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        """Join, validate and encode the inputs.

        Args:
            texts: Iterable of token lists (``List[List[str]]``); each inner
                list is joined with ``" "`` into one string.
            labels: Sized, indexable collection with one label per text.
            tokenizer: Callable invoked as ``tokenizer(list_of_str,
                truncation=True, padding=False, max_length=max_length)``.
                Its result must expose ``.items()`` yielding
                ``(name, per-example sequences)`` pairs. Presumably a Hugging
                Face tokenizer -- verify against the caller.
            max_length: Forwarded to the tokenizer as ``max_length``.

        Raises:
            ValueError: If ``texts`` and ``labels`` differ in length.
        """
        # The tokenizer is fed plain strings, so join the token lists first.
        self.texts = [" ".join(tokens) for tokens in texts]
        self.labels = labels

        # __len__ is driven by the labels while the encodings are built from
        # the texts; a mismatch would either silently drop examples or raise
        # IndexError deep inside __getitem__. Fail fast instead.
        if len(self.texts) != len(self.labels):
            raise ValueError(
                "texts and labels must have the same length, got "
                f"{len(self.texts)} texts and {len(self.labels)} labels"
            )

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.encodings = self.tokenizer(
            self.texts,
            truncation=True,
            padding=False,
            max_length=self.max_length,
        )

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors.

        The dict holds one tensor per field produced by the tokenizer plus a
        ``'labels'`` entry containing the label as a ``torch.long`` tensor.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item
