import os
import subprocess
import h5py

import numpy as np
import torch

import torchvision
from src.dataloaders.disent.base import DisentangledDataset

# Directory holding this module, and the directory five levels above it.
THIS_PATH = os.path.dirname(__file__)
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 5)))
# Default location of the Shapes3D files below ROOT_PATH.
IMAGE_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/shapes3d')

# Remote source and local file name of the dataset, keyed by split name.
urls = {"train": "https://storage.googleapis.com/3d-shapes/3dshapes.h5"}
files = {"train": "3dshapes.h5"}

# Names of the six latent factors and the number of values each one takes
# (same order in both containers).
lat_names = (
    'floorCol',
    'wallCol',
    'objCol',
    'objSize',
    'objType',
    'objAzimuth',
)
lat_sizes = np.array((10, 10, 10, 8, 4, 15))

# Image shape as (channels, height, width).
img_size = (3, 64, 64)

class Shapes3D(DisentangledDataset):
    """Shapes3D Dataset from [1].

    3dshapes is a dataset of 3D shapes procedurally generated from 6 ground truth independent
    latent factors. These factors are floor colour (10), wall colour (10), object colour (10), size (8), type (4) and azimuth (15).
    All possible combinations of these latents are present exactly once, generating N = 480000 total images.

    Notes
    -----
    - Link : https://storage.googleapis.com/3d-shapes
    - hard coded metadata because issue with python 3 loading of python 2
    - The HDF5 file is converted once into two ``.npy`` caches (images and
      integer class labels) stored next to it; instances load those caches.

    Parameters
    ----------
    root : string
        Root directory of dataset.

    transforms_list : list, optional
        Transforms forwarded to the base class. When omitted,
        ``[torchvision.transforms.ToTensor()]`` is used.

    References
    ----------
    [1] Hyunjik Kim, Andriy Mnih (2018). Disentangling by Factorising.

    """

    # Values (rounded to 2 decimals) taken by the size and azimuth factors,
    # listed in class-index order.
    _SIZE_LEVELS = (0.75, 0.82, 0.89, 0.96, 1.04, 1.11, 1.18, 1.25)
    _AZIMUTH_LEVELS = (
        -30.0, -25.71, -21.43, -17.14, -12.86, -8.57, -4.29, 0.0,
        4.29, 8.57, 12.86, 17.14, 21.43, 25.71, 30.0,
    )

    def __init__(self, root=IMAGE_PATH, transforms_list=None):
        # Build the default per instance instead of sharing one mutable list
        # created at import time.
        if transforms_list is None:
            transforms_list = [torchvision.transforms.ToTensor()]
        super(Shapes3D, self).__init__(root=root, transforms_list=transforms_list)

        self.train_data = os.path.join(root, files["train"])
        imgs_path = self._cache_path('_imgs.npy')
        labels_path = self._cache_path('_labels.npy')

        # Check for the caches themselves rather than for the directory, so
        # that an empty or partially populated `root` is repaired.
        if not (os.path.isfile(imgs_path) and os.path.isfile(labels_path)):
            # self.logger.info("Downloading {} ...".format(str(type(self))))
            self.download()

        self.data = np.load(imgs_path)
        self.latents_values = np.load(labels_path)
        # Independent copy of the same labels (avoids reading the file twice).
        self.latents_classes = self.latents_values.copy()
        self.factor_num = self.latents_classes.shape[-1]
        self.factor_dict = {i: int(size) for i, size in enumerate(lat_sizes)}
        self.lat_names = lat_names
        self.lat_sizes = lat_sizes

    def _cache_path(self, suffix):
        """Return the HDF5 path with its extension replaced by `suffix`.

        Only the file extension is touched, so a ``.h5`` occurring in a
        directory name is left alone.
        """
        stem, _ = os.path.splitext(self.train_data)
        return stem + suffix

    @staticmethod
    def _levels_to_index(values, levels):
        """Replace every entry equal to ``levels[i]`` (after rounding) by ``i``.

        Parameters
        ----------
        values : np.ndarray
            Float array of raw factor values.

        levels : sequence of float
            Known factor values rounded to 2 decimals, in class order.

        Return
        ------
        np.ndarray
            Float array shaped like `values`. Entries that match no level keep
            their rounded value.
        """
        # first round, since values are very precise floats
        rounded = np.round(values, 2)
        indexed = np.copy(rounded)
        # Compare against `rounded`, not `indexed`, so an index that was
        # already written can never be mistaken for a level.
        for index, level in enumerate(levels):
            indexed[rounded == level] = index
        return indexed

    @classmethod
    def _labels_to_classes(cls, lat_values):
        """Convert the raw float labels into integer class indices.

        Parameters
        ----------
        lat_values : np.ndarray
            Array with one row per image and six columns
            (floor, wall and object colour, size, type, azimuth).

        Return
        ------
        np.ndarray
            Integer array of the same shape.
        """
        classes = np.array(lat_values, dtype=float)
        # The three colour columns are multiples of 0.1; round after scaling
        # so that the final integer cast cannot truncate e.g. 6.999... to 6.
        classes[:, :3] = np.rint(classes[:, :3] * 10)  # [0, 1, ..., 9]
        classes[:, 3] = cls._levels_to_index(classes[:, 3], cls._SIZE_LEVELS)  # [0, 1, ..., 7]
        # shape (column 4) is already on int scale  # [0, 1, ..., 3]
        classes[:, 5] = cls._levels_to_index(classes[:, 5], cls._AZIMUTH_LEVELS)  # [0, 1, ..., 14]
        return classes.astype(int)

    def download(self):
        """Download the dataset and write the numpy caches.

        The HDF5 file is fetched with ``curl`` unless it is already present,
        then its images and converted labels are saved as ``*_imgs.npy`` and
        ``*_labels.npy`` next to it.

        Raises
        ------
        subprocess.CalledProcessError
            If ``curl`` exits with a non-zero status.
        """
        os.makedirs(self.root, exist_ok=True)
        if not os.path.isfile(self.train_data):
            # Download to a temporary name and rename afterwards, so that an
            # interrupted transfer never leaves a truncated file at the final
            # path. `--fail` makes curl exit non-zero on HTTP errors instead
            # of saving the error page.
            partial_path = self.train_data + '.part'
            subprocess.check_call([
                "curl", "-L", "--fail",
                urls["train"], "--output", partial_path
            ])
            os.replace(partial_path, self.train_data)

        # For faster loading, a numpy copy is created (at the cost of more storage).
        with h5py.File(self.train_data, 'r') as dataset:
            imgs = dataset['images'][()]
            lat_values = dataset['labels'][()]

        np.save(self._cache_path('_imgs.npy'), imgs)
        np.save(self._cache_path('_labels.npy'), self._labels_to_classes(lat_values))

    def __getitem__(self, idx):
        """Get the image of `idx`
        Return
        ------
        sample : torch.Tensor
            Tensor in [0.,1.] of shape `img_size`.

        lat_value : torch.Tensor
            Tensor of length 6, that gives the class index of each factor of variation.
        """
        # ToTensor transforms numpy.ndarray (H x W x C) in the range
        # [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        return self.transforms(self.data[idx]), torch.Tensor(self.latents_classes[idx])

    def factor_to_idx(self, factor):
        """Convert per-factor class indices into a single flat index.

        Parameters
        ----------
        factor : array-like
            Class indices of the six factors, in the order of `lat_names`.
            The last axis must have length 6.

        Return
        ------
        idx : int or np.ndarray
            ``sum(factor[i] * prod(lat_sizes[i + 1:]))``, i.e. the first factor
            varies slowest. Presumably this is the row of the matching image
            in the stored arrays; confirm against the dataset ordering.
        """
        sizes = np.asarray(lat_sizes)
        # Weight of factor i is the product of the sizes of all later factors:
        # [48000, 4800, 480, 60, 15, 1].
        base = np.append(np.cumprod(sizes[::-1])[::-1][1:], 1)
        idx = np.dot(factor, base)
        return idx




















