"""dataset.py"""

import os
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import pandas as pd


class CustomImageFolder(ImageFolder):
    """ImageFolder whose labels come from a CSV file instead of folder names.

    Images are discovered under ``root`` by ``ImageFolder``; the folder-derived
    class index that ``ImageFolder`` computes is not returned. Instead, each
    image's labels are looked up in a CSV (loaded with its first column as the
    index) that must contain ``class_label``, ``shape_label`` and
    ``color_label`` columns.

    Args:
        root: directory scanned by ``ImageFolder``.
        transform: optional callable applied to each loaded image.
        filename: path to the label CSV. Required despite the ``None`` default
            (kept for signature compatibility).

    Raises:
        ValueError: if ``filename`` is not given.
    """

    def __init__(self, root, transform=None, filename=None):
        if filename is None:
            raise ValueError(
                "CustomImageFolder requires `filename` (path to the label CSV)")
        # ImageFolder's third positional parameter is ``target_transform``;
        # the CSV path must not be forwarded into it.
        super(CustomImageFolder, self).__init__(root, transform=transform)
        self.df = pd.read_csv(filename, index_col=0)

    def _label_key(self, path):
        """Return the CSV index key for the image stored at ``path``.

        Prefers the path relative to ``self.root`` (with '/' separators), which
        does not depend on how deep ``root`` is. Falls back to the legacy key
        (everything after the second '/' in ``path``) when the relative path is
        not present in the CSV index.
        """
        rel_key = os.path.relpath(path, self.root).replace(os.sep, '/')
        if rel_key in self.df.index:
            return rel_key
        return path.split('/', 2)[-1]

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label)`` for ``index``.

        ``img`` is the loaded image after ``self.transform`` (if any); the three
        labels are read from the CSV row matching the image's path.
        """
        path = self.imgs[index][0]
        key = self._label_key(path)

        class_label = self.df.loc[key, 'class_label']
        shape_label = self.df.loc[key, 'shape_label']
        color_label = self.df.loc[key, 'color_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label


# Per-dataset configuration used by ``load_data``. Each entry is:
# (image subdirectory of ``dset_dir``, label-CSV path parts relative to
# ``dset_dir``, progress message printed when the dataset is selected).
_DATASET_CONFIGS = {
    'traffic': ('traffic', ('class_label.csv',), " generate traffic data"),
    'stopsign': ('stopsign', ('attack_class_label.csv',),
                 " generate stopsign data"),
    'real_data': ('real_signs', ('real_signs', 'class_label.csv'),
                  " generate real-world traffic data"),
}


def load_data(args):
    """Build a shuffled DataLoader over one of the supported datasets.

    Args:
        args: namespace with attributes ``dataset`` (``'traffic'``,
            ``'stopsign'`` or ``'real_data'``, case-insensitive), ``dset_dir``
            (base data directory), ``num_workers`` (DataLoader workers) and
            ``image_size`` (images are resized to ``image_size x image_size``).

    Returns:
        A ``DataLoader`` over ``CustomImageFolder`` items
        ``(img, class_label, shape_label, color_label)`` with batch size 1.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset.lower()
    if name not in _DATASET_CONFIGS:
        print("wrong data folder names!!")
        raise NotImplementedError(
            "unknown dataset {!r}; expected one of {}".format(
                args.dataset, sorted(_DATASET_CONFIGS)))

    subdir, csv_parts, message = _DATASET_CONFIGS[name]
    print(message)

    dset_dir = args.dset_dir
    root = os.path.join(dset_dir, subdir)
    filename = os.path.join(dset_dir, *csv_parts)
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
    ])

    # Batch size is fixed at 1 here; ``args`` is not consulted for it.
    batch_size = 1

    train_data = CustomImageFolder(root=root, transform=transform,
                                   filename=filename)
    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=args.num_workers,
                      # Pinned memory only speeds up host-to-GPU copies; skip
                      # it (and the accompanying warning) when CUDA is absent.
                      pin_memory=torch.cuda.is_available(),
                      drop_last=True)


if __name__ == '__main__':
    # Smoke test: fetch one batch using the directory layout that
    # ``load_data`` uses for the 'traffic' dataset. The paths are assumptions
    # about the local data directory -- adjust as needed.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dset = CustomImageFolder(os.path.join('data', 'traffic'), transform,
                             filename=os.path.join('data', 'class_label.csv'))
    loader = DataLoader(dset,
                        batch_size=32,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=False,
                        drop_last=True)

    # Python 3 iterators are advanced with the builtin ``next``; the
    # ``.next()`` method does not exist on current DataLoader iterators.
    images1 = next(iter(loader))
