import torch
from torch.utils.data import Dataset

class SteeringDataset(Dataset):
    """Dataset of (question, steering vector) pairs.

    Each element of ``data`` is indexed with the keys ``"question"`` (text
    passed to the tokenizer) and ``"vector"`` (the steering target).

    Args:
        data: Indexable collection of records supporting ``len()``; each
            record must provide ``"question"`` and ``"vector"`` entries.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            padding="max_length", truncation=True, max_length=...)`` and
            returning a mapping of tensors. Presumably a Hugging Face
            tokenizer; TODO confirm against callers.
        max_len: Length every tokenized question is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_len: int = 256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Return ``(inputs, vector)`` for record ``idx``.

        ``inputs`` maps each tokenizer output key to a tensor with the leading
        dimension of size 1 removed. ``vector`` is ``item["vector"]`` converted
        with ``torch.as_tensor``; an existing tensor is returned unchanged.
        """
        item = self.data[idx]
        enc = self.tokenizer(
            item["question"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        # return_tensors="pt" on a single string yields a batch dim of size 1;
        # drop it so the DataLoader's collate adds the real batch dimension.
        inputs = {k: v.squeeze(0) for k, v in enc.items()}
        # Without this, a list-valued vector would be transposed element-wise
        # by default_collate instead of being stacked into (batch, dim).
        vector = torch.as_tensor(item["vector"])
        return inputs, vector

    def __len__(self) -> int:
        """Return the number of records in ``data``."""
        return len(self.data)
