import os
import torch
import torchvision.datasets as datasets

class ImageFolderDataset(datasets.ImageFolder):
    """``torchvision.datasets.ImageFolder`` whose items also carry their index.

    Each item is ``(sample, target, index)`` instead of ``(sample, target)``,
    where ``index`` is the position passed to ``__getitem__`` (i.e. the entry
    in ``self.samples`` the item was loaded from).
    """

    def __init__(self, root, transform=None, target_transform=None, **kwargs):
        """Create the dataset.

        Args:
            root: directory containing one sub-folder per class.
            transform: optional callable applied to each loaded sample.
            target_transform: optional callable applied to each class index.
                Previously this could not be set, which left the
                ``target_transform`` branch of ``__getitem__`` unreachable.
            **kwargs: forwarded unchanged to ``ImageFolder`` (for example
                ``loader`` or ``is_valid_file``).
        """
        super().__init__(root, transform=transform,
                         target_transform=target_transform, **kwargs)

    def __getitem__(self, index: int):
        """Return ``(sample, target, index)`` for the item at ``index``."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, index

class ImageNet100aux:
    """Datasets and loaders for the 100-class ``ImageNet100aux`` image folders.

    Reads ``<location>/ImageNet100aux/train`` and ``<location>/ImageNet100aux/val``
    (one sub-folder per class, named by synset ID such as ``"n01616318"``) and
    exposes:

    * ``train_dataset`` / ``train_loader`` -- training split, shuffled.
    * ``test_dataset`` / ``test_loader`` -- validation split, fixed order.
    * ``test_loader_shuffle`` -- the same validation dataset, shuffled.
    * ``classnames`` -- human-readable name for each class index.

    Items come from ``ImageFolderDataset`` and are ``(sample, target, index)``.
    """

    # Synset folder name -> human-readable class name for the 100 classes.
    _WNID_TO_NAME = {
        "n01616318": "vulture",
        "n02097130": "giant schnauzer",
        "n02364673": "guinea pig",
        "n02786058": "Band Aid",
        "n03384352": "forklift",
        "n03814639": "neck brace",
        "n04116512": "rubber eraser",
        "n04404412": "television",
        "n01697457": "African crocodile",
        "n02101006": "Gordon setter",
        "n02403003": "ox",
        "n02790996": "barbell",
        "n03445924": "golfcart",
        "n03825788": "nipple",
        "n04118776": "rule",
        "n04417672": "thatch",
        "n01704323": "triceratops",
        "n02102318": "cocker spaniel",
        "n02408429": "water buffalo",
        "n02916936": "bulletproof vest",
        "n03447721": "gong",
        "n03854065": "organ",
        "n04125021": "safe",
        "n04525305": "vending machine",
        "n01807496": "partridge",
        "n02105641": "Old English sheepdog",
        "n02412080": "ram",
        "n02966687": "carpenter's kit",
        "n03481172": "hammer",
        "n03873416": "paddle",
        "n04162706": "seat belt",
        "n04557648": "water bottle",
        "n01817953": "African grey",
        "n02108915": "French bulldog",
        "n02480495": "orangutan",
        "n02977058": "cash machine",
        "n03594734": "jean",
        "n03961711": "plate rack",
        "n04192698": "shield",
        "n04590129": "window shade",
        "n01855032": "red-breasted merganser",
        "n02110185": "Siberian husky",
        "n02484975": "guenon",
        "n03000684": "chain saw",
        "n03633091": "ladle",
        "n03976467": "Polaroid camera",
        "n04235860": "sleeping bag",
        "n04612504": "yawl",
        "n01945685": "slug",
        "n02110627": "affenpinscher",
        "n02500267": "indri",
        "n03016953": "chiffonier",
        "n03657121": "lens cap",
        "n04004767": "printer",
        "n04254120": "soap dispenser",
        "n07583066": "guacamole",
        "n02017213": "European gallinule",
        "n02115913": "dhole",
        "n02504458": "African elephant",
        "n03017168": "chime",
        "n03721384": "marimba",
        "n04019541": "puck",
        "n04311004": "steel arch bridge",
        "n07695742": "pretzel",
        "n02085782": "Japanese spaniel",
        "n02132136": "brown bear",
        "n02666196": "abacus",
        "n03110669": "cornet",
        "n03743016": "megalith",
        "n04023962": "punch bag",
        "n04350905": "suit",
        "n07875152": "potpie",
        "n02086910": "papillon",
        "n02229544": "cricket",
        "n02676566": "acoustic guitar",
        "n03126707": "crane",
        "n03773504": "missile",
        "n04026417": "purse",
        "n04355338": "sundial",
        "n02088632": "bluetick",
        "n02281787": "lycaenid",
        "n02708093": "analog clock",
        "n03201208": "dining table",
        "n03777568": "Model T",
        "n04044716": "radio telescope",
        "n04371430": "swimming trunks",
        "n02089973": "English foxhound",
        "n02317335": "starfish",
        "n02776631": "bakery",
        "n03290653": "entertainment center",
        "n03782006": "monitor",
        "n04074963": "remote control",
        "n04372370": "swing",
        "n02091635": "otterhound",
        "n02363005": "beaver",
        "n02783161": "ballpoint",
        "n03297495": "espresso maker",
        "n03792972": "mountain tent",
        "n04111531": "rotisserie",
        "n04376876": "syringe",
    }

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('./data'),
                 batch_size=32,
                 num_workers=16):
        """Build the train/val datasets and their data loaders.

        Args:
            preprocess: transform applied to every image in both splits.
            location: directory that contains the ``ImageNet100aux`` folder.
                A leading ``~`` is expanded.
            batch_size: batch size used by every loader.
            num_workers: number of worker processes used by every loader.

        Raises:
            KeyError: if a training class folder is not one of the 100 known
                synset IDs.
        """
        # Honor the caller's location. It used to be overwritten with './data',
        # which silently ignored the argument.
        dataset_root = os.path.join(os.path.expanduser(location), 'ImageNet100aux')
        traindir = os.path.join(dataset_root, 'train')
        valdir = os.path.join(dataset_root, 'val')

        self.train_dataset = ImageFolderDataset(traindir, transform=preprocess)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        # NOTE(review): train and val compute their class indices independently
        # from their own folder listings. Labels only agree if both splits
        # contain the same class folders -- confirm for this dataset.
        self.test_dataset = ImageFolderDataset(valdir, transform=preprocess)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        self.test_loader_shuffle = torch.utils.data.DataLoader(
            self.test_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        # ``classes`` is already ordered by class index, so position i of
        # ``classnames`` names target i.
        folder_names = self.train_dataset.classes
        unknown = [name for name in folder_names if name not in self._WNID_TO_NAME]
        if unknown:
            raise KeyError(
                f"Unknown class folder(s) in {traindir!r}: {unknown}")
        self.classnames = [self._WNID_TO_NAME[name] for name in folder_names]