import torch
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torchvision as tv
import normflows as nf
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import pandas as pd


def compute_complexity(img):
    """Return the byte length of ``img`` after PNG encoding.

    The encoded size is used by the loaders in this file as a per-image
    complexity score.

    Args:
        img: CPU tensor in channel-first layout (it is permuted with
            ``(1, 2, 0)`` before encoding). Values are multiplied by 255, so
            they are presumably expected in ``[0, 1]`` -- TODO confirm with
            callers.

    Returns:
        int: number of bytes in the buffer produced by ``cv2.imencode`` with
        PNG compression level 9.
    """
    pixels = torch.permute(img, (1, 2, 0)).numpy() * 255
    # Saturate before the cast: converting out-of-range floats (e.g. an
    # unclamped perturbed image) straight to uint8 wraps around or is
    # platform dependent, which would distort the measured size.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    encode_params = [int(cv2.IMWRITE_PNG_COMPRESSION), 9]
    _, encoded = cv2.imencode(".png", pixels, encode_params)
    return len(encoded)


def log_prob(z):
    """Standard-normal log-density of ``z``, reduced over dimension 1.

    Computes ``-(0.5 * sum(z**2, dim=1) + 0.5 * D * log(2*pi))`` where
    ``D = z.size(1)``.

    Args:
        z: Tensor with at least two dimensions.

    Returns:
        Tensor with dimension 1 summed out.
    """
    n_dims = z.size(1)
    squared_norm = torch.sum(torch.pow(z, 2), dim=1)
    log_two_pi = torch.log(torch.tensor(2 * np.pi))
    return -(0.5 * squared_norm + n_dims * 0.5 * log_two_pi)


