# From https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py
import numpy as np
import os
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torchvision.datasets.utils import extract_archive


class FGVCAircraft(VisionDataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.

    Samples are ``(image, target)`` pairs, where ``target`` is the integer index
    of the image's class name in the sorted list of class names read from the
    annotation file ``images_<class_type>_<split>.txt``.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): If True, creates dataset from the ``trainval``
            split, otherwise creates it from the ``test`` split.
        class_type (string, optional): choose from ('variant', 'family', 'manufacturer').
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        ValueError: If ``class_type`` is not one of ``class_types``.
    """

    url = "http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
    class_types = ("variant", "family", "manufacturer")
    splits = ("train", "val", "trainval", "test")
    img_folder = os.path.join("fgvc-aircraft-2013b", "data", "images")

    def __init__(
        self,
        root,
        train=True,
        class_type="variant",
        transform=None,
        target_transform=None,
        download=False,
    ):
        super(FGVCAircraft, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        # The split is derived from ``train`` and is therefore always one of
        # the two literals below; no membership check against ``splits`` is
        # needed.
        split = "trainval" if train else "test"
        if class_type not in self.class_types:
            raise ValueError(
                'Class type "{}" not found. Valid class types are: {}'.format(
                    class_type,
                    ", ".join(self.class_types),
                )
            )

        self.class_type = class_type
        self.split = split
        # Annotation file listing "<image_id> <class name>" per line.
        self.classes_file = os.path.join(
            self.root,
            "fgvc-aircraft-2013b",
            "data",
            "images_%s_%s.txt" % (self.class_type, self.split),
        )

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = self.find_classes()
        samples = self.make_dataset(image_ids, targets)

        self.loader = default_loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """Load the sample at ``index`` and apply the configured transforms.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded (and
            optionally transformed) image and ``target`` is the (optionally
            transformed) class index.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def _check_exists(self):
        """Return True if both the image folder and the annotation file exist."""
        return os.path.exists(
            os.path.join(self.root, self.img_folder)
        ) and os.path.exists(self.classes_file)

    def download(self):
        """Download and extract the dataset archive unless it is already present."""
        if self._check_exists():
            return

        # prepare to download data to <root>/fgvc-aircraft-2013b.tar.gz
        print("Downloading %s..." % self.url)
        tar_name = self.url.rpartition("/")[-1]
        download_url(self.url, root=self.root, filename=tar_name)
        tar_path = os.path.join(self.root, tar_name)
        print("Extracting %s..." % tar_path)
        # No destination is passed; presumably torchvision extracts next to the
        # archive (i.e. into root) -- confirm against the installed version.
        extract_archive(tar_path)
        print("Done!")

    def find_classes(self):
        """Parse the annotation file into image IDs and integer targets.

        Each non-blank line is expected to be ``<image_id> <class name>``,
        where the class name may itself contain spaces.

        Returns:
            tuple: ``(image_ids, targets, classes, class_to_idx)`` where
            ``image_ids`` is a list of ID strings, ``targets`` is the matching
            list of class indices, ``classes`` is the sorted array of unique
            class names and ``class_to_idx`` maps class name to index.
        """
        image_ids = []
        class_names = []
        with open(self.classes_file, "r") as f:
            for line in f:
                # Strip the line terminator so class names do not end in "\n"
                # (and so a final line without a newline maps to the same
                # class as its newline-terminated twins).
                line = line.strip()
                if not line:
                    # Blank lines carry no sample.
                    continue
                # Only the first space separates the ID from the class name.
                image_id, _, class_name = line.partition(" ")
                image_ids.append(image_id)
                class_names.append(class_name)

        # index class names (np.unique returns them sorted)
        classes = np.unique(class_names)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        targets = [class_to_idx[name] for name in class_names]

        return image_ids, targets, classes, class_to_idx

    def make_dataset(self, image_ids, targets):
        """Pair each image ID's ``.jpg`` path under the image folder with its target.

        Returns:
            list: ``(image_path, target)`` tuples, in the order given.
        """
        assert len(image_ids) == len(targets)
        image_dir = os.path.join(self.root, self.img_folder)
        return [
            (os.path.join(image_dir, "%s.jpg" % image_id), target)
            for image_id, target in zip(image_ids, targets)
        ]