import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import h5py
import csv
import pdb
import os
import argparse
import random
import h5py

from torch.autograd import Variable
from utils import cuda


def return_semi_data(args, change_latent_factor=1):
    """Build the shuffled training DataLoader of semi-supervised image pairs.

    The full image array of the requested dataset is loaded into memory,
    wrapped in a ``CustomSemiDataset`` and served through a ``DataLoader``.

    Args:
        args: Namespace providing ``dataset`` (case-insensitive; one of
            ``traffic``, ``attack``, ``trafficv2``, ``dsprites``,
            ``3d_shapes``, ``3d_chairs``), ``batch_size``, ``dset_dir``,
            ``semi_percentage`` and ``num_workers``.
        change_latent_factor: Forwarded to ``CustomSemiDataset`` as the number
            of latent factors altered in each non-identical pair.

    Returns:
        A ``DataLoader`` over the dataset with ``shuffle=True``,
        ``pin_memory=True`` and ``drop_last=True``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    dataset = args.dataset
    batch_size = args.batch_size
    dset_dir = args.dset_dir
    semi_p = args.semi_percentage
    num_workers = args.num_workers

    name = dataset.lower()

    if name == 'traffic' or name == 'attack':
        print("Generate semi-supervised data for traffic.")
        root = os.path.join(dset_dir, 'train')
        data_path = os.path.join(root, 'traffic_8000_3x128x128.npz')
        # Arrays are fully read inside the block, so the archive can be closed.
        with np.load(data_path) as data_zip:
            imgs = data_zip['imgs']
            factor_sizes = data_zip['latent_sizes']
        imgs_tensor = torch.from_numpy(imgs).float()

    elif name == 'trafficv2':
        print("Generate semi-supervised data for traffic (version 8x10x10x10).")
        root = os.path.join(dset_dir, 'train')
        data_path = os.path.join(root, 'traffic_8x10x10x10_3x128x128.npz')
        with np.load(data_path) as data_zip:
            imgs = data_zip['imgs']
            factor_sizes = data_zip['latent_sizes']
        print("Factor_sizes:", factor_sizes)
        imgs_tensor = torch.from_numpy(imgs).float()

    elif name == 'dsprites':
        print("Generate semi-supervised data for dsprites.")
        data_path = os.path.join(dset_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        # NOTE(security): allow_pickle=True unpickles the 'metadata' object
        # array; only load archives from a trusted source.
        with np.load(data_path, encoding='latin1', allow_pickle=True) as data_zip:
            imgs = data_zip['imgs']
            metadata = data_zip['metadata'][()]
        factor_sizes = np.array(metadata['latents_sizes'], dtype=np.int64)
        # Insert a channel axis after the sample axis.
        imgs_tensor = torch.from_numpy(imgs).unsqueeze(1).float()

    elif name == '3d_shapes':
        print("Generate semi-supervised data for 3d shapes.")
        data_path = os.path.join(dset_dir, '3dshapes.h5')
        # Read-only access is sufficient; the context manager releases the
        # HDF5 handle once the images have been copied into memory.
        with h5py.File(data_path, 'r') as hf:
            imgs = hf['images'][()]   # (480000, 64, 64, 3) per the 3D Shapes README
        ## Source: https://github.com/google-deepmind/3d-shapes/blob/master/README.md
        # floor hue: 10 values linearly spaced in [0, 1]
        # wall hue: 10 values linearly spaced in [0, 1]
        # object hue: 10 values linearly spaced in [0, 1]
        # scale: 8 values linearly spaced in [0, 1]
        # shape: 4 values in [0, 1, 2, 3]
        # orientation: 15 values linearly spaced in [-30, 30]
        factor_sizes = np.array([10, 10, 10, 8, 4, 15])
        # Move the last (channel) axis in front of the two spatial axes.
        imgs_tensor = torch.from_numpy(imgs.transpose(0, 3, 1, 2)).float()

    elif name == '3d_chairs':
        print("Generate semi-supervised data for 3d chairs.")
        data_path = os.path.join(dset_dir, '3dchairs.npz')
        # NOTE(security): allow_pickle=True; only load trusted archives.
        with np.load(data_path, encoding='latin1', allow_pickle=True) as data_zip:
            imgs = data_zip['imgs']
        factor_sizes = np.array([1393, 2, 31])
        imgs_tensor = torch.from_numpy(imgs.transpose(0, 3, 1, 2)).float()

    else:
        print("wrong data folder names!!")
        raise NotImplementedError

    # Mixed-radix place values: flat index = dot(latents, factor_bases).
    factor_bases = np.prod(factor_sizes) / np.cumprod(factor_sizes)
    train_data = CustomSemiDataset(imgs_tensor, factor_sizes, factor_bases,
                                   p=semi_p, change_latent_factor=change_latent_factor)

    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)


class CustomIterableTensorDataset(IterableDataset):
    """
    Old dataset for pairwise inputs; yields an endless stream of image pairs.

    data_tensor: Images, indexed along dim 0 by the flat latent index.
    factor_sizes: Array with the number of values of each latent factor.
    factor_bases: Place value of each factor; flat index = dot(latents, factor_bases).
    change_latent_factor: Number of different factors for each pair.
    random_seed: Seed of the RandomState that picks which factors change.
        When None, a seed below 10000 is drawn from numpy's global RNG.
    """
    def __init__(self,data_tensor,factor_sizes,factor_bases,change_latent_factor=1,random_seed=None):
        self.data_tensor = data_tensor
        self.factor_bases = factor_bases
        self.factor_sizes = factor_sizes
        self.num_factors = factor_sizes.size
        self.change_latent_factor = change_latent_factor
        # Always store the seed: _generator reads self.random_seed, so a
        # caller-supplied value must not be dropped.
        if random_seed is None:
            random_seed = np.random.choice(10000)
        self.random_seed = random_seed

    def __iter__(self):
        return self._generator(return_label=True)

    def _viz_pair(self, i):
        """Return element ``i`` of the first generated sample.

        A sample is the tuple ``(pair, label)``, so ``i=0`` gives the stacked
        image pair and ``i=1`` the changed-factor label. A generator cannot be
        subscripted, hence the explicit ``next``.
        """
        return next(self._generator(return_label=True))[i]

    def generate_another(self,z,random_state,k=-1):
        """Change ``k`` randomly chosen factors of ``z`` to a different value.

        ``z`` is a (1, num_factors) array and is modified in place.
        ``k=-1`` draws the number of factors from ``random_state``.
        Returns ``(z, idx)`` where ``idx`` is the last factor that was changed,
        or -1 when none could be changed.
        """
        if k == -1:
            k_observed = random_state.randint(1,self.num_factors)
        else:
            k_observed = k
        # Choose which latent factors to alter (without repetition).
        index_list = random_state.choice(z.shape[1],k_observed,replace=False)

        idx = -1
        for index in index_list:
            current = int(z[0, index])
            # Every admissible value of this factor except the current one.
            r = [*range(0,current)] + [*range(current+1,int(self.factor_sizes[index]))]
            if not r:
                # Factor has a single value; nothing to change.
                continue
            # NOTE: the new value comes from numpy's global RNG, not from
            # random_state, so it is not controlled by random_seed.
            z[:,index] = np.random.choice(r)
            idx = index
        return z,idx

    def _generator(self,return_label=False):
        """ Generate images pairs. """
        random_state = np.random.RandomState(seed=self.random_seed)
        while True:
            first_factors = self.sample_latent(1)
            first_factors_index = self.latent_to_index(first_factors)
            # Fetch the first image before generate_another mutates the factors.
            first_image = self.data_tensor[torch.as_tensor(first_factors_index, dtype=torch.long)]

            next_factors,indexs = self.generate_another(first_factors,random_state,k=self.change_latent_factor)
            next_factors_index = self.latent_to_index(next_factors)
            next_image = self.data_tensor[torch.as_tensor(next_factors_index, dtype=torch.long)]
            labels = indexs

            pair = torch.cat((first_image,next_image),0)
            if return_label:
                yield pair, labels
            else:
                yield pair

    def latent_to_index(self,latents):
        """Map latent factor values to flat indices into ``data_tensor``."""
        return np.dot(latents,self.factor_bases).astype(int)

    def sample_latent(self,size=1):
        """Draw ``size`` random latent vectors (uses numpy's global RNG)."""
        samples = np.zeros((size,self.factor_sizes.size))
        for lat_i,lat_size in enumerate(self.factor_sizes):
            samples[:,lat_i] = np.random.randint(lat_size,size=size)
        return samples


def return_data(args, change_latent_factor=1):
    """Train dataloader of image pairs built on ``CustomPairwiseDataset``.

    Supported ``args.dataset`` names (case-insensitive): ``traffic``,
    ``trafficv2``, ``stopsign``, ``dsprites`` and ``real_data``. Any other name
    prints a message and raises ``NotImplementedError``.

    Note that the dsprites archive is read from the current working
    directory, not from ``args.dset_dir``.
    """
    dataset = args.dataset
    image_size = args.image_size  # read from args but not used below
    batch_size = args.batch_size
    dset_dir = args.dset_dir
    num_workers = args.num_workers

    # dataset name -> (sub-directory of dset_dir, archive file name)
    npz_locations = {
        'traffic': ('train', 'traffic_8000_3x128x128.npz'),
        'trafficv2': ('train', 'traffic_8x10x10x10_3x128x128.npz'),
        'stopsign': ('val', 'traffic_8000_3x128x128.npz'),
        'real_data': ('train', 'real_data_1960.npz'),
    }
    name = dataset.lower()

    if name in npz_locations:
        subdir, npz_name = npz_locations[name]
        archive = np.load(os.path.join(dset_dir, subdir, npz_name))
        images = archive['imgs']
        factor_sizes = archive['latent_sizes']
        imgs_tensor = torch.from_numpy(images).float()
    elif name == 'dsprites':
        archive = np.load('./dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                          encoding='latin1', allow_pickle=True)
        images = archive['imgs']
        metadata = archive['metadata'][()]
        factor_sizes = np.array(metadata['latents_sizes'], dtype=np.int64)
        latent_values = archive['latents_values']  # loaded but not used
        # Insert a channel axis after the sample axis.
        imgs_tensor = torch.from_numpy(images).unsqueeze(1).float()
    else:
        print("wrong data folder names!!")
        raise NotImplementedError

    # Place value of each factor: flat index = dot(latents, factor_bases).
    factor_bases = np.prod(factor_sizes) / np.cumprod(factor_sizes)
    train_data = CustomPairwiseDataset(imgs_tensor, factor_sizes, factor_bases,
                                       change_latent_factor=change_latent_factor)

    loader_options = dict(batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=True)
    return DataLoader(train_data, **loader_options)


class CustomPairwiseDataset(Dataset):
    """ Dataset for pairwise inputs.

    Item ``index`` is the pair (image ``index``, image obtained by changing
    ``change_latent_factor`` latent factors of image ``index``) together with
    the index of the (last) changed factor.

    data_tensor: Images, indexed along dim 0 by the flat latent index.
    factor_sizes: Array with the number of values of each latent factor.
    factor_bases: Place value of each factor; flat index = dot(latents, factor_bases).
    change_latent_factor: Number of factors changed per pair (-1: random).
    random_seed: Base seed for choosing which factors change. When None, a
        seed below 10000 is drawn from numpy's global RNG.
    """
    def __init__(self,data_tensor,factor_sizes,factor_bases,change_latent_factor=1,random_seed=None):
        self.data_tensor = data_tensor
        self.factor_bases = factor_bases
        self.factor_sizes = factor_sizes
        self.num_factors = factor_sizes.size
        self.change_latent_factor = change_latent_factor
        # Always store the seed: __getitem__ reads self.random_seed, so a
        # caller-supplied value must not be dropped.
        if random_seed is None:
            random_seed = np.random.choice(10000)
        self.random_seed = random_seed

    def __len__(self):
        return self.data_tensor.size(0)

    def generate_another(self,z,random_state,k=-1):
        """Change ``k`` randomly chosen factors of ``z`` to a different value.

        ``z`` is a (1, num_factors) array and is modified in place.
        ``k=-1`` draws the number of factors from ``random_state``.
        Returns ``(z, idx)`` where ``idx`` is the last factor that was changed,
        or -1 when none could be changed.
        """
        if k == -1:
            k_observed = random_state.randint(1,self.num_factors)
        else:
            k_observed = k
        # Choose which latent factors to alter (without repetition).
        index_list = random_state.choice(z.shape[1],k_observed,replace=False)

        idx = -1
        for index in index_list:
            current = int(z[0, index])
            # Every admissible value of this factor except the current one.
            r = [*range(0,current)] + [*range(current+1,int(self.factor_sizes[index]))]
            if not r:
                # Factor has a single value; nothing to change.
                continue
            # NOTE: the new value comes from numpy's global RNG, so it is
            # re-drawn on every access and not controlled by random_seed.
            z[:,index] = np.random.choice(r)
            idx = index
        return z,idx

    def latent_to_index(self, latents):
        """Map latent factor values to flat indices into ``data_tensor``."""
        return np.dot(latents,self.factor_bases).astype(int)

    def index_to_latent(self, index):
        """Decompose a flat index into its list of latent factor values."""
        latents = []
        for i in self.factor_bases:
            latents.append(int(index // i))
            index %= i
        return latents

    def sample_latent(self, size=1):
        """Draw ``size`` random latent vectors (uses numpy's global RNG)."""
        samples = np.zeros((size,self.factor_sizes.size))
        for lat_i,lat_size in enumerate(self.factor_sizes):
            samples[:,lat_i] = np.random.randint(lat_size,size=size)
        return samples

    def __getitem__(self, index):
        """Return ``(pair, label)`` for image ``index``.

        ``pair`` stacks the original and the altered image along a new leading
        dim of size 2; ``label`` is the index of the last changed factor (-1
        if no factor could be changed).
        """
        index = int(index)
        # Mix the item index into the seed: re-seeding with the bare
        # random_seed on every call would select the same factor for every
        # item of the dataset.
        random_state = np.random.RandomState(seed=(int(self.random_seed) + index) % (2 ** 32))

        # The first image of the pair
        first_image = self.data_tensor[index].unsqueeze(0)
        # generate_another expects a (1, num_factors) array, not a list.
        first_factors = np.array([self.index_to_latent(index)], dtype=float)

        # The second image of the pair
        next_factors,indexs = self.generate_another(first_factors,random_state,k=self.change_latent_factor)
        next_factors_index = torch.as_tensor(self.latent_to_index(next_factors), dtype=torch.long)
        next_image = self.data_tensor[next_factors_index]

        return torch.cat((first_image,next_image),0), indexs


class CustomSemiDataset(Dataset):
    """Dataset of fixed image pairs for semi-supervised training.

    A fraction ``p`` of the images is paired with an image that differs in
    ``change_latent_factor`` latent factors; every other image is paired with
    itself. The pairing is computed once in the constructor and is fully
    determined by ``random_seed``.
    """
    def __init__(self, data_tensor,factor_sizes,factor_bases,change_latent_factor=1,p=0.2,random_seed=None):
        self.data_tensor = data_tensor                    # img data, e.g. size = [8000, 3, 128, 128]
        self.factor_bases = factor_bases                  # place values for latent<->index, e.g. [800., 100., 1.]
        self.factor_sizes = factor_sizes                  # number of values per latent factor
        self.num_factors = factor_sizes.size              # number of factors
        self.change_latent_factor = change_latent_factor  # number of factors to be changed
        self.semi_percent = p                             # fraction of images that get a changed partner
        if random_seed is None:
            random_seed = np.random.choice(10000)         # drawn from numpy's global RNG
        self.random_seed = random_seed

        self.index_map = self.get_index_map(k=self.change_latent_factor)

    def __len__(self):
        return self.data_tensor.size(0)

    def latent_to_index(self,latents):
        """Map latent factor values to the flat index into ``data_tensor``."""
        return np.dot(latents,self.factor_bases).astype(int)

    def index_to_latent(self, index):
        """Decompose a flat index into its list of latent factor values."""
        latents = []
        for i in self.factor_bases:
            latents.append(int(index // i))
            index %= i
        return latents

    def get_index_map(self, k=-1):
        """Build the (num_images, 2, num_factors) array of paired latents.

        ``[i, 0]`` holds the latents of image ``i`` and ``[i, 1]`` those of
        its partner. ``k`` is the number of factors to change in a
        semi-supervised pair; ``k=-1`` draws it at random.

        Raises:
            NotImplementedError: If ``k`` is negative and not -1.
        """
        random_state = np.random.RandomState(seed=self.random_seed)

        if k == -1:
            k_observed = random_state.randint(1, self.num_factors)
        elif k >= 0:
            k_observed = k
        else:
            raise NotImplementedError

        num_images = self.data_tensor.size(0)
        num_semi = int(self.semi_percent * num_images)
        # Drawn from the seeded RandomState so the whole pairing is
        # reproducible from random_seed; kept as a set so the membership test
        # in the loop below is O(1) rather than a scan per image.
        semi_indices = set(random_state.choice(num_images, num_semi, replace=False).tolist())
        input_map = np.zeros((num_images, 2, self.num_factors))

        for i in range(num_images):
            first_latents = self.index_to_latent(i)
            # Self-contrastive pair unless i was selected above.
            second_latents = list(first_latents)

            if i in semi_indices:
                # Semi-supervised pair: change k_observed distinct factors.
                index_list = random_state.choice(len(first_latents),k_observed,replace=False)
                for index in index_list:
                    current = int(first_latents[index])
                    r = [*range(0,current)] + [*range(current+1, int(self.factor_sizes[index]))]
                    if not r:     # this factor has a single value
                        continue
                    second_latents[index] = random_state.choice(r)

            input_map[i][0] = first_latents
            input_map[i][1] = second_latents

        return input_map

    def __getitem__(self, index, return_label=True):
        """Return the image pair for ``index``.

        The pair is stacked along a new leading dim of size 2. With
        ``return_label=True`` the result is ``(pair, n)`` where ``n`` is the
        number of latent factors in which the two images differ (0 for a
        self-pair).
        """
        first_latents, second_latents = self.index_map[index]
        first_index = int(self.latent_to_index(first_latents))
        second_index = int(self.latent_to_index(second_latents))
        pair = torch.stack((self.data_tensor[first_index], self.data_tensor[second_index]), 0)

        if return_label:
            return pair, int(np.count_nonzero(first_latents != second_latents))
        return pair


class CustomImageFolder(ImageFolder):
    """ Dataset for test data.

    An ``ImageFolder`` whose labels are read from a CSV file instead of being
    derived from the folder names. The CSV's first column is used as the row
    index and is looked up with the key ``<parent dir>/<file name>``; the
    columns ``class_label``, ``shape_label``, ``color_label`` and
    ``rotate_label`` are read for each image.
    """
    def __init__(self, root, transform=None,filename=None):
        # ImageFolder's third positional parameter is ``target_transform``,
        # so the CSV path must not be forwarded to the base class.
        super(CustomImageFolder, self).__init__(root, transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label, rotate_label)``.

        Raises ``KeyError`` when the image key is missing from the CSV index.
        """
        path = self.imgs[index][0]
        # Key = last two '/'-separated components of the part of the path that
        # follows the last 'traffic/'.
        parts = path.split('traffic/')[-1].split('/')
        image_index = os.path.join(parts[-2], parts[-1])

        class_label = self.df.loc[image_index, 'class_label']
        shape_label = self.df.loc[image_index, 'shape_label']
        color_label = self.df.loc[image_index, 'color_label']
        rotate_label = self.df.loc[image_index, 'rotate_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label, rotate_label
    
    
class CustomImageFolderReal(ImageFolder):
    """ Dataset for test data (real-world images).

    An ``ImageFolder`` whose labels are read from a CSV file instead of being
    derived from the folder names. The CSV's first column is used as the row
    index and is looked up with the key ``<parent dir>/<file name>``; the
    columns ``class_label``, ``shape_label`` and ``color_label`` are read for
    each image.
    """
    def __init__(self, root, transform=None,filename=None):
        # ImageFolder's third positional parameter is ``target_transform``,
        # so the CSV path must not be forwarded to the base class.
        super(CustomImageFolderReal, self).__init__(root, transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label)``.

        Raises ``KeyError`` when the image key is missing from the CSV index.
        """
        path = self.imgs[index][0]
        # Key = last two '/'-separated components of the part of the path that
        # follows the last 'real_data/'.
        parts = path.split('real_data/')[-1].split('/')
        image_index = os.path.join(parts[-2], parts[-1])

        class_label = self.df.loc[image_index, 'class_label']
        shape_label = self.df.loc[image_index, 'shape_label']
        color_label = self.df.loc[image_index, 'color_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label


class CustomTensorDataset(Dataset):
    """Map-style dataset that serves the entries of a single tensor."""

    def __init__(self, data_tensor):
        # Kept by reference; no copy is made.
        self.data_tensor = data_tensor

    def __len__(self):
        """Number of samples, i.e. the size of the tensor's first dimension."""
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        """Return ``data_tensor[index]``."""
        return self.data_tensor[index]


def return_test_data(args):
    """ Test dataloader.

    Args:
        args: Namespace providing ``dataset`` (case-insensitive; one of
            ``traffic``, ``trafficv2``, ``attack``, ``real_data``,
            ``dsprites``, ``3d_shapes``, ``3d_chairs``), ``batch_size``,
            ``num_workers``, ``image_size`` and ``dset_dir``.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``pin_memory=True`` and
        ``drop_last=True``. Image-folder datasets yield the image plus its CSV
        labels; tensor datasets yield images only.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    dataset = args.dataset
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    dset_dir = args.dset_dir

    name = dataset.lower()

    # Image-folder datasets:
    # name -> (message, sub-directory of dset_dir or None for dset_dir itself, label CSV)
    folder_specs = {
        'traffic': ("Generate traffic sign data.", 'train', 'class_label.csv'),
        'trafficv2': ("Generate traffic sign data (version 8x10x10x10).", 'train', 'class_label.csv'),
        'attack': ("Generate attack data.", 'val', 'attack_class_label_all.csv'),
        'real_data': ("Generate real-world traffic sign data.", None, 'class_label.csv'),
    }

    if name in folder_specs:
        message, subdir, csv_name = folder_specs[name]
        print(message)
        root = dset_dir if subdir is None else os.path.join(dset_dir, subdir)
        transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                        transforms.ToTensor(),])
        filename = os.path.join(root, csv_name)
        # real_data has no rotate_label column, hence its own dataset class.
        folder_cls = CustomImageFolderReal if name == 'real_data' else CustomImageFolder
        train_data = folder_cls(root=root, transform=transform, filename=filename)

    elif name == 'dsprites':
        print("Test data for dsprites.")
        root = os.path.join(dset_dir, 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        # The array is fully read inside the block, so the archive can be closed.
        with np.load(root, encoding='bytes') as data_zip:
            imgs = data_zip['imgs']
        # Insert a channel axis after the sample axis.
        data = torch.from_numpy(imgs).unsqueeze(1).float()
        train_data = CustomTensorDataset(data_tensor=data)

    elif name == '3d_shapes':
        print("Test data for 3d shapes.")
        data_path = os.path.join(dset_dir, '3dshapes.h5')
        # Read-only access is sufficient; the context manager releases the
        # HDF5 handle once the images have been copied into memory.
        with h5py.File(data_path, 'r') as hf:
            imgs = hf['images'][()]
        # Move the last (channel) axis in front of the two spatial axes.
        data = torch.from_numpy(imgs.transpose(0, 3, 1, 2)).float()
        train_data = CustomTensorDataset(data_tensor=data)

    elif name == '3d_chairs':
        print("Test data for 3d chairs.")
        # Prefer the archive inside dset_dir; fall back to the historical
        # hard-coded location so existing setups keep working.
        data_path = os.path.join(dset_dir, '3dchairs.npz')
        if not os.path.exists(data_path):
            data_path = '/data2/qidi/Research/ControlVAE-ICML2020/Disentangling/data/3dchairs.npz'
        with np.load(data_path) as data_zip:
            imgs = data_zip['imgs']
        data = torch.from_numpy(imgs.transpose(0, 3, 1, 2)).float()
        train_data = CustomTensorDataset(data_tensor=data)

    else:
        print("wrong data folder names!!")
        raise NotImplementedError

    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)





class AttackImageFolder(ImageFolder):
    """ Dataset for test data (attack evaluation).

    An ``ImageFolder`` whose labels are read from a CSV file instead of being
    derived from the folder names. The CSV's first column is used as the row
    index and is looked up with the key ``<parent dir>/<file name>``; the
    columns ``class_label``, ``shape_label`` and ``color_label`` are read for
    each image.
    """
    def __init__(self, root, transform=None,filename=None):
        # ImageFolder's third positional parameter is ``target_transform``,
        # so the CSV path must not be forwarded to the base class.
        super(AttackImageFolder, self).__init__(root, transform)
        self.df = pd.read_csv(filename, index_col=0)

    def __getitem__(self, index):
        """Return ``(img, class_label, shape_label, color_label)``.

        Raises ``KeyError`` when the image key is missing from the CSV index.
        """
        path = self.imgs[index][0]
        # Key = last two '/'-separated components of the part of the path that
        # follows the last 'traffic/'.
        parts = path.split('traffic/')[-1].split('/')
        image_index = os.path.join(parts[-2], parts[-1])

        shape_label = self.df.loc[image_index, 'shape_label']
        color_label = self.df.loc[image_index, 'color_label']
        class_label = self.df.loc[image_index, 'class_label']

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, class_label, shape_label, color_label


def return_attack_data(args):
    """Build the DataLoader used for attack evaluation.

    ``args.dataset`` (case-insensitive) may be ``traffic``, ``attack`` or one
    of the attack types listed below. For any other name the function returns
    ``None`` without raising.
    """
    dataset = args.dataset
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    dset_dir = args.dset_dir

    attack_types = ['all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker']

    # Resolve the message, image sub-directory and label CSV for the dataset.
    name = dataset.lower()
    if name == 'traffic':
        message = "Generate traffic data."
        subdir = 'train'
        csv_name = 'class_label.csv'
    elif name == 'attack':
        message = "Generate attack data."
        subdir = 'val'
        csv_name = 'attack_class_label_all.csv'
    elif name in attack_types:
        attack = name
        message = f"Generate {attack} attack data."
        subdir = f'val/{attack}'
        csv_name = f'attack_class_label_{attack}.csv'
    else:
        return

    print(message)
    root = os.path.join(dset_dir, subdir)
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    filename = os.path.join(root, csv_name)
    attack_data = AttackImageFolder(root=root, transform=transform, filename=filename)

    loader_options = dict(batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers,
                          pin_memory=True,
                          drop_last=True)
    return DataLoader(attack_data, **loader_options)


def th_delete(tensor, indices):
    """Return the elements of ``tensor`` whose positions are not in ``indices``."""
    # Start with every position kept, then knock out the requested ones.
    keep = torch.full((tensor.numel(),), True, dtype=torch.bool)
    keep[indices] = False
    return tensor[keep]


def choose_swap_dims(kl_per_indices, disentangled_dims, unchanged_latent_indices=1, threshold=0.5):
    """Build a boolean mask of the latent dimensions that must NOT be swapped.

    Args:
        kl_per_indices: 2-D tensor of per-dimension KL values
            (batch_size x latent_length).
        disentangled_dims: Sequence of latent dimension indices; every
            dimension outside it is always marked as unchanged.
        unchanged_latent_indices: ``k`` passed to ``torch.topk``.
            NOTE(review): the ``torch.where`` below broadcasts a (batch, k)
            mask against (batch, len(disentangled_dims)) operands, which
            lines up directly only for k == 1 -- confirm before using other
            values.
        threshold: Top KL values above this replace the selected index by the
            whole of ``disentangled_dims``.

    Returns:
        Boolean tensor with the shape of ``kl_per_indices``; True at every
        position listed, for that row, among the unchanged indices.

    ``cuda`` is the project helper imported from ``utils``; presumably it
    moves the tensor to the GPU when its second argument is True -- verify
    against ``utils``.
    """
    # Find the maximum value location of kl for every sample.
    # Labels has the same shape with kl_per_indices (batch_size x latent_length).
    dim_z = kl_per_indices.shape[1]
    indices = kl_per_indices
    
    # Per row: the k largest KL values and the dimensions they belong to.
    top_k_indices_value, top_k_indices = torch.topk(indices,k=unchanged_latent_indices,dim=1)
    
    ## filter 2 - Modify top_k_indices
    # Where the top KL value exceeds the threshold, take the disentangled
    # dims (mat2) instead of the repeated top-k index (mat1).
    mask = top_k_indices_value > threshold
    mat1 = torch.repeat_interleave(top_k_indices, len(disentangled_dims), dim=1)
    mat2 = Variable(cuda(torch.tensor(disentangled_dims), True)).repeat(indices.shape[0], 1)
    top_k_indices = torch.where(mask, mat2, mat1)

    # All dimension indices with the disentangled ones removed, repeated for
    # every row, then extended with the per-row selection from above.
    matrix = Variable(cuda(torch.arange(indices.shape[1]), True))
    unchanged_indices = th_delete(matrix, disentangled_dims).repeat(indices.shape[0], 1)
    unchanged_indices = torch.cat((unchanged_indices, top_k_indices), dim=1)
    # dim=1 removes duplicate *columns* (compared over the whole batch), not
    # duplicate entries within a row.
    unchanged_indices = torch.unique(unchanged_indices, dim=1)

    # OR together one one-hot mask per column of unchanged_indices.
    labels = Variable(cuda(torch.zeros_like(kl_per_indices).type(torch.ByteTensor), True))
    for i in range(unchanged_indices.shape[1]):
        current_indices = unchanged_indices[:,i]

        # matrix1[b, j] == j; matrix2[b, 0] is the dimension to mark in row b.
        matrix1 = Variable(cuda(torch.arange(indices.shape[1]).repeat(indices.shape[0],1), True))
        matrix2 = Variable(cuda(torch.unsqueeze(current_indices,1), True))
        current_label = torch.eq(matrix1, matrix2)
        labels = torch.logical_or(labels, current_label)
    return labels


def change_latent_space(mu1,mu2,logvar1,logvar2,labels,disentangled_dims,unchanged_latent_indices=1,threshold=0.5):
    """Swap latent statistics between the two images of a pair.

    ``choose_swap_dims`` yields a boolean mask; where it is True each image
    keeps its own ``mu``/``logvar`` entry, elsewhere it receives the entry of
    the other image.

    Returns ``(mu1_new, logvar1_new, mu2_new, logvar2_new)``.
    """
    keep_mask = choose_swap_dims(labels,
                                 disentangled_dims=disentangled_dims,
                                 unchanged_latent_indices=unchanged_latent_indices,
                                 threshold=threshold)

    def _mix(own, other):
        # Own value where the mask is True, the partner's value elsewhere.
        return torch.where(keep_mask, own, other)

    return _mix(mu1, mu2), _mix(logvar1, logvar2), _mix(mu2, mu1), _mix(logvar2, logvar1)


def process_kl(path_root):
    """Convert the KL log ``train.kl`` under ``path_root`` into ``train.csv``.

    Each non-blank line of ``train.kl`` holds comma-separated KL values; any
    prefix up to the last ``:`` of the first field is discarded. The CSV gets
    a header ``dim_0, dim_1, ...`` as wide as the widest row (10 columns when
    the log contains no rows), followed by one row per log line.

    Args:
        path_root: Directory that contains ``train.kl``; ``train.csv`` is
            written (overwritten) in the same directory.

    Raises:
        ValueError: If a field cannot be parsed as a float.
    """
    kl_path = os.path.join(path_root, 'train.kl')
    csv_path = os.path.join(path_root, 'train.csv')

    kls = []

    ## Read in kl values
    with open(kl_path, 'r', encoding='UTF8') as klfile:
        for line in klfile:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            dim_kls = line.split(',')
            # Drop the "<prefix>:" that precedes the first value.
            dim_kls[0] = dim_kls[0].split(':')[-1]
            kls.append([float(value) for value in dim_kls])

    # Size the header from the data instead of assuming ten latent dims.
    num_dims = max((len(row) for row in kls), default=10)
    column_names = [f'dim_{i}' for i in range(num_dims)]

    ## Write rows into csv
    with open(csv_path, 'w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(column_names)
        csv_writer.writerows(kls)



if __name__ == '__main__':
    # Manual smoke test: build the traffic-sign image folder and load its
    # first sample.
    #
    # To convert a KL log instead, call process_kl(<checkpoint_root>), e.g.
    # with './checkpoints/Traffic_128_c11_0.15_semi0.2p_dataSimonv2'.
    dset_dir = '/data/open-datasets/traffic'
    root = os.path.join(dset_dir, 'train')
    filename = os.path.join(root, 'class_label.csv')
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])
    dset = CustomImageFolder(root=root, transform=transform, filename=filename)
    dset.__getitem__(index=0)