import os
import torch
import pandas as pd
from torch.utils.data import IterableDataset, DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np

class CustomTensorDataset(Dataset):
    """Map-style dataset that serves entries of a single tensor.

    A sample is simply ``data_tensor[index]``; no label is attached.
    """

    def __init__(self, data_tensor):
        # Only a reference is stored; the tensor is not copied here.
        self.data_tensor = data_tensor

    def __len__(self):
        """Return the extent of the stored tensor along dimension 0."""
        num_samples = self.data_tensor.size(0)
        return num_samples

    def __getitem__(self, index):
        """Return the entry of the stored tensor selected by ``index``."""
        sample = self.data_tensor[index]
        return sample


def get_dataloader(batch_size: int = 100,
                   num_workers: int = 2,
                   data_path: str = ('/data/datasets/traffic/traffic_10x8x100/train/'
                                     'traffic_8000_3x128x128.npz')):
    """Build a shuffled DataLoader over the images stored in an ``.npz`` archive.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the DataLoader.
        data_path: Path of the ``.npz`` archive to read. It must contain an
            ``imgs`` array. Defaults to the location that used to be
            hard-coded in this function.

    Returns:
        A ``DataLoader`` over a ``CustomTensorDataset`` wrapping the images
        converted to float tensors. Batches are shuffled and an incomplete
        final batch is dropped.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        KeyError: If the archive has no ``imgs`` entry.
    """
    print("Load traffic data")
    # np.load returns an NpzFile that keeps the archive open until it is
    # closed; the context manager releases the handle once the array is read.
    with np.load(data_path) as data_zip:
        imgs = data_zip['imgs']
    imgs_tensor = torch.from_numpy(imgs).float()

    dataset = CustomTensorDataset(imgs_tensor)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers,
                             pin_memory=True,
                             drop_last=True)
    return data_loader


class CustomImageFolder(ImageFolder):
    """ImageFolder variant that returns labels looked up in a CSV file.

    The CSV at ``filename`` is read with its first column as the index.
    Each image is looked up by a ``<directory>/<file name>`` key derived from
    its path, and the ``class_label``, ``shape_label``, ``color_label`` and
    ``rotate_label`` columns are returned alongside the image.

    Args:
        root: Image root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied to each loaded image.
        filename: Path of the label CSV file.
    """
    def __init__(self, root, transform=None, filename=None):
        # `filename` is the label CSV, not a transform. ImageFolder's third
        # positional parameter is `target_transform`, so it must not be
        # forwarded there.
        super(CustomImageFolder, self).__init__(root, transform=transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label, rotate_label)``."""
        path = self.imgs[index][0]
        # Key is the last two '/'-separated components of whatever follows
        # the final 'traffic/' in the path.
        parts = path.split('traffic/')[-1].split('/')
        image_index = os.path.join(parts[-2], parts[-1])

        class_label = self.df.loc[image_index, 'class_label']
        shape_label = self.df.loc[image_index, 'shape_label']
        color_label = self.df.loc[image_index, 'color_label']
        rotate_label = self.df.loc[image_index, 'rotate_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label, rotate_label


class CustomImageFolderReal(ImageFolder):
    """ImageFolder variant for the real-sign images with labels from a CSV file.

    The CSV at ``filename`` is read with its first column as the index.
    Each image is looked up by the part of its path that follows the last
    ``real_signs/``, and the ``class_label``, ``shape_label`` and
    ``color_label`` columns are returned alongside the image.

    Args:
        root: Image root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied to each loaded image.
        filename: Path of the label CSV file.
    """
    def __init__(self, root, transform=None, filename=None):
        # `filename` is the label CSV, not a transform. ImageFolder's third
        # positional parameter is `target_transform`, so it must not be
        # forwarded there.
        super(CustomImageFolderReal, self).__init__(root, transform=transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label)``."""
        path = self.imgs[index][0]
        # The CSV is keyed by the path relative to the 'real_signs/' folder.
        image_index = path.split('real_signs/')[-1]

        class_label = self.df.loc[image_index, 'class_label']
        shape_label = self.df.loc[image_index, 'shape_label']
        color_label = self.df.loc[image_index, 'color_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label


def return_data(args):
    """Create the shuffled DataLoader for the dataset selected by ``args``.

    Reads ``dataset``, ``batch_size``, ``num_workers``, ``image_size`` and
    ``dset_dir`` from ``args``. ``dataset`` is matched case-insensitively:
    ``'traffic'`` uses ``CustomImageFolder`` rooted at ``<dset_dir>/train``,
    and ``'real_data'`` uses ``CustomImageFolderReal`` rooted at ``dset_dir``.
    In both cases labels come from ``class_label.csv`` inside the root and
    images are resized to ``image_size`` x ``image_size`` and converted to
    tensors.

    Raises:
        NotImplementedError: If ``dataset`` is neither of the supported names.
    """
    dataset = args.dataset
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    dset_dir = args.dset_dir

    name = dataset.lower()
    if name == 'traffic':
        print("Generate traffic sign data.")
        root = os.path.join(dset_dir, 'train')
        dset = CustomImageFolder
    elif name == 'real_data':
        print("Generate real traffic sign data.")
        # The real-sign images live directly under dset_dir.
        root = dset_dir
        dset = CustomImageFolderReal
    else:
        print("wrong data folder names!!")
        raise NotImplementedError

    # Both dataset kinds share the same preprocessing and label-file layout.
    resize_to_tensor = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    label_csv = os.path.join(root, 'class_label.csv')

    train_data = dset(root=root, transform=resize_to_tensor, filename=label_csv)
    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)
