import abc
import collections.abc
import functools
import random
from typing import Tuple, Union, Optional, Callable, Type, Any, Dict, List

import torch
from torch._six import string_classes
from torch.utils.data import Subset
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format

from utils.utils import split_list


# def cv_split(self, k: int):
#     '''Generates cross validation splits of the training set.
#
#     :param k: Number of folds.
#     :type k: int
#     :return: Anomaly detection dataset for each fold.
#     :rtype: AnomalyDetectionDataset
#     '''
#
#     size    = len(self.train_set)
#     indices = list(range(size))
#
#     random.shuffle(indices)
#
#     folds = split_list(indices, k)
#
#     for i in range(k):
#
#         dataset           = self.create_dataset()
#         dataset.train_set = Subset(self.train_set, [idx for idx in range(size) if idx not in folds[i]])
#         dataset.test_set  = Subset(self.train_set, folds[i])
#
#         yield dataset


class DatasetMixin(abc.ABC):
    """
    Abstract interface for time series datasets.

    Concrete datasets have to provide their size, the length and feature layout of their time series, a default
    transform pipeline and human-readable feature names.
    """

    @abc.abstractmethod
    def __len__(self) -> int:
        """
        Number of independent time series contained in the dataset.

        :return: The number of time series.
        """
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def seq_len(self) -> Union[int, List[int]]:
        """
        Length of the time series in the dataset.

        :return: A single int if every time series has the same length, otherwise a list that holds the length
            of each individual time series.
        """
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def num_features(self) -> Union[int, Tuple[int, ...]]:
        """
        Number of features of a single datapoint.

        :return: An int, or a tuple of ints if the data has more than one feature dimension.
        """
        raise NotImplementedError()

    @staticmethod
    @abc.abstractmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        """
        Pipeline that is used for this dataset whenever the user does not specify a different one.

        The returned dict must have the form
        {
            '<name>': {'class': '<name-of-transform-class>', 'args': {<arguments for the constructor>}},
            ...
        }

        :return: The default pipeline description, keyed by the name of each pipeline step.
        """
        raise NotImplementedError()

    @staticmethod
    @abc.abstractmethod
    def get_feature_names() -> List[str]:
        """
        Names of the features, given in the same order in which the features appear in the data tensors.

        :return: A list with one name per feature.
        """
        raise NotImplementedError()


def __default_collate(batch, batch_dim=0):
    """
    Puts each data field into a tensor whose dimension ``batch_dim`` has the size of the batch.

    Adapted from PyTorch's ``default_collate``, with the dimension along which tensors are stacked made
    configurable. The type of the first sample decides how the whole batch is handled:

    * tensors are stacked along ``batch_dim``
    * numpy arrays / memmaps are converted to tensors and then stacked along ``batch_dim``
    * numpy scalars, floats and ints are turned into a single tensor (floats become ``torch.float64``)
    * ``str`` and ``bytes`` samples are returned unchanged
    * mappings are collated per key (using the keys of the first sample) into a dict
    * namedtuples are collated per field into a namedtuple of the same type
    * any other sequence is collated per position into a list

    :param batch: Non-empty sequence of samples that all share the same structure.
    :param batch_dim: Dimension along which tensors are stacked.
    :return: The collated batch.
    :raises TypeError: If the samples have an unsupported type (including numpy arrays of strings or objects).
    :raises RuntimeError: If the samples are sequences of different lengths.
    """

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            # Give the output its final stacked shape up front. Otherwise torch.stack has to resize a
            # non-empty ``out`` tensor, which newer PyTorch versions warn about as deprecated.
            stack_dim = batch_dim if batch_dim >= 0 else batch_dim + elem.dim() + 1
            out_shape = list(elem.size())
            out_shape.insert(stack_dim, len(batch))
            out = elem.new(storage).resize_(*out_shape)
        return torch.stack(batch, dim=batch_dim, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        if elem_type.__name__ in ('ndarray', 'memmap'):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return __default_collate([torch.as_tensor(b) for b in batch], batch_dim=batch_dim)
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other numpy type deliberately falls through to the TypeError at the end.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # (str, bytes) is what ``torch._six.string_classes`` resolved to on Python 3. It is spelled out here so
        # that this function does not depend on the private ``torch._six`` module (removed in PyTorch 2.0).
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: __default_collate([d[key] for d in batch], batch_dim=batch_dim) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(__default_collate(samples, batch_dim=batch_dim) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        elem_size = len(elem)
        if not all(len(sample) == elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        # Regroup the samples by position and collate every position separately
        return [__default_collate(samples, batch_dim=batch_dim) for samples in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(elem_type))


def collate_fn(batch_dim: int) -> Callable:
    """
    Create a collate function that stacks tensors along the given dimension.

    :param batch_dim: Dimension of the collated tensors that should index the samples of the batch.
    :return: A callable that takes a batch (a sequence of samples) and returns the collated batch.
    """
    return functools.partial(__default_collate, batch_dim=batch_dim)
