import os

import torch
import torchvision


class Caltech101:
    """Caltech-101 wrapper exposing train/test data loaders and class names.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<location>/caltech-101/train`` (training) and
    ``<location>/caltech-101/<split>`` (evaluation), where ``<split>`` is
    ``test_subset`` when a subclass has set it and the ``subset`` argument
    otherwise.

    Args:
        preprocess: Passed as ``transform`` to both ``ImageFolder`` datasets.
        location: Root directory that contains the ``caltech-101`` folder.
        batch_size: Batch size used by both data loaders.
        num_workers: Number of worker processes used by both data loaders.
        subset: Evaluation split directory name, used only when
            ``test_subset`` is ``None`` (i.e. the base class is used directly).
        classnames: Accepted but not used by this class.
        **kwargs: Accepted and ignored.

    Attributes set on the instance: ``batch_size``, ``num_workers``,
    ``train_location``, ``train_dataset``, ``train_loader``, ``test_subset``,
    ``test_location``, ``test_dataset``, ``test_loader`` and ``classnames``.
    """

    # Evaluation split directory name. Subclasses assign it before calling
    # ``__init__``; when it is still ``None`` the ``subset`` argument is used.
    test_subset = None

    # Human-readable label strings, one per class. The strings are kept
    # verbatim (including the double space in 'head of a  crocodile' and the
    # spelling 'water lilly') so anything derived from them stays unchanged.
    # NOTE(review): presumably ordered to match ImageFolder's sorted class
    # directories -- verify against the on-disk dataset layout.
    _CLASSNAMES = (
        'off-center face',
        'centered face',
        'leopard',
        'motorbike',
        'accordion',
        'airplane',
        'anchor',
        'ant',
        'barrel',
        'bass',
        'beaver',
        'binocular',
        'bonsai',
        'brain',
        'brontosaurus',
        'buddha',
        'butterfly',
        'camera',
        'cannon',
        'side of a car',
        'ceiling fan',
        'cellphone',
        'chair',
        'chandelier',
        'body of a cougar cat',
        'face of a cougar cat',
        'crab',
        'crayfish',
        'crocodile',
        'head of a  crocodile',
        'cup',
        'dalmatian',
        'dollar bill',
        'dolphin',
        'dragonfly',
        'electric guitar',
        'elephant',
        'emu',
        'euphonium',
        'ewer',
        'ferry',
        'flamingo',
        'head of a flamingo',
        'garfield',
        'gerenuk',
        'gramophone',
        'grand piano',
        'hawksbill',
        'headphone',
        'hedgehog',
        'helicopter',
        'ibis',
        'inline skate',
        'joshua tree',
        'kangaroo',
        'ketch',
        'lamp',
        'laptop',
        'llama',
        'lobster',
        'lotus',
        'mandolin',
        'mayfly',
        'menorah',
        'metronome',
        'minaret',
        'nautilus',
        'octopus',
        'okapi',
        'pagoda',
        'panda',
        'pigeon',
        'pizza',
        'platypus',
        'pyramid',
        'revolver',
        'rhino',
        'rooster',
        'saxophone',
        'schooner',
        'scissors',
        'scorpion',
        'sea horse',
        'snoopy (cartoon beagle)',
        'soccer ball',
        'stapler',
        'starfish',
        'stegosaurus',
        'stop sign',
        'strawberry',
        'sunflower',
        'tick',
        'trilobite',
        'umbrella',
        'watch',
        'water lilly',
        'wheelchair',
        'wild cat',
        'windsor chair',
        'wrench',
        'yin and yang symbol',
    )

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('~/data'),
                 batch_size=128,
                 num_workers=16,
                 subset='test',
                 classnames=None,
                 **kwargs):
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Training split: shuffled every epoch.
        self.train_location = os.path.join(location, 'caltech-101', 'train')
        self.train_dataset = torchvision.datasets.ImageFolder(
            root=self.train_location, transform=preprocess)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=self.num_workers)

        # Evaluation split: a subclass-provided ``test_subset`` wins; the
        # ``subset`` argument is the fallback so the base class is usable
        # directly instead of failing on ``os.path.join(..., None)``.
        self.test_subset = self._resolve_test_subset(subset)
        self.test_location = os.path.join(location, 'caltech-101',
                                          self.test_subset)
        print("Loading Test Data from ", self.test_location)
        self.test_dataset = torchvision.datasets.ImageFolder(
            root=self.test_location, transform=preprocess)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=self.num_workers)

        # Fresh list per instance so callers may mutate it safely.
        self.classnames = list(self._CLASSNAMES)

    def _resolve_test_subset(self, subset):
        """Return the evaluation split name: ``test_subset`` if set, else *subset*."""
        if self.test_subset is not None:
            return self.test_subset
        return subset


class Caltech101Val(Caltech101):
    """Caltech101 whose evaluation loader reads the ``val`` split directory."""

    def __init__(self, *args, **kwargs):
        # Must be assigned before delegating: the base initializer reads
        # ``self.test_subset`` to build the evaluation directory path.
        self.test_subset = 'val'
        super().__init__(*args, **kwargs)


class Caltech101Test(Caltech101):
    """Caltech101 whose evaluation loader reads the ``test`` split directory."""

    def __init__(self, *args, **kwargs):
        # Must be assigned before delegating: the base initializer reads
        # ``self.test_subset`` to build the evaluation directory path.
        self.test_subset = 'test'
        super().__init__(*args, **kwargs)
