import argparse
import functools
import glob
import os
import shutil
import sys
import tarfile
from zipfile import ZipFile

from PIL import Image
import numpy as np
import torch
from scipy.io import loadmat
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm, trange

_DEFAULT_DATA_ROOT = 'data'


def _prepare_mnist(out_dir):
    """Download MNIST via torchvision and save the images as ``.npy`` arrays.

    Writes ``train.npy`` (60000x28x28x1) and ``test.npy`` (10000x28x28x1),
    both uint8 NHWC, into ``out_dir``.

    Args:
        out_dir: destination folder; it must already exist.
    """
    # Scratch folder for the torchvision download. It lives inside ``out_dir``
    # (not the current working directory) and is removed once the arrays are saved.
    temp = os.path.join(out_dir, 'mnist_temp')

    train = datasets.MNIST(root=temp, download=True, train=True).data
    test = datasets.MNIST(root=temp, download=True, train=False).data

    # (N, 28, 28) -> (N, 28, 28, 1): add an explicit channel axis (NHWC).
    train = train.numpy()[..., np.newaxis]
    test = test.numpy()[..., np.newaxis]

    assert train.shape == (60000, 28, 28, 1) and test.shape == (10000, 28, 28, 1)
    assert train.dtype == test.dtype == np.uint8

    np.save(os.path.join(out_dir, 'train.npy'), train)
    np.save(os.path.join(out_dir, 'test.npy'), test)

    shutil.rmtree(temp)
    print('MNIST dataset successfully created.')


def _prepare_omniglot(out_dir):
    """Create the Omniglot train/val/test arrays from ``chardata.mat``.

    Downloads ``chardata.mat`` into ``out_dir`` unless it is already there.
    It then writes three float32 arrays of shape (N, 28, 28, 1):

    * ``train.npy`` -- 23000 images with their original (non-binarized) values;
    * ``val.npy``   -- 1345 images held out from the shuffled training data,
      binarized once by Bernoulli sampling;
    * ``test.npy``  -- 8070 images, binarized once by Bernoulli sampling.

    The shuffle and the binarization use a fixed seed (1234), so the output is
    reproducible. The ``.mat`` file is deleted at the end.

    Args:
        out_dir: destination folder; it must already exist.
    """
    OMNIGLOT_URL = "https://github.com/yburda/iwae/raw/master/datasets/OMNIGLOT/chardata.mat"
    mat_path = os.path.join(out_dir, 'chardata.mat')

    if not os.path.exists(mat_path):
        _download_url(OMNIGLOT_URL, mat_path)
    raw_data = loadmat(mat_path)

    def reshape_data(data):
        # Flatten each 28x28 image to 784 values in column-major (Fortran)
        # order. Presumably this matches the pixel layout used by the IWAE
        # code the file comes from -- verify before changing.
        return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='F')

    train = reshape_data(raw_data['data'].T.astype('float32'))
    test = reshape_data(raw_data['testdata'].T.astype('float32'))

    n_valid = 1345
    # A private generator seeded with 1234 yields exactly the same draws as
    # np.random.seed(1234) followed by np.random.* calls, without clobbering
    # the caller's global NumPy random state.
    rng = np.random.RandomState(1234)
    rng.shuffle(train)

    train, val = train[:-n_valid], train[-n_valid:]
    # The stored values are used as Bernoulli probabilities here, which
    # presumes they lie in [0, 1].
    val = rng.binomial(1, val).astype(np.float32)
    test = rng.binomial(1, test).astype(np.float32)

    train = train.reshape(-1, 28, 28, 1)
    val = val.reshape(-1, 28, 28, 1)
    test = test.reshape(-1, 28, 28, 1)

    assert train.shape == (23000, 28, 28, 1) and val.shape == (1345, 28, 28, 1) and test.shape == (8070, 28, 28, 1)
    assert train.dtype == val.dtype == test.dtype == np.float32

    np.save(os.path.join(out_dir, 'train.npy'), train)
    np.save(os.path.join(out_dir, 'val.npy'), val)
    np.save(os.path.join(out_dir, 'test.npy'), test)

    os.remove(mat_path)
    print('Omniglot (binarized) dataset successfully created.')