def dataload(dataset, batch_size, input_size):
    """Load a torchvision dataset and score every image by PNG size.

    Images are converted to tensors, resized to ``input_size`` x
    ``input_size`` and jittered with ``nf.utils.Jitter(1 / 256.)``. For
    "MNIST", "FashionMNIST" and "omniglot" the image tensor is repeated three
    times along dimension 0 before being stored.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size for both returned loaders.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_loader,
        test_loader)`` where the complexities are lists of
        ``compute_complexity`` values and the loaders iterate over the
        pre-processed image tensors (train shuffled, test in order).

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    # These datasets get their image tensor tiled to three channels.
    tile_channels = dataset in ("MNIST", "FashionMNIST", "omniglot")

    def collect(split, show_shape):
        """Materialise one split as (images, PNG sizes)."""
        images, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            if tile_channels:
                image = image.repeat(3, 1, 1)
            if show_shape and i == 0:
                print(image.shape)
            images.append(image)
            sizes.append(compute_complexity(image))
        return images, sizes

    train_tensor, complexity_train = collect(train_data, show_shape=True)
    test_tensor, complexity_test = collect(test_data, show_shape=False)

    train_loader = torch.utils.data.DataLoader(
        train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader
    
def dataload_background_perturb(dataset, batch_size, input_size):
    """Load a dataset whose training images have random pixels overwritten.

    Images are converted to tensors and resized (no jitter is applied here).
    For every *training* image each tensor element is replaced, with
    probability 0.2, by a random value drawn from ``{0, 1, ..., 255} / 255``.
    Test images are left untouched. For "MNIST", "FashionMNIST" and
    "omniglot" the image tensor is repeated three times along dimension 0
    before any perturbation.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size for both returned loaders.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_loader,
        test_loader)``; complexities are ``compute_complexity`` values of the
        stored (i.e. perturbed, for train) images.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    tile_channels = dataset in ("MNIST", "FashionMNIST", "omniglot")
    replace_prob = 0.2  # fraction of tensor elements overwritten in train images

    def collect(split, is_train):
        """Materialise one split as (images, PNG sizes); perturb only train."""
        images, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            if tile_channels:
                image = image.repeat(3, 1, 1)
            if is_train:
                mask = torch.rand_like(image.float()) < replace_prob
                random_values = torch.randint(
                    0, 256, size=image.shape, dtype=image.dtype) / 255
                image[mask] = random_values[mask]
            images.append(image)
            if is_train and i == 0:
                print(image.shape)
            sizes.append(compute_complexity(image))
        return images, sizes

    train_tensor, complexity_train = collect(train_data, is_train=True)
    test_tensor, complexity_test = collect(test_data, is_train=False)

    train_loader = torch.utils.data.DataLoader(
        train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader
    


def dataload_flatten(dataset, batch_size, input_size):
    """Load a dataset as flattened vectors plus per-image PNG sizes.

    Images are converted to tensors, resized to ``input_size`` x
    ``input_size`` and jittered. Each image is stored flattened to 1-D, while
    its complexity is measured on the un-flattened tensor.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Accepted for signature parity with the other loaders;
            not used by this function.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_tensor,
        test_tensor)`` where the tensors are stacked flattened images (one
        row per image) rather than ``DataLoader`` objects.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    def collect(split):
        """Materialise one split as (list of flat vectors, PNG sizes)."""
        vectors, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            flat = image.reshape(-1)
            vectors.append(flat)
            if i == 0:
                print(flat.shape)
            # Complexity is taken on the image layout, not the flat vector.
            sizes.append(compute_complexity(image))
        return vectors, sizes

    train_vectors, complexity_train = collect(train_data)
    test_vectors, complexity_test = collect(test_data)
    train_tensor = torch.stack(train_vectors)
    test_tensor = torch.stack(test_vectors)
    return complexity_train, complexity_test, train_tensor, test_tensor

def dataload_gray(dataset, batch_size, input_size):
    """Load a dataset as flattened single-channel vectors plus PNG sizes.

    Images are converted to tensors, resized to ``input_size`` x
    ``input_size``, passed through ``Grayscale(1)`` and jittered. Each image
    is stored flattened to 1-D, while its complexity is measured on the
    un-flattened tensor.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Accepted for signature parity with the other loaders;
            not used by this function.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_tensor,
        test_tensor)`` where the tensors are stacked flattened images (one
        row per image) rather than ``DataLoader`` objects.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        tv.transforms.Grayscale(1),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    def collect(split):
        """Materialise one split as (list of flat vectors, PNG sizes)."""
        vectors, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            flat = image.reshape(-1)
            vectors.append(flat)
            if i == 0:
                print(flat.shape)
            # Complexity is taken on the image layout, not the flat vector.
            sizes.append(compute_complexity(image))
        return vectors, sizes

    train_vectors, complexity_train = collect(train_data)
    test_vectors, complexity_test = collect(test_data)
    train_tensor = torch.stack(train_vectors)
    test_tensor = torch.stack(test_vectors)
    return complexity_train, complexity_test, train_tensor, test_tensor

