import itertools
from copy import deepcopy
from typing import List, Union, Tuple

import numpy as np
import os
import os.path as pt
import json
from torchvision.transforms import Resize, Compose

from xad.datasets.bases import TorchvisionDataset, CombinedDataset
from xad.datasets.cifar import ADCIFAR10, ADCIFAR100
from xad.datasets.imagenet import ADImageNet, ADImageNet21k, ADImageNet21kSubSet, ADImagenetNeighbors, MyImageNet, ADImagenetNeighborsDebug
from xad.datasets.imagenetoe import ADImageNetOE
from xad.datasets.mnist import ADMNIST, ADEMNIST
from xad.datasets.mvtec import ADMvTec
from xad.utils.logger import Logger
from xad.datasets.oe_concat import ConcatOEDataset
from xad.datasets.confetti_noise import ADConfettiDataset
from xad.datasets.colored_mnist import ADColoredMNIST, ADColoredEMNIST, ADThreeEMNIST
from xad.datasets.gtsdb import ADGTSDB
from xad.utils.training_tools import int_set_to_str


# String labels for the 1000 ImageNet-1k classes.
# If the optional meta file exists next to this module, the labels are read from it, ordered by the
# integer value of their JSON keys, and shortened to the first comma-separated alias.
# Otherwise the stringified indices "0" ... "999" serve as a fallback.
_IMAGENET1K_LABELS_FILE = pt.abspath(pt.join(__file__, '..', 'meta', 'imagenet1000_clsidx_to_labels.json'))
if not os.path.exists(_IMAGENET1K_LABELS_FILE):
    imagenet1k_str_labels = [str(i) for i in range(1000)]
else:
    with open(_IMAGENET1K_LABELS_FILE, 'r') as reader:
        imagenet1k_str_labels = json.load(reader)
    imagenet1k_str_labels = [
        names.split(",")[0]
        for _, names in sorted(imagenet1k_str_labels.items(), key=lambda entry: int(entry[0]))
    ]



