#!/usr/bin/env python

import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import torch
import torch.utils.data
from torchvision import transforms, utils

def GetDatasetIfNecessary():
    """Download dSprites and build the train/test split files if missing.

    All paths are under ``data/dsprites`` relative to the current working
    directory:

    1. The raw dSprites ``.npz`` archive is downloaded unless it is already
       on disk.
    2. Unless both ``train.npz`` and ``test.npz`` exist, the samples are
       shuffled with a fixed seed, the last 5000 shuffled samples are held
       out as the test set, and both splits are written as compressed
       ``.npz`` archives holding ``metadata``, ``imgs``, ``latents_values``
       and ``latents_classes``.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    os.makedirs('data/dsprites', exist_ok=True)
    zip_file = 'data/dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    if not os.path.exists(zip_file):
        url = 'https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
        print('Downloading {}'.format(url))
        response = requests.get(url)
        # Fail loudly rather than caching an HTTP error page as the dataset.
        response.raise_for_status()
        with open(zip_file, 'wb') as raw_file:
            raw_file.write(response.content)

    zip_train = 'data/dsprites/train.npz'
    zip_test = 'data/dsprites/test.npz'
    if os.path.exists(zip_train) and os.path.exists(zip_test):
        return

    # Load full dataset
    print('Processing...')
    # 'metadata' is stored as a pickled object array, so allow_pickle=True is
    # required (NumPy >= 1.16.3 refuses to load it otherwise).
    # NOTE: unpickling can execute arbitrary code -- only use this with the
    # trusted upstream archive.
    with np.load(zip_file, encoding='bytes', allow_pickle=True) as dataset_zip:
        metadata = dataset_zip['metadata'][()]
        imgs = dataset_zip['imgs']
        latents_values = dataset_zip['latents_values']
        latents_classes = dataset_zip['latents_classes']

    # Shuffle deterministically.  A private RandomState(0) produces the same
    # stream as np.random.seed(0) without disturbing the global RNG, and
    # shuffling a plain list of indices performs the same swaps as shuffling
    # a list of samples, so the resulting split is the same.
    order = list(range(len(imgs)))
    np.random.RandomState(0).shuffle(order)
    order = np.asarray(order, dtype=np.intp)

    # Split into train/test and save
    splits = [(order[:-5000], zip_train), (order[-5000:], zip_test)]
    for indices, path in splits:
        np.savez_compressed(
            path,
            metadata=metadata,
            imgs=imgs[indices],
            latents_values=latents_values[indices],
            latents_classes=latents_classes[indices],
        )

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the dSprites train or test split.

    Each item is an ``(img, latents_values)`` pair read from
    ``data/dsprites/train.npz`` or ``data/dsprites/test.npz``.
    """

    def __init__(self, train=True, transform=None):
        """Load one split into memory, preparing the files first if needed.

        Args:
            train: If true, load ``train.npz``; otherwise load ``test.npz``.
            transform: Optional callable applied to each image (not to the
                latent values) in ``__getitem__``.
        """
        self.transform = transform
        GetDatasetIfNecessary()
        mode = 'train' if train else 'test'
        # Read the arrays eagerly and close the archive so that no file
        # handle is left open for the lifetime of the dataset.
        with np.load('data/dsprites/{}.npz'.format(mode), encoding='bytes') as dataset_zip:
            self.imgs = dataset_zip['imgs']
            self.latents_values = dataset_zip['latents_values']
        # 'latents_classes' and 'metadata' are intentionally not loaded here.

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``(img, latents_values)`` for index ``i``.

        The image is passed through ``self.transform`` when one was given.
        """
        img = self.imgs[i]
        lvs = self.latents_values[i]
        # Compare against None explicitly: a callable that defines __len__
        # may be falsy even though it should still be applied.
        if self.transform is not None:
            img = self.transform(img)
        return img, lvs

if __name__ == '__main__':
    # Smoke test: build/load the training split and report its size.
    # Dataset.__init__ takes `train` (bool); it has no `mode` keyword.
    d = Dataset(train=True)
    print(len(d))
