from torchvision.datasets.vision import VisionDataset
import glob, os
from PIL import Image
import numpy as np
import random


class ImageNet30(VisionDataset):
    """Image-classification dataset over the 30 classes listed in ``class_paths``.

    Image files are discovered on disk at construction time; images themselves
    are only opened lazily in ``__getitem__``.

    Expected layout under ``root``:

    * training split: ``<root>/one_class_train/<class_name>/<class_id>*``
    * test split:     ``<root>/one_class_test/<class_name>/<class_id>/*``

    The integer label of a sample is the position of its class in
    ``class_paths``.

    Args:
        root: Directory containing the ``one_class_train`` / ``one_class_test``
            folders.
        transform: Optional callable applied to the RGB ``PIL.Image`` sample.
        target_transform: Optional callable applied to the integer label.
        train: If true, index the training split; otherwise the test split.

    Attributes:
        data: 1-D numpy array of image file paths, in shuffled order.
        targets: 1-D ``int64`` numpy array of labels aligned with ``data``.
    """

    # (class_name, class_id) pairs. ``class_name`` is the per-class folder name;
    # ``class_id`` is the file-name prefix (train) or sub-folder name (test).
    # NOTE(review): the ids look like ImageNet WordNet ids -- confirm.
    class_paths = [
        ("acorn", "n12267677"),
        ("airliner", "n02690373"),
        ("ambulance", "n02701002"),
        ("american_alligator", "n01698640"),
        ("banjo", "n02787622"),
        ("barn", "n02793495"),
        ("bikini", "n02837789"),
        ("digital_clock", "n03196217"),
        ("dragonfly", "n02268443"),
        ("dumbbell", "n03255030"),
        ("forklift", "n03384352"),
        ("goblet", "n03443371"),
        ("grand_piano", "n03452741"),
        ("hotdog", "n07697537"),
        ("hourglass", "n03544143"),
        ("manhole_cover", "n03717622"),
        ("mosque", "n03788195"),
        ("nail", "n03804744"),
        ("parking_meter", "n03891332"),
        ("pillow", "n03938244"),
        ("revolver", "n04086273"),
        ("rotary_dial_telephone", "n03187595"),
        ("schooner", "n04147183"),
        ("snowmobile", "n04252077"),
        ("soccer_ball", "n04254680"),
        ("stingray", "n01498041"),
        ("strawberry", "n07745940"),
        ("tank", "n04389033"),
        ("toaster", "n04442312"),
        ("volcano", "n09472597"),
    ]

    def __init__(self, root, transform=None, target_transform=None, train=True):
        super().__init__(root, transform=transform, target_transform=target_transform)
        paths = []
        labels = []
        for label, (cls, cls_dir) in enumerate(self.class_paths):
            if train:
                pattern = os.path.join(
                    self.root, "one_class_train", cls, f"{cls_dir}*"
                )
            else:
                pattern = os.path.join(
                    self.root, "one_class_test", cls, cls_dir, "*"
                )
            # glob returns entries in arbitrary, filesystem-dependent order;
            # sort so the seeded shuffle below is reproducible across machines.
            files = sorted(glob.glob(pattern))
            paths.extend(files)
            labels.extend([label] * len(files))

        # Shuffle with a fixed seed (for both splits) so that labels are not
        # contiguous, which can result in errors during sanity checks.
        order = np.arange(len(paths))
        np.random.default_rng(0).shuffle(order)
        self.data = np.array(paths)[order]
        # Pin the dtype so labels are int64 on every platform, and stay an
        # integer array even when no files were found.
        self.targets = np.array(labels, dtype=np.int64)[order]

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position into ``data`` / ``targets``.

        Returns:
            A ``(sample, label)`` tuple: the image converted to RGB with
            ``transform`` applied (if given), and its label with
            ``target_transform`` applied (if given).
        """
        path = self.data[index]
        # ``convert`` fully loads the pixels and returns a new image, so the
        # underlying file handle can be released right away.
        with Image.open(path) as img:
            sample = img.convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        label = self.targets[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.data)
