from typing import Any, Callable, Optional, Tuple, Union

import torch
import torchvision


class PreProcessCIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 whose images are normalized once and kept as device tensors.

    At construction time the whole image array is scaled by 1/255,
    standardized per channel with ``mean`` / ``std``, permuted from
    channel-last to channel-first layout and moved to the target device.
    The labels are converted to an ``int64`` tensor on the same device.
    ``__getitem__`` then only indexes these tensors.

    Note:
        ``transform`` and ``target_transform`` are forwarded to the base
        class constructor, but ``__getitem__`` of this class does not apply
        them; samples are returned exactly as preprocessed here.
    """

    # Per-channel statistics applied to pixel values after scaling by 1/255.
    mean = [0.49139967861519745, 0.4821584083946076, 0.44653091444546616]
    std = [0.2470322324632823, 0.24348512800005553, 0.2615878417279641]

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        device: Optional[Union[str, torch.device]] = None,
    ) -> None:
        """Load the dataset, then normalize it and move it to ``device``.

        Args:
            root: Dataset root directory, forwarded to the base class.
            train: Forwarded to the base class.
            transform: Forwarded to the base class; not applied by
                ``__getitem__`` of this class.
            target_transform: Forwarded to the base class; not applied by
                ``__getitem__`` of this class.
            download: Forwarded to the base class.
            device: Where to keep the preprocessed tensors. ``None`` (the
                default) selects CUDA when it is available and falls back
                to the CPU otherwise.
        """
        super().__init__(root, train, transform, target_transform, download)

        if device is None:
            # Keep the previous "always on the GPU" behaviour where possible,
            # but do not crash on machines without CUDA.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device)

        # Preprocess data: the true division yields a fresh floating-point
        # tensor, so the in-place ops below never touch the array that was
        # loaded by the base class.
        images = torch.from_numpy(self.data) / 255
        channel_mean = torch.tensor(self.mean, dtype=images.dtype)
        channel_std = torch.tensor(self.std, dtype=images.dtype)
        # Channels are the last axis at this point, so the length-3
        # statistics broadcast over it; in-place ops avoid extra full-size
        # temporaries.
        images.sub_(channel_mean).div_(channel_std)
        # Move channels from the last axis to axis 1 before the transfer.
        self.data = images.permute(0, 3, 1, 2).to(device)

        # Preprocess targets
        self.targets = torch.as_tensor(self.targets, dtype=torch.int64).to(device)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the preprocessed ``(image, target)`` pair at ``index``."""
        return self.data[index], self.targets[index]
