import pdb
import os
import subprocess
import numpy as np
import torch

import torchvision
from dl.datasets.base import DisentangledDataset

# Directory that contains this module.
THIS_PATH = os.path.dirname(__file__)
# Five directory levels above this module.
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 5)))
# Default directory in which the dSprites archive is stored.
IMAGE_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/dsprites')

# Download location of the archive, keyed by split name.
urls = dict(
    train="https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true",
)
# Local file name of the archive, keyed by split name.
files = dict(train="dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

# Names of the latent factors and the number of values each one takes
# (same order in both containers).
lat_names = (
    'color',
    'shape',
    'scale',
    'orientation',
    'posX',
    'posY',
)
lat_sizes = np.array([1, 3, 6, 40, 32, 32])

# Image size as (channels, height, width).
img_size = (1, 64, 64)

class dSprites_Semi(DisentangledDataset):
    """DSprites Dataset from [1], with a semi-supervised label split.

    Disentanglement test Sprites dataset. Procedurally generated 2D shapes, from 6
    disentangled latent factors. This dataset uses 6 latents, controlling the color,
    shape, scale, rotation and position of a sprite. All possible variations of
    the latents are present. Ordering along dimension 1 is fixed and can be mapped
    back to the exact latent values that generated that image. Pixel outputs are
    different. No noise added.

    A random subset of the samples (a fraction ``semi_ratio`` of the dataset)
    is treated as labelled; for every other sample the "semi" label vector
    returned by ``__getitem__`` is shifted by 9999.

    Notes
    -----
    - Link : https://github.com/deepmind/dsprites-dataset/
    - hard coded metadata because issue with python 3 loading of python 2
    - The constant color factor (first latent column) is dropped from
      ``latents_classes``, leaving 5 factors.

    Parameters
    ----------
    root : string
        Root directory of dataset. The archive is downloaded into it when the
        ``.npz`` file is not already present.

    transforms_list : list, optional
        Transforms forwarded to the base class. Defaults to
        ``[torchvision.transforms.ToTensor()]``.

    semi_ratio : float
        Fraction (between 0 and 1) of the samples that keep their true labels
        in the semi-supervised label vector.

    Attributes
    ----------
    semi_idx : numpy.ndarray
        Indices of the labelled samples, drawn without replacement.

    semi_mask : numpy.ndarray of bool
        ``semi_mask[i]`` is True when sample ``i`` is in ``semi_idx``. Built
        once at construction time from ``semi_idx``.

    References
    ----------
    [1] Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick,
        M., ... & Lerchner, A. (2017). beta-vae: Learning basic visual concepts
        with a constrained variational framework. In International Conference
        on Learning Representations.

    """

    def __init__(self, root=IMAGE_PATH, transforms_list=None, semi_ratio=0.5):
        # Build the default per instance instead of sharing one mutable list
        # between every instance created with the default argument.
        if transforms_list is None:
            transforms_list = [torchvision.transforms.ToTensor()]
        super(dSprites_Semi, self).__init__(root=root, transforms_list=transforms_list)

        if not 0.0 <= semi_ratio <= 1.0:
            raise ValueError(
                "semi_ratio must be in [0, 1], got {!r}".format(semi_ratio)
            )

        self.train_data = os.path.join(root, files["train"])

        # Check for the archive itself: the directory may exist while the file
        # is missing (e.g. after a failed or interrupted download).
        if not os.path.isfile(self.train_data):
            # self.logger.info("Downloading {} ...".format(str(type(self))))
            self.download()

        # Close the archive once the needed arrays have been read into memory.
        with np.load(self.train_data) as dataset_zip:
            imgs = dataset_zip["imgs"]
            self.latents_values = dataset_zip["latents_values"]
            latents_classes = dataset_zip["latents_classes"]

        # Scale the images by 255 and append a trailing channel axis, i.e.
        # (# of samples, H, W, 1) assuming the stored images are (N, H, W).
        self.data = np.expand_dims(imgs * 255, axis=-1)
        # Drop the first (color) column: (# of samples, 6 - 1).
        self.latents_classes = latents_classes[:, 1:]
        self.factor_num = self.latents_values[:, 1:].shape[-1]
        # Number of classes for each remaining factor, keyed by column index.
        self.factor_dict = {0: 3, 1: 6, 2: 40, 3: 32, 4: 32}
        assert self.factor_num == 5

        # set semi-supervised learning
        # NOTE(review): __len__ is not defined in this class; it is assumed to
        # be provided by DisentangledDataset.
        t_total = len(self)
        self.semi_idx = np.random.choice(
            t_total, round(t_total * semi_ratio), replace=False
        )
        # Boolean lookup table so __getitem__ can test membership in O(1)
        # instead of scanning semi_idx for every sample.
        self.semi_mask = np.zeros(t_total, dtype=bool)
        self.semi_mask[self.semi_idx] = True

    def download(self):
        """Download the dataset archive into ``self.root`` using ``curl``.

        The archive is first written to a temporary ``.part`` file and only
        renamed to its final name once ``curl`` succeeded, so a failed or
        interrupted download is never mistaken for a valid dataset file.

        Raises
        ------
        subprocess.CalledProcessError
            If ``curl`` exits with a non-zero status.
        """
        os.makedirs(self.root, exist_ok=True)
        partial_path = self.train_data + ".part"
        subprocess.check_call([
            "curl", "-L",
            # Make HTTP errors fail instead of saving the error page.
            "--fail",
            urls["train"], "--output", partial_path
        ])
        os.replace(partial_path, self.train_data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        data
            ``self.data[idx]`` after applying ``self.transforms``.
        classes : torch.Tensor
            Class index of each of the 5 latent factors.
        semi_classes : torch.Tensor
            Same as ``classes`` for labelled samples; for samples outside the
            labelled subset every entry is shifted by 9999.
        """
        # sample = np.expand_dims(self.data[idx] * 255, axis=-1)
        data = self.transforms(self.data[idx])
        classes = torch.Tensor(self.latents_classes[idx])
        semi_classes = classes.clone()
        if not self.semi_mask[idx]:
            # Flag the labels of samples outside the labelled subset.
            semi_classes += 9999
        return data, classes, semi_classes

    def factor_to_idx(self, factor):
        """Map latent class indices to the flat dataset index.

        Parameters
        ----------
        factor : array-like
            Class indices ordered as (shape, scale, orientation, posX, posY);
            the last axis must have length 5.

        Returns
        -------
        idx
            Dot product of ``factor`` with the per-factor strides.
        """
        # Stride of each factor = product of the sizes of all later factors
        # (sizes are 3, 6, 40, 32, 32).
        base = np.array([6 * 40 * 32 * 32, 40 * 32 * 32, 32 * 32, 32, 1])
        idx = np.dot(factor, base)
        return idx

