import torch
from torch.utils.data import Dataset
import numpy as np

class ToxicityDataset(Dataset):
    """Map-style dataset that tokenizes all texts once, at construction time.

    The tokenizer is called a single time on the full list of texts with
    ``padding='max_length'``, ``truncation=True`` and ``return_tensors='np'``;
    the result is kept in ``self.encodings`` and indexed per item in
    ``__getitem__``. The tokenizer is expected to follow the Hugging Face
    tokenizer call convention and to return a mapping that contains
    ``'input_ids'`` and ``'attention_mask'`` arrays with one row per text.

    Args:
        texts: The input strings. Objects exposing ``tolist()`` (pandas
            Series, numpy arrays) are converted with it; any other iterable
            of strings is copied into a list.
        labels: Array-like of targets aligned with ``texts`` (pandas
            Series/DataFrame, numpy array, list, ...). Stored as a
            ``float32`` numpy array.
        tokenizer: Callable used to encode ``texts`` (see above).
        max_length: Value forwarded to the tokenizer as ``max_length``.

    Raises:
        TypeError: If ``texts`` is a single ``str``/``bytes`` object.
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_length: int = 512) -> None:
        # A bare string is iterable, but iterating it would silently produce
        # one "text" per character, so reject it explicitly.
        if isinstance(texts, (str, bytes)):
            raise TypeError(
                "texts must be a sequence of strings, not a single string"
            )

        # Series / ndarray inputs keep going through ``tolist()``; plain
        # lists, tuples and other iterables are copied instead.
        self.texts = texts.tolist() if hasattr(texts, "tolist") else list(texts)

        # ``np.array`` always copies, so later changes to the caller's data
        # do not leak into the dataset.
        self.labels = np.array(labels, dtype=np.float32)

        # Fail before the (potentially expensive) tokenization rather than
        # with an IndexError in the middle of training.
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )

        self.tokenizer = tokenizer
        self.max_length = max_length

        # Pre-tokenize the entire dataset; results are requested as numpy
        # arrays and converted to tensors one item at a time.
        self.encodings = self.tokenizer(
            self.texts,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='np',
            return_attention_mask=True,
        )

    def __len__(self) -> int:
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx) -> dict:
        """Return the encoded sample at position ``idx``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` as
            ``torch.long`` tensors and ``'labels'`` as a ``torch.float``
            tensor.
        """
        input_ids = self.encodings['input_ids'][idx]
        attention_mask = self.encodings['attention_mask'][idx]
        label = self.labels[idx]
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.float),
        }