import os
import subprocess
import h5py
import random
import numpy as np
import torch

import torchvision
from dl.datasets.factorvae.base import DisentangledDataset

# Filesystem locations. The project root is six directory levels above the
# directory holding this module; the dataset archive is looked up below it.
THIS_PATH = os.path.dirname(__file__)
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 6)))
IMAGE_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/mpi3d_real')

# Local file name and remote source of the single training archive.
files = dict(train="mpi3d_real.npz")
urls = dict(
    train="https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz",
)

# Names of the seven factors of variation and the number of classes of each,
# listed in the order used to enumerate latent labels (first factor outermost).
lat_names = ('objCol', 'objShape', 'objSize', 'cameraHeight', 'backCol', 'posX', 'posY')
lat_sizes = np.array([6, 6, 2, 3, 3, 40, 40])

# Shape of one image as returned by the default ToTensor transform,
# presumably (channels, height, width).
img_size = (3, 64, 64)


# lat_values = {
#     'objCol': np.linspace(0, 5, 6),
#     'objShape': np.linspace(0, 5, 6),
#     'objSize': np.linspace(0, 1, 2),
#     'cameraHeight': np.linspace(0, 2, 3),
#     'backCol': np.linspace(0, 2, 3),
#     'posX': np.linspace(0, 39, 40),
#     'posY': np.linspace(0, 39, 40)
# }


class MPI3D_real_f(DisentangledDataset):
    """MPI3D-real dataset, part of the NeurIPS 2019 Disentanglement Challenge.

    A data-set of over one million (6 * 6 * 2 * 3 * 3 * 40 * 40 = 1,036,800)
    images of physical 3D objects with seven factors of variation: object
    color, shape and size, camera height, background color, and x / y position.

    Notes
    -----
    - Link : https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz
    - Metadata is hard coded because of issues loading the Python 2 metadata
      under Python 3.

    Parameters
    ----------
    root : string
        Root directory of dataset; expected to contain ``mpi3d_real.npz``.
    transforms_list : list of callables, optional
        Transforms passed to the base class. Defaults to a fresh
        ``[torchvision.transforms.ToTensor()]`` for every instance.

    """

    def __init__(self, root=IMAGE_PATH, transforms_list=None):
        if transforms_list is None:
            # Build a new list per instance; a list default would be shared by all instances.
            transforms_list = [torchvision.transforms.ToTensor()]
        super(MPI3D_real_f, self).__init__(root=root, transforms_list=transforms_list)

        self.train_data = os.path.join(root, files["train"])

        # `data['images']` is read fully into memory, so the archive can be closed afterwards.
        with np.load(self.train_data) as data:
            self.data = data['images']

        # Factor classes of every image, enumerated row-major over `lat_sizes`
        # (objCol outermost, posY innermost): row k == np.unravel_index(k, lat_sizes).
        # Same ordering as
        # https://github.com/facebookresearch/disentangling-correlated-factors/blob/main/datasets/mpi3d_real.py
        # Built vectorised instead of with seven nested Python loops over ~1M rows.
        factor_grid = np.indices(tuple(lat_sizes))
        self.latents_classes = np.ascontiguousarray(
            factor_grid.reshape(len(lat_sizes), -1).T
        )
        self.factor_num = self.latents_classes.shape[-1]
        # Number of classes per factor index. This used to be hard coded with 4
        # classes for factors 0 and 1, which disagreed with `lat_sizes` (6 and 6)
        # and with the latents built above.
        # NOTE(review): confirm no caller relied on the old value of 4.
        self.factor_dict = {i: int(n) for i, n in enumerate(lat_sizes)}
        assert self.factor_num == 7

    def download(self):
        """Download the dataset archive into ``self.root`` using ``curl``."""
        # Tolerate an existing directory.
        os.makedirs(self.root, exist_ok=True)
        # Derive the target from `self.root` rather than `self.train_data`, which
        # is only assigned after the base-class constructor has run.
        train_data = os.path.join(self.root, files["train"])
        subprocess.check_call([
            "curl", "-L",
            urls["train"], "--output", train_data
        ])

    def __getitem__(self, idx):
        """Get the image of `idx`, a second random image, and the latents of `idx`.

        Return
        ------
        sample : torch.Tensor
            Transformed image `idx`; with the default ToTensor transform, a
            tensor in [0.,1.] of shape `img_size`.

        other_sample : torch.Tensor
            Image at an index drawn uniformly at random from the whole dataset,
            transformed the same way.

        lat_value : torch.Tensor
            Float tensor of length 7 holding the factor classes of image `idx`.
        """
        # ToTensor transforms numpy.ndarray (H x W x C) in the range
        # [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        idx2 = random.randrange(len(self.data))

        return (
            self.transforms(self.data[idx]),
            self.transforms(self.data[idx2]),
            torch.Tensor(self.latents_classes[idx]),
        )

    def factor_to_idx(self, factor):
        """Map factor classes to flat dataset indices.

        Parameters
        ----------
        factor : array-like
            Factor classes with last dimension 7, in `lat_names` order.

        Returns
        -------
        Index (or array of indices) into the row-major enumeration over
        `lat_sizes`, i.e. the inverse of `self.latents_classes`.
        """
        sizes = np.asarray(lat_sizes)
        # Row-major strides: product of the sizes of all later factors, 1 for the last.
        base = np.append(np.cumprod(sizes[:0:-1])[::-1], 1)
        idx = np.dot(factor, base)
        return idx








