from dataclasses import dataclass
import torch
from torch.utils.data import Dataset
from PIL import Image

@dataclass
class VLInputs:
    """Container for the tensors fed to a vision-language model.

    Groups the preprocessed image tensor with the tokenized text so the
    whole batch can be moved between devices in one call.
    """

    pixel_values: torch.Tensor
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    def to(self, device: torch.device) -> "VLInputs":
        """Move every tensor field to ``device`` (fields are reassigned in place).

        Args:
            device: Target device, forwarded unchanged to ``torch.Tensor.to``.

        Returns:
            This same instance. Returning ``self`` mirrors ``torch.Tensor.to``,
            so both ``batch.to(device)`` and ``batch = batch.to(device)`` work;
            previously the latter silently rebound ``batch`` to ``None``.
        """
        self.pixel_values = self.pixel_values.to(device)
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self
        
###

class VLDataset(Dataset):
    """Map-style dataset yielding (image+text features, target) pairs.

    Args:
        inputs: 2-D indexable of samples; column 0 holds the image path and
            column 1 the text (it is sliced as ``inputs[:, 0]`` /
            ``inputs[:, 1]``, so presumably a NumPy array — TODO confirm).
        targets: Per-sample labels, indexed by position.
        image_processor: Callable applied to an RGB ``PIL.Image``; its result
            must expose a ``pixel_values`` attribute.
        text_processor: Tokenizer-like callable invoked with
            ``truncation=True``; its result must expose ``input_ids`` and
            ``attention_mask`` attributes.
    """

    def __init__(self, inputs, targets, image_processor, text_processor):
        self.images = inputs[:, 0]
        self.texts = inputs[:, 1]
        self.labels = targets
        self.image_processor = image_processor
        self.text_processor = text_processor

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and preprocess sample ``idx``.

        Returns:
            A ``(features, target)`` tuple. ``features`` is a dict with
            ``pixel_values``, ``input_ids`` and ``attention_mask`` keys, and
            ``target`` is ``targets[idx]``.
        """
        # Close the image file deterministically: Image.open is lazy and may
        # keep the handle open until garbage collection (e.g. multi-frame
        # formats), which can exhaust file descriptors in DataLoader workers.
        # convert() loads the pixel data, so the result outlives the file.
        with Image.open(self.images[idx]) as img:
            rgb_image = img.convert("RGB")
        image = self.image_processor(rgb_image)
        text = self.text_processor(self.texts[idx], truncation=True)
        target = self.labels[idx]
        return dict(
            # Take element 0 of the processor output (presumably dropping a
            # size-1 batch axis — TODO confirm against the processor).
            pixel_values=image.pixel_values[0],
            input_ids=text.input_ids,
            attention_mask=text.attention_mask,
        ), target