# Registry of all implemented datasets, keyed by the dataset name that is passed to `load_dataset`.
# Every entry is a dict with the following keys:
#   'class':        the dataset implementation that `load_dataset` instantiates.
#   'default_size': side length (in pixels) that `get_raw_shape` falls back to when the training
#                   transform does not start with a `Resize`.
#   'no_classes':   number of classes, returned by `no_classes`.
#   'oe_only':      if True, `load_dataset` only accepts the dataset as OE and passes the normal dataset
#                   as first constructor argument.
#   'str_labels':   one label per class id, returned by `str_labels`.
# NOTE: for some entries ('imagenetoe', 'emnist', 'threeemnist') the 'str_labels' are plain ints rather
# than strings, so consumers that concatenate labels need to convert them first.
DS_CHOICES = {  # list of implemented datasets (most can also be used as OE)
    'cifar10': {
        'class': ADCIFAR10, 'default_size': 32, 'no_classes': 10, 'oe_only': False,
        'str_labels':  ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
    },
    'imagenet': {
        'class': ADImageNet, 'default_size': 256, 'no_classes': 30, 'oe_only': False,
        'str_labels': deepcopy(ADImageNet.ad_classes),
    },
    'imagenet_neighbours': {
        'class': ADImagenetNeighbors, 'default_size': 256, 'no_classes': 1000, 'oe_only': False,
        'str_labels':imagenet1k_str_labels
    },
    'imagenet_neighbours_debug': {
        'class': ADImagenetNeighborsDebug, 'default_size': 256, 'no_classes': 1000, 'oe_only': False,
        'str_labels':imagenet1k_str_labels
    },
    'cifar100': {
        'class': ADCIFAR100, 'default_size': 32, 'no_classes': 100, 'oe_only': False,
        'str_labels': [
            'beaver', 'dolphin', 'otter', 'seal', 'whale',
            'aquarium_fish', 'flatfish', 'ray', 'shark', 'trout',
            'orchid', 'poppy', 'rose', 'sunflower', 'tulip',
            'bottle', 'bowl', 'can', 'cup', 'plate',
            'apple', 'mushroom', 'orange', 'pear', 'sweet_pepper',
            'clock', 'keyboard', 'lamp', 'telephone', 'television',
            'bed', 'chair', 'couch', 'table', 'wardrobe',
            'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',
            'bear', 'leopard', 'lion', 'tiger', 'wolf',
            'bridge', 'castle', 'house', 'road', 'skyscraper',
            'cloud', 'forest', 'mountain', 'plain', 'sea',
            'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',
            'fox', 'porcupine', 'possum', 'raccoon', 'skunk',
            'crab', 'lobster', 'snail', 'spider', 'worm',
            'baby', 'boy', 'girl', 'man', 'woman',
            'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',
            'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',
            'maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree',
            'bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train',
            'lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'
        ]
    },
    'imagenet21k': {
        'class': ADImageNet21k, 'default_size': 256, 'no_classes': 21811, 'oe_only': False,
        'str_labels': [str(i) for i in range(21811)],  # ? (placeholder: stringified class ids)
    },
    'imagenet21ksubset': {
        'class': ADImageNet21kSubSet, 'default_size': 256, 'no_classes': 21811, 'oe_only': False,
        'str_labels': [str(i) for i in range(21811)],  # ? (placeholder: stringified class ids)
    },
    'mvtec': {
        'class': ADMvTec, 'default_size': 256, 'no_classes': 15, 'oe_only': False,
        'str_labels': [
            'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather',
            'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor',
            'wood', 'zipper'
        ]
    },
    'imagenetoe': {
        'class': ADImageNetOE, 'default_size': 256, 'no_classes': 1000, 'oe_only': True,
        'str_labels':  list(range(1000)),  # not required (plain ints, not strings)
    },
    'mnist': {
        'class': ADMNIST, 'default_size': 28, 'no_classes': 10, 'oe_only': False,
        'str_labels': [
          "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"
        ]
    },
    'emnist': {
        'class': ADEMNIST, 'default_size': 28, 'no_classes': 26, 'oe_only': False, 'str_labels': list(range(26)),  # ? (plain ints, not strings)
    },
    'confetti': {
        'class': ADConfettiDataset, 'default_size': 256, 'no_classes': 1, 'oe_only': True, 'str_labels': ["confetti"],
    },
    'coloredmnist': {
        'class': ADColoredMNIST, 'default_size': 28, 'no_classes': 10 * 7, 'oe_only': False,
        'str_labels': ADColoredMNIST.classes,
    },
    'coloredemnist': {
        'class': ADColoredEMNIST, 'default_size': 28, 'no_classes': 26 * 7, 'oe_only': False,
        'str_labels': [f"{i}-{j}" for j in range(26) for i in range(7)],  # ? (placeholder "<i>-<j>", i in 0..6 varies fastest)
    },
    'threeemnist': {  # nx3x28x28 instead of nx1x28x28
        'class': ADThreeEMNIST, 'default_size': 28, 'no_classes': 26, 'oe_only': False, 'str_labels': list(range(26)),  # ? (plain ints, not strings)
    },
    'gtsdb': {
        'class': ADGTSDB, 'default_size': 32, 'no_classes': len(ADGTSDB.classes), 'oe_only': False,
        'str_labels': ADGTSDB.classes,
    },
}
# Integer ids distinguishing four kinds of samples; judging by the names: nominal training samples,
# OE training samples, nominal test samples, and anomalous test samples.
# NOTE(review): not referenced in the visible part of this module -- presumably consumed elsewhere; confirm usage.
TRAIN_NOMINAL_ID = 0
TRAIN_OE_ID = 1
TEST_NOMINAL_ID = 2
TEST_ANOMALOUS_ID = 3


def get_raw_shape(train_transform: Compose, dataset_name: str) -> Tuple[int, int, int]:
    """
    Detects the raw_shape of the data (i.e., the shape before clipping etc.) using the first resize transform.

    @param train_transform: The training preprocessing pipeline. Only its first transformation is inspected.
    @param dataset_name: Name of the dataset, see :attr:`DS_CHOICES`. Only used to look up the default size
        if the pipeline does not start with a `Resize`.
    @return: The shape as (channels, height, width). The number of channels is always reported as 3.
        An integer size (or a sequence containing just one integer) yields a square shape.
    """
    transforms = train_transform.transforms
    if len(transforms) > 0 and isinstance(transforms[0], Resize):
        size = transforms[0].size
        if isinstance(size, int):
            return 3, size, size
        size = tuple(size)
        if len(size) == 1:
            # torchvision also accepts `[size]`; treat it like the int case instead of returning a 2-tuple
            return 3, size[0], size[0]
        return 3, *size
    # no leading resize: fall back to the default size registered for the dataset
    side = DS_CHOICES[dataset_name]['default_size']
    return 3, side, side
    
    