def _prepare_cifar10(out_dir):
    """Download CIFAR-10 via torchvision and store the images as ``.npy`` files.

    Produces ``train.npy`` (50000x32x32x3) and ``test.npy`` (10000x32x32x3),
    both uint8 NHWC, in ``out_dir``. The torchvision download goes to a
    ``cifar10_temp`` sub-folder of ``out_dir``, which is deleted afterwards.

    Args:
        out_dir: destination folder; it must already exist.
    """
    scratch_dir = os.path.join(out_dir, 'cifar10_temp')

    # Download train first, then test (dicts preserve insertion order).
    split_flags = {'train': True, 'test': False}
    images = {
        name: datasets.CIFAR10(root=scratch_dir, download=True, train=is_train).data
        for name, is_train in split_flags.items()
    }
    train_images, test_images = images['train'], images['test']

    assert train_images.shape == (50000, 32, 32, 3) and test_images.shape == (10000, 32, 32, 3)
    assert train_images.dtype == test_images.dtype == np.uint8

    for name, array in images.items():
        np.save(os.path.join(out_dir, f'{name}.npy'), array)

    shutil.rmtree(scratch_dir)
    print('CIFAR10 dataset successfully created.')


def _prepare_oord_imagenet(out_dir, *, size, data_root=None):
    """Pack the downsampled ImageNet image folders into ``.npy`` arrays.

    The archives must already have been extracted so that ``out_dir`` contains
    ``train_{size}x{size}/`` and ``valid_{size}x{size}/``, for example
    ``DATA_ROOT/oord_imagenet32/train_32x32`` and
    ``DATA_ROOT/oord_imagenet32/valid_32x32``.

    Writes ``valid.npy`` and ``train.npy`` (uint8, shape
    (N, size, size, 3)) into ``out_dir``. An array file that already exists is
    left untouched.

    Args:
        out_dir: dataset folder that holds the extracted image folders.
        size: image side length, 32 or 64.
        data_root: unused. It is accepted for backward compatibility and is
            optional because ``prepare_dataset`` only passes ``out_dir``.
    """
    assert size in (32, 64)

    train_dir = os.path.join(out_dir, f'train_{size}x{size}')
    valid_dir = os.path.join(out_dir, f'valid_{size}x{size}')
    assert os.path.isdir(train_dir) and os.path.isdir(valid_dir)

    # NOTE: os.listdir order depends on the filesystem, so row order in the
    # saved arrays may differ between machines. It is deliberately left
    # unsorted so already-generated arrays keep their order.
    train_files = [os.path.join(train_dir, fn) for fn in os.listdir(train_dir)]
    valid_files = [os.path.join(valid_dir, fn) for fn in os.listdir(valid_dir)]
    assert len(train_files) == 1281149 and len(valid_files) == 49999

    def _process(split, files):
        array_fn = os.path.join(out_dir, f'{split}.npy')
        if os.path.exists(array_fn):
            print(f'{array_fn} already exists, skipping.')
            return
        images = []
        for img_fn in tqdm(files, desc=f'Creating {split} set for OordImageNet {size}x{size}'):
            # Close each file handle promptly instead of relying on GC.
            with Image.open(img_fn) as img:
                images.append(np.array(img))
        images = np.stack(images, axis=0)
        # Validate before writing. A malformed array saved to disk would be
        # silently "skipped" as already existing on the next run.
        assert images.shape == (len(files), size, size, 3)
        np.save(array_fn, images)

    _process('valid', valid_files)
    _process('train', train_files)

    print(f'Downsampled ImageNet {size}x{size} dataset successfully created.')


def _download_url(url, fpath):
    """Download ``url`` to the local file ``fpath``.

    The payload is first written to ``fpath + '.part'`` and moved to ``fpath``
    only after the transfer completes. Callers skip the download whenever
    ``fpath`` exists, so an interrupted transfer must not leave a truncated
    file there. On failure the partial file is removed and the error re-raised.

    Raises:
        FileExistsError: if ``fpath`` already exists.
        OSError: (including ``urllib.error.URLError``) if the download fails.
    """
    # ``import urllib`` alone does not guarantee that the ``request``
    # submodule is loaded, so import it explicitly.
    import urllib.request

    if os.path.exists(fpath):
        raise FileExistsError(f'Refusing to overwrite existing file: {fpath}')

    part_path = fpath + '.part'
    try:
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, fpath)
    except BaseException:
        if os.path.exists(part_path):
            os.remove(part_path)
        raise


