import sys
import os
import os.path as pt
import pandas as pd
from typing import List, Tuple, Callable, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision.datasets.folder import DatasetFolder, default_loader, IMG_EXTENSIONS
from torchvision.datasets.imagenet import ImageFolder
from torchvision.datasets.imagenet import verify_str_arg, load_meta_file
from torchvision.datasets.utils import download_and_extract_archive

from xad.datasets.bases import TorchvisionDataset
from xad.utils.logger import Logger


class ADGTSDB(TorchvisionDataset):
    """
    Anomaly-detection wrapper around the German Traffic Sign Detection Benchmark (see :class:`GTSDB`).
    """
    base_folder = 'gtsdb'  # appended to root directory as a subdirectory
    classes = [
        "speed_limit_20",  # 00 --- 004 samples
        "speed_limit_30",  # 01 --- 079 samples
        "speed_limit_50",  # 02 --- 081 samples
        "speed_limit_60",  # 03 --- 030 samples
        "speed_limit_70",  # 04 --- 068 samples
        "speed_limit_80",  # 05 --- 053 samples
        "restriction_ends_80",  # 06 --- 019 samples
        "speed_limit_100",  # 07 --- 041 samples
        "speed_limit_120",  # 08 --- 057 samples
        "no_overtaking",  # 09 --- 041 samples
        "no_overtaking_trucks",  # 10 --- 080 samples
        "priority_next_intersection",  # 11 --- 038 samples
        "priority_road",  # 12 --- 085 samples
        "give_way",  # 13 --- 083 samples
        "stop",  # 14 --- 032 samples
        "no_traffic_both_ways",  # 15 --- 015 samples
        "no_trucks",  # 16 --- 008 samples
        "no_entry",  # 17 --- 029 samples
        "danger",  # 18 --- 038 samples
        "bend_left",  # 19 --- 002 samples
        "bend_right",  # 20 --- 009 samples
        "bend",  # 21 --- 005 samples
        "uneven_road",  # 22 --- 013 samples
        "slippery_road",  # 23 --- 020 samples
        "road_narrows",  # 24 --- 005 samples
        "construction",  # 25 --- 031 samples
        "traffic_signal",  # 26 --- 018 samples
        "pedestrian_crossing",  # 27 --- 003 samples
        "school_crossing",  # 28 --- 014 samples
        "cycles_crossing",  # 29 --- 005 samples
        "snow",  # 30 --- 016 samples
        "animals",  # 31 --- 002 samples
        "restriction_ends",  # 32 --- 008 samples
        "go_right",  # 33 --- 016 samples
        "go_left",  # 34 --- 012 samples
        "go_straight",  # 35 --- 020 samples
        "go_right_or_straight",  # 36 --- 009 samples
        "go_left_or_straight",  # 37 --- 002 samples
        "keep_right",  # 38 --- 088 samples
        "keep_left",  # 39 --- 006 samples
        "roundabout",  # 40 --- 010 samples
        "restriction_ends_overtaking",  # 41 --- 007 samples
        "restriction_ends_overtaking_trucks",  # 42 --- 011 samples
    ]

    def __init__(self, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: Union[int, List[int]] = np.inf,
                 **kwargs):
        """
        AD dataset for German Traffic Sign Detection Benchmark. Implements :class:`xad.datasets.bases.TorchvisionDataset`.

        The training set is the 'train' split of :class:`GTSDB`, passed through the base class's ``create_subset``,
        and then rebalanced so that every normal class contributes as many samples as the largest normal class
        (random oversampling with the global NumPy RNG). The test set is the complete 'val' split.

        :param root: base directory; the data is stored in ``<root>/gtsdb``.
        :param normal_classes: indices into :attr:`classes` that are considered normal.
        :param nominal_label: forwarded to the base class (presumably the label assigned to normal samples).
        :param train_transform: transform applied to training images.
        :param test_transform: transform applied to test images.
        :param raw_shape: forwarded to the base class; its last entry is used as resize target for the raw train set.
        :param logger: optional logger, also handed to :class:`GTSDB`.
        :param limit_samples: forwarded to the base class (presumably caps the number of training samples).
        """
        root = pt.join(root, self.base_folder)
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, len(self.classes),
            raw_shape, logger, limit_samples,
            **kwargs
        )

        self._train_set = GTSDB(
            self.root, split='train', transform=self.train_transform, target_transform=self.target_transform,
            logger=logger
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)

        # GTSDB classes are very unevenly sized, so oversample the smaller normal classes.
        self._train_set.indices = self._balance_classes(
            self._train_set.indices, self._train_set.dataset.targets, normal_classes
        )

        self._test_set = GTSDB(
            root=self.root, split='val', transform=self.test_transform, target_transform=self.target_transform,
            logger=logger
        )
        self._test_set = Subset(self._test_set, list(range(len(self._test_set))))  # create improper subset with all indices

    @staticmethod
    def _balance_classes(indices, targets, classes) -> List[int]:
        """
        Return a sorted index list in which every class in ``classes`` is oversampled to the size of the largest one.

        Each class keeps all of its own indices. Extra copies are drawn with ``np.random.choice`` in rounds. Within
        a round no index repeats, and each round draws at most as many indices as the class has. Classes with no
        indices are skipped, because there is nothing to oversample from. Without that skip the draw loop would
        never terminate.

        :param indices: dataset indices of the current subset.
        :param targets: per-sample class labels of the underlying dataset, indexable by the entries of ``indices``.
        :param classes: the classes to balance.
        :return: sorted list of (possibly repeated) dataset indices.
        """
        per_class = [[ind for ind in indices if targets[ind] == cls] for cls in classes]
        target_size = max(len(ids) for ids in per_class)
        picks = []
        for ids in per_class:
            picks.extend(ids)
            if not ids:
                continue  # np.random.choice([], 0) returns nothing -> `need` would never shrink
            need = target_size - len(ids)
            while need > 0:
                drawn = np.random.choice(ids, min(need, len(ids)), replace=False).tolist()
                picks.extend(drawn)
                need -= len(drawn)
        return sorted(picks)

    def _get_raw_train_set(self):
        """
        Return the untransformed 'train' split (only resized and converted to tensors), restricted to normal classes.
        """
        train_set = GTSDB(
            self.root, split='train',
            transform=transforms.Compose([transforms.Resize(self.raw_shape[-1]), transforms.ToTensor(), ]),
            target_transform=self.target_transform, logger=self.logger
        )
        return Subset(
            train_set,
            np.flatnonzero(np.isin(np.asarray(train_set.targets), self.normal_classes)).tolist()
        )


