from pathlib import Path

import pandas as pd

from PIL import Image

import torch
from torch.utils.data import Dataset


class IdentificationFaceRace(Dataset):
    """Dataset of identification questions described by a CSV file.

    Every CSV row is one question.  For a given row, the image file names
    are read from the third column up to (but excluding) the last two
    columns, and the label is parsed from the row's ``correct`` column.

    Args:
        csv_file: Passed straight to ``pandas.read_csv``.
        root_dir: Directory the image file names are resolved against.
        transform: Optional callable applied to every opened image.  Its
            results are stacked with ``torch.stack``, so it must return
            tensors of equal shape.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.questions_df = pd.read_csv(csv_file)
        self.root_dir = Path(root_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (questions) in the CSV."""
        return len(self.questions_df)

    def _load_images(self, img_names):
        """Open every named image under ``root_dir`` and read it into memory."""
        images = []
        for img_name in img_names:
            # Image.open is lazy and keeps the file handle open; loading
            # inside the ``with`` block reads the pixel data and releases
            # the handle while leaving the image object usable.
            with Image.open(self.root_dir / img_name) as image:
                image.load()
            images.append(image)
        return images

    def __getitem__(self, index):
        """Return ``(index + 1, images, label)`` for the row at ``index``.

        ``index`` is positional (0-based row number).  ``images`` is a
        tensor with a new leading dimension (one entry per image) when a
        transform is set, otherwise a list of PIL images.  ``label`` is the
        last character of the ``correct`` cell converted to ``int``.
        """
        img_names = self.questions_df.iloc[index, 2:-2].to_numpy()
        images = self._load_images(img_names)

        if self.transform:
            images = torch.stack([self.transform(image) for image in images], dim=0)

        # Positional lookup, consistent with the image columns above, so the
        # label always comes from the same row even if the frame's index is
        # not the default 0..n-1 range.
        correct = self.questions_df['correct'].iloc[index]
        # Only the final character is parsed, i.e. a single-digit label.
        # NOTE(review): presumably cells look like "<name><digit>" - confirm.
        label = int(correct[-1])

        return index + 1, images, label
    

class VerificationFaceRace(Dataset):
    """Dataset of verification questions described by a CSV file.

    Every CSV row is one question.  For a given row, the image file names
    are read from the fourth column up to (but excluding) the last column,
    and the target is the row's ``positive`` column converted to ``int``.

    Args:
        csv_file: Passed straight to ``pandas.read_csv``.
        root_dir: Directory the image file names are resolved against.
        transform: Optional callable applied to every opened image.  Its
            results are stacked with ``torch.stack``, so it must return
            tensors of equal shape.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.questions_df = pd.read_csv(csv_file)
        self.root_dir = Path(root_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (questions) in the CSV."""
        return len(self.questions_df)

    def _load_images(self, img_names):
        """Open every named image under ``root_dir`` and read it into memory."""
        images = []
        for img_name in img_names:
            # Image.open is lazy and keeps the file handle open; loading
            # inside the ``with`` block reads the pixel data and releases
            # the handle while leaving the image object usable.
            with Image.open(self.root_dir / img_name) as image:
                image.load()
            images.append(image)
        return images

    def __getitem__(self, index):
        """Return ``(index + 1, images, positive)`` for the row at ``index``.

        ``index`` is positional (0-based row number).  ``images`` is a
        tensor with a new leading dimension (one entry per image) when a
        transform is set, otherwise a list of PIL images.  ``positive`` is
        the row's ``positive`` cell converted to ``int``.
        """
        img_names = self.questions_df.iloc[index, 3:-1].to_numpy()
        images = self._load_images(img_names)

        if self.transform:
            images = torch.stack([self.transform(image) for image in images], dim=0)

        # Positional lookup, consistent with the image columns above, so the
        # target always comes from the same row even if the frame's index is
        # not the default 0..n-1 range.
        positive = int(self.questions_df['positive'].iloc[index])

        return index + 1, images, positive