def _prepare_celebahq(out_dir, *, size):
    """Build the CelebA-HQ ``.npy`` arrays at ``size`` x ``size`` resolution.

    Steps:

    1. Download ``celeba-tfr.tar`` into ``out_dir``, unless the archive or
       its extracted ``celeba-tfr`` folder is already present.
    2. Extract the archive.
    3. Decode the 256x256x3 images stored in the TFRecords and resize them
       with bilinear interpolation.

    For each split in ``('train', 'valid')`` it writes:

    * ``{split}_{size}x{size}.npy``      -- uint8 images, shape (N, size, size, 3);
    * ``{split}_{size}x{size}_attr.npy`` -- the per-record ``attr`` values, as int32.

    N is 27000 for train and 3000 for valid. If both image arrays already exist
    the function returns without doing anything.

    Args:
        out_dir: destination folder, also used for the download and extraction.
        size: output side length in pixels.
    """
    image_files = [os.path.join(out_dir, f'{split}_{size}x{size}.npy')
                   for split in ('train', 'valid')]
    # Check both splits. Checking only train would permanently skip a run
    # that was interrupted after saving train but before saving valid.
    if all(os.path.exists(fn) for fn in image_files):
        print(f'CelebA-HQ {size}x{size} arrays already exist!')
        return

    import tensorflow as tf
    tar_path = os.path.join(out_dir, 'celeba-tfr.tar')
    tfr_path = os.path.join(out_dir, 'celeba-tfr')
    if not os.path.exists(tfr_path):
        if not os.path.exists(tar_path):
            print('Downloading CelebA-HQ tar archive...')
            TFR_URL = 'https://storage.googleapis.com/glow-demo/data/celeba-tfr.tar'
            _download_url(TFR_URL, tar_path)
        assert os.path.exists(tar_path)

        # The context manager closes the archive handle after extraction.
        # NOTE(review): extractall trusts member paths inside the archive;
        # only use it on archives from a trusted source.
        with tarfile.open(tar_path) as tar_file:
            tar_file.extractall(out_dir)
    # Presumably the two entries are the 'train' and 'validation' folders read
    # by _process_folder below.
    assert os.path.isdir(tfr_path) and len(os.listdir(tfr_path)) == 2

    def _resize_image_array(images, size, interpolation=Image.BILINEAR):
        """Resize a uint8 NHWC batch to ``size`` (a tuple) using PIL."""
        assert isinstance(size, tuple)
        N, origH, origW, C = images.shape  # Assume NHWC
        assert C in (1, 3) and images.dtype == np.uint8

        if size == (origH, origW):
            return images

        resized = []
        for img in images:
            pil = Image.fromarray(img.astype('uint8'), 'RGB')
            # PIL takes (width, height); callers here pass square sizes.
            pil = pil.resize(size, resample=interpolation)
            resized.append(np.array(pil))

        resized = np.stack(resized, axis=0)
        assert resized.shape == (N, *size, C)

        return resized

    def _process_folder(split):
        """Decode one split's TFRecords, resize the images and save them with their attrs."""
        split_str = {'train': 'train', 'valid': 'validation'}[split]
        filenames = sorted(glob.glob(os.path.join(tfr_path, split_str, '*.tfrecords')))
        dataset = tf.data.TFRecordDataset(filenames=filenames)
        processed = []
        attrs = []

        for example in tqdm(dataset, desc=f'Processing {split} set...'):
            parsed = tf.train.Example.FromString(example.numpy())
            shape = parsed.features.feature['shape'].int64_list.value
            attr = list(parsed.features.feature['attr'].int64_list.value)
            img_bytes = parsed.features.feature['data'].bytes_list.value[0]
            img = np.frombuffer(img_bytes, dtype=np.uint8).reshape(shape)
            processed.append(img)
            attrs.append(attr)

        processed = np.stack(processed)
        assert len(processed) == len(attrs) == {'train': 27000, 'valid': 3000}[split]
        assert processed.shape[1:] == (256, 256, 3)

        out_path = os.path.join(out_dir, f'{split}_{size}x{size}.npy')
        resized = _resize_image_array(processed, (size, size))
        assert resized.shape == (processed.shape[0], size, size, processed.shape[3])
        np.save(out_path, resized)
        np.save(os.path.join(out_dir, f'{split}_{size}x{size}_attr.npy'),
                np.array(attrs, dtype=np.int32))
        print(f'Saved {out_path}')

    _process_folder('train')
    _process_folder('valid')
    print(f'CelebA-HQ {size}x{size} successfully created.')


