# Wrapper for the StanfordCars dataset
import os
from typing import Callable, Optional

from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

# StanfordCars dataset
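# The original download source is no longer available, so the dataset must be
# downloaded manually from
# https://www.kaggle.com/datasets/jutrera/stanford-car-dataset-by-classes-folder
# and extracted so that <root>/StanfordCars/train and <root>/StanfordCars/test exist.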


class StanfordCars(ImageFolder):
    """Stanford Cars images stored as one sub-folder per class.

    Thin wrapper around ``torchvision.datasets.ImageFolder`` that reads
    ``<root>/StanfordCars/train`` or ``<root>/StanfordCars/test``.

    Instance attributes added on top of ``ImageFolder``:
        classes_names: alias of ``self.classes``.
        mean, std: per-channel constants handed to ``transforms.Normalize``.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Open the requested split of the dataset.

        Args:
            root: Directory that contains the ``StanfordCars`` folder; a
                leading ``~`` is expanded.
            train: Read the ``train`` sub-folder when true, ``test`` otherwise.
            transform: Image transform. ``None`` falls back to ``ToTensor()``.
                If a ``transforms.Compose`` is given, a copy of it extended
                with ``Normalize(self.mean, self.std)`` is used; any other
                callable is used as-is (no normalization is added).
            target_transform: Forwarded unchanged to ``ImageFolder``.
            download: Accepted for signature compatibility only; it has no
                effect. The data must be placed on disk manually.

        Raises:
            FileNotFoundError: If the split directory does not exist.
        """
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        split = "train" if train else "test"
        path = os.path.join(os.path.expanduser(root), "StanfordCars", split)
        if not os.path.isdir(path):
            # Fail early with an actionable message instead of the bare
            # directory-scan error raised inside ImageFolder.
            raise FileNotFoundError(
                f"StanfordCars split directory not found: {path!r}. "
                "Automatic download is not supported; get the data from "
                "https://www.kaggle.com/datasets/jutrera/"
                "stanford-car-dataset-by-classes-folder")

        if transform is None:
            transform = transforms.ToTensor()
        elif isinstance(transform, transforms.Compose):
            # Build a new pipeline rather than appending to the caller's
            # object: sharing one Compose between several datasets would
            # otherwise accumulate one extra Normalize per construction.
            transform = transforms.Compose(
                [*transform.transforms, transforms.Normalize(mean, std)])

        super().__init__(path,
                         transform=transform,
                         target_transform=target_transform)

        self.classes_names = self.classes
        self.mean, self.std = mean, std


if __name__ == "__main__":
    # Smoke test: open the test split under ./data and report how many
    # samples and how many classes were found.
    cars_test = StanfordCars(root="data", download=True, train=False)
    n_samples = len(cars_test)
    n_classes = len(cars_test.classes_names)
    print(n_samples, n_classes)
