import torch
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
import os
import natsort

class AdversarialTrainDataset(Dataset):
    """Map-style dataset pairing the entries of a directory with integer labels.

    Every entry returned by ``os.listdir(src_path)`` is treated as an image
    file; entries are ordered with ``natsort.natsorted``. The label of the
    sample at position ``i`` is line ``i`` of ``label_path`` parsed with
    ``int``.

    Attributes are kept public because callers may read them:
        transform: callable applied to the opened ``PIL.Image``.
        src_path: directory that holds the image files.
        img_names: naturally sorted directory entries.
        labels: raw text lines of the label file (newlines included).
    """

    def __init__(self, src_path, label_path, transform):
        """Index the image directory and read the label file.

        Args:
            src_path: Directory whose entries are the samples.
            label_path: Text file with one integer label per line, in the
                same order as the naturally sorted directory entries.
            transform: Callable that receives the opened ``PIL.Image`` and
                returns the sample handed back by ``__getitem__``.

        Raises:
            ValueError: If the label file has fewer lines than the directory
                has entries.
        """
        self.transform = transform
        self.src_path = src_path
        self.img_names = natsort.natsorted(os.listdir(self.src_path))

        with open(label_path, "r") as f:
            self.labels = f.readlines()

        # Fail fast: a short label file would otherwise surface as an
        # IndexError in the middle of an epoch (or silently truncate a plain
        # ``for`` loop over the dataset, which stops on IndexError).
        if len(self.labels) < len(self.img_names):
            raise ValueError(
                f"{label_path!r} has {len(self.labels)} label line(s) but "
                f"{src_path!r} contains {len(self.img_names)} entries"
            )

    def __len__(self):
        """Return the number of entries found in ``src_path``."""
        return len(self.img_names)

    def _label_at(self, idx):
        """Return label line ``idx`` as a 0-dimensional ``torch.long`` tensor."""
        # ``int`` tolerates the trailing newline kept by ``readlines``.
        return torch.tensor(int(self.labels[idx]), dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            ValueError: If the label line at ``idx`` is not an integer.
        """
        img_path = os.path.join(self.src_path, self.img_names[idx])
        # ``Image.open`` is lazy and keeps the file handle open. The context
        # manager releases it deterministically; ``load()`` decodes the pixel
        # data first so the image stays usable even if the transform returns
        # the very same object.
        with Image.open(img_path) as raw:
            raw.load()
            img = self.transform(raw)
        return img, self._label_at(idx)