class MNIST(Dataset):
    """MNIST digits read from ``{data_root}/mnist/{split}.npy``.

    The NHWC array written by ``_prepare_mnist`` is exposed channels-first, so
    every item is a tensor of shape (1, 28, 28).

    Args:
        split: ``'train'`` or ``'test'``.
        data_root: folder that contains the ``mnist`` sub-directory.

    Raises:
        ValueError: for any other ``split``.
    """

    def __init__(self, *, split, data_root):
        self.split = split
        self.image_shape = (1, 28, 28)
        self.data_root = data_root

        if split not in ('train', 'test'):
            raise ValueError(f'Invalid split: {split}')

        nhwc = np.load(os.path.join(self.data_root, f'mnist/{self.split}.npy'))
        # NHWC -> NCHW; torch.from_numpy shares memory with the loaded array.
        self.data = torch.from_numpy(nhwc.transpose(0, 3, 1, 2))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.data.shape[0]


class CIFAR10(Dataset):
    """CIFAR-10 images read from ``{data_root}/cifar10/{split}.npy``.

    The stored NHWC array is exposed channels-first, so every item is a tensor
    of shape (3, 32, 32).

    Args:
        split: ``'train'`` or ``'test'``.
        data_root: folder that contains the ``cifar10`` sub-directory.

    Raises:
        ValueError: for any other ``split``.
    """

    def __init__(self, *, split, data_root):
        self.split = split
        self.image_shape = (3, 32, 32)
        self.data_root = data_root

        if split not in ('train', 'test'):
            raise ValueError(f'Invalid split: {split}')

        nhwc = np.load(os.path.join(self.data_root, f'cifar10/{split}.npy'))
        # NHWC -> NCHW view over the loaded array.
        self.data = torch.from_numpy(nhwc.transpose(0, 3, 1, 2))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.data.shape[0]


class _CelebAHQBase(Dataset):
    """CelebA-HQ images read from ``{data_root}/celebahq/*.npy``.

    Subclasses set ``image_size`` (S below). The splits are:

    * ``'train'`` -- reads ``train_{S}x{S}.npy`` (27000 images);
    * ``'valid'`` -- reads ``valid_{S}x{S}.npy`` (3000 images);
    * ``'full'``  -- concatenates train then valid (30000 images).

    The stored NHWC arrays are exposed channels-first, so each item has shape
    (3, S, S).

    Raises:
        ValueError: for any other ``split``.
    """
    image_size = None

    def __init__(self, *, split, data_root):
        self.split = split
        self.image_shape = (3, self.image_size, self.image_size)
        self.data_root = data_root
        side = self.image_size

        def load(part):
            return np.load(os.path.join(self.data_root, f'celebahq/{part}_{side}x{side}.npy'))

        if split == 'train':
            images, count = load('train'), 27000
        elif split == 'valid':
            images, count = load('valid'), 3000
        elif split == 'full':
            images, count = np.concatenate([load('train'), load('valid')]), 30000
        else:
            raise ValueError(f'Invalid split: {split}')

        assert images.shape == (count, side, side, 3)
        self.data = torch.from_numpy(images.transpose(0, 3, 1, 2))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.data.shape[0]


class CelebAHQ32(_CelebAHQBase):
    """CelebA-HQ at 32x32 resolution; items have shape (3, 32, 32)."""
    image_size = 32