def dataload_gray_2d(dataset, batch_size, input_size):
    """Load a dataset as single-channel image tensors wrapped in loaders.

    Images are converted to tensors, resized to ``input_size`` x
    ``input_size``, passed through ``Grayscale(1)`` and jittered. Unlike the
    flattening loaders, images keep their image layout.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size for both returned loaders.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_loader,
        test_loader)`` where the complexities are lists of
        ``compute_complexity`` values (train loader shuffled, test in order).

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        tv.transforms.Grayscale(1),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    def collect(split, show_shape):
        """Materialise one split as (images, PNG sizes)."""
        images, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            images.append(image)
            if show_shape and i == 0:
                print(image.shape)
            sizes.append(compute_complexity(image))
        return images, sizes

    train_tensor, complexity_train = collect(train_data, show_shape=True)
    test_tensor, complexity_test = collect(test_data, show_shape=False)

    train_loader = torch.utils.data.DataLoader(
        train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader

def dataload_gray_binary(dataset, batch_size, input_size):
    """Load a dataset as flattened, binarised single-channel vectors.

    Images are converted to tensors, resized to ``input_size`` x
    ``input_size``, passed through ``Grayscale(1)`` and jittered. Each image
    is then thresholded at its own mean (elements above the mean become 1.0,
    the rest 0.0). The binarised image is stored flattened; its complexity is
    measured on the un-flattened binarised tensor.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Accepted for signature parity with the other loaders;
            not used by this function.
        input_size: Target height and width passed to ``Resize``.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_tensor,
        test_tensor)`` where the tensors are stacked flattened images (one
        row per image) rather than ``DataLoader`` objects.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        tv.transforms.Grayscale(1),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    def collect(split):
        """Materialise one split as (list of flat binary vectors, PNG sizes)."""
        vectors, sizes = [], []
        for i, sample in tqdm(enumerate(split, 0)):
            image = sample[0]
            # Per-image threshold at the image's own mean value.
            binary = torch.where(image > image.mean(), 1.0, 0.0)
            flat = binary.reshape(-1)
            vectors.append(flat)
            if i == 0:
                print(flat.shape)
            sizes.append(compute_complexity(binary))
        return vectors, sizes

    train_vectors, complexity_train = collect(train_data)
    test_vectors, complexity_test = collect(test_data)
    train_tensor = torch.stack(train_vectors)
    test_tensor = torch.stack(test_vectors)
    return complexity_train, complexity_test, train_tensor, test_tensor



