import os
import random

import torch
from torchvision.datasets import CIFAR100 as PyTorchCIFAR100
from torch.utils.data import DataLoader, Subset

from PIL import Image

class MyPyTorchCIFAR100(PyTorchCIFAR100):
    """CIFAR-100 dataset whose samples also report their own dataset index."""

    def __init__(self, root, download, train, transform):
        """Forward the four construction arguments to the parent dataset."""
        super().__init__(
            root=root,
            download=download,
            train=train,
            transform=transform,
        )

    def __getitem__(self, index: int):
        """Fetch one sample together with its position in the dataset.

        Args:
            index (int): Position of the sample to fetch.

        Returns:
            tuple: ``(image, target, index)`` -- the (optionally transformed)
            image, the (optionally transformed) class index, and the
            ``index`` argument echoed back unchanged.
        """
        raw_image = self.data[index]
        label = self.targets[index]

        # Wrap the stored array in a PIL Image so transforms receive the
        # same input type as in the other datasets.
        picture = Image.fromarray(raw_image)

        image_transform = self.transform
        if image_transform is not None:
            picture = image_transform(picture)

        label_transform = self.target_transform
        if label_transform is not None:
            label = label_transform(label)

        return picture, label, index

class CIFAR100:
    """Bundle of CIFAR-100 datasets and DataLoaders for training and evaluation.

    Attributes set by the constructor:
        random_seed: Seed used to pick the "fast" evaluation subset.
        train_dataset / train_loader: Training split and its loader.
        test_dataset / test_loader: Test split and an unshuffled loader.
        test_loader_shuffle: Shuffled loader over the same test split.
        fast_test_dataset / fast_test_loader: 10% random subset of the test
            split and a shuffled loader over it.
        classnames: The ``classes`` attribute of the test dataset.

    Every sample produced by these datasets is ``(image, target, index)``.
    """

    def __init__(self,
                 preprocess,
                 location=os.path.expanduser('./data'),
                 batch_size=128,
                 num_workers=16,
                 random_seed=42):
        """Build the datasets and loaders.

        Args:
            preprocess: Transform applied to every image of both splits.
            location: Root directory handed to the dataset (constructed with
                ``download=True``).
            batch_size: Batch size shared by all loaders.
            num_workers: Worker-process count shared by all loaders.
            random_seed: Seed for selecting the fast-test subset. The default
                (42) reproduces the subset chosen by earlier versions, which
                ignored this argument and always used 42.
        """
        # Honor the caller's seed (previously hard-coded to 42).
        self.random_seed = random_seed

        self.train_dataset = MyPyTorchCIFAR100(
            root=location, download=True, train=True, transform=preprocess
        )

        # NOTE(review): the training loader is neither shuffled nor
        # pin_memory'd, unlike the test loaders. Kept as-is because callers
        # may rely on the fixed order -- confirm whether this is intentional.
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=batch_size, num_workers=num_workers
        )

        self.test_dataset = MyPyTorchCIFAR100(
            root=location, download=True, train=False, transform=preprocess
        )

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

        self.test_loader_shuffle = torch.utils.data.DataLoader(
            self.test_dataset,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
        )

        self.fast_test_dataset, self.fast_test_loader = self.get_fast_test_loader(
            batch_size, num_workers
        )

        self.classnames = self.test_dataset.classes

    def get_fast_test_loader(self, batch_size, num_workers):
        """Create a dataset/loader pair over a random 10% of the test split.

        The subset is chosen with a private ``random.Random`` seeded from
        ``self.random_seed``: the selection is reproducible and identical to
        seeding the global ``random`` module with the same value, but the
        global generator's state is left untouched.

        Args:
            batch_size: Batch size of the returned loader.
            num_workers: Worker-process count of the returned loader.

        Returns:
            tuple: ``(fast_test_dataset, fast_test_loader)`` -- a ``Subset``
            of ``self.test_dataset`` and a shuffled ``DataLoader`` over it.
        """
        test_size = len(self.test_dataset)
        # int() truncates, so the subset holds floor(10%) of the samples.
        fast_test_size = int(test_size * 0.1)

        # Local generator: avoids reseeding the process-wide `random` state.
        rng = random.Random(self.random_seed)
        indices = rng.sample(range(test_size), fast_test_size)

        fast_test_dataset = Subset(self.test_dataset, indices)
        fast_test_loader = DataLoader(
            fast_test_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        return fast_test_dataset, fast_test_loader