class CelebAHQ64(_CelebAHQBase):
    """CelebA-HQ at 64x64 resolution; items have shape (3, 64, 64)."""
    image_size = 64


class CelebAHQ128(_CelebAHQBase):
    """CelebA-HQ at 128x128 resolution; items have shape (3, 128, 128)."""
    image_size = 128


class CelebAHQ256(_CelebAHQBase):
    """CelebA-HQ at 256x256 resolution; items have shape (3, 256, 256)."""
    image_size = 256


class _OordImageNetBase(Dataset):
    """Downsampled ImageNet read from ``{data_root}/oord_imagenet{S}/{split}.npy``.

    Subclasses set ``image_size`` (S). The stored NHWC array is exposed
    channels-first, so each item has shape (3, S, S).

    Args:
        split: ``'train'`` or ``'valid'``.
        data_root: folder that contains the ``oord_imagenet{S}`` directory.
        mmap: if True (the default), memory-map the file read-only
            (``mmap_mode='r'``) instead of reading it fully into RAM.

    Raises:
        ValueError: for any other ``split``.
    """
    image_size = None

    def __init__(self, *, split, data_root, mmap=True):
        self.split = split
        self.mmap = mmap
        self.image_shape = (3, self.image_size, self.image_size)
        self.data_root = data_root

        if split not in ('train', 'valid'):
            raise ValueError(f'Invalid split: {split}')

        path = os.path.join(self.data_root, f'oord_imagenet{self.image_size}/{split}.npy')
        # NOTE: with mmap the tensor is backed by a read-only mapping. torch may
        # warn that the array is not writable, and items must not be modified
        # in place.
        array = np.load(path, mmap_mode='r' if self.mmap else None)
        self.data = torch.from_numpy(array.transpose(0, 3, 1, 2))

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.data.shape[0]


class OordImageNet32(_OordImageNetBase):
    """Downsampled ImageNet at 32x32; items have shape (3, 32, 32)."""
    image_size = 32


class OordImageNet64(_OordImageNetBase):
    """Downsampled ImageNet at 64x64; items have shape (3, 64, 64)."""
    image_size = 64


# Named module-level functions instead of lambdas: lambdas cannot be pickled,
# which (per the original note) breaks distributed training.
def _convert(x):
    """Return the PIL image ``x`` converted to RGB mode."""
    return x.convert('RGB')


def _tobyte(x):
    """Map a float tensor with values in [0, 1] to uint8 values in [0, 255].

    Rounds to the nearest integer before casting. ``.byte()`` alone truncates,
    so float round-off such as 199.99998 would become 199 instead of 200.
    """
    return (x * 255).round().byte()
class _ImageNetBase(Dataset):
    """ImageNet JPEGs, converted to RGB, resized and center-cropped on the fly.

    Subclasses set ``image_size`` (S). Each item is a uint8 tensor of shape
    (3, S, S). The image is:

    1. converted to RGB;
    2. resized with ``transforms.Resize(S)``;
    3. center-cropped to S x S.

    Images are the ``*.JPEG`` files directly inside
    ``{data_root}/imagenet/train`` (split ``'train'``) or
    ``{data_root}/imagenet/val`` (split ``'valid'``). Sub-folders are not searched.

    Args:
        split: ``'train'`` or ``'valid'``.
        data_root: folder that contains the ``imagenet`` directory.

    Raises:
        ValueError: for any other ``split``.
    """
    image_size = None

    def __init__(self, *, split, data_root):
        self.split = split
        self.image_shape = (3, self.image_size, self.image_size)

        if split == 'train':
            self.image_dir = os.path.join(data_root, 'imagenet/train')
        elif split == 'valid':
            self.image_dir = os.path.join(data_root, 'imagenet/val')
        else:
            raise ValueError(f'Invalid split: {split}')

        # _convert/_tobyte are module-level functions (not lambdas) so the
        # pipeline stays picklable.
        self.transforms = transforms.Compose([
            transforms.Lambda(_convert),
            transforms.Resize(self.image_size),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Lambda(_tobyte),
        ])
        # NOTE(review): this glob is non-recursive and expects a flat folder of
        # JPEGs. If the train split is stored in per-class sub-folders it will
        # match nothing -- confirm the on-disk layout.
        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, '*.JPEG')))
        assert len(self.image_paths) > 0

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # The context manager releases the file handle deterministically, even
        # if a transform raises. _convert returns a new, fully loaded image, so
        # nothing reads the file after it is closed.
        with Image.open(image_path) as img:
            x = self.transforms(img)
        assert x.shape == (3, self.image_size, self.image_size), \
                f'Wrong shape {x.shape} for ImageNet split={self.split} index={idx}'
        return x

    def __len__(self):
        return len(self.image_paths)


