import os
import random
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from imageio import imwrite
import torch
import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.dataset import Subset
from PIL import Image


def set_seeds(seed=0, fully_deterministic=True):
    """Seed every random number generator this project relies on.

    Seeds torch (and, when CUDA is available, every CUDA device), NumPy's
    global RNG and Python's ``random`` module, and exports ``PYTHONHASHSEED``.

    Args:
        seed: integer seed shared by all generators.
        fully_deterministic: when True, also force cuDNN to pick deterministic
            algorithms and disable its benchmark autotuner.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    if not fully_deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def myprint(statement, no_print=False, newline=True):
    """Print ``statement`` unless printing is suppressed.

    Args:
        statement: object to print.
        no_print: when True, print nothing.
        newline: when False, omit the trailing newline.
    """
    if no_print:
        return
    terminator = "\n" if newline else ""
    print(statement, end=terminator)


def get_loader(dataset, path_dataset, bs=64, n_work=2, get_transform=False):
    if dataset.lower() in ["mnist", "fashionmnist"]:
        preproc_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        all_size = 60000
        train_size = 50000
        subset1_indices = list(range(0, train_size))
        subset2_indices = list(range(train_size, all_size))
        trainval_dataset = eval("datasets."+dataset)(
                os.path.join(path_dataset, "{}/".format(dataset)),
                train=True, download=True, transform=preproc_transform
        )
        train_dataset = Subset(trainval_dataset, subset1_indices)
        val_dataset   = Subset(trainval_dataset, subset2_indices)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=True,
            num_workers=n_work, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
        test_loader = torch.utils.data.DataLoader(
            eval("datasets."+dataset)(
                os.path.join(path_dataset, "{}/".format(dataset)),
                train=False, download=True, transform=preproc_transform
            ), batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
    elif dataset.lower() == "celeba":
        preproc_transform = transforms.Compose([
            transforms.CenterCrop(140),
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
        dset_train = datasets.CelebA(os.path.join(path_dataset, "CelebA/"), split="train", target_type="attr",
            transform=preproc_transform, target_transform=None, download=True)
        dset_valid = datasets.CelebA(os.path.join(path_dataset, "CelebA/"), split="valid", target_type="attr",
            transform=preproc_transform, target_transform=None, download=True)
        dset_test = datasets.CelebA(os.path.join(path_dataset, "CelebA/"), split="test", target_type="attr",
            transform=preproc_transform, target_transform=None, download=True)
        train_loader = torch.utils.data.DataLoader(dset_train,
            batch_size=bs, shuffle=True, num_workers=n_work, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(dset_valid,
            batch_size=bs, shuffle=False, num_workers=n_work, pin_memory=False
        )
        test_loader = torch.utils.data.DataLoader(dset_test,
            batch_size=bs, shuffle=False, num_workers=n_work, pin_memory=False
        )
    elif dataset.lower() == "cifar10":
        preproc_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        all_size = 50000
        train_size = 40000
        subset1_indices = list(range(0, train_size))
        subset2_indices = list(range(train_size, all_size))
        trainval_dataset = datasets.CIFAR10(
                os.path.join(path_dataset, "{}/".format(dataset)), train=True, download=True,
                transform=preproc_transform
        )
        train_dataset = Subset(trainval_dataset, subset1_indices)
        val_dataset   = Subset(trainval_dataset, subset2_indices)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=True,
            num_workers=n_work, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                os.path.join(path_dataset, "{}/".format(dataset)), train=False, download=True,
                transform=preproc_transform
            ), batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
    elif dataset.lower() =='celeba-hq':
        train_dataset, val_dataset, test_dataset = get_loader_celeba_mask_hq(bs, path_dataset, imsize=256)
        preproc_transform = train_dataset.transform_img
        train_loader = train_dataset.loader()
        val_loader = val_dataset.loader()
        test_loader = test_dataset.loader()
    elif dataset.lower() == "ffhq":
        root = os.path.join(path_dataset, "FFHQ/")
        preproc_transform = transforms.Compose([
            transforms.CenterCrop(1024),
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
        ])
        train_dataset = FFHQ(root, split='train', transform=preproc_transform)
        val_dataset = FFHQ(root, split='val', transform=preproc_transform)
        test_dataset = FFHQ(root, split='test', transform=preproc_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=True,
            num_workers=n_work, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
    elif dataset.lower() =='imagenet':
        preproc_transform = transforms.Compose([
            transforms.CenterCrop(256),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        root = os.path.join(path_dataset, "ILSVRC/Data/CLS-LOC/")
        train_dataset = ImageFolder(root=os.path.join(root, "train"), transform=preproc_transform)
        val_dataset = ImageFolder(root=os.path.join(root, "val_gen"), transform=preproc_transform)
        test_dataset = ImageFolder(root=os.path.join(root, "test_gen"), transform=preproc_transform)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=bs, shuffle=True,
            num_workers=n_work, pin_memory=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=bs, shuffle=False,
            num_workers=n_work, pin_memory=False
        )


    if get_transform:
        return train_loader, val_loader, test_loader, preproc_transform
    else:
        return train_loader, val_loader, test_loader


class CelebAMaskHQ():
    """Map-style dataset over one CelebAMask-HQ split directory.

    Samples are addressed by index: image ``<img_path>/<i>.jpg`` and mask
    ``<label_path>/<i>.png`` for ``i`` in ``0 .. N-1``, where ``N`` is the
    number of regular files found in ``img_path``.

    Args:
        img_path: directory holding ``<i>.jpg`` images.
        label_path: directory holding ``<i>.png`` masks.
        transform_img: callable applied to each opened image.
        transform_label: callable applied to each opened mask.
        mode: when truthy, the path lists are stored in ``train_dataset`` /
            ``train_labels``; otherwise in ``test_dataset`` / ``test_labels``.
            The lists hold the same paths either way.
        type_data: ``"image"``, ``"label"`` or ``"both"``. Selects what
            ``__getitem__`` returns.
    """

    _VALID_TYPES = ("image", "label", "both")

    def __init__(self, img_path, label_path, transform_img, transform_label, mode, type_data):
        self.img_path = img_path
        self.label_path = label_path
        self.transform_img = transform_img
        self.transform_label = transform_label
        self.train_dataset = []
        self.test_dataset = []
        # Mask paths, index-aligned with train_dataset / test_dataset.
        self.train_labels = []
        self.test_labels = []
        self.mode = mode
        self.type = type_data
        self.preprocess()

        self.num_images = len(self.train_dataset if self.mode else self.test_dataset)

    def preprocess(self):
        """Append image and mask paths for indices ``0 .. N-1`` to the active lists."""
        num_files = sum(
            1 for name in os.listdir(self.img_path)
            if os.path.isfile(os.path.join(self.img_path, name))
        )
        img_paths = [os.path.join(self.img_path, str(i) + '.jpg') for i in range(num_files)]
        label_paths = [os.path.join(self.label_path, str(i) + '.png') for i in range(num_files)]
        if self.mode:
            self.train_dataset.extend(img_paths)
            self.train_labels.extend(label_paths)
        else:
            self.test_dataset.extend(img_paths)
            self.test_labels.extend(label_paths)

    def __getitem__(self, index):
        """Return the transformed image, mask, or ``(image, mask)`` pair at ``index``.

        Raises:
            IndexError: if ``index`` is out of range.
            ValueError: if ``type_data`` is not ``image``, ``label`` or ``both``.
        """
        if self.mode:
            img_path, label_path = self.train_dataset[index], self.train_labels[index]
        else:
            img_path, label_path = self.test_dataset[index], self.test_labels[index]

        if self.type == "image":
            return self.transform_img(Image.open(img_path))
        if self.type == "label":
            return self.transform_label(Image.open(label_path))
        if self.type == "both":
            return self.transform_img(Image.open(img_path)), self.transform_label(Image.open(label_path))
        raise ValueError("Unknown type_data {!r}; expected one of {}.".format(
            self.type, ", ".join(self._VALID_TYPES)))

    def __len__(self):
        """Return the number of images."""
        return self.num_images

class Data_Loader():
    """Factory that builds a DataLoader over one CelebAMask-HQ split directory.

    Args:
        img_path: directory of ``<i>.jpg`` images.
        label_path: directory of ``<i>.png`` masks.
        image_size: side length of the square output (stored as ``imsize``).
        batch_size: DataLoader batch size.
        mode: forwarded to ``CelebAMaskHQ`` and used as the DataLoader ``shuffle`` flag.
        type_data: what each sample contains; forwarded to ``CelebAMaskHQ``.
        gray: append a single-channel ``Grayscale`` to the image transform.
        num_workers: DataLoader worker processes (defaults to 2, the previously
            hard-coded value).
    """

    def __init__(self, img_path, label_path, image_size, batch_size, mode, type_data="image", gray=False,
                 num_workers=2):
        self.img_path = img_path
        self.label_path = label_path
        self.imsize = image_size
        self.batch = batch_size
        self.mode = mode
        self.gray = gray
        self.type = type_data
        self.num_workers = num_workers

    def transform_img(self, resize, totensor, normalize, centercrop):
        """Compose the image transform from the given flags.

        Order: CenterCrop, Resize, ToTensor, Normalize (maps [0, 1] to [-1, 1]),
        then Grayscale when ``gray`` is set.
        """
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(self.imsize))
        if resize:
            options.append(transforms.Resize((self.imsize, self.imsize)))
        if totensor:
            options.append(transforms.ToTensor())
        if normalize:
            options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        if self.gray:
            options.append(transforms.Grayscale(1))
        return transforms.Compose(options)

    def transform_label(self, resize, totensor, normalize, centercrop):
        """Compose the mask transform from the given flags.

        ``normalize`` is accepted for interface compatibility but ignored. The
        previous ``Normalize((0, 0, 0), (0, 0, 0))`` had a zero std, which
        torchvision rejects (division by zero) as soon as the transform runs.
        """
        options = []
        if centercrop:
            options.append(transforms.CenterCrop(self.imsize))
        if resize:
            # Nearest-neighbour resampling never blends neighbouring mask values.
            options.append(transforms.Resize((self.imsize, self.imsize), interpolation=Image.NEAREST))
        if totensor:
            options.append(transforms.ToTensor())
        return transforms.Compose(options)

    def loader(self):
        """Create the ``CelebAMaskHQ`` dataset (kept on ``self.dataset``) and wrap it in a DataLoader.

        Images and masks are resized and converted to tensors. No crop or
        normalisation is applied. Shuffling follows ``mode``.
        """
        transform_img = self.transform_img(True, True, False, False)
        transform_label = self.transform_label(True, True, False, False)
        self.dataset = CelebAMaskHQ(self.img_path, self.label_path, transform_img, transform_label,
                                    self.mode, self.type)
        return torch.utils.data.DataLoader(dataset=self.dataset,
                                           batch_size=self.batch,
                                           shuffle=self.mode,
                                           num_workers=self.num_workers,
                                           drop_last=False)
    

def get_loader_celeba_mask_hq(bs, path, imsize, type_data="image", gray=False):
    """Create ``Data_Loader`` factories for the CelebAMask-HQ train/val/test splits.

    Each split reads ``<path>/CelebAMask-HQ/<split>_img`` and
    ``<path>/CelebAMask-HQ/<split>_label``. Only the training split is built
    with ``mode=True`` (shuffled). Validation and test iterate in file order,
    matching how ``get_loader`` treats the validation split of every other dataset.

    Args:
        bs: batch size.
        path: dataset root containing the ``CelebAMask-HQ`` folder.
        imsize: output image side length.
        type_data: forwarded to ``Data_Loader``.
        gray: forwarded to ``Data_Loader``.

    Returns:
        ``(train, val, test)`` ``Data_Loader`` objects; call ``.loader()`` on each.
    """
    base = os.path.join(path, "CelebAMask-HQ")

    def make_split(split_name, mode):
        return Data_Loader(os.path.join(base, split_name + "_img"),
                           os.path.join(base, split_name + "_label"),
                           imsize, bs, mode, type_data, gray)

    return make_split("train", True), make_split("val", False), make_split("test", False)


class ImageFolder_(torchvision.datasets.VisionDataset):
    """Flat-directory image dataset whose splits are defined by list files.

    With ``split='trainval'`` it uses every file directly under ``root`` whose
    lower-cased name ends with one of torchvision's ``IMG_EXTENSIONS``, sorted
    so the order is reproducible. With ``'train'``, ``'val'`` or ``'test'`` it
    reads one ``root``-relative path per line from the matching list file and
    skips blank lines. Every sample is paired with the dummy target ``0``.

    Args:
        root: directory containing the images.
        train_list_file, val_list_file, test_list_file: paths of the split list files.
        split: one of ``valid_splits``.
        **kwargs: forwarded to ``VisionDataset`` (e.g. ``transform``).

    Raises:
        ValueError: if ``split`` is not a valid split name.
    """

    def __init__(self, root, train_list_file, val_list_file, test_list_file, split='train', **kwargs):

        root = Path(root)
        super().__init__(root, **kwargs)

        self.train_list_file = train_list_file
        self.val_list_file = val_list_file
        self.test_list_file = test_list_file

        # Fail fast with a clear message instead of an UnboundLocalError below.
        self.split = self._verify_split(split)

        self.loader = torchvision.datasets.folder.default_loader
        self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS

        if self.split == 'trainval':
            samples = sorted(
                self.root.joinpath(fname) for fname in os.listdir(self.root)
                if fname.lower().endswith(self.extensions)
            )
        else:
            listfile = {
                'train': self.train_list_file,
                'val': self.val_list_file,
                'test': self.test_list_file,
            }[self.split]
            with open(listfile, 'r') as f:
                names = (line.strip() for line in f)
                # A blank line would otherwise resolve to ``root`` itself.
                samples = [self.root.joinpath(name) for name in names if name]

        self.samples = samples

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is valid, otherwise raise ``ValueError``."""
        if split not in self.valid_splits:
            msg = "Unknown split {}. Valid splits are {}.".format(split, ", ".join(self.valid_splits))
            raise ValueError(msg)
        return split

    @property
    def valid_splits(self):
        """Split names accepted by ``__init__``."""
        return 'train', 'val', 'test', 'trainval'

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index, with_transform=True):
        """Load sample ``index``; apply ``transforms`` unless ``with_transform`` is False.

        Returns:
            ``(sample, 0)``. The target is always the constant ``0``.
        """
        path = self.samples[index]
        sample = self.loader(path)
        if self.transforms is not None and with_transform:
            sample, _ = self.transforms(sample, None)
        return sample, 0