def load_dataset(dataset_name: str, data_path: str, normal_classes: set[int], nominal_label: int,
                 train_transform: Compose, test_transform: Compose, logger: Logger = None,
                 oe_names: List[str] = None, oe_limit_samples: Union[int, List[int]] = np.inf,
                 oe_limit_classes: Union[int, set[int]] = np.inf,) -> TorchvisionDataset:
    """
    Prepares a dataset, includes setting up all the necessary attributes such as a list of filepaths and labels.
    Requires a list of normal classes that determines the labels and which classes are available during training.
    If OE datasets are specified, prepares a combined dataset.
    The combined dataset's test split is the test split of the normal dataset.
    The combined dataset's training split is a combination of the normal training data and the OE data.
    It provides a balanced concatenated data loader. See :class:`xad.datasets.bases.CombinedDataset`.

    @param dataset_name: Defines the normal dataset (containing also anomalous test samples). See :attr:`DS_CHOICES`.
    @param data_path: Defines the root directory for all datasets. Most of them get automatically downloaded if not present
        at this directory. Each dataset has its own subdirectory (e.g., xad/data/datasets/imagenet/).
    @param normal_classes: A set of normal classes. Normal training samples are all from these classes.
        Samples from other classes are not available during training. During testing, other classes will have anomalous labels.
    @param nominal_label: The integer defining the normal (==nominal) label. Usually 0.
    @param train_transform: preprocessing pipeline used for training, includes all kinds of image transformations.
        May contain the dummy transformation 'norm' that will be replaced with a torchvision normalization instance.
        The required mean and std of the normal training data will be extracted automatically.
    @param test_transform: preprocessing pipeline used for testing,
        includes all kinds of image transformations but no data augmentation.
        May contain the dummy transformation 'norm' that will be replaced with a torchvision normalization instance.
        The required mean and std of the normal training data will be extracted automatically.
    @param logger: Optional. Some logger instance. Is only required for logging messages and warnings related to
        the datasets. If None, nothing is logged.
    @param oe_names: Optional. Defines the OE datasets. See method description.
        An OE name that equals `dataset_name` corresponds to semi-supervision; i.e., the OE samples are taken
        from the non-normal classes of the normal dataset.
    @param oe_limit_samples: Optional. If given, limits the number of different OE samples. That is, instead of using the
        complete OE dataset, creates a subset to be used as OE. If `oe_limit_samples` is an integer, samples a random subset
        with the provided size. If `oe_limit_samples` is a list of integers, create a subset with the indices provided.
    @param oe_limit_classes: Optional. If given, limits the number of different classes of OE samples. That is, instead of
        using the complete OE dataset, creates a subset consisting only of OE images from selected classes.
        If `oe_limit_classes` is an integer, randomly selects that many classes.
        If `oe_limit_classes` is a class set, uses the specified class set from the OE dataset.
        Note that some OE dataset implementations (80MTI, ImageNet-21k) come with no classes.
        That is, they just have one pseudo-class. Limiting their classes won't work.
    @return: The normal dataset if no OE dataset is specified, otherwise a :class:`CombinedDataset` of the normal
        dataset and the (concatenated) OE dataset(s).
    @raise NotImplementedError: If multiple OE datasets are combined with a limit on OE samples or OE classes.
    @raise ValueError: If MVTec-AD is used for semi-supervision together with a limit on OE classes.
    """

    # only list the names; the registry itself contains tens of thousands of labels
    assert dataset_name in DS_CHOICES, f'{dataset_name} is not in {list(DS_CHOICES)}'

    raw_shape = get_raw_shape(train_transform, dataset_name)
    normal_classes = list(normal_classes)
    oe_limit_classes = oe_limit_classes if oe_limit_classes is not None else np.inf

    def log(msg: str) -> None:
        """ prints the message with the logger, does nothing if there is no logger """
        if logger is not None:
            logger.print(msg)

    def get_ds(name: str, train_classes: List[int], train_label: int, total_train_transform: Compose,
               total_test_transform: Compose, limit: Union[int, List[int]], normal_dataset: TorchvisionDataset = None,
               **kwargs):
        """ instantiates the dataset `name`; if `normal_dataset` is given, the dataset is prepared as OE """
        args = (
            data_path, train_classes, train_label, total_train_transform, total_test_transform,
            raw_shape, logger, limit
        )
        # matches the exact class name only (e.g., not ADImagenetNeighborsDebug)
        if normal_dataset is not None and normal_dataset.__class__.__name__ == "ADImagenetNeighbors":
            args = (*args, normal_dataset.exclude_class_labels)
        if DS_CHOICES[name]['oe_only']:
            assert normal_dataset is not None, f"{name} can only be used as OE!"
            dataset = DS_CHOICES[name]['class'](normal_dataset, *args, **kwargs)
        else:
            dataset = DS_CHOICES[name]['class'](*args, **kwargs)
        if normal_dataset is not None:  # oe case, GPU transformations need to be copied too
            log('OE GPU Transformations are copied from the normal dataset.')
            dataset.gpu_train_transform = deepcopy(Compose([
                *normal_dataset.gpu_train_transform.transforms, *dataset.gpu_train_transform.transforms
            ]))
            dataset.gpu_test_transform = deepcopy(Compose([
                *normal_dataset.gpu_test_transform.transforms, *dataset.gpu_test_transform.transforms
            ]))
        return dataset

    def do_limit_classes(oe_classes: List[int], oesetstr: str):
        """ reduces the available OE classes according to `oe_limit_classes`, returns them sorted """
        if isinstance(oe_limit_classes, set):
            not_available = [o for o in oe_limit_classes if o not in oe_classes]
            if len(not_available) > 0 and logger is not None:
                logger.warning(
                    f'Found an invalid class id in the limited oe class set ({oe_limit_classes}) for {oesetstr}! '
                    f'Classes {not_available} are not valid. Will ignore these.'
                )
            oe_classes = sorted([o for o in oe_limit_classes if o in oe_classes])
        else:
            # also executed without a limit (infinity), in which case all classes are chosen
            oe_classes = sorted(np.random.choice(oe_classes, min(len(oe_classes), oe_limit_classes), False))
        return oe_classes

    if logger is not None:
        logger.print('---Loading normal dataset...')
        normal_classes_strs = [DS_CHOICES[dataset_name]['str_labels'][idx] for idx in normal_classes]
        logger.print(f"---normal classes: {normal_classes_strs}")
    dataset = get_ds(dataset_name, normal_classes, nominal_label, train_transform, test_transform, np.inf)

    if oe_names is not None and len(oe_names) > 0:
        train_label = 1 - nominal_label  # OE samples get the anomalous label
        oes = []
        if len(oe_names) > 1:
            # a list of sample indices is always a limit and must not be compared with infinity
            if (
                oe_limit_classes != np.inf
                or isinstance(oe_limit_samples, (list, tuple, np.ndarray))
                or oe_limit_samples < np.inf
            ):
                raise NotImplementedError('Multiple OE datasets in combination with limits are not supported atm.')
        normal_class_set = set(normal_classes)
        for i, oe_name in enumerate(oe_names):
            log(
                f'---Loading OE dataset {i+1}/{len(oe_names)} {"(semi-supervision)" if dataset_name == oe_name else ""}...'
            )
            kwargs = {}
            if dataset_name == oe_name:  # semi-supervision!
                if oe_name == 'mvtec':  # in mvtec, the anomalies are not other classes but subclasses
                    if oe_limit_classes != np.inf:
                        raise ValueError(
                            'For MVTec-AD, there is one particular anomalous class per normal class. Limiting not supported.'
                        )
                    kwargs = {'defects_are_normal': False}
                    train_classes = deepcopy(normal_classes)
                else:
                    train_classes = do_limit_classes(
                        [c for c in range(no_classes(oe_name)) if c not in normal_class_set], oe_name
                    )
            else:
                train_classes = do_limit_classes(list(range(no_classes(oe_name))), oe_name)
            total_train_transform = deepcopy(dataset.train_transform)  # copy to have the same normalize etc.
            total_test_transform = deepcopy(dataset.test_transform)  # copy to have the same normalize etc.
            limit = oe_limit_samples
            oes.append(get_ds(
                oe_name, train_classes, train_label, total_train_transform, total_test_transform, limit, dataset, **kwargs
            ))
        if len(oe_names) == 1:
            oe = oes[0]
        else:
            log('--Combining OE datasets...\n')
            oe = ConcatOEDataset(oes)
        dataset = CombinedDataset(dataset, oe)
    log('---Datasets loaded successfully...\n')

    return dataset