class ImageNet128(_ImageNetBase):
    """ImageNet resized and center-cropped to 128x128; items have shape (3, 128, 128)."""
    image_size = 128


class ImageNet256(_ImageNetBase):
    """ImageNet resized and center-cropped to 256x256; items have shape (3, 256, 256)."""
    image_size = 256



def prepare_dataset(args):
    """Run the preparation routine for ``args.dataset`` under ``args.data_root``.

    Output goes to ``{data_root}/celebahq`` for every ``celebahq*`` variant and
    to ``{data_root}/{dataset}`` otherwise. The folder is created if needed.

    Args:
        args: namespace with ``dataset`` and ``data_root`` attributes, as
            produced by this module's command-line parser.

    Raises:
        ValueError: if ``args.dataset`` has no preparation routine.
    """
    mapping = {
        'mnist': _prepare_mnist,
        'omniglot': _prepare_omniglot,
        'cifar10': _prepare_cifar10,
    }
    for size in (32, 64, 128, 256):
        mapping[f'celebahq{size}'] = functools.partial(_prepare_celebahq, size=size)
    for size in (32, 64):
        # _prepare_oord_imagenet takes a keyword-only ``data_root``. Supply it
        # here, because the call below passes only ``out_dir``.
        mapping[f'oord_imagenet{size}'] = functools.partial(
            _prepare_oord_imagenet, size=size, data_root=args.data_root)

    if args.dataset not in mapping:
        raise ValueError(f'Invalid dataset name {args.dataset}')
    if args.dataset.startswith('celebahq'):
        out_dir = os.path.join(args.data_root, 'celebahq')
    else:
        out_dir = os.path.join(args.data_root, args.dataset)

    os.makedirs(out_dir, exist_ok=True)
    mapping[args.dataset](out_dir)


def get_dataset(dataset, *args, data_root=None, **kwargs):
    """Instantiate the dataset class registered under the name ``dataset``.

    Extra positional and keyword arguments (for example ``split=...``) are
    forwarded to the class constructor. ``data_root`` falls back to
    ``_DEFAULT_DATA_ROOT`` when not given.

    Raises:
        KeyError: if ``dataset`` is not a registered name.
    """
    registry = {
        'mnist': MNIST,
        'cifar10': CIFAR10,
        'celebahq32': CelebAHQ32,
        'celebahq64': CelebAHQ64,
        'celebahq128': CelebAHQ128,
        'celebahq256': CelebAHQ256,
        'oord_imagenet32': OordImageNet32,
        'oord_imagenet64': OordImageNet64,
        'imagenet128': ImageNet128,
        'imagenet256': ImageNet256,
    }
    dataset_cls = registry[dataset]
    kwargs['data_root'] = _DEFAULT_DATA_ROOT if data_root is None else data_root
    return dataset_cls(*args, **kwargs)


##### TESTS #####


def _first_and_last_two(dataset):
    """Return items 0, 1, -2 and -1 of ``dataset`` stacked into one tensor."""
    return torch.stack([dataset[0], dataset[1], dataset[-2], dataset[-1]], dim=0)


def test_mnist():
    """MNIST splits have the expected sizes, dtype, shape and pixel sums."""
    data = get_dataset('mnist', split='train')
    assert len(data) == 60000
    x = _first_and_last_two(data)
    assert x.dtype == torch.uint8 and x.shape == (4, 1, 28, 28)
    assert x.sum().item() == 99968

    data = get_dataset('mnist', split='test')
    assert len(data) == 10000
    x = _first_and_last_two(data)
    assert x.dtype == torch.uint8 and x.shape == (4, 1, 28, 28)
    assert x.sum().item() == 115696


