import torch
import os
import PIL
from torchvision.datasets.vision import VisionDataset


class Kodak(VisionDataset):
    """Kodak Lossless True Color Image Suite.

    Images are fetched from a GitHub mirror of the suite and stored under
    ``<root>/kodak/images``.

    Args:
        root (string): Root directory where images are downloaded to.
            A subdirectory 'kodak' will be created.
        split (string): Currently only 'all'. The value is stored on the
            instance but does not change which images are loaded.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.ToTensor``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the image folder is missing after the optional
            download, or if the downloaded archive fails its MD5 check.
    """
    url = "https://github.com/MohamedBakrAli/Kodak-Lossless-True-Color-Image-Suite/archive/master.zip"

    file_list = [
        # Filename     MD5 Hash
        ("master.zip", "26c60c55a88f9667a074e534b8b7867b"),
    ]
    # Name of the subdirectory of ``self.root`` that holds the images.
    base_folder = 'images'

    def __init__(self, root, split="all", transform=None, download=False):
        super().__init__(root, transform=transform)

        self.split = split
        # Build on ``self.root`` (as set by VisionDataset) rather than the raw
        # argument, so any normalisation the base class applies (e.g. ``~``
        # expansion, depending on the torchvision version) is preserved.
        self.root = os.path.join(self.root, 'kodak')

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # ``os.listdir`` order is arbitrary and filesystem-dependent; sort so
        # that ``dataset[i]`` refers to the same image on every machine.
        self.files = sorted(os.listdir(os.path.join(self.root, self.base_folder)))

    def _check_integrity(self):
        """Return True if the image folder exists.

        Only the folder's presence is checked; individual files are not verified.
        """
        return os.path.isdir(os.path.join(self.root, self.base_folder))

    def download(self):
        """Download the archive, verify its MD5 and extract the images.

        If the image folder already exists, a message is printed and nothing
        else happens.

        Raises:
            RuntimeError: If the MD5 of the downloaded archive does not match
                the hash recorded in ``file_list``.
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        import hashlib
        import zipfile
        from shutil import move, rmtree
        from tempfile import TemporaryFile
        from urllib.request import urlopen

        expected_md5 = self.file_list[0][1]
        digest = hashlib.md5()
        # Stream into an anonymous temporary file and hash chunk by chunk, so
        # the archive is never held in memory as a whole. Reading the open
        # handle with ``zipfile`` avoids re-opening a temporary file by name,
        # which is not supported on Windows.
        with urlopen(self.url) as response, TemporaryFile() as tmp:
            for chunk in iter(lambda: response.read(1 << 20), b''):
                digest.update(chunk)
                tmp.write(chunk)
            # Explicit exception: unlike ``assert``, this is not stripped under ``python -O``.
            if digest.hexdigest() != expected_md5:
                raise RuntimeError('md5sum check failed')
            tmp.seek(0)
            with zipfile.ZipFile(tmp) as archive:
                archive.extractall(self.root)

        # The archive unpacks into a single top-level folder; keep only the
        # image sub-folder (renamed to ``base_folder``) and discard the rest.
        archive_dir = os.path.join(self.root, 'Kodak-Lossless-True-Color-Image-Suite-master')
        move(os.path.join(archive_dir, 'PhotoCD_PCD0992'),
             os.path.join(self.root, self.base_folder))
        rmtree(archive_dir)

    def __getitem__(self, index):
        """Return image ``index`` as an RGB PIL image, passed through ``transform`` if one is set."""
        path = os.path.join(self.root, self.base_folder, self.files[index])
        img = pil_loader(path)
        return self.transform(img) if self.transform is not None else img

    def __len__(self):
        """Return the number of files found in the image folder."""
        return len(self.files)


def pil_loader(path):
    """Load the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly and converted while still open, so the
    returned image is a new, fully decoded object that no longer needs the
    file handle.

    Args:
        path (string): Path to an image file readable by Pillow.

    Returns:
        PIL.Image.Image: The image in ``'RGB'`` mode.
    """
    # ``import PIL`` alone does not guarantee that the ``PIL.Image`` submodule
    # is loaded. Import it explicitly instead of relying on another module's
    # import side effect.
    from PIL import Image

    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
