from torch.utils.data import Dataset
from PIL import Image
import os
import torch
from torchvision import transforms
import pandas as pd

class VisionCSV(Dataset):
    """Map-style dataset that loads images listed in a pandas DataFrame.

    Each row is expected to provide an ``image_path`` column (joined onto
    ``img_root``), a ``label`` column, and a ``graph_idx`` column that is
    cast to ``int``. Items are returned as ``(img, label, graph_idx)``.

    Args:
        csv_df: DataFrame describing the samples. Its index is reset so rows
            can be addressed positionally from ``0`` to ``len(self) - 1``.
        img_root: Directory prepended to every ``image_path`` value.
        transform: Optional callable applied to the RGB ``PIL.Image``. When
            ``None``, the PIL image itself is returned.
    """

    def __init__(self, csv_df, img_root, transform=None):
        # drop=True discards the old index rather than inserting it as a column.
        self.df = csv_df.reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, label, graph_idx)``."""
        row = self.df.iloc[idx]
        # Item access (not ``row.image_path``) so column names cannot be
        # shadowed by pandas Series attributes/methods.
        img_path = os.path.join(self.img_root, row["image_path"])
        # Close the file handle deterministically: Image.open is lazy and may
        # keep the file open. convert() returns a new, fully loaded image
        # that remains valid after the source file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = row["label"]
        gidx = int(row["graph_idx"])
        return img, label, gidx
    

def get_transforms(
    image_size,
    split,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Build the torchvision preprocessing pipeline for a dataset split.

    Every split is resized to ``image_size`` x ``image_size``, converted to a
    tensor and normalized. Only ``split == 'train'`` additionally applies a
    random horizontal flip; any other value gets the deterministic pipeline.

    Args:
        image_size: Side length used for both height and width in ``Resize``.
        split: Split name; compared against ``'train'`` exactly.
        mean: Per-channel means for ``Normalize``. Defaults to the values this
            function previously hard-coded (the statistics torchvision
            documents for its ImageNet-pretrained models).
        std: Per-channel standard deviations for ``Normalize``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    steps = [transforms.Resize((image_size, image_size))]
    if split == 'train':
        # Augmentation is applied to the training split only.
        steps.append(transforms.RandomHorizontalFlip())
    # Shared tail: a single definition keeps train/eval normalization in sync.
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ]
    return transforms.Compose(steps)