def no_classes(dataset_name: str) -> int:
    """ looks up the number of classes registered in :attr:`DS_CHOICES` for the given dataset """
    dataset_entry = DS_CHOICES[dataset_name]
    return dataset_entry['no_classes']


def str_labels(dataset_name: str) -> List[str]:
    """ looks up the list of class descriptions registered in :attr:`DS_CHOICES` for the given dataset """
    dataset_entry = DS_CHOICES[dataset_name]
    return dataset_entry['str_labels']


# Named class sets for colored MNIST. A class id is computed as `color index * 10 + digit index`.
# NOTE(review): the factor 10 is hard-coded -- presumably ADColoredMNIST.digits has exactly 10 entries and
# the ids match the class ids of ADColoredMNIST; confirm.
# The following sets are created for each color <col> and digit <digit>:
#   "<col>":             all ids with that color (any digit)
#   "<digit>":           all ids with that digit (any color)
#   "<col>+or+<digit>":  all ids that have that color or that digit (or both)
#   "<col>+xor+<digit>": same as "+or+", but without the one id that has both that color and that digit
COLORED_MNIST_CSETS = {}
for cid, col in enumerate(ADColoredMNIST.colors):
    COLORED_MNIST_CSETS[f"{col}"] = set()
    for did, digit in enumerate(ADColoredMNIST.digits):
        COLORED_MNIST_CSETS[f"{col}+or+{digit}"] = set()
        COLORED_MNIST_CSETS[f"{col}+xor+{digit}"] = set()
        # the digit set is reset and refilled for every color; the resulting content is always the same
        COLORED_MNIST_CSETS[f"{digit}"] = set()
        COLORED_MNIST_CSETS[f"{col}"].add(cid * 10 + did)
        for cid2, col2 in enumerate(ADColoredMNIST.colors):
            COLORED_MNIST_CSETS[f"{digit}"].add(cid2 * 10 + did)
            for did2, digit2 in enumerate(ADColoredMNIST.digits):
                if cid2 == cid or did2 == did:
                    COLORED_MNIST_CSETS[f"{col}+or+{digit}"].add(cid2 * 10 + did2)
                    if not (cid2 == cid and did2 == did):
                        COLORED_MNIST_CSETS[f"{col}+xor+{digit}"].add(cid2 * 10 + did2)
