import pathlib
from typing import Callable, Optional, Any, Tuple

from PIL import Image

from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset


class StanfordCars(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        root (string): Root directory of dataset
        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        prompt_template (string, optional): ``str.format`` template with one ``{}``
            placeholder, filled with a normalized class name to build ``clip_prompts``.
            Default: ``"A photo of a {}."``

    Attributes:
        classes (list): Class names in target-index order, with the model year moved to
            the front (e.g. ``"2012 Acura RL Sedan"``).
        class_to_idx (dict): Maps the class names as stored in the devkit's
            ``cars_meta.mat`` to target indices. These keys are the devkit names, NOT the
            year-first strings in ``classes``.
        clip_prompts (list): One prompt per class, built from the devkit name with its last
            five characters removed, lower-cased, and ``_``/``-`` replaced by spaces.
    """

    # Year-first class names ("<year> <make> <model> <body>"). Position i must be the name
    # of target i, i.e. this tuple must follow the class order of ``cars_meta.mat``.
    _YEAR_FIRST_CLASS_NAMES = (
        "2000 AM General Hummer SUV", "2012 Acura RL Sedan", "2012 Acura TL Sedan",
        "2008 Acura TL Type-S", "2012 Acura TSX Sedan", "2001 Acura Integra Type R",
        "2012 Acura ZDX Hatchback", "2012 Aston Martin V8 Vantage Convertible",
        "2012 Aston Martin V8 Vantage Coupe", "2012 Aston Martin Virage Convertible",
        "2012 Aston Martin Virage Coupe", "2008 Audi RS 4 Convertible", "2012 Audi A5 Coupe",
        "2012 Audi TTS Coupe", "2012 Audi R8 Coupe", "1994 Audi V8 Sedan",
        "1994 Audi 100 Sedan", "1994 Audi 100 Wagon", "2011 Audi TT Hatchback",
        "2011 Audi S6 Sedan", "2012 Audi S5 Convertible", "2012 Audi S5 Coupe",
        "2012 Audi S4 Sedan", "2007 Audi S4 Sedan", "2012 Audi TT RS Coupe",
        "2012 BMW ActiveHybrid 5 Sedan", "2012 BMW 1 Series Convertible",
        "2012 BMW 1 Series Coupe", "2012 BMW 3 Series Sedan", "2012 BMW 3 Series Wagon",
        "2007 BMW 6 Series Convertible", "2007 BMW X5 SUV", "2012 BMW X6 SUV",
        "2012 BMW M3 Coupe", "2010 BMW M5 Sedan", "2010 BMW M6 Convertible", "2012 BMW X3 SUV",
        "2012 BMW Z4 Convertible", "2012 Bentley Continental Supersports Conv. Convertible",
        "2009 Bentley Arnage Sedan", "2011 Bentley Mulsanne Sedan",
        "2012 Bentley Continental GT Coupe", "2007 Bentley Continental GT Coupe",
        "2007 Bentley Continental Flying Spur Sedan", "2009 Bugatti Veyron 16.4 Convertible",
        "2009 Bugatti Veyron 16.4 Coupe", "2012 Buick Regal GS", "2007 Buick Rainier SUV",
        "2012 Buick Verano Sedan", "2012 Buick Enclave SUV", "2012 Cadillac CTS-V Sedan",
        "2012 Cadillac SRX SUV", "2007 Cadillac Escalade EXT Crew Cab",
        "2012 Chevrolet Silverado 1500 Hybrid Crew Cab", "2012 Chevrolet Corvette Convertible",
        "2012 Chevrolet Corvette ZR1", "2007 Chevrolet Corvette Ron Fellows Edition Z06",
        "2012 Chevrolet Traverse SUV", "2012 Chevrolet Camaro Convertible",
        "2010 Chevrolet HHR SS", "2007 Chevrolet Impala Sedan",
        "2012 Chevrolet Tahoe Hybrid SUV", "2012 Chevrolet Sonic Sedan",
        "2007 Chevrolet Express Cargo Van", "2012 Chevrolet Avalanche Crew Cab",
        "2010 Chevrolet Cobalt SS", "2010 Chevrolet Malibu Hybrid Sedan",
        "2009 Chevrolet TrailBlazer SS", "2012 Chevrolet Silverado 2500HD Regular Cab",
        "2007 Chevrolet Silverado 1500 Classic Extended Cab", "2007 Chevrolet Express Van",
        "2007 Chevrolet Monte Carlo Coupe", "2007 Chevrolet Malibu Sedan",
        "2012 Chevrolet Silverado 1500 Extended Cab",
        "2012 Chevrolet Silverado 1500 Regular Cab", "2009 Chrysler Aspen SUV",
        "2010 Chrysler Sebring Convertible", "2012 Chrysler Town and Country Minivan",
        "2010 Chrysler 300 SRT-8", "2008 Chrysler Crossfire Convertible",
        "2008 Chrysler PT Cruiser Convertible", "2002 Daewoo Nubira Wagon",
        "2012 Dodge Caliber Wagon", "2007 Dodge Caliber Wagon", "1997 Dodge Caravan Minivan",
        "2010 Dodge Ram Pickup 3500 Crew Cab", "2009 Dodge Ram Pickup 3500 Quad Cab",
        "2009 Dodge Sprinter Cargo Van", "2012 Dodge Journey SUV", "2010 Dodge Dakota Crew Cab",
        "2007 Dodge Dakota Club Cab", "2008 Dodge Magnum Wagon", "2011 Dodge Challenger SRT8",
        "2012 Dodge Durango SUV", "2007 Dodge Durango SUV", "2012 Dodge Charger Sedan",
        "2009 Dodge Charger SRT-8", "1998 Eagle Talon Hatchback", "2012 FIAT 500 Abarth",
        "2012 FIAT 500 Convertible", "2012 Ferrari FF Coupe",
        "2012 Ferrari California Convertible", "2012 Ferrari 458 Italia Convertible",
        "2012 Ferrari 458 Italia Coupe", "2012 Fisker Karma Sedan",
        "2012 Ford F-450 Super Duty Crew Cab", "2007 Ford Mustang Convertible",
        "2007 Ford Freestar Minivan", "2009 Ford Expedition EL SUV", "2012 Ford Edge SUV",
        "2011 Ford Ranger SuperCab", "2006 Ford GT Coupe", "2012 Ford F-150 Regular Cab",
        "2007 Ford F-150 Regular Cab", "2007 Ford Focus Sedan", "2012 Ford E-Series Wagon Van",
        "2012 Ford Fiesta Sedan", "2012 GMC Terrain SUV", "2012 GMC Savana Van",
        "2012 GMC Yukon Hybrid SUV", "2012 GMC Acadia SUV", "2012 GMC Canyon Extended Cab",
        "1993 Geo Metro Convertible", "2010 HUMMER H3T Crew Cab", "2009 HUMMER H2 SUT Crew Cab",
        "2012 Honda Odyssey Minivan", "2007 Honda Odyssey Minivan", "2012 Honda Accord Coupe",
        "2012 Honda Accord Sedan", "2012 Hyundai Veloster Hatchback",
        "2012 Hyundai Santa Fe SUV", "2012 Hyundai Tucson SUV", "2012 Hyundai Veracruz SUV",
        "2012 Hyundai Sonata Hybrid Sedan", "2007 Hyundai Elantra Sedan",
        "2012 Hyundai Accent Sedan", "2012 Hyundai Genesis Sedan", "2012 Hyundai Sonata Sedan",
        "2012 Hyundai Elantra Touring Hatchback", "2012 Hyundai Azera Sedan",
        "2012 Infiniti G Coupe IPL", "2011 Infiniti QX56 SUV", "2008 Isuzu Ascender SUV",
        "2012 Jaguar XK XKR", "2012 Jeep Patriot SUV", "2012 Jeep Wrangler SUV",
        "2012 Jeep Liberty SUV", "2012 Jeep Grand Cherokee SUV", "2012 Jeep Compass SUV",
        "2008 Lamborghini Reventon Coupe", "2012 Lamborghini Aventador Coupe",
        "2012 Lamborghini Gallardo LP 570-4 Superleggera", "2001 Lamborghini Diablo Coupe",
        "2012 Land Rover Range Rover SUV", "2012 Land Rover LR2 SUV",
        "2011 Lincoln Town Car Sedan", "2012 MINI Cooper Roadster Convertible",
        "2012 Maybach Landaulet Convertible", "2011 Mazda Tribute SUV",
        "2012 McLaren MP4-12C Coupe", "1993 Mercedes-Benz 300-Class Convertible",
        "2012 Mercedes-Benz C-Class Sedan", "2009 Mercedes-Benz SL-Class Coupe",
        "2012 Mercedes-Benz E-Class Sedan", "2012 Mercedes-Benz S-Class Sedan",
        "2012 Mercedes-Benz Sprinter Van", "2012 Mitsubishi Lancer Sedan",
        "2012 Nissan Leaf Hatchback", "2012 Nissan NV Passenger Van",
        "2012 Nissan Juke Hatchback", "1998 Nissan 240SX Coupe", "1999 Plymouth Neon Coupe",
        "2012 Porsche Panamera Sedan", "2012 Ram C/V Cargo Van Minivan",
        "2012 Rolls-Royce Phantom Drophead Coupe Convertible", "2012 Rolls-Royce Ghost Sedan",
        "2012 Rolls-Royce Phantom Sedan", "2012 Scion xD Hatchback",
        "2009 Spyker C8 Convertible", "2009 Spyker C8 Coupe", "2007 Suzuki Aerio Sedan",
        "2012 Suzuki Kizashi Sedan", "2012 Suzuki SX4 Hatchback", "2012 Suzuki SX4 Sedan",
        "2012 Tesla Model S Sedan", "2012 Toyota Sequoia SUV", "2012 Toyota Camry Sedan",
        "2012 Toyota Corolla Sedan", "2012 Toyota 4Runner SUV",
        "2012 Volkswagen Golf Hatchback", "1991 Volkswagen Golf Hatchback",
        "2012 Volkswagen Beetle Hatchback", "2012 Volvo C30 Hatchback", "1993 Volvo 240 Sedan",
        "2007 Volvo XC90 SUV", "2012 smart fortwo Convertible",
    )

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        prompt_template: str = "A photo of a {}.",
    ) -> None:
        try:
            import scipy.io as sio
        except ImportError as err:
            raise RuntimeError(
                "Scipy is not found. This dataset needs to have scipy installed: pip install scipy"
            ) from err

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        self._base_folder = pathlib.Path(root) / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if self._split == "train":
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # The labelled test annotations are a separate download, stored next to the
            # devkit rather than inside it (see download()).
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self._samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
            )
            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
        ]

        devkit_classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        # The year-first names replace the devkit names positionally, so a count mismatch
        # would silently attach the wrong names to targets.
        if len(devkit_classes) != len(self._YEAR_FIRST_CLASS_NAMES):
            raise RuntimeError(
                f"Expected {len(self._YEAR_FIRST_CLASS_NAMES)} classes in cars_meta.mat, "
                f"found {len(devkit_classes)}"
            )
        # NOTE: keyed by the devkit names, not by the year-first names stored in self.classes.
        self.class_to_idx = {cls: i for i, cls in enumerate(devkit_classes)}

        self.prompt_template = prompt_template
        # label[:-5] drops the last five characters, presumably the trailing " <year>"
        # of the devkit names, before normalizing the text for the prompt.
        self.clip_prompts = [
            prompt_template.format(label[:-5].lower().replace("_", " ").replace("-", " "))
            for label in devkit_classes
        ]
        self.classes = list(self._YEAR_FIRST_CLASS_NAMES)

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``idx``.

        The image is loaded as an RGB PIL image and passed through ``transform`` if set;
        the 0-based class index is passed through ``target_transform`` if set.
        """
        image_path, target = self._samples[idx]
        # Close the file handle promptly; convert() returns a new image that does not
        # depend on it.
        with Image.open(image_path) as img:
            pil_image = img.convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download and extract the devkit plus the archive(s) for the selected split.

        Does nothing if ``_check_exists()`` already succeeds. For the test split, the
        labelled annotation file is downloaded as well.
        """
        if self._check_exists():
            return

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self._split == "train":
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True if the devkit folder, the split's annotation file and image folder exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
