import os
import random
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms


class CustomTensorDataset(Dataset):
    """Index-aligned dataset over one data tensor and two per-sample factor tensors.

    Item ``i`` is the triple ``(data_tensor[i], factor_class[i], factor_value[i])``.
    The dataset length is ``data_tensor.size(0)``.
    """

    def __init__(self, data_tensor, factor_class, factor_value):
        # The short attribute names ``fc`` / ``fv`` are kept because other code may read them.
        self.data_tensor = data_tensor
        self.fc = factor_class
        self.fv = factor_value

    def __len__(self):
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        # Index the three tensors in a fixed order: data, class factors, value factors.
        return tuple(source[index] for source in (self.data_tensor, self.fc, self.fv))


class CustomTensorDataset_teapots(Dataset):
    """Index-aligned dataset pairing each data sample with its factor row.

    Item ``i`` is the pair ``(data_tensor[i], factor_class[i])``.
    The dataset length is ``data_tensor.size(0)``.
    """

    def __init__(self, data_tensor, factor_class):
        self.data_tensor = data_tensor
        self.fc = factor_class  # attribute name kept for external readers

    def __len__(self):
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        image = self.data_tensor[index]
        factors = self.fc[index]
        return image, factors


class CustomImageFolder(ImageFolder):
    """``ImageFolder`` variant that yields two images per item and no labels.

    The first image is the one at the requested index. The second is chosen
    with ``random.choice`` over every index in the folder, so it may be the
    same image. Both images go through ``self.transform`` when one is set.
    """

    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)
        # Candidate pool for the randomly paired second image.
        self.indices = range(len(self))

    def __getitem__(self, index1):
        index2 = random.choice(self.indices)

        # Load both images first, then transform them in the same order.
        pair = [self.loader(self.imgs[i][0]) for i in (index1, index2)]
        if self.transform is not None:
            pair = [self.transform(img) for img in pair]

        return tuple(pair)


def return_data_loader(name, dset_dir='../data/DVAE', batch_size=64,
                       num_workers=8, image_size=64):
    """Build a shuffled training ``DataLoader`` for the dataset called *name*.

    Args:
        name: Dataset identifier, matched case-insensitively. Supported values
            are ``celeba``, ``3dchairs``, ``traffic_lights``, ``bstld``,
            ``traffic_lights_wb`` and ``aug_tl_city`` (image folders under
            *dset_dir*), plus ``teapots`` and ``dsprites`` (``.npz`` archives
            under *dset_dir*).
        dset_dir: Root directory that holds the datasets.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the ``DataLoader``.
        image_size: Side length the image-folder datasets are resized to.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``pin_memory=True`` and
        ``drop_last=True``.

    Raises:
        NotImplementedError: If *name* is not a supported dataset.
    """
    key = name.lower()

    if key == 'abc':
        # NOTE(review): no dataset branch below handles 'abc', so this prints
        # and then raises NotImplementedError. Confirm which dataset is meant
        # to receive the normalizing transform.
        print("Additional Transform for TL data!")
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    else:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),])

    # Image-folder datasets: lowercase name -> sub-directory of ``dset_dir``.
    image_folders = {
        'celeba': 'CelebA',
        '3dchairs': '3DChairs',
        'traffic_lights': 'TrafficLights',
        'bstld': 'BSTLD',
        'traffic_lights_wb': 'TrafficLights_wb',
        'aug_tl_city': 'TL_cityscape_aug_2',
    }

    if key in image_folders:
        root = os.path.join(dset_dir, image_folders[key])
        train_kwargs = {'root': root, 'transform': transform}
        dset = CustomImageFolder
    elif key == 'teapots':
        root = os.path.join(dset_dir, 'teapots/teapots.npz')
        # ``with`` closes the archive's file handle once the arrays are read.
        with np.load(root, encoding='bytes') as npz:
            fc = torch.from_numpy(npz['gts']).float()
            data = torch.from_numpy(npz['images']).unsqueeze(1).float()
        # Swap the inserted axis 1 with the last axis, then drop every size-1
        # axis and divide by 255. This presumably means channels-last uint8
        # images become channels-first in [0, 1]; confirm.
        # NOTE(review): squeeze() would also drop the batch axis if N == 1.
        data = data.transpose(1, 4).squeeze() / 255.0
        train_kwargs = {'data_tensor': data, 'factor_class': fc}
        dset = CustomTensorDataset_teapots
    elif key == 'dsprites':
        root = os.path.join(dset_dir, 'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        if not os.path.exists(root):
            import subprocess
            print('Now download dsprites-dataset')
            # The script path is resolved relative to the current working directory.
            subprocess.call(['./download_dsprites.sh'])
            print('Finished')

        # ``with`` closes the archive's file handle once the arrays are read.
        with np.load(root, encoding='bytes') as npz:
            fc = torch.from_numpy(npz['latents_classes']).unsqueeze(1).float()
            fv = torch.from_numpy(npz['latents_values']).unsqueeze(1).float()
            data = torch.from_numpy(npz['imgs']).unsqueeze(1).float()
        train_kwargs = {'data_tensor': data, 'factor_class': fc, 'factor_value': fv}
        dset = CustomTensorDataset
    else:
        raise NotImplementedError(name)

    train_data = dset(**train_kwargs)
    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)

def return_data(dset_dir='../data/DVAE', seed=None):
    """Load dSprites for evaluation and split the images 90/10 into train/test.

    If the ``.npz`` archive is missing, ``./download_dsprites.sh`` is run
    first. That path is resolved relative to the current working directory.

    Args:
        dset_dir: Directory that contains ``dsprites-dataset/``.
        seed: Optional seed for the train/test permutation. ``None`` (the
            default) draws from NumPy's global RNG, as before.

    Returns:
        A tuple ``(data_train, data_test, all_imgs, all_factors, n_classes)``:

        - ``all_imgs``: the ``imgs`` array as a float tensor, with a new axis
          inserted at dim 1.
        - ``data_train``: the first ``(9 * N) // 10`` entries of a random
          permutation of ``all_imgs``.
        - ``data_test``: the remaining entries of that permutation.
        - ``all_factors``: ``latents_classes`` as float, with a new axis at
          dim 1 and the first (color) factor column dropped.
        - ``n_classes``: the number of classes per remaining factor.
    """
    root = os.path.join(dset_dir, 'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
    if not os.path.exists(root):
        import subprocess
        print('Now download dsprites-dataset')
        subprocess.call(['./download_dsprites.sh'])
        print('Finished')

    # ``with`` closes the archive's file handle once the arrays are read.
    with np.load(root, encoding='bytes') as data:
        all_factors = torch.from_numpy(data['latents_classes']).unsqueeze(1).float()
        all_imgs = torch.from_numpy(data['imgs']).unsqueeze(1).float()
    all_factors = all_factors[:, :, 1:]  # Remove color factor
    # Class counts for the remaining factors. These presumably match the
    # filename tags sh3/sc6/or40/x32/y32 (shape, scale, orientation, x, y).
    n_classes = np.array([3, 6, 40, 32, 32])

    n_data = all_imgs.shape[0]
    if seed is None:
        idx_random = np.random.permutation(n_data)
    else:
        idx_random = np.random.default_rng(seed).permutation(n_data)
    n_train = (9 * n_data) // 10
    data_train = all_imgs[idx_random[:n_train]]
    data_test = all_imgs[idx_random[n_train:]]

    return data_train, data_test, all_imgs, all_factors, n_classes