class FFHQ(ImageFolder_):
    """FFHQ faces; split membership comes from list files in ``dataset/`` beside this module."""

    train_list_file = Path(__file__).parent / 'dataset' / 'ffhqtrain.txt'
    val_list_file = Path(__file__).parent / 'dataset' / 'ffhqvalidation.txt'
    test_list_file = Path(__file__).parent / 'dataset' / 'ffhqtest.txt'

    def __init__(self, root, split='train', **kwargs):
        super().__init__(
            root,
            train_list_file=FFHQ.train_list_file,
            val_list_file=FFHQ.val_list_file,
            test_list_file=FFHQ.test_list_file,
            split=split,
            **kwargs
        )


## Generate images
def plot_images(images, filename, nrows=4, ncols=8, flg_norm=False):
    """Save a grid of channel-first images to ``filename``.

    The grid has ``2 * nrows`` rows and ``ncols`` columns and is rendered at
    400 dpi. More images than grid cells raises an IndexError from the
    GridSpec lookup.

    Args:
        images: NumPy array indexed as ``(N, C, H, W)``. Single-channel
            images are repeated to three channels.
        filename: output path. Missing parent directories are created.
        nrows: half the number of grid rows.
        ncols: number of grid columns.
        flg_norm: treat pixels as lying in [-1, 1] and map them to [0, 1].
    """
    if images.shape[1] == 1:
        images = np.repeat(images, 3, axis=1)
    fig = plt.figure(figsize=(nrows * 2, ncols), dpi=400)
    try:
        gs = gridspec.GridSpec(nrows * 2, ncols)
        gs.update(wspace=0.05, hspace=0.05)
        for i, image in enumerate(images):
            ax = plt.subplot(gs[i])
            plt.axis("off")
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect("equal")
            image_transposed = image.transpose((1, 2, 0))
            if flg_norm:
                image_transposed = image_transposed * 0.5 + 0.5
            plt.imshow(np.clip(image_transposed, 0., 1.))

        dirname = os.path.dirname(filename)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        plt.savefig(filename, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def plot_images_paper(images, filename, nrows=4, ncols=8, flg_norm=False):
    """Save a compact ``nrows`` x ``ncols`` grid of channel-first images to ``filename``.

    Args:
        images: NumPy array indexed as ``(N, C, H, W)``. Single-channel
            images are repeated to three channels.
        filename: output path. Missing parent directories are created.
        nrows: number of grid rows.
        ncols: number of grid columns.
        flg_norm: treat pixels as lying in [-1, 1] and map them to [0, 1].
    """
    if images.shape[1] == 1:
        images = np.repeat(images, 3, axis=1)
    fig = plt.figure(figsize=(nrows, ncols))
    try:
        gs = gridspec.GridSpec(nrows, ncols)
        gs.update(wspace=0.05, hspace=0.05)
        for i, image in enumerate(images):
            ax = plt.subplot(gs[i])
            plt.axis("off")
            ax.set_aspect("equal")
            image_transposed = image.transpose((1, 2, 0))
            if flg_norm:
                image_transposed = image_transposed * 0.5 + 0.5
            # imshow clips float RGB data to [0, 1] anyway, but warns. Clip explicitly,
            # as plot_images does.
            plt.imshow(np.clip(image_transposed, 0., 1.))

        dirname = os.path.dirname(filename)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        plt.savefig(filename, bbox_inches="tight")
    finally:
        plt.close(fig)


def save_images_for_fid_test(data_loader, dataset, dirname, comparison_size=50, flg_norm=False, fmt='.png'):
    """Dump ground-truth images from ``data_loader`` into ``<dirname>/<dataset>``.

    If that directory already exists, it is assumed complete and nothing is
    written. Otherwise whole batches are written until at least
    ``comparison_size`` images have been saved (so up to one batch more than
    requested). ``comparison_size=None`` saves the entire loader.

    Args:
        data_loader: iterable of batches. Each batch is either an image tensor
            or a sequence whose first element is the image tensor.
        dataset: sub-directory name under ``dirname``.
        dirname: parent output directory.
        comparison_size: minimum number of images to save, or None for all.
        flg_norm, fmt: accepted for signature compatibility but unused; files
            are written as ``%08d.png`` by ``save_set_of_images``.

    Returns:
        The output directory ``<dirname>/<dataset>``.
    """
    dirname = os.path.join(dirname, dataset)
    if os.path.exists(dirname):
        print("Already saved!")
        return dirname

    num_samples = 0
    with torch.no_grad():
        for data in data_loader:
            # Some loaders (e.g. CelebA-HQ) yield bare image tensors; others yield
            # (images, targets). No device transfer: save_set_of_images moves
            # the batch to the CPU anyway.
            x = data if isinstance(data, torch.Tensor) else data[0]
            # Number files by the running count so indices stay contiguous even
            # when batch sizes vary.
            save_set_of_images(dirname, x, num_samples)
            num_samples += x.shape[0]
            if comparison_size is not None and num_samples >= comparison_size:
                break
    print("GT Images were saved!")

    return dirname


def save_images_for_fid_reconst(model, data_loader, dataset, dirname, comparison_size=50, flg_norm=False, fmt='.png'):
    """Write ``model`` reconstructions of ``data_loader`` images into ``dirname``.

    ``dirname`` is wiped and recreated first. Whole batches are processed until
    at least ``comparison_size`` images have been saved; ``comparison_size=None``
    processes the entire loader.

    Args:
        model: callable whose output's first element is the reconstructed batch.
            Inputs are sent to the device of its first parameter. For a model
            without parameters, CUDA is used when available, else the CPU.
        data_loader: iterable of batches. Each batch is either an image tensor
            or a sequence whose first element is the image tensor.
        dataset: accepted for signature compatibility; no longer used to
            detect the batch format.
        dirname: output directory (deleted and recreated).
        comparison_size: minimum number of images to save, or None for all.
        flg_norm: map reconstructions from [-1, 1] to [0, 1] before saving.
        fmt: accepted for signature compatibility but unused; files are
            written as ``%08d.png`` by ``save_set_of_images``.

    Returns:
        ``dirname``.
    """
    import shutil

    if os.path.exists(dirname):
        shutil.rmtree(dirname)
    os.makedirs(dirname)

    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_samples = 0
    with torch.no_grad():
        for data in data_loader:
            x = data if isinstance(data, torch.Tensor) else data[0]
            x_reconst = model(x.to(device))[0]
            if flg_norm:
                x_reconst = (x_reconst + 1) / 2
            # Number files by the running count so indices stay contiguous even
            # when batch sizes vary.
            save_set_of_images(dirname, x_reconst, num_samples)
            num_samples += x.shape[0]
            if comparison_size is not None and num_samples >= comparison_size:
                break
    print("Images were reconstructed!")

    return dirname


def save_set_of_images(dirname, images, idx_start):
    if not os.path.exists(dirname):
        os.mkdir(dirname)

    images = images.to('cpu').detach().numpy().copy()
    images = (np.clip(images, 0, 1) * 255).astype('uint8')
    images = np.transpose(images, (0,2,3,1))
    # print(images.shape)

    for i, img in enumerate(images):
        if img.shape[-1] == 1:
            # img = img[:, :, 0]
            img = np.tile(img, (1, 1, 3))
        imwrite(os.path.join(dirname, '%08d.png' % (i + idx_start)), img)