import csv
from pathlib import Path
from typing import List, Tuple
from PIL import Image

import torch
from torch.utils.data import Dataset
from typing import List, Optional, Tuple, Union

from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor

from avalanche.benchmarks.datasets import (
    SimpleDownloadableDataset,
    default_dataset_location,
)

class tinyimg(SimpleDownloadableDataset):
    """Tiny Imagenet Pytorch Dataset.

    Only image paths and integer targets are kept in memory; images are read
    from disk by ``__getitem__``, which returns
    ``(global_index, image, target)`` triples.
    """

    # (archive file name, download URL); only the URL is given to the parent.
    filename = (
        "tiny-imagenet-200.zip",
        "http://cs231n.stanford.edu/tiny-imagenet-200.zip",
    )
    # MD5 checksum passed to the parent class together with the URL.
    md5 = "90528d7ca1a48142e341f4ef8d21d0de"

    def __init__(
        self,
        root: Optional[Union[str, Path]] = None,
        *,
        train: bool = True,
        transform=None,
        target_transform=None,
        loader=default_loader,
        download=True
    ):
        """
        Creates an instance of the Tiny Imagenet dataset.

        :param root: folder in which to download dataset. Defaults to None,
            which means that the default location for 'tinyimagenet' will be
            used.
        :param train: True for training set, False for test set.
        :param transform: Pytorch transformation function for x.
        :param target_transform: Pytorch transformation function for y.
        :param loader: the procedure to load the instance from the storage.
        :param bool download: If True, the dataset will be  downloaded if
            needed.
        """

        if root is None:
            root = default_dataset_location("tinyimagenet")

        # These attributes are read by ``load_data`` / ``__getitem__``, so
        # they must exist before the dataset is loaded below.
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.loader = loader

        super().__init__(
            root, self.filename[1], self.md5, download=download, verbose=True
        )

        self._load_dataset()

    def _load_metadata(self) -> bool:
        """
        Indexes the extracted dataset folder.

        Sets ``data_folder``, ``label2id``, ``id2label``, ``data`` (image
        paths), ``targets`` (class ids) and ``global_indices`` (the position
        of each sample in ``data``).

        NOTE(review): presumably this is the hook invoked by the parent's
        ``_load_dataset()`` and an exception raised here (e.g. missing files)
        is what triggers the download -- confirm against the parent class.

        :returns: True once the metadata has been loaded.
        """
        self.data_folder = self.root / "tiny-imagenet-200"

        self.label2id, self.id2label = tinyimg.labels2dict(self.data_folder)
        self.data, self.targets = self.load_data()
        self.global_indices = list(range(len(self.data)))

        return True

    @staticmethod
    def labels2dict(data_folder: Path) -> Tuple[dict, dict]:
        """
        Returns dictionaries to convert class names into progressive ids
        and viceversa.

        Ids are assigned in order of first appearance in ``wnids.txt``;
        blank lines and repeated names are ignored.

        :param data_folder: The root path of tiny imagenet
        :returns: label2id, id2label: two Python dictionaries.
        """

        label2id: dict = {}
        id2label: dict = {}

        with open(str(data_folder / "wnids.txt"), "r", encoding="utf-8") as f:
            for row in csv.reader(f):
                # csv.reader yields [] for a blank line (e.g. a trailing
                # newline at the end of the file): nothing to register.
                if not row:
                    continue
                class_name = row[0]
                if class_name not in label2id:
                    class_id = len(label2id)
                    label2id[class_name] = class_id
                    id2label[class_id] = class_name

        return label2id, id2label

    def load_data(self) -> Tuple[List[Path], List[int]]:
        """
        Load all images paths and targets of the selected split.

        The split is chosen by ``self.train``. Classes are visited in
        increasing id order, so samples of the same class are contiguous.

        :return: (paths, targets): two lists of the same length holding the
            image paths and the matching class ids.
        :raises ValueError: if no class was found in ``wnids.txt``.
        """

        if not self.id2label:
            # Fail loudly instead of silently building an empty dataset.
            raise ValueError("No classes found in wnids.txt")

        if self.train:
            get_paths = self.get_train_images_paths
        else:
            get_paths = self.get_test_images_paths

        paths: List[Path] = []
        targets: List[int] = []

        # Iterate over the classes actually listed in wnids.txt rather than
        # assuming a fixed number of them.
        for class_id in sorted(self.id2label):
            class_paths = get_paths(self.id2label[class_id])
            paths.extend(class_paths)
            targets.extend([class_id] * len(class_paths))

        return paths, targets

    def get_train_images_paths(self, class_name) -> List[Path]:
        """
        Gets the training set image paths.

        :param class_name: name of the class of the images to be collected.
        :returns img_paths: list of paths, sorted by file name.
        """
        train_img_folder: Path = self.data_folder / "train" / class_name / "images"

        return sorted(
            (f for f in train_img_folder.iterdir() if f.is_file()),
            key=lambda p: p.name,  # Sort by filename
        )

    def _val_names_by_class(self) -> dict:
        """
        Maps each class name to its validation image file names.

        ``val_annotations.txt`` is parsed only once per annotations file and
        the result is cached on the instance, so collecting every class does
        not re-read the file each time.

        :returns: dictionary ``class name -> list of image file names``, in
            the order in which they appear in the annotations file.
        """
        annotations_file: Path = self.data_folder / "val" / "val_annotations.txt"

        # getattr: the cache attribute does not exist before the first call.
        cached = getattr(self, "_val_annotations_cache", None)
        if cached is not None and cached[0] == annotations_file:
            return cached[1]

        names_by_class: dict = {}
        with open(str(annotations_file), "r", encoding="utf-8") as f:
            # Tab-separated rows: first field is the image file name, second
            # field is the class name; further fields are not used here.
            for row in csv.reader(f, dialect="excel-tab"):
                if len(row) < 2:
                    # Blank or malformed line.
                    continue
                names_by_class.setdefault(row[1], []).append(row[0])

        self._val_annotations_cache = (annotations_file, names_by_class)
        return names_by_class

    def get_test_images_paths(self, class_name) -> List[Path]:
        """
        Gets the test set image paths

        Images are looked up in ``val/<class_name>/images`` when that folder
        exists; otherwise the flat ``val/images`` folder of the official
        archive is used.

        :param class_name: name of the class of the images to be collected.
        :returns img_paths: sorted list of paths.
        """

        val_folder: Path = self.data_folder / "val"

        per_class_folder = val_folder / class_name / "images"
        if per_class_folder.is_dir():
            val_img_folder = per_class_folder
        else:
            val_img_folder = val_folder / "images"

        # filter validation images by class using appropriate file
        valid_names = self._val_names_by_class().get(class_name, [])

        return sorted(val_img_folder / name for name in valid_names)

    def __len__(self):
        """Returns the length of the set"""
        return len(self.data)

    def __getitem__(self, index):
        """
        Returns the index-th pattern of the set.

        :param index: position of the sample in ``self.data``.
        :returns: (global_index, x, y), where ``global_index`` is the entry
            of ``self.global_indices`` for this sample.
        """

        path = self.data[index]
        target = int(self.targets[index])
        img = self.loader(path)

        # Compare against None so that a transform object which happens to be
        # falsy (e.g. an empty container of transforms) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        global_index = self.global_indices[index]
        return global_index, img, target