# Standard library imports
import os

# Third party library imports
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class CelebAHQ(Dataset):
    """Map-style dataset pairing image files with label column(s) of a table.

    Item ``idx`` is loaded from ``os.path.join(img_dir, csv.iloc[idx, 0])``,
    converted to RGB, turned into a tensor and resized to 256x256. Optionally
    a random rotation and random translation are applied afterwards.

    Args:
        img_dir: Directory that the filenames in the first column of ``csv``
            are joined onto.
        csv: Table supporting ``.iloc`` and ``csv[label].values`` --
            presumably a pandas DataFrame (NOTE(review): confirm with callers).
        label: Column key(s) of ``csv`` whose values become the float label
            tensor.
        transform: ``True`` applies the built-in random rotation/translation
            (``self.t``); a callable is applied to the preprocessed image
            tensor instead; any falsy value (default ``False``) applies no
            augmentation.
    """

    def __init__(self, img_dir, csv, label, transform=False):
        self.img_dir = img_dir
        self.csv = csv
        self.labels = self.csv[label].values
        # Despite its name, this is deterministic preprocessing applied to
        # every image: PIL -> tensor, then a fixed 256x256 resize.
        self.augment = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
        ])
        self.transform = transform
        # Built-in random augmentation, used when ``transform`` is True.
        self.t = transforms.Compose([
            transforms.RandomRotation(degrees=40),
            transforms.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        ])

    def __len__(self):
        """Return the number of samples (one per label row)."""
        return len(self.labels)

    def _apply_transform(self, image):
        """Apply the augmentation selected by ``self.transform`` to ``image``."""
        # A user-supplied callable takes precedence. The previous
        # ``self.transform == True`` check silently ignored callables.
        if callable(self.transform):
            return self.transform(image)
        if self.transform:
            return self.t(image)
        return image

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the RGB tensor produced by ``self.augment`` (3x256x256
        before any augmentation), optionally augmented. ``label`` is a
        float32 tensor built from ``self.labels[idx]``.
        """
        img_path = os.path.join(self.img_dir, self.csv.iloc[idx, 0])
        # PIL opens files lazily. The context manager guarantees the handle
        # is closed once the pixels have been read, instead of leaking it
        # until garbage collection. Converting to RGB keeps the channel count
        # consistent (grayscale/palette/RGBA files) so batches collate.
        with Image.open(img_path) as img:
            image = self.augment(img.convert("RGB"))
        image = self._apply_transform(image)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label
