from torchvision.datasets import VisionDataset

from PIL import Image

import os
import os.path
import sys

import random


class Caltech(VisionDataset):
    """Caltech-101 dataset with a reproducible per-class train/test split.

    Images are read from
    ``<root>/101_ObjectCategories/<category>/image_XXXX.jpg``. For every
    category, the image numbers ``1..class_n`` are shuffled with a generator
    seeded by ``seed``. Here ``class_n`` is the number of entries in the
    category folder. The first ``n`` numbers form the ``'train'`` split and
    the remainder form the ``'test'`` split. Building both splits with the
    same ``root``, ``n`` and ``seed`` therefore gives complementary, disjoint
    sets.

    Args:
        root: dataset root directory containing ``101_ObjectCategories``.
        split: ``'train'`` or ``'test'``.
        transform: optional callable applied to each loaded PIL image.
        target_transform: optional callable applied to each integer label.
        n: number of training images taken per class.
        seed: seed for the per-class shuffling and the optional
            cross-class shuffle.
        shuffle: if True, the ``'train'`` samples are additionally shuffled
            across classes. It is ignored for ``'test'``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, split='train', transform=None, target_transform=None, n=30, seed=0, shuffle=False):
        super().__init__(root, transform=transform, target_transform=target_transform)

        # Fail loudly but recoverably; library code must not call sys.exit().
        if split not in ('train', 'test'):
            raise ValueError('split type error: %s' % split)
        self.split = split  # which split to load: 'train' or 'test'

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class
        # Label i corresponds to self.categories[i].
        self.classes = range(len(self.categories))

        # A private generator keeps the split reproducible without reseeding
        # the process-wide `random` module. It produces the same sequence as
        # `random.seed(seed)` followed by module-level `random.shuffle` calls.
        rng = random.Random(seed)

        self.index: list[int] = []  # image numbers (the XXXX in image_XXXX.jpg)
        self.y: list[int] = []      # class label per sample, aligned with self.index
        for label, category in enumerate(self.categories):
            # NOTE(review): assumes each category folder contains exactly
            # image_0001.jpg .. image_<class_n>.jpg and no other entries.
            class_n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", category)))
            numbers = list(range(1, class_n + 1))
            rng.shuffle(numbers)
            # Both splits draw the same permutation, so train/test are complementary.
            chosen = numbers[:n] if split == 'train' else numbers[n:]
            self.index.extend(chosen)
            self.y.extend([label] * len(chosen))

        if split == 'train' and shuffle:
            # Permute samples and labels together so they stay aligned.
            pairs = list(zip(self.index, self.y))
            random.Random(seed).shuffle(pairs)
            self.index = [number for number, _ in pairs]
            self.y = [label for _, label in pairs]

    def _pil_loader(self, path):
        """Load the image at ``path`` and return it converted to RGB."""
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to load before the file is closed.
            return img.convert('RGB')

    def __getitem__(self, index):
        '''
        Return the sample at position ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is the class_index of the
            target class, after ``target_transform`` if one was given.
        '''
        label = self.y[index]
        path = os.path.join(
            self.root,
            "101_ObjectCategories",
            self.categories[label],
            f"image_{self.index[index]:04d}.jpg",
        )
        image = self._pil_loader(path)

        # Applies preprocessing when accessing the image
        if self.transform is not None:
            image = self.transform(image)
        # Previously ignored even though it was accepted by __init__.
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        '''
        Return the number of samples in the selected split.

        This method is mandatory, because several other components (e.g.
        DataLoader samplers) rely on it.
        '''
        return len(self.index)