# Reverse lookup: string representation of a class set -> its name.
# If several names share the same string representation, the one inserted last wins.
COLORED_MNIST_CSETS_REV = {int_set_to_str(v): k for k, v in COLORED_MNIST_CSETS.items()}


# Named class sets for GTSDB: maps a readable name to a set of integer class ids.
# NOTE(review): the ids presumably index ADGTSDB.classes -- confirm against the dataset implementation.
GTSDB_CSETS = {
    "warning_signs": {18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
    "blue_signs": {33, 34, 35, 36, 37, 38, 39, 40},
    "restriction_ends": {6, 32, 41, 42},
    "speed_signs": {0, 1, 2, 3, 4, 5, 7, 8},
    "stop+no_entry+warning+construction": {14, 15, 18, 25},
}
# Reverse lookup: string representation of a class set -> its name.
GTSDB_CSETS_REV = {int_set_to_str(v): k for k, v in GTSDB_CSETS.items()}


def cset_str_description(dataset_name: str, class_set: set[int], data_path: str = None) -> str:
    """
    Returns a readable description of the given class set.
    For 'coloredmnist' and 'gtsdb', returns the name of the class set if it matches one of the named sets in
    :attr:`COLORED_MNIST_CSETS` or :attr:`GTSDB_CSETS`. In all other cases, returns the labels of the
    classes joined with "+", in the iteration order of `class_set`.

    @param dataset_name: Name of the dataset, see :attr:`DS_CHOICES`.
    @param class_set: The set of class ids to be described.
    @param data_path: Not used.
    @return: The description of the class set.
    """
    labels = str_labels(dataset_name)
    # some datasets register plain ints as labels, which str.join would reject without the conversion
    default = "+".join(str(labels[c]) for c in class_set)
    if dataset_name == 'coloredmnist':
        named_sets = COLORED_MNIST_CSETS_REV
    elif dataset_name == 'gtsdb':
        named_sets = GTSDB_CSETS_REV
    else:
        return default
    return named_sets.get(int_set_to_str(class_set), default)
