from torchvision.datasets import VisionDataset
import os
import numpy as np
import h5py


class Shapes3D(VisionDataset):
	"""
	Shapes3D dataset.
	The data set was originally introduced in "Disentangling by Factorising".

	The ground-truth factors of variation are:
		0 - floor color (10 different values)
		1 - wall color (10 different values)
		2 - object color (10 different values)
		3 - object size (8 different values)
		4 - object type (4 different values)
		5 - azimuth (15 different values)

	You need to download the file:
	https://storage.cloud.google.com/3d-shapes/3dshapes.h5?_ga=2.264153871.-282447512.1586967252
	and place it at ``<root>/<class name>/3dshapes.h5`` (i.e. ``<root>/Shapes3D/3dshapes.h5``).

	Args:
		root (str): Directory that contains the per-class data folder.
		transform (callable, optional): Applied to the image array read from the
			``images`` HDF5 dataset, before the 1/255 scaling and reshape.
		target_transform (callable, optional): Applied to the entry read from the
			``labels`` HDF5 dataset.

	Raises:
		FileNotFoundError: If the HDF5 file is not present at construction time.
	"""

	# Maps a storage format to [download URL, expected file name].
	type_list = {
		"h5py": [
			"https://storage.cloud.google.com/3d-shapes/3dshapes.h5?_ga=2.264153871.-282447512.1586967252",
			"3dshapes.h5"
		]
	}

	def __init__(self, root, transform=None, target_transform=None):
		super(Shapes3D, self).__init__(root, transform=transform, target_transform=target_transform)

		# The HDF5 handle is opened lazily in __getitem__ (presumably so each
		# DataLoader worker opens its own handle instead of inheriting one).
		self.dataset = None

		if not self._check_exists():
			# FileNotFoundError subclasses OSError, which is what h5py reports for
			# a missing file, so callers catching OSError keep working.
			raise FileNotFoundError(
				"Shapes3D data file not found at {}. Download it from {}".format(
					self._file_path, self.type_list["h5py"][0]))

		# Open only briefly to read the number of samples; the handle is closed here.
		with h5py.File(self._file_path, 'r') as file:
			self.dataset_len = len(file["images"])
		# Number of distinct values of each ground-truth factor (order as in the class docstring).
		self.factor_sizes = [10, 10, 10, 8, 4, 15]

	def __getitem__(self, index):
		"""
		Args:
			index (int): Index

		Returns:
			tuple: (image, target). ``image`` is the (optionally transformed) image
			divided by 255 and reshaped to ``observation_shape``; ``target`` is the
			(optionally transformed) ``labels`` entry, as a numpy array.
		"""
		if self.dataset is None:
			self.dataset = h5py.File(self._file_path, 'r')

		img, target = self.dataset["images"][index], self.dataset["labels"][index]

		if self.transform is not None:
			img = self.transform(img)

		if self.target_transform is not None:
			target = self.target_transform(target)

		# Scale pixel values to [0, 1] (for 8-bit input) and force the (H, W, C) layout.
		return (np.array(img) / 255.).reshape(self.observation_shape), np.array(target)

	def __len__(self):
		return self.dataset_len

	def close(self):
		"""Close the lazily-opened HDF5 handle, if one is open. Safe to call repeatedly."""
		if self.dataset is not None:
			self.dataset.close()
			self.dataset = None

	def _check_exists(self):
		"""Return True if the HDF5 data file itself is present on disk."""
		return os.path.isfile(self._file_path)

	@property
	def _file_path(self):
		"""Full path of the expected HDF5 data file."""
		return os.path.join(self.processed_folder, self.type_list["h5py"][1])

	@property
	def processed_folder(self):
		"""Folder holding the data file: ``<root>/<class name>``."""
		return os.path.join(self.root, self.__class__.__name__)

	@property
	def observation_shape(self):
		"""Shape (height, width, channels) of every image returned by __getitem__."""
		return [64, 64, 3]

	@property
	def factors_num_values(self):
		"""Number of distinct values for each ground-truth factor."""
		return self.factor_sizes


if __name__ == '__main__':
	# Quick manual smoke test: expects ./Shapes3D/3dshapes.h5 to exist.
	import matplotlib.pyplot as plt

	shape3d = Shapes3D(root='.')
	# len() invokes __len__; printing `shape3d.__len__` would only show the bound method.
	print(len(shape3d))
	image, label = shape3d[0]

	plt.imshow(image)
	plt.show()
	print(label)
	print(shape3d._check_exists())
	print(shape3d.processed_folder)
	print(shape3d.factors_num_values)
