import os
import torch
import pandas as pd
from torch.utils.data import IterableDataset, DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms

import numpy as np
import random

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator. ``np.random.seed``
            only accepts values in ``[0, 2**32 - 1]``.

    Note:
        ``cudnn.benchmark`` is explicitly switched off as well: with
        benchmarking enabled cuDNN may select different convolution
        algorithms between runs, which defeats ``cudnn.deterministic``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Guard against benchmarking having been enabled elsewhere.
    torch.backends.cudnn.benchmark = False

# Seed every RNG once, at import time, so RNG-dependent work done after
# import (e.g. the random_split in return_data) starts from a fixed state.
SEED = 2020233259
setup_seed(SEED)

class CustomImageFolder(ImageFolder):
    """Image folder whose labels are looked up in a CSV file.

    Images are discovered by ``ImageFolder`` under ``root``, but the label
    returned for each image is read from the ``class_label`` column of the
    CSV ``filename``. The CSV's first column is used as its index and must
    hold keys of the form ``<parent directory>/<file name>`` (joined with
    ``os.path.join``). ``return_data`` uses this class for the ``'traffic'``
    dataset.
    """

    def __init__(self, root, transform=None, filename=None):
        """
        Args:
            root: Directory scanned by ``ImageFolder``.
            transform: Optional callable applied to each loaded image.
            filename: Path to the label CSV (required).

        Raises:
            ValueError: ``filename`` is not given.
        """
        if filename is None:
            raise ValueError("CustomImageFolder requires a label CSV `filename`.")
        # Pass `transform` by keyword: ImageFolder's third positional
        # parameter is `target_transform`, which must not get the CSV path.
        super().__init__(root, transform=transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(image, class_label)`` for the sample at ``index``."""
        path = self.imgs[index][0]
        # CSV key is "<parent dir>/<file>". Splitting with os.path avoids
        # depending on a marker substring such as 'traffic/' in the absolute
        # path, which broke for directories whose name ends in "traffic".
        parent_dir, file_name = os.path.split(path)
        image_index = os.path.join(os.path.basename(parent_dir), file_name)

        class_label = self.df.loc[image_index, 'class_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, class_label


class CustomImageFolderReal(ImageFolder):
    """Real-world image folder whose labels are looked up in a CSV file.

    Images are discovered by ``ImageFolder`` under ``root``, but the label
    returned for each image is read from the ``class_label`` column of the
    CSV ``filename``. The CSV's first column is used as its index and must
    hold keys of the form ``<parent directory>/<file name>`` (joined with
    ``os.path.join``). ``return_data`` uses this class for the
    ``'real_data'`` dataset.
    """

    def __init__(self, root, transform=None, filename=None):
        """
        Args:
            root: Directory scanned by ``ImageFolder``.
            transform: Optional callable applied to each loaded image.
            filename: Path to the label CSV (required).

        Raises:
            ValueError: ``filename`` is not given.
        """
        if filename is None:
            raise ValueError("CustomImageFolderReal requires a label CSV `filename`.")
        # Pass `transform` by keyword: ImageFolder's third positional
        # parameter is `target_transform`, which must not get the CSV path.
        super().__init__(root, transform=transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(image, class_label)`` for the sample at ``index``."""
        path = self.imgs[index][0]
        # CSV key is "<parent dir>/<file>". Splitting with os.path avoids
        # depending on a marker substring such as 'real_data/' in the
        # absolute path, which broke for directories ending in "real_data".
        parent_dir, file_name = os.path.split(path)
        image_index = os.path.join(os.path.basename(parent_dir), file_name)

        class_label = self.df.loc[image_index, 'class_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, class_label


def return_data(dataset, batch_size, image_size, num_workers=2, train_fraction=0.8):
    """Build train/test DataLoaders for a traffic-sign image dataset.

    The selected image folder is loaded once and split at random (via
    ``torch.utils.data.random_split``, i.e. the global torch RNG) into a
    training part and a held-out test part.

    Args:
        dataset: Case-insensitive name: ``'traffic'`` (images under
            ``/data/datasets/traffic/traffic_10x8x100/train``) or
            ``'real_data'`` (images under ``/data/datasets/traffic/real_data``).
            Either folder must contain ``class_label.csv``.
        batch_size: Samples per batch for both loaders.
        image_size: Images are resized to ``(image_size, image_size)``.
        num_workers: Worker processes per DataLoader.
        train_fraction: Share of samples assigned to the training split;
            ``int(train_fraction * N)`` samples go to training.

    Returns:
        ``(train_loader, test_loader, n_train, n_test)``. The train loader
        shuffles and drops the last incomplete batch; the test loader visits
        every test sample once, in a fixed order.

    Raises:
        ValueError: ``train_fraction`` is outside ``[0, 1]``.
        NotImplementedError: ``dataset`` is not a known name.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    name = dataset.lower()
    if name == 'traffic':
        print("Generate traffic sign data.")
        root = os.path.join('/data/datasets/traffic/traffic_10x8x100', 'train')
        dset = CustomImageFolder
    elif name == 'real_data':
        print("Generate real-world traffic sign data.")
        root = '/data/datasets/traffic/real_data'
        dset = CustomImageFolderReal
    else:
        print("wrong data folder names!!")
        raise NotImplementedError(
            f"Unknown dataset {dataset!r}; expected 'traffic' or 'real_data'.")

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    # Labels are read from this CSV by the dataset class, not from folder names.
    full_data = dset(root=root, transform=transform,
                     filename=os.path.join(root, 'class_label.csv'))

    train_size = int(train_fraction * len(full_data))
    test_size = len(full_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        full_data, [train_size, test_size])

    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=pin_memory,
                              drop_last=True)

    # Evaluation must cover the same, complete test set every time. Shuffling
    # combined with drop_last would discard a different random subset on each
    # pass, and a test split smaller than batch_size would yield no batches.
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             drop_last=False)

    return train_loader, test_loader, len(train_dataset), len(test_dataset)