def test_cifar10():
    """CIFAR10 splits have the expected dtype, shape and pixel sums."""
    # Compare exact integer sums rather than testing float means for exact
    # equality, which depends on float32 reduction details. Equivalently, the
    # means are 133.779541015625 and 119.116943359375 over 4*3*32*32 = 12288 values.
    ds = get_dataset('cifar10', split='train')
    x = _first_and_last_two(ds)
    assert x.dtype == torch.uint8 and x.shape == (4, 3, 32, 32)
    assert x.sum().item() == 1643883

    ds = get_dataset('cifar10', split='test')
    x = _first_and_last_two(ds)
    assert x.dtype == torch.uint8 and x.shape == (4, 3, 32, 32)
    assert x.sum().item() == 1463709


def test_celebahq64():
    """CelebA-HQ 64x64 train/valid/full splits match the reference checksums."""
    dataset = get_dataset('celebahq64', split='train')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 27000
    assert x.shape == (4, 3, 64, 64) and x.dtype == torch.uint8
    assert x.sum().item() == 4881970

    dataset = get_dataset('celebahq64', split='valid')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 3000
    assert x.shape == (4, 3, 64, 64) and x.dtype == torch.uint8
    assert x.sum().item() == 5172284

    dataset = get_dataset('celebahq64', split='full')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 30000
    assert x.shape == (4, 3, 64, 64) and x.dtype == torch.uint8
    assert x.sum().item() == 4956676


def test_celebahq128():
    """The CelebA-HQ 128x128 full split matches the reference checksum."""
    dataset = get_dataset('celebahq128', split='full')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 30000
    assert x.shape == (4, 3, 128, 128) and x.dtype == torch.uint8
    assert x.sum().item() == 20168755


def test_oord_imagenet32():
    """Downsampled ImageNet 32x32 splits match the reference checksums."""
    dataset = get_dataset('oord_imagenet32', split='train')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 1281149
    assert x.shape == (4, 3, 32, 32) and x.dtype == torch.uint8
    assert x.sum().item() == 1476073

    dataset = get_dataset('oord_imagenet32', split='valid')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 49999
    assert x.shape == (4, 3, 32, 32) and x.dtype == torch.uint8
    assert x.sum().item() == 1365737


def test_oord_imagenet64():
    """Downsampled ImageNet 64x64 splits match the reference checksums."""
    dataset = get_dataset('oord_imagenet64', split='train')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 1281149
    assert x.shape == (4, 3, 64, 64) and x.dtype == torch.uint8
    assert x.sum().item() == 5319624

    dataset = get_dataset('oord_imagenet64', split='valid')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 49999
    assert x.shape == (4, 3, 64, 64) and x.dtype == torch.uint8
    assert x.sum().item() == 5447102


def test_imagenet128():
    """The ImageNet 128x128 valid split matches the reference checksum."""
    # TODO: Add test for train set
    dataset = get_dataset('imagenet128', split='valid')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 50000
    assert x.shape == (4, 3, 128, 128) and x.dtype == torch.uint8
    assert x.sum().item() == 28558047


def test_imagenet256():
    """The ImageNet 256x256 valid split matches the reference checksum."""
    # TODO: Add test for train set
    dataset = get_dataset('imagenet256', split='valid')
    x = _first_and_last_two(dataset)
    assert len(dataset) == 50000
    assert x.shape == (4, 3, 256, 256) and x.dtype == torch.uint8
    assert x.sum().item() == 114236788



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Prepare datasets as .npy arrays.')
    parser.add_argument('command', type=str, choices=['prepare'])
    # Required: without it prepare_dataset could only fail with
    # "Invalid dataset name None".
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to prepare, e.g. mnist, cifar10, celebahq64')
    parser.add_argument('--data_root', type=str, default=_DEFAULT_DATA_ROOT)
    args = parser.parse_args()

    if args.command == 'prepare':
        prepare_dataset(args)
    else:
        raise ValueError(f'Invalid command {args.command}')
