# MIT License

# Copyright (c) [2023] [Anima-Lab]


import io
import os
import json
import zipfile

import lmdb
import numpy as np
from PIL import Image
import torch
from torchvision.datasets import ImageFolder, VisionDataset



def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image so its short side equals ``image_size`` and return the
    central ``image_size`` x ``image_size`` crop as a new PIL image.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Stage 1: halve with a box filter for as long as the short side is still
    # at least twice the target size.
    width, height = pil_image.size
    while min(width, height) >= 2 * image_size:
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)
        width, height = pil_image.size

    # Stage 2: one bicubic resize that brings the short side to the target.
    ratio = image_size / min(width, height)
    pil_image = pil_image.resize(
        (round(width * ratio), round(height * ratio)), resample=Image.BICUBIC
    )

    # Stage 3: cut the centered square out of the pixel array.
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)


################################################################################
# ImageNet - LMDB
###############################################################################

def lmdb_loader(path, lmdb_data, resolution):
    """
    Load one encoded image from an LMDB environment and center-crop it.

    Args:
        path: Record key; it is ASCII-encoded for the lookup.
        lmdb_data: Open LMDB environment (the object ``begin`` is called on).
        resolution: Target side length forwarded to ``center_crop_arr``.

    Returns:
        The RGB PIL image returned by ``center_crop_arr``.
    """
    with lmdb_data.begin(write=False, buffers=True) as txn:
        bytedata = txn.get(path.encode('ascii'))
        # With buffers=True the value is a view into the database that must not
        # be used once the transaction has ended, so copy it into an in-memory
        # binary stream while the transaction is still open.
        stream = io.BytesIO(bytedata)
    img = Image.open(stream).convert('RGB')
    return center_crop_arr(img, resolution)


