import os
from typing import Callable, Optional

import torch
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

# CUB200 dataset class.
# The download code originates from JH-LEE-KR's ContinualDatasets repository:
# https://github.com/JH-LEE-KR/ContinualDatasets/blob/main/continual_datasets/continual_datasets.py


class CUB200(ImageFolder):
    """CUB-200-2011 images as an ``ImageFolder`` restricted to one split.

    Expected on-disk layout, relative to the ``root`` constructor argument::

        <root>/CUB200-2011/images/<class_dir>/<image files>
        <root>/CUB200-2011/train_test_split.txt

    After construction ``self.root`` is the *images* directory (the base
    class stores the directory it was given), not the ``root`` argument.

    Extra attributes set on top of the ``ImageFolder`` ones:

    * ``classes_names`` -- one readable name per entry of ``self.classes``:
      the first four characters are dropped, underscores become spaces and
      the result is lower-cased (presumably directory names look like
      ``001.Black_footed_Albatross`` -- confirm against the data).
    * ``mean`` / ``std`` -- per-channel statistics used for normalization.
    """

    def __init__(self,
                 root: str,
                 train: bool,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 download: bool = False) -> None:
        """Index the image folder and keep only the requested split.

        Args:
            root: Directory containing the ``CUB200-2011`` folder; ``~`` is
                expanded.
            train: Selects which split flag to keep (see the note on the
                flag values in the body).
            transform: Applied to each loaded image. ``None`` falls back to
                ``ToTensor()`` (without normalization). If it is a
                ``transforms.Compose``, a ``Normalize(self.mean, self.std)``
                step is added at the end; any other callable is used as-is.
            target_transform: Applied to each class index.
            download: Accepted for interface compatibility but currently
                ignored; the dataset must already be present on disk.

        Raises:
            RuntimeError: If the number of entries in
                ``train_test_split.txt`` differs from the number of images
                found, since the split is matched to images by position.
        """
        # NOTE: automatic download/extraction is disabled. It used to fetch
        # 'https://data.deepai.org/CUB200(2011).zip' into ``root`` and unpack
        # it; for now the files have to be placed there manually.

        self.mean, self.std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        if transform is None:
            transform = transforms.ToTensor()
        elif isinstance(transform, transforms.Compose):
            # Build a fresh Compose rather than appending to the caller's
            # object: mutating it in place would stack one more Normalize
            # every time the same Compose is reused for another dataset.
            transform = transforms.Compose([
                *transform.transforms,
                transforms.Normalize(self.mean, self.std),
            ])

        images_dir = os.path.join(os.path.expanduser(root), 'CUB200-2011',
                                  'images')
        super().__init__(images_dir,
                         transform=transform,
                         target_transform=target_transform)

        # ``self.root`` now points at the images directory, so the split
        # file is looked up one level above it.
        split_path = os.path.join(self.root, '..', 'train_test_split.txt')
        with open(split_path, 'r') as f:
            # Last whitespace-separated field of each non-blank line is the
            # split flag.
            split = [line.split()[-1] for line in f if line.strip()]

        # Flags are matched to samples purely by position, i.e. this relies
        # on the folder scan order agreeing with the line order of the split
        # file. A length mismatch means the pairing cannot be trusted.
        if len(split) != len(self.samples):
            raise RuntimeError(
                'train_test_split.txt has {} entries but {} images were '
                'found in {}'.format(len(split), len(self.samples),
                                     self.root))

        # NOTE(review): '0' is treated as the training flag here. The
        # official CUB-200-2011 README describes the column as
        # ``is_training_image`` (1 = train) -- confirm this inversion is
        # intended before changing it, as it alters the benchmark split.
        wanted = '0' if train else '1'
        self.samples = [
            sample for sample, flag in zip(self.samples, split)
            if flag == wanted
        ]
        # Keep the ImageFolder alias and the targets in sync with the
        # filtered sample list.
        self.imgs = self.samples
        self.targets = [s[1] for s in self.samples]

        self.classes_names = [
            x[4:].replace('_', ' ').lower() for x in self.classes
        ]

if __name__ == "__main__":
    # Manual smoke test: build the non-training split from ./data and report
    # how many images and how many class names were found.
    dataset = CUB200(root="data", train=False, download=True)
    num_images = len(dataset)
    num_class_names = len(dataset.classes_names)
    print(num_images, num_class_names)
