from torch.utils.data import Dataset
import os
import pandas as pd
from PIL import Image


class GTSRB(Dataset):
    """Map-style dataset that reads images and class ids listed in a CSV file.

    The CSV file is looked up at ``<root>/<split>/<csv>``, where the split
    directory is ``trainingset`` with ``training.csv`` for the training split
    and ``testset`` with ``test.csv`` otherwise. Column 0 of the CSV is used
    as the image path relative to the split directory, column 1 as the class
    id.
    """

    def __init__(self, train=False, transform=None, root="./data/gtsrb", **kwargs):
        """
        Args:
            train (bool): Load the training set if True, else the test set.
            transform (callable, optional): Optional transform to be applied
                on each loaded image.
            root (string): Directory containing the ``trainingset`` and
                ``testset`` folders.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()

        self.root_dir = root

        self.sub_directory = "trainingset" if train else "testset"
        self.csv_file_name = "training.csv" if train else "test.csv"

        csv_file_path = os.path.join(
            self.root_dir, self.sub_directory, self.csv_file_name
        )

        # The whole index is read once up front; images are loaded on demand.
        self.csv_data = pd.read_csv(csv_file_path)

        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the CSV index."""
        return len(self.csv_data)

    def __getitem__(self, idx):
        """Return ``(image, class_id)`` for the sample at position ``idx``.

        The image is a PIL image, or whatever ``transform`` returns when a
        transform was supplied. The class id is the raw value stored in
        column 1 of the CSV row.
        """
        row = self.csv_data.iloc[idx]
        img_path = os.path.join(self.root_dir, self.sub_directory, row.iloc[0])

        img = Image.open(img_path)
        # Image.open is lazy: without an explicit load the underlying file
        # stays open until the pixel data is first accessed, which leaks file
        # descriptors when no transform touches the image. Loading here lets
        # Pillow release the handle right away.
        img.load()

        class_id = row.iloc[1]

        if self.transform is not None:
            img = self.transform(img)

        return img, class_id
