from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import numpy as np
from torchvision.datasets.utils import download_url, check_integrity, verify_str_arg
import torch
import torchvision.transforms as transforms


class SVHNMultiDigit(VisionDataset):
    """Multi-digit SVHN dataset backed by one pre-packed file per split.

    Each split is a single file that is loaded with ``torch.load`` and is
    expected to yield an indexable pair ``(images, labels)``. The whole split
    is kept in memory; samples are converted to PIL images on access.

    Args:
        root: Directory in which the split file is looked up and, when
            ``download`` is true, into which it is downloaded.
        split: One of the keys of ``split_list`` (``'train'``, ``'val'`` or
            ``'test'``); validated with ``verify_str_arg``.
        transform: Callable applied to the PIL image of every sample. The
            default random-crops to 54x54, converts to a tensor and
            normalizes each of the three channels with mean 0.5 / std 0.5.
            NOTE(review): the random crop is also the default for the
            ``'val'`` and ``'test'`` splits -- confirm this is intended.
            Pass ``None`` to get the raw PIL image.
        target_transform: Optional callable applied to the integer target.
        download: If true, call :meth:`download` before the integrity check.

    Raises:
        RuntimeError: If the split file is missing or fails the MD5 check.
    """

    # split name -> [download URL, local file name, expected MD5 checksum]
    split_list = {
        'train': ["https://nyc3.digitaloceanspaces.com/publicdata1/svhn-multi-digit-3x64x64_train.p",
                  "svhn-multi-digit-3x64x64_train.p", "25df8732e1f16fef945c3d9a47c99c1a"],
        'val': ["https://nyc3.digitaloceanspaces.com/publicdata1/svhn-multi-digit-3x64x64_val.p",
                "svhn-multi-digit-3x64x64_val.p", "fe5a3b450ce09481b68d7505d00715b3"],
        'test': ["https://nyc3.digitaloceanspaces.com/publicdata1/svhn-multi-digit-3x64x64_test.p",
                 "svhn-multi-digit-3x64x64_test.p", "332977317a21e9f1f5afe7ef47729c5c"]
    }

    def __init__(self, root, split='train',
                 transform=transforms.Compose([
                     transforms.RandomCrop([54, 54]),
                     transforms.ToTensor(),
                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                 ]),
                 target_transform=None, download=False):
        super(SVHNMultiDigit, self).__init__(root, transform=transform, target_transform=target_transform)
        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))

        # Resolve the split's metadata once, from the *verified* split name.
        # The helper methods below read these cached attributes so that URL,
        # file name and checksum always refer to the same split.
        self.url, self.filename, self.file_md5 = self.split_list[self.split]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. It is only reached after the MD5 check above, but
        # MD5 is not collision-resistant; if the file holds nothing but
        # tensors, consider loading with weights_only=True -- TODO confirm.
        data = torch.load(os.path.join(self.root, self.filename))

        self.data = data[0]
        self.labels = data[1].type(torch.LongTensor)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is converted to a PIL image and passed through
        ``self.transform`` when one is set; the target is an ``int`` passed
        through ``self.target_transform`` when one is set.
        """
        # NOTE(review): int() only works for a single-element label tensor --
        # confirm the per-sample label shape for this dataset.
        img, target = self.data[index], int(self.labels[index])

        # Move axis 0 to the end before handing the array to PIL.
        # Presumably the stored layout is channels-first uint8 (C, H, W),
        # which this turns into (H, W, C) -- TODO confirm.
        img = Image.fromarray(np.transpose(img.numpy(), (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def _check_integrity(self):
        """Check the split file under ``self.root`` against its expected MD5."""
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.file_md5)

    def download(self):
        """Fetch the split file into ``self.root`` via ``download_url``."""
        download_url(self.url, self.root, self.filename, self.file_md5)

    def extra_repr(self):
        """Return the split name for inclusion in the dataset's repr."""
        return "Split: {split}".format(split=self.split)