def imagenet_lmdb_dataset(
        root,
        transform=None, target_transform=None,
        resolution=256):
    """
    Build (or reuse) an LMDB copy of an ImageFolder tree and wrap it in ImageLMDB.

    Two cache artifacts are kept next to ``root``:
      * ``<root>_faster_imagefolder.lmdb.pt`` - the ImageFolder index saved with torch.save
      * ``<root>_faster_imagefolder.lmdb``    - raw file bytes keyed by the ASCII file path
    If both exist they are reused; otherwise both are (re)built from ``root``.

    You can create this dataloader using:
    train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
    valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)
    """

    if root.endswith('/'):
        root = root[:-1]
    pt_path = root + '_faster_imagefolder.lmdb.pt'
    lmdb_path = root + '_faster_imagefolder.lmdb'
    if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
        print('Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
        # NOTE(security): this unpickles a full Python object; only point it at
        # cache files produced by this function.
        try:
            # Recent torch defaults to weights_only=True, which refuses to
            # unpickle an ImageFolder object.
            data_set = torch.load(pt_path, weights_only=False)
        except TypeError:
            # Older torch releases do not know the weights_only keyword.
            data_set = torch.load(pt_path)
    else:
        data_set = ImageFolder(
            root, None, None, None)
        print('Building lmdb to {}'.format(lmdb_path))
        # map_size is a byte count; pass an int rather than the float 1e12.
        env = lmdb.open(lmdb_path, map_size=int(1e12))
        try:
            with env.begin(write=True) as txn:
                for path, _class_index in data_set.imgs:
                    with open(path, 'rb') as f:
                        txn.put(path.encode('ascii'), f.read())
        finally:
            # Release the write environment before a reader opens the same path.
            env.close()
        # Save the index only after the LMDB was written, so an interrupted
        # build is not mistaken for a complete cache on the next call.
        torch.save(data_set, pt_path, pickle_protocol=4)
        print('Saving pt to {}'.format(pt_path))

    lmdb_dataset = ImageLMDB(lmdb_path, transform, target_transform, resolution, data_set.imgs, data_set.class_to_idx, data_set.classes)
    return lmdb_dataset


################################################################################
# ImageNet Dataset class- LMDB
###############################################################################

class ImageLMDB(VisionDataset):
    """
    Dataset that serves (image, target) pairs whose encoded image bytes live in
    an LMDB store keyed by the sample path, instead of in individual files.
    """
    def __init__(self, root, transform=None, target_transform=None,
                 resolution=256, samples=None, class_to_idx=None, classes=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Set explicitly after the base constructor has run.
        self.root = root
        self.resolution = resolution
        # Sequence of (path, target) pairs; the path doubles as the LMDB key.
        self.samples = samples
        self.class_to_idx = class_to_idx
        self.classes = classes

    def __len__(self) -> int:
        return len(self.samples)

    def open_db(self):
        """Open the LMDB environment at ``self.root`` and start a read transaction."""
        self.env = lmdb.open(
            self.root,
            readonly=True,
            max_readers=256,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.txn = self.env.begin(write=False, buffers=True)

    def __getitem__(self, index: int):
        key, label = self.samples[index]

        # The database is opened on first access rather than in __init__.
        if not hasattr(self, 'txn'):
            self.open_db()

        encoded = self.txn.get(key.encode('ascii'))
        decoded = Image.open(io.BytesIO(encoded)).convert('RGB')
        sample = center_crop_arr(decoded, self.resolution)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label



################################################################################
# ImageNet - LMDB - latent space
###############################################################################



# ----------------------------------------------------------------------------
# Abstract base class for datasets.

class Dataset(torch.utils.data.Dataset):
    """
    Abstract base for indexable (image, condition) datasets.

    Subclasses implement ``_load_raw_data(raw_idx)`` and return an ``np.ndarray``
    image together with either a label or a ``[label, feature]`` list. Integer
    labels are converted to one-hot float32 vectors of length ``label_dim``.
    """
    def __init__(self,
                 name,  # Name of the dataset.
                 raw_shape,  # Shape of the raw image data (NCHW).
                 max_size=None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
                 label_dim=1000,  # Ensure specific number of classes
                 xflip=False,  # Artificially double the size of the dataset via x-flips. Applied after max_size.
                 random_seed=0,  # Random seed to use when applying max_size.
                 ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._label_dim = label_dim
        self._label_shape = None  # Filled lazily by the label_shape property.

        # Apply max_size: keep a seeded random subset, stored in ascending order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip. (Assume the dataset already contains the same number of xflipped samples)
        # The flipped copy of raw index i is addressed as i + N, N = raw_shape[0].
        if xflip:
            self._raw_idx = np.concatenate([self._raw_idx, self._raw_idx + self._raw_shape[0]])

    def close(self):  # to be overridden by subclass
        pass

    def _load_raw_data(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Pickle everything except the label cache, which is reset to None.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: errors from close() are ignored, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image_copy, cond)`` for position ``idx`` of the index table."""
        raw_idx = self._raw_idx[idx]
        image, cond = self._load_raw_data(raw_idx)
        assert isinstance(image, np.ndarray)
        if isinstance(cond, list):  # [label, feature]
            cond[0] = self._get_onehot(cond[0])
        else:  # label
            cond = self._get_onehot(cond)
        return image.copy(), cond

    def _get_onehot(self, label):
        """
        Convert an integer label to a one-hot float32 vector of ``label_shape``.

        Any integer scalar (Python ``int`` or a NumPy integer such as
        ``np.int32``) and any int64 array is used as an index into the one-hot
        vector; other arrays are treated as ready-made labels. A copy is
        returned in every case.
        """
        if isinstance(label, (int, np.integer)) or label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        assert isinstance(label, np.ndarray)
        return label.copy()

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        # Raw shape without the leading sample-count dimension.
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3  # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            self._label_shape = [self._label_dim]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return True


# ----------------------------------------------------------------------------
# Dataset subclass that loads latent images (and optional features) by key from LMDB directories.

class ImageNetLatentDataset(Dataset):
    """
    Latent-image dataset backed by LMDB.

    The main database is opened at ``os.path.join(path, split)`` and is read
    through the keys ``length``, ``z-<idx>`` (float32 latent bytes) and
    ``y-<idx>`` (label as a decimal string). When a feature database is
    configured, ``feat-<idx>`` and ``y-<idx>`` are read from it as well and the
    condition becomes ``[label, feature]``.
    """
    # Class-level defaults keep close() and the lazy reopen safe even when
    # __init__ failed early or no feature database is used.
    env = None
    txn = None
    feat_env = None
    feat_txn = None
    _feat_path = None

    def __init__(self,
                 path,  # Path to the lmdb root directory; the split name is appended.
                 resolution=32,  # Ensure specific resolution, default 32.
                 num_channels=4,  # Ensure specific number of channels, default 4.
                 split='train',  # train or val split
                 feat_path=None,  # Path to features lmdb root (original note: only works when feat_cond=True); ignored unless it is an existing directory.
                 feat_dim=0,  # feature dim
                 **super_kwargs,  # Additional arguments for the Dataset base class.
                 ):
        self._path = os.path.join(path, split)
        self.feat_dim = feat_dim
        if self.txn is None:
            self.open_lmdb()
        if feat_path is not None and os.path.isdir(feat_path):
            assert self.feat_dim > 0
            self._feat_path = os.path.join(feat_path, split)
            self.open_feat_lmdb()

        length_bytes = self.txn.get('length'.encode('utf-8'))
        if length_bytes is None:
            raise IOError('LMDB at {} has no "length" record'.format(self._path))
        length = int(length_bytes.decode('utf-8'))
        name = os.path.basename(path)
        # The stored data is not inspected here; the shape is taken from the arguments.
        raw_shape = [length, num_channels, resolution, resolution]  # 1281167 x 4 x 32 x 32

        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def open_lmdb(self):
        """Open the main database read-only and start a read transaction."""
        self.env = lmdb.open(self._path, readonly=True, lock=False, create=False)
        self.txn = self.env.begin(write=False)

    def open_feat_lmdb(self):
        """Open the feature database read-only and start a read transaction."""
        self.feat_env = lmdb.open(self._feat_path, readonly=True, lock=False, create=False)
        self.feat_txn = self.feat_env.begin(write=False)

    def __getstate__(self):
        # LMDB handles are process-local: drop them from the pickled state (as
        # ImageFolderDataset does for its zip handle); _load_raw_data reopens them.
        return dict(super().__getstate__(), env=None, txn=None, feat_env=None, feat_txn=None)

    def _load_raw_data(self, idx):
        """Return ``(z, cond)`` where cond is the int label or ``[label, feature]``."""
        # Reopen lazily after close() or after unpickling.
        if self.txn is None:
            self.open_lmdb()
        if self.feat_txn is None and self._feat_path is not None:
            self.open_feat_lmdb()

        z_bytes = self.txn.get(f'z-{str(idx)}'.encode('utf-8'))
        y_bytes = self.txn.get(f'y-{str(idx)}'.encode('utf-8'))
        z = np.frombuffer(z_bytes, dtype=np.float32).reshape([-1, self.resolution, self.resolution]).copy()
        y = int(y_bytes.decode('utf-8'))

        cond = y
        if self.feat_txn is not None:
            feat_bytes = self.feat_txn.get(f'feat-{str(idx)}'.encode('utf-8'))
            feat_y_bytes = self.feat_txn.get(f'y-{str(idx)}'.encode('utf-8'))
            feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([self.feat_dim]).copy()
            feat_y = int(feat_y_bytes.decode('utf-8'))
            assert y == feat_y, 'Ordering mismatch between txn and feat_txn!'
            cond = [y, feat]

        return z, cond

    def close(self):
        """Close both environments (if open) and forget their transactions."""
        env, feat_env = self.env, self.feat_env
        # Clear the transactions together with the environments so a later
        # access reopens the databases instead of using a stale transaction.
        self.env = self.txn = None
        self.feat_env = self.feat_txn = None
        try:
            if env is not None:
                env.close()
        finally:
            if feat_env is not None:
                feat_env.close()


# ----------------------------------------------------------------------------
# Dataset subclass that loads images recursively from the specified directory or zip file.

class ImageFolderDataset(Dataset):
    """
    Dataset over the image files found under a directory tree or inside a zip
    archive. Optional labels come from a ``dataset.json`` entry in the same
    location.
    """
    def __init__(self,
                 path,  # Path to directory or zip.
                 resolution=None,  # Ensure specific resolution, None = highest available.
                 use_labels=False,  # Enable conditioning labels? False = label dimension is zero.
                 **super_kwargs,  # Additional arguments for the Dataset base class.
                 ):
        self._path = path
        self._zipfile = None
        self._raw_labels = None
        self._use_labels = use_labels

        # Collect every file name, relative to the dataset root.
        if os.path.isdir(self._path):
            self._type = 'dir'
            found = set()
            for root, _dirs, files in os.walk(self._path):
                for fname in files:
                    found.add(os.path.relpath(os.path.join(root, fname), start=self._path))
            self._all_fnames = found
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        # Keep only names whose extension PIL has a registered format for.
        Image.init()
        self._image_fnames = sorted(
            candidate for candidate in self._all_fnames
            if self._file_ext(candidate) in Image.EXTENSION
        )
        if not self._image_fnames:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # The first image determines the per-sample shape (CHW).
        raw_shape = [len(self._image_fnames), *self._load_raw_image(0).shape]
        if resolution is not None:
            if raw_shape[2] != resolution or raw_shape[3] != resolution:
                raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname``, including the dot."""
        _stem, ext = os.path.splitext(fname)
        return ext.lower()

    def _get_zipfile(self):
        """Return the archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` for binary reading from the directory or the archive."""
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        return None

    def close(self):
        # Forget the handle first so it is cleared even if closing fails.
        archive, self._zipfile = self._zipfile, None
        if archive is not None:
            archive.close()

    def __getstate__(self):
        # The open archive handle is left out of the pickled state.
        state = super().__getstate__()
        state['_zipfile'] = None
        return state

    def _load_raw_data(self, raw_idx):
        image = self._load_raw_image(raw_idx)
        assert image.dtype == np.uint8
        label = self._get_raw_labels()[raw_idx]
        return image, label

    def _load_raw_image(self, raw_idx):
        """Decode image ``raw_idx`` and return it as a CHW array."""
        with self._open_file(self._image_fnames[raw_idx]) as f:
            pixels = np.array(Image.open(f))
        if pixels.ndim == 2:
            pixels = pixels[:, :, np.newaxis]  # HW => HWC
        return pixels.transpose(2, 0, 1)  # HWC => CHW

    def _get_raw_labels(self):
        """Return the cached label array, loading and validating it on first use."""
        if self._raw_labels is not None:
            return self._raw_labels

        self._raw_labels = self._load_raw_labels() if self._use_labels else None
        if self._raw_labels is None:
            # No labels available or requested: one zero-width row per raw image.
            self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)

        labels = self._raw_labels
        assert isinstance(labels, np.ndarray)
        assert labels.shape[0] == self._raw_shape[0]
        assert labels.dtype in [np.float32, np.int64]
        if labels.dtype == np.int64:
            assert labels.ndim == 1
            assert np.all(labels >= 0)
        return labels

    def _load_raw_labels(self):
        """Read labels from ``dataset.json``; return None when unavailable."""
        meta_name = 'dataset.json'
        if meta_name not in self._all_fnames:
            return None
        with self._open_file(meta_name) as f:
            entries = json.load(f)['labels']
        if entries is None:
            return None
        lookup = dict(entries)
        # Label keys use forward slashes, so normalise Windows-style separators.
        ordered = np.array([lookup[name.replace('\\', '/')] for name in self._image_fnames])
        # 1-D arrays are class indices; 2-D arrays are stored as float32 vectors.
        target_dtype = {1: np.int64, 2: np.float32}[ordered.ndim]
        return ordered.astype(target_dtype)

# ----------------------------------------------------------------------------
