import os
import torchvision


from PIL import Image
from scipy.io import loadmat
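# NOTE(review): the note below appears to be left over from an ImageNet loader;
# none of the dataset classes in this file read train_cls.txt.
# Copy ILSVRC/ImageSets/CLS-LOC/train_cls.txt to ./root/ so the slow os.walk
# can be skipped by reading the image list from that file instead.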


class Pet(torchvision.datasets.VisionDataset):
    """Pet breed classification dataset (37 classes).

    The files read look like the Oxford-IIIT Pet release (confirm against your
    copy)::

        root/annotations/trainval.txt   # sample list for split='train'
        root/annotations/test.txt       # sample list for split='test'
        root/images/<image_name>.jpg    # images for both splits

    Each non-blank line of a sample list holds whitespace-separated fields:
    the first is the image name (without ``.jpg``), the second is the 1-based
    class id. Any further fields are ignored.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, ``train`` or ``test``.
        download (bool, optional): Accepted for interface compatibility only;
            downloading is not implemented and the flag is ignored.
        **kwargs: Forwarded to ``VisionDataset``, e.g. ``transform`` (applied
            to the RGB PIL image) and ``target_transform`` (applied to the
            0-based class index).

    Attributes:
        id2class (dict): 1-based class id (as a string) -> breed name.
        images (list): Image names, one per sample.
        targets (list): 1-based class id per sample; ``__getitem__`` returns
            it shifted to a 0-based index.
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        super(Pet, self).__init__(root, **kwargs)
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        self.id2class = {
            '1': 'Abyssinian',
            '2': 'american_bulldog',
            '3': 'american_pit_bull_terrier',
            '4': 'basset_hound',
            '5': 'beagle',
            '6': 'Bengal',
            '7': 'Birma',
            '8': 'Bombay',
            '9': 'boxer',
            '10': 'British_Shorthair',
            '11': 'chihuahua',
            '12': 'Egyptian_Mau',
            '13': 'english_cocker_spaniel',
            '14': 'english_setter',
            '15': 'german_shorthaired',
            '16': 'great_pyrenees',
            '17': 'havanese',
            '18': 'japanese_chin',
            '19': 'keeshond',
            '20': 'leonberger',
            '21': 'Maine_Coon',
            '22': 'miniature_pinscher',
            '23': 'newfoundland',
            '24': 'Persian',
            '25': 'pomeranian',
            '26': 'pug',
            '27': 'Ragdoll',
            '28': 'Russian_Blue',
            '29': 'saint_bernard',
            '30': 'samoyed',
            '31': 'scottish_terrier',
            '32': 'shiba_inu',
            '33': 'Siamese',
            '34': 'Sphynx',
            '35': 'staffordshire_bull_terrier',
            '36': 'wheaten_terrier',
            '37': 'yorkshire_terrier',
        }

        # 'train' uses the combined train+val list; the only other valid split is 'test'.
        list_name = 'trainval.txt' if split == 'train' else 'test.txt'
        listfile = os.path.join(root, 'annotations', list_name)
        if not os.path.exists(listfile):
            raise RuntimeError(f'Dataset not found or corrupted: {listfile} does not exist.')

        self.images = []
        self.targets = []
        with open(listfile, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:  # skip blank lines
                    continue
                self.images.append(fields[0])
                self.targets.append(int(fields[1]))

    def __getitem__(self, item):
        """Return ``(image, target)`` for sample ``item``.

        ``image`` is an RGB PIL image (passed through ``transform`` if set);
        ``target`` is the 0-based class index (passed through
        ``target_transform`` if set).
        """
        path = os.path.join(self.split_folder, f'{self.images[item]}.jpg')
        # Close the file handle deterministically; convert() yields a new image
        # object that remains usable after the source file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        target = self.targets[item] - 1  # list files are 1-based
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is valid, else raise ``ValueError``."""
        if split not in self.valid_splits:
            raise ValueError(
                f"Unknown split {split}. Valid splits are {', '.join(self.valid_splits)}.")
        return split

    @property
    def valid_splits(self):
        return 'train', 'test'

    @property
    def split_folder(self):
        """Folder holding the images; both splits share ``root/images``."""
        return os.path.join(self.root, 'images')


class Car(torchvision.datasets.VisionDataset):
    """Car model classification dataset backed by MATLAB annotation files.

    The files read look like the Stanford Cars release (confirm against your
    copy)::

        root/devkit/cars_train_annos.mat      # annotations for split='train'
        root/cars_test_annos_withlabels.mat   # annotations for split='test'
        root/cars_train/, root/cars_test/     # image folders (see split_folder)

    Each ``.mat`` file must contain an ``annotations`` struct array with
    ``fname`` (image file name inside the split folder) and ``class``
    (1-based class id) fields.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, ``train`` or ``test``.
        download (bool, optional): Accepted for interface compatibility only;
            downloading is not implemented and the flag is ignored.
        **kwargs: Forwarded to ``VisionDataset``, e.g. ``transform`` (applied
            to the RGB PIL image) and ``target_transform`` (applied to the
            0-based class index).

    Attributes:
        images (list): Image file names, one per sample.
        targets (list): 1-based class id per sample; ``__getitem__`` returns
            it shifted to a 0-based index.
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        super(Car, self).__init__(root, **kwargs)
        root = self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)

        if split == 'train':
            list_path = os.path.join(root, 'devkit', 'cars_train_annos.mat')
        else:
            list_path = os.path.join(root, 'cars_test_annos_withlabels.mat')
        if not os.path.exists(list_path):
            raise RuntimeError(f'Dataset not found or corrupted: {list_path} does not exist.')

        annotations = loadmat(list_path)['annotations']
        # Each struct-array entry is a 1-element array; .item() unwraps it.
        self.images = [f.item() for f in annotations['fname'][0]]
        self.targets = [c.item() for c in annotations['class'][0]]

    def __getitem__(self, item):
        """Return ``(image, target)`` for sample ``item``.

        ``image`` is an RGB PIL image (passed through ``transform`` if set);
        ``target`` is the 0-based class index (passed through
        ``target_transform`` if set).
        """
        path = os.path.join(self.split_folder, self.images[item])
        # Close the file handle deterministically; convert() yields a new image
        # object that remains usable after the source file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        target = self.targets[item] - 1  # annotation classes are 1-based
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is valid, else raise ``ValueError``."""
        if split not in self.valid_splits:
            raise ValueError(
                f"Unknown split {split}. Valid splits are {', '.join(self.valid_splits)}.")
        return split

    @property
    def valid_splits(self):
        return 'train', 'test'

    @property
    def split_folder(self):
        """Image folder for the current split: ``root/cars_<split>``."""
        return os.path.join(self.root, f'cars_{self.split}')


class Aircraft(torchvision.datasets.VisionDataset):
    """Aircraft variant classification dataset.

    The files read look like the FGVC-Aircraft release (confirm against your
    copy); everything lives under ``<root>/data``::

        data/variants.txt                    # one variant (class) name per line
        data/images_variant_trainval.txt     # sample list for split='train'
        data/images_variant_test.txt         # sample list for split='test'
        data/images/<image_id>.jpg

    Each non-blank sample-list line is ``<image_id> <variant name>``. The
    variant name may itself contain spaces.

    Args:
        root (string): Dataset root; files are read from ``<root>/data``.
        split (string, optional): The dataset split, ``train`` or ``test``.
        download (bool, optional): Accepted for interface compatibility only;
            downloading is not implemented and the flag is ignored.
        **kwargs: Forwarded to ``VisionDataset``, e.g. ``transform`` (applied
            to the RGB PIL image) and ``target_transform`` (applied to the
            0-based class index).

    Attributes:
        root (string): ``<root>/data`` after construction.
        num_classes (int): Number of distinct variant names.
        images (list): Image ids, one per sample.
        targets (list): 0-based class index per sample (position of the
            variant name in ``variants.txt``).
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        super(Aircraft, self).__init__(root, **kwargs)
        # All annotation and image files live under <root>/data.
        root = self.root = os.path.join(os.path.expanduser(root), 'data')
        self.split = self._verify_split(split)

        # Class index = position of the variant name in variants.txt; blank
        # lines are skipped so they cannot become a bogus '' class.
        with open(os.path.join(root, 'variants.txt'), 'r') as f:
            variants = [line.strip() for line in f if line.strip()]
        variants_dict = {name: idx for idx, name in enumerate(variants)}
        self.num_classes = len(variants_dict)

        if split == 'train':
            list_path = os.path.join(root, 'images_variant_trainval.txt')
        else:
            list_path = os.path.join(root, 'images_variant_test.txt')

        self.images = []
        self.targets = []
        with open(list_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Split only on the first whitespace run: variant names contain spaces.
                image_id, variant = line.split(maxsplit=1)
                self.images.append(image_id)
                self.targets.append(variants_dict[variant])

    def __getitem__(self, item):
        """Return ``(image, target)`` for sample ``item``.

        ``image`` is an RGB PIL image (passed through ``transform`` if set);
        ``target`` is the 0-based class index (passed through
        ``target_transform`` if set).
        """
        path = os.path.join(self.root, 'images', f'{self.images[item]}.jpg')
        # Close the file handle deterministically; convert() yields a new image
        # object that remains usable after the source file is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        target = self.targets[item]  # already 0-based
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is valid, else raise ``ValueError``."""
        if split not in self.valid_splits:
            raise ValueError(
                f"Unknown split {split}. Valid splits are {', '.join(self.valid_splits)}.")
        return split

    @property
    def valid_splits(self):
        return 'train', 'test'


class Flower(torchvision.datasets.ImageFolder):
    """Flower 102 classification dataset stored as one image folder per split.

    The folder ``<root>/<split>`` is handed to torchvision's ``ImageFolder``,
    which presumably expects one sub-folder per class inside it.

    Args:
        root (string): Directory containing the ``train``, ``valid`` and
            ``test`` split folders.
        split (string, optional): The dataset split: ``train``, ``valid`` or
            ``test``.
        download (bool, optional): Accepted for interface compatibility only;
            downloading is not implemented and the flag is ignored.
        **kwargs: Forwarded to ``ImageFolder`` (e.g. ``transform``,
            ``target_transform``, ``loader``).

    Attributes:
        classes, class_to_idx, imgs, targets: inherited from ``ImageFolder``.
    """

    def __init__(self, root, split='train', download=False, **kwargs):
        # self.root must be set before super().__init__ so split_folder can
        # be resolved. NOTE(review): the base initializer presumably reassigns
        # self.root to the split folder afterwards; confirm for your torchvision.
        self.root = os.path.expanduser(root)
        self.split = self._verify_split(split)
        super(Flower, self).__init__(self.split_folder, **kwargs)

    def _verify_split(self, split):
        """Return ``split`` unchanged if it is valid, else raise ``ValueError``."""
        if split not in self.valid_splits:
            raise ValueError(
                f"Unknown split {split}. Valid splits are {', '.join(self.valid_splits)}.")
        return split

    @property
    def valid_splits(self):
        return 'train', 'valid', 'test'

    @property
    def split_folder(self):
        return os.path.join(self.root, self.split)