def dataload_mean(dataset, batch_size, input_size, out_flag, mean_complexity):
    """Load a dataset whose test split is altered relative to a mean complexity.

    Training images are always stored unchanged (after ToTensor, Resize and
    Jitter). What happens to the test split depends on ``out_flag`` and on
    the dataset family:

    * ``out_flag == False``, "MNIST"/"FashionMNIST"/"omniglot": test images
      are stored unchanged; ``mean_complexity`` is returned as passed in.
    * ``out_flag == False``, other datasets: ``mean_complexity`` is replaced
      by the mean training complexity, and each test image receives Gaussian
      noise scaled by ``0.01 * |mean_complexity - complexity| / 100`` and is
      clamped to ``[0, 1]``.
    * otherwise, "MNIST"/"FashionMNIST"/"omniglot": the test split is
      replaced by 10000 references to a single all-zero image shaped like a
      training image.
    * otherwise, other datasets: test images are perturbed as above but
      using the ``mean_complexity`` argument; extra diagnostics are printed.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size for both returned loaders.
        input_size: Target height and width passed to ``Resize``.
        out_flag: Selects the branch as described above. The comparison is
            ``out_flag == False``, so e.g. ``None`` selects the second group.
        mean_complexity: Reference complexity; must be numeric when the
            last branch above is taken.

    Returns:
        tuple: ``(complexity_train, complexity_test, train_loader,
        test_loader, mean_complexity)``.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        nf.utils.Jitter(1 / 256.),
    ])
    # dataset name -> (torchvision class name, root dir, train kwargs, test kwargs)
    sources = {
        "CIFAR10": ("CIFAR10", './dataset/cifar10/', {"train": True}, {"train": False}),
        "CIFAR100": ("CIFAR100", './dataset/cifar100/', {"train": True}, {"train": False}),
        "SVHN": ("SVHN", './dataset/svhn', {"split": 'train'}, {"split": 'test'}),
        "MNIST": ("MNIST", './dataset/mnist', {"train": True}, {"train": False}),
        "FashionMNIST": ("FashionMNIST", './dataset/fashion_mnist', {"train": True}, {"train": False}),
        "celebA": ("CelebA", './dataset/celebA', {"split": 'train'}, {"split": 'test'}),
        "omniglot": ("Omniglot", './dataset/omniglot', {"background": True}, {"background": False}),
    }
    if dataset not in sources:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(sources)}"
        )
    class_name, root, train_kwargs, test_kwargs = sources[dataset]
    dataset_cls = getattr(tv.datasets, class_name)
    train_data = dataset_cls(root, download=True, transform=transform, **train_kwargs)
    test_data = dataset_cls(root, download=True, transform=transform, **test_kwargs)
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    single_channel = dataset in ("MNIST", "FashionMNIST", "omniglot")
    # `== False` is intentional: falsy non-bool values such as None must keep
    # selecting the second group of branches, exactly as before.
    plain_mode = out_flag == False  # noqa: E712
    # Only the "flagged + multi-channel" branch prints per-image diagnostics.
    verbose = (not plain_mode) and (not single_channel)

    def make_noise(image, reference):
        """Gaussian noise scaled by the gap between reference and image complexity."""
        complexity = compute_complexity(image)
        # Shape follows the image so any input_size works (was fixed at 3x32x32).
        return torch.randn(image.shape) * 0.01 * abs(reference - complexity) / 100

    train_tensor, complexity_train = [], []
    test_tensor, complexity_test = [], []

    for i, sample in tqdm(enumerate(train_data, 0)):
        image = sample[0]
        train_tensor.append(image)
        if i == 0:
            print(image.shape)
        complexity = compute_complexity(image)
        complexity_train.append(complexity)
        if verbose and i < 1000:
            print(complexity)

    if plain_mode and single_channel:
        for i, sample in tqdm(enumerate(test_data, 0)):
            test_tensor.append(sample[0])
            complexity_test.append(compute_complexity(sample[0]))
    elif plain_mode:
        mean_complexity = np.mean(complexity_train)
        for i, sample in tqdm(enumerate(test_data, 0)):
            image = sample[0]
            perturbed = torch.clamp(image + make_noise(image, mean_complexity), 0, 1)
            complexity_test.append(compute_complexity(perturbed))
            test_tensor.append(perturbed)
    else:
        # Shape of the stacked training set, reported without materialising
        # a full copy of the data just for a print.
        print(torch.Size([len(train_tensor), *train_tensor[0].shape]))
        if single_channel:
            blank = torch.zeros(train_tensor[0].shape)
            print(blank.shape)
            # Every test entry is the same constant image, so its PNG size
            # only needs to be computed once.
            blank_complexity = compute_complexity(blank)
            test_tensor.extend([blank] * 10000)
            complexity_test.extend([blank_complexity] * 10000)
        else:
            constant = torch.full(train_tensor[0].shape, 1/255)
            print(constant)
            print(constant.shape)
            print(f"train PNG bit mean : {np.mean(complexity_train)}")
            for i, sample in tqdm(enumerate(test_data, 0)):
                image = sample[0]
                noise = make_noise(image, mean_complexity)
                perturbed = torch.clamp(image + noise, 0, 1)
                complexity_test.append(compute_complexity(perturbed))
                test_tensor.append(perturbed)
                if i < 1000:
                    # Diagnostic on the *unclamped* image, kept as-is.
                    print(compute_complexity(image + noise))

    train_loader = torch.utils.data.DataLoader(
        train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader, mean_complexity


def dataload_perturb(dataset, batch_size, input_size, in_dist_flag):
    """Build small train/test loaders and PNG-complexity scores for ``dataset``.

    At most 1002 training images are read (the loop stops once the index
    exceeds 1000).  The test set depends on ``in_dist_flag``:

    * ``in_dist_flag == True``: 1000 synthetic images, each the training mean
      plus standard-normal noise scaled by 0.003.  For the single-channel
      group (MNIST, FashionMNIST, omniglot) the mean is one scalar over all
      collected pixels; for the three-channel group (CIFAR10, CIFAR100, SVHN,
      celebA) it is the per-pixel mean image.
    * otherwise: the first (at most 1002) images of the real test split.

    Args:
        dataset: Dataset name; one of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size used for both returned loaders.
        input_size: Side length every image is resized to.
        in_dist_flag: Selects the synthetic test set when equal to ``True``.

    Returns:
        Tuple ``(complexity_train, complexity_test, train_loader,
        test_loader)`` where the complexities are lists of
        ``compute_complexity`` results for the collected images.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        nf.utils.Jitter(1 / 256.),
    ])

    if dataset == "CIFAR10":
        train_data = tv.datasets.CIFAR10('./dataset/cifar10/', train=True, download=True, transform=transform)
        test_data = tv.datasets.CIFAR10('./dataset/cifar10/', train=False, download=True, transform=transform)
    elif dataset == "CIFAR100":
        train_data = tv.datasets.CIFAR100('./dataset/cifar100/', train=True, download=True, transform=transform)
        test_data = tv.datasets.CIFAR100('./dataset/cifar100/', train=False, download=True, transform=transform)
    elif dataset == "SVHN":
        train_data = tv.datasets.SVHN(root='./dataset/svhn', split='train', download=True, transform=transform)
        test_data = tv.datasets.SVHN(root='./dataset/svhn', split='test', download=True, transform=transform)
    elif dataset == "MNIST":
        train_data = tv.datasets.MNIST(root='./dataset/mnist', train=True, download=True, transform=transform)
        test_data = tv.datasets.MNIST(root='./dataset/mnist', train=False, download=True, transform=transform)
    elif dataset == 'FashionMNIST':
        train_data = tv.datasets.FashionMNIST(root='./dataset/fashion_mnist', train=True, download=True, transform=transform)
        test_data = tv.datasets.FashionMNIST(root='./dataset/fashion_mnist', train=False, download=True, transform=transform)
    elif dataset == 'celebA':
        train_data = tv.datasets.CelebA(root='./dataset/celebA', split='train', download=True, transform=transform)
        test_data = tv.datasets.CelebA(root='./dataset/celebA', split='test', download=True, transform=transform)
    elif dataset == 'omniglot':
        train_data = tv.datasets.Omniglot(root='./dataset/omniglot', background=True, download=True, transform=transform)
        test_data = tv.datasets.Omniglot(root='./dataset/omniglot', background=False, download=True, transform=transform)
    else:
        # Previously an unknown name surfaced later as an unbound-variable error.
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    single_channel = dataset in ("MNIST", "FashionMNIST", "omniglot")

    # Training subset: identical for every dataset / flag combination.
    train_tensor = []
    complexity_train = []
    for i, data in tqdm(enumerate(train_data, 0)):
        train_tensor.append(data[0])
        if i == 0:
            print(data[0].shape)
        complexity_train.append(compute_complexity(data[0]))
        if i > 1000:
            break

    test_tensor = []
    complexity_test = []
    # Explicit comparison kept on purpose: only values equal to True (not
    # merely truthy ones, e.g. the string "False") select the synthetic set.
    if in_dist_flag == True:  # noqa: E712
        tmp = torch.stack(train_tensor)
        print(tmp.shape)
        if single_channel:
            mean = torch.mean(tmp)
        else:
            mean = torch.mean(tmp, dim=0)
            print(mean)
        print(mean.shape)
        # Draw noise with the shape of the actual images instead of a fixed
        # 32x32 shape so that any input_size works.
        sample_shape = tmp[0].shape
        for _ in range(1000):
            data = mean + torch.randn(sample_shape) * 0.003
            test_tensor.append(data)
            complexity_test.append(compute_complexity(data))
    else:
        for i, data in tqdm(enumerate(test_data, 0)):
            test_tensor.append(data[0])
            complexity_test.append(compute_complexity(data[0]))
            if i > 1000:
                break

    train_loader = torch.utils.data.DataLoader(train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader


def dataload_mean_vis(dataset, batch_size, input_size, out_flag, var):
    """Build loaders and PNG-complexity scores for the mean-image experiment.

    Behaviour depends on ``out_flag`` and on the dataset group:

    * ``out_flag == False``: only the first training image is kept (so
      ``complexity_train`` stays empty) and the whole real test split is used.
    * otherwise, single-channel group (MNIST, FashionMNIST, omniglot): only
      the first training image is kept; the test set is 10000 copies of an
      all-zero image with that image's shape.
    * otherwise, three-channel group (CIFAR10, CIFAR100, SVHN, celebA): the
      whole training split is read; the test set is 10000 samples of the
      per-pixel training mean plus standard-normal noise scaled by ``var``.

    NOTE(review): keeping just one training image in the first two cases
    mirrors the original early ``break`` statements -- confirm this shortcut
    is intended.

    Args:
        dataset: Dataset name; one of "CIFAR10", "CIFAR100", "SVHN", "MNIST",
            "FashionMNIST", "celebA", "omniglot".
        batch_size: Batch size used for both returned loaders.
        input_size: Side length every image is resized to.
        out_flag: The real test split is used when this equals ``False``.
        var: Multiplier applied to the standard-normal noise (it scales the
            noise directly, i.e. acts as a standard deviation).

    Returns:
        Tuple ``(complexity_train, complexity_test, train_loader,
        test_loader)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Resize((input_size, input_size)),
        nf.utils.Jitter(1 / 256.),
    ])

    if dataset == "CIFAR10":
        train_data = tv.datasets.CIFAR10('./dataset/cifar10/', train=True, download=True, transform=transform)
        test_data = tv.datasets.CIFAR10('./dataset/cifar10/', train=False, download=True, transform=transform)
    elif dataset == "CIFAR100":
        train_data = tv.datasets.CIFAR100('./dataset/cifar100/', train=True, download=True, transform=transform)
        test_data = tv.datasets.CIFAR100('./dataset/cifar100/', train=False, download=True, transform=transform)
    elif dataset == "SVHN":
        train_data = tv.datasets.SVHN(root='./dataset/svhn', split='train', download=True, transform=transform)
        test_data = tv.datasets.SVHN(root='./dataset/svhn', split='test', download=True, transform=transform)
    elif dataset == "MNIST":
        train_data = tv.datasets.MNIST(root='./dataset/mnist', train=True, download=True, transform=transform)
        test_data = tv.datasets.MNIST(root='./dataset/mnist', train=False, download=True, transform=transform)
    elif dataset == 'FashionMNIST':
        train_data = tv.datasets.FashionMNIST(root='./dataset/fashion_mnist', train=True, download=True, transform=transform)
        test_data = tv.datasets.FashionMNIST(root='./dataset/fashion_mnist', train=False, download=True, transform=transform)
    elif dataset == 'celebA':
        train_data = tv.datasets.CelebA(root='./dataset/celebA', split='train', download=True, transform=transform)
        test_data = tv.datasets.CelebA(root='./dataset/celebA', split='test', download=True, transform=transform)
    elif dataset == 'omniglot':
        train_data = tv.datasets.Omniglot(root='./dataset/omniglot', background=True, download=True, transform=transform)
        test_data = tv.datasets.Omniglot(root='./dataset/omniglot', background=False, download=True, transform=transform)
    else:
        # Previously an unknown name surfaced later as an unbound-variable error.
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    print(f"# of Train Data : {len(train_data)}")
    print(f"# of Test Data : {len(test_data)}")

    single_channel = dataset in ("MNIST", "FashionMNIST", "omniglot")
    # Explicit comparison kept on purpose: values such as None must still
    # select the synthetic branch, exactly as before.
    use_real_test = out_flag == False  # noqa: E712

    train_tensor = []
    complexity_train = []
    if use_real_test or single_channel:
        # Only the first training image is collected; its complexity is
        # never computed, so complexity_train is returned empty.
        for sample in train_data:
            train_tensor.append(sample[0])
            print(sample[0].shape)
            break
    else:
        for i, data in tqdm(enumerate(train_data, 0)):
            train_tensor.append(data[0])
            if i == 0:
                print(data[0].shape)
            complexity_train.append(compute_complexity(data[0]))

    test_tensor = []
    complexity_test = []
    if use_real_test:
        for sample in tqdm(test_data):
            test_tensor.append(sample[0])
            complexity_test.append(compute_complexity(sample[0]))
    elif single_channel:
        tmp = torch.stack(train_tensor)
        print(tmp.shape)
        mean = torch.zeros(tmp[0].shape)
        print(mean.shape)
        # Every synthetic sample is the same zero image, so its complexity
        # is computed once and reused.
        zero_complexity = compute_complexity(mean)
        test_tensor = [mean] * 10000
        complexity_test = [zero_complexity] * 10000
    else:
        tmp = torch.stack(train_tensor)
        print(tmp.shape)
        mean = torch.mean(tmp, dim=0)
        print(mean)
        print(mean.shape)
        print(f"train PNG bit mean : {np.mean(complexity_train)}")
        for _ in range(10000):
            # Noise takes the shape of the mean image rather than a fixed
            # (3, 32, 32), so any input_size works.
            data = mean + torch.randn(mean.shape) * var
            test_tensor.append(data)
            complexity_test.append(compute_complexity(data))

    train_loader = torch.utils.data.DataLoader(train_tensor, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_tensor, batch_size=batch_size, drop_last=False)
    return complexity_train, complexity_test, train_loader, test_loader

def estimate_kl(p_model, q_model, q_dataloader, device):
    """Average ``q_model(x) - p_model(x)`` over every sample of a loader.

    Both models are called directly on each batch; their outputs are
    presumably per-sample log-likelihoods, which would make the result a
    sample-based estimate of KL(q || p) -- confirm against the callers.
    NaN and infinite differences are discarded before averaging.

    Args:
        p_model: Callable applied to a batch tensor.
        q_model: Callable applied to a batch tensor.
        q_dataloader: Iterable yielding batch tensors.
        device: Device each batch is moved to before evaluation.

    Returns:
        Mean of the finite per-sample differences, or NaN when none remain.
    """
    diffs = []
    with torch.no_grad():
        for x in q_dataloader:
            # Transfer the batch once and share it between both models.
            batch = x.to(device)
            q_ll = q_model(batch)
            p_ll = p_model(batch)
            diffs.extend((q_ll - p_ll).tolist())

    diffs = np.array(diffs)
    finite = diffs[np.isfinite(diffs)]  # drops NaN and +/-inf in one pass
    if finite.size == 0:
        # Avoid numpy's "mean of empty slice" warning; the result is still NaN.
        return np.nan
    return np.mean(finite)

def estimate_kl_realnvp(p_model, q_model, q_dataloader, device):
    """Average ``q_model.log_prob(x) - p_model.log_prob(x)`` over a loader.

    Same computation as a plain log-likelihood difference average, but the
    models are queried through their ``log_prob`` method.  If the loader
    holds samples from q this is presumably an estimate of KL(q || p) --
    confirm against the callers.  NaN and infinite differences are discarded
    before averaging.

    Args:
        p_model: Object exposing ``log_prob(batch)``.
        q_model: Object exposing ``log_prob(batch)``.
        q_dataloader: Iterable yielding batch tensors.
        device: Device each batch is moved to before evaluation.

    Returns:
        Mean of the finite per-sample differences, or NaN when none remain.
    """
    diffs = []
    with torch.no_grad():
        for x in q_dataloader:
            # Transfer the batch once and share it between both models.
            batch = x.to(device)
            q_ll = q_model.log_prob(batch)
            p_ll = p_model.log_prob(batch)
            diffs.extend((q_ll - p_ll).tolist())

    diffs = np.array(diffs)
    finite = diffs[np.isfinite(diffs)]  # drops NaN and +/-inf in one pass
    if finite.size == 0:
        # Avoid numpy's "mean of empty slice" warning; the result is still NaN.
        return np.nan
    return np.mean(finite)

def tsne_visualization(in_tensor, out_tensor):
    in_np = in_tensor.detach().cpu().numpy()
    out_np = out_tensor.detach().cpu.numpy()
    in_target = np.zeros(in_np.shape[0])
    out_target = np.ones(out_np.shape[0])

    tsne = TSNE(n_components=2, random_state=42)