class GTSDB(ImageFolder):
    """
    Cropped traffic signs of the GTSDB archive ``FullIJCNN2013``, resized to 32x32 and held in memory.

    All samples found in the per-class sub-folders are split 80/20 (stratified, fixed seed) into 'train' and 'val'.
    ``__getitem__`` returns ``(img, target, index)``.
    """
    url = "https://sid.erda.dk/public/archives/ff17dc924eba88d5d01a807357d6614c/FullIJCNN2013.zip"

    def __init__(self, root: str, split: str = 'train', transform: Callable = None, target_transform: Callable = None,
                 **kwargs):
        """
        :param root: directory that contains (or will receive) the extracted ``FullIJCNN2013`` folder.
        :param split: either 'train' or 'val'.
        :param transform: optional transform applied to each image.
        :param target_transform: optional transform applied to each target.
        :param kwargs: optional ``logger``, ``loader`` (defaults to ``default_loader``) and ``extensions``
            (defaults to ``IMG_EXTENSIONS``). Any remaining kwargs are forwarded to ``VisionDataset``.
        """
        self.logger = kwargs.pop('logger', None)
        # Pop the ImageFolder-specific options: VisionDataset.__init__ below does not accept them.
        self.loader = kwargs.pop('loader', default_loader)
        self.extensions = kwargs.pop('extensions', None)
        # Skip DatasetFolder/ImageFolder.__init__ and call VisionDataset.__init__ directly. Presumably this is
        # because DatasetFolder scans `root` for class folders right away, which fails before the download.
        super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform, **kwargs)
        self.split = verify_str_arg(split, "split", ("train", "val"))
        self._download()

    def _download(self):
        """Load the data, downloading and extracting the archive into ``root`` first if needed."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root)
        if not self._check_integrity():
            # An explicit raise instead of `assert`, so the check also runs under `python -O`.
            raise RuntimeError('GTSDB is corrupted. Please redownload.')

    def _check_integrity(self):
        """
        Return whether the extracted data exists and loads as 32x32x3 images.

        Side effect: on success, ``self.data`` and ``self.targets`` are populated via :meth:`_load`.
        """
        if not pt.exists(pt.join(self.root, 'FullIJCNN2013')) or not self._load():
            return False

        if self.data.shape[1:] != (32, 32, 3):
            print('Data is of wrong shape!', file=sys.stderr)
            return False
        return True

    def _load(self):
        """
        Read all class folders, select this split's samples, and load them resized to 32x32.

        WARNING: permanently deletes the ``*.ppm`` files at the top level of ``FullIJCNN2013`` (the complete street
        pictures), so that only the per-class crops remain.
        """
        folder = pt.join(self.root, 'FullIJCNN2013')
        for file in os.listdir(folder):  # remove complete street pictures
            if file.endswith('.ppm'):
                os.remove(pt.join(folder, file))

        self.class_folders, class_to_idx = self.find_classes(folder)
        extensions = IMG_EXTENSIONS if self.extensions is None else self.extensions
        samples = self.make_dataset(folder, class_to_idx, extensions, None)
        paths, targets = zip(*samples)
        paths, targets = np.asarray(paths), np.asarray(targets)
        # Fixed seed + stratification: 'train' and 'val' are disjoint and reproducible across instances.
        train_ids, val_ids = train_test_split(
            np.arange(len(targets)), test_size=0.2, random_state=8121943, stratify=targets
        )
        split_ids = train_ids if self.split == 'train' else val_ids
        self.data = np.stack([np.asarray(self.loader(p).resize((32, 32))) for p in paths[split_ids]])
        self.targets = targets[split_ids]

        return True

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, int]:
        """
        Return ``(img, target, index)``.

        Without a transform (or with an empty ``Compose``), the image is returned as a CHW float tensor in [0, 1].
        Otherwise a PIL image is passed to the transform.
        """
        img, target = self.data[index], self.targets[index]
        no_transform = self.transform is None or (
            isinstance(self.transform, transforms.Compose) and len(self.transform.transforms) == 0
        )
        if no_transform:
            img = torch.from_numpy(img).float().div(255).permute(2, 0, 1)
        else:
            img = Image.fromarray(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target, index

    def __len__(self) -> int:
        """Return the number of samples in this split."""
        return self.data.shape[0]

