from torchvision.datasets import VisionDataset

from PIL import Image

import os
import os.path
import sys


def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly (rather than passing the path straight to
    ``Image.open``) to avoid a ResourceWarning, see
    https://github.com/python-pillow/Pillow/issues/835. The RGB conversion is
    performed inside the ``with`` block, i.e. while the file is still open.
    """
    with open(path, mode='rb') as handle:
        return Image.open(handle).convert('RGB')


class Caltech(VisionDataset):
    """Caltech-101 classification dataset driven by plain-text split files.

    Each split file (``<split_dir>/train.txt``, ``<split_dir>/test.txt``)
    lists one image path per line, relative to ``root``, in the form
    ``<class_name>/<image_file>``. Entries belonging to the
    ``BACKGROUND_Google`` folder and blank lines are skipped.

    Attributes:
        split: the split requested ('train', 'test', or 'both').
        split_dir: directory containing the split files.
        data: image paths (``root`` joined with each split-file entry).
        targets: integer class index for each entry of ``data``.
        classes: sorted list of class names; a name's position is its index.
        num_unique_files: same list as ``classes`` (kept for backward
            compatibility with existing callers).
        class_to_idx: mapping from class name to class index.
    """

    BACKGROUND_CLASS = "BACKGROUND_Google"

    def __init__(self, root, split='train', transform=None, target_transform=None,
                 split_dir='./caltech101'):
        """Index the images listed in the requested split file(s).

        Args:
            root: directory the split-file entries are relative to.
            split: name of the split file to read (without ``.txt``), or
                'both' to concatenate the 'test' and 'train' files (in that
                order).
            transform: optional callable applied to each loaded PIL image.
            target_transform: optional callable applied to each label.
            split_dir: directory containing the split files.

        Raises:
            OSError (e.g. FileNotFoundError): if a split file cannot be opened.
        """
        # VisionDataset stores root (user-expanded when it is a str),
        # transform and target_transform; no need to re-assign them here.
        super(Caltech, self).__init__(root, transform=transform, target_transform=target_transform)

        self.split = split
        self.split_dir = split_dir

        split_names = ('test', 'train') if split == "both" else (split,)
        entries = []
        for split_name in split_names:
            entries.extend(self._read_split_file(split_name))

        # The first path component of an entry is its class folder.
        entries = [e for e in entries if e.split('/')[0] != self.BACKGROUND_CLASS]
        entry_class_names = [e.split('/')[0] for e in entries]

        self.data = [os.path.join(self.root, e) for e in entries]

        self.classes = sorted(set(entry_class_names))
        self.num_unique_files = self.classes
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.targets = [self.class_to_idx[name] for name in entry_class_names]

    def _read_split_file(self, split_name):
        """Return the stripped, non-empty lines of ``<split_dir>/<split_name>.txt``."""
        split_path = os.path.join(self.split_dir, '{}.txt'.format(split_name))
        with open(split_path, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so they do not turn
            # into bogus samples pointing at the root directory.
            return [line.strip() for line in f if line.strip()]

    def __getitem__(self, index):
        '''
        Load one sample.

        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where sample is the (optionally
            transformed) RGB image and target is the (optionally
            transformed) class index.
        '''
        image_path, label = self.data[index], self.targets[index]

        image = pil_loader(image_path)

        # Preprocessing is applied lazily, when the sample is accessed.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def __len__(self):
        '''
        Return the number of samples in the dataset (required by samplers
        and DataLoader).
        '''
        return len(self.targets)

    