from copy import Error

import numpy as np
import torch
from braindecode.augmentation import Transform
from braindecode.datasets import (BaseConcatDataset, MOABBDataset,
                                  SleepPhysionet)
from braindecode.datautil.preprocess import zscore
from braindecode.preprocessing import create_windows_from_events
from braindecode.preprocessing.preprocess import Preprocessor, preprocess
from braindecode.util import set_random_seeds
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data.dataset import Subset


class EEGDataset(BaseConcatDataset):
    """Concatenated EEG dataset bundling the usual preprocessing steps.

    Subclass of BaseConcatDataset. Implements methods that cover most of the
    preprocessing pipeline, e.g. applying preprocessors such as
    standardization and extracting windows from the dataset.

    Parameters
    ----------
    dataset: BaseConcatDataset
        Dataset whose ``datasets`` attribute is wrapped by this class.
    """

    def __init__(self, dataset):
        super().__init__(dataset.datasets)

    def preprocess(self, preprocessors=None):
        """Preprocess the dataset using braindecode.preprocessing.preprocess.

        The default preprocessing consists in converting the voltage of EEG
        channels to µV, low-pass filtering the channels with a cut-off
        frequency of 30Hz and standardizing the channels.

        Parameters
        ----------
        preprocessors: list | None
            List of preprocessing operations applied to the dataset. Must be
            instances of braindecode.preprocessing.preprocess.Preprocessor.
            If None (default), the default pipeline described above is used.
        """
        if preprocessors is None:
            # Built on every call rather than as a default argument, so that
            # no Preprocessor instances are shared between calls.
            preprocessors = [
                Preprocessor(lambda x: x * 1e6),
                Preprocessor('filter', l_freq=None, h_freq=30),
                Preprocessor(fn=zscore)]
        # Inside the method body, ``preprocess`` resolves to the module-level
        # braindecode function, not to this method.
        preprocess(self, preprocessors)

    def get_windows(self,
                    mapping=None,
                    window_size_samples=None,
                    window_stride_samples=None,
                    trial_start_offset_samples=0,
                    trial_stop_offset_samples=0,
                    preload=True,
                    n_jobs=1,):
        """Extract the epochs from the EEGDataset using
        braindecode.preprocessing.create_windows_from_events.

        Parameters
        ----------
        mapping: dict
            Dictionary that maps comprehensible labels to numbers.
            e.g. mapping = {'Sleep stage W': 0}
        window_size_samples: int
            Number of samples per window. Mandatory if the dataset does not
            have a stim channel.
        window_stride_samples: int
            Stride used to extract windows from the recordings. Mandatory if
            the dataset does not have a stim channel.
        trial_start_offset_samples: int
            Start offset from original trial onsets, in samples. Defaults to
            zero.
        trial_stop_offset_samples: int
            Stop offset from original trial stop, in samples. Defaults to zero.
        preload: bool
            If True, preload the data of the Epochs objects. This is useful to
            reduce disk reading overhead when returning windows in a training
            scenario, however very large data might not fit into memory.
        n_jobs: int
            Number of jobs to use to parallelize the windowing.

        Returns
        -------
        The value returned by create_windows_from_events for this dataset.

        Note
        ----
        The mapping can be found using `windows.datasets[0].windows.event_id`
        (not that intuitive to find).
        """
        return create_windows_from_events(
            self,
            mapping=mapping,
            window_size_samples=window_size_samples,
            window_stride_samples=window_stride_samples,
            trial_start_offset_samples=trial_start_offset_samples,
            trial_stop_offset_samples=trial_stop_offset_samples,
            preload=preload,
            n_jobs=n_jobs,)


def find_device(device=None):
    """Determine the device that can be used for computation.

    Parameters
    ----------
    device: str | None
        Name of the device to use (e.g. 'cpu', 'cuda', 'cuda:0'). If None,
        the device is picked automatically: a CUDA device when one is
        available, the CPU otherwise.

    Returns
    -------
    torch.device:
        The device that will be used for computation.
    bool:
        Whether the returned device is a CUDA device.

    Raises
    ------
    AssertionError
        If ``device`` is given but is not a str.
    """
    if device is not None:
        assert isinstance(device, str), "device should be a str."
        requested = torch.device(device)
        # Derive the flag from the requested device itself, so that asking
        # for 'cuda' is not reported as "no CUDA".
        return requested, requested.type == 'cuda'

    cuda = torch.cuda.is_available()  # check if GPU is available
    if not cuda:
        return torch.device('cpu'), cuda

    torch.backends.cudnn.benchmark = True
    # When several GPUs are visible, the second one ('cuda:1') is used.
    if torch.cuda.device_count() > 1:
        return torch.device('cuda:1'), cuda
    return torch.device('cuda'), cuda


def get_subjects(dataset):
    """Return one subject (group) identifier per sample of the dataset.

    Originally named get_groups in eeg_augment.

    Parameters
    ----------
    dataset
        Any sized dataset. When it exposes both a ``description`` containing
        a 'subject' entry and a ``datasets`` list, the subject of each
        recording is repeated once per sample of that recording.

    Returns
    -------
    np.array
        1-D array with one entry per sample. Falls back to
        ``np.arange(len(dataset))`` (every sample is its own group) when no
        subject information is available.

    Note
    ----
    Also works with windows!
    """
    has_subject_info = (
        hasattr(dataset, "description")
        and hasattr(dataset, "datasets")
        and "subject" in dataset.description)
    if not has_subject_info:
        return np.arange(len(dataset))

    subject_per_recording = dataset.description['subject'].values
    chunks = []
    for rec_idx, subject in enumerate(subject_per_recording):
        n_samples = len(dataset.datasets[rec_idx])
        chunks.append([subject] * n_samples)
    return np.hstack(chunks)


def get_labels(windows):
    """Collect the label of every window of a windows dataset.

    Parameters
    ----------
    windows: braindecode.dataset.WindowsDataset.
        Indexable, sized dataset whose items hold the label at position 1.

    Return
    ------
    np.array
        Array of shape (len(windows),) that contains the label for each window.
    """
    labels = []
    for idx in range(len(windows)):
        sample = windows[idx]
        labels.append(sample[1])
    return np.array(labels)


def group_split(dataset, groups, train_size=0.8, random_state=None):
    """Single grouped train/test split.

    Parameters
    ----------
    dataset
        Any sized dataset; only its length is used to build the indices
        that are split.
    groups: np.array
        Group labels for the samples used while splitting the dataset into
        train/test set.
    train_size: float
        Forwarded to GroupShuffleSplit as ``train_size``. Defaults to 0.8.
    random_state
        Controls the randomness of the training and testing indices produced.

    Return
    ------
    Train indices, as yielded by GroupShuffleSplit.
    Test indices, as yielded by GroupShuffleSplit.
    """
    splitter = GroupShuffleSplit(
        n_splits=1, train_size=train_size, random_state=random_state)
    sample_ids = np.arange(len(dataset))

    # n_splits=1: only the first (and only) split is consumed.
    all_splits = iter(splitter.split(X=sample_ids, y=None, groups=groups))
    first_split = next(all_splits)
    train_indices, test_indices = first_split
    return train_indices, test_indices


def worker_init_fn(worker_id):
    """Seed the random number generators of a DataLoader worker.

    The seed is derived from the first word of the current global NumPy
    random state, offset by ``worker_id`` so that each worker gets a
    different seed.

    Parameters
    ----------
    worker_id: int
        Identifier of the worker being initialised.
    """
    base_seed = int(np.random.get_state()[1][0])
    # set_random_seeds expects a boolean ``cuda`` flag as second argument,
    # not a device object (a torch.device is always truthy).
    _, cuda = find_device()
    # Keep the seed inside the 32-bit range accepted by NumPy seeding.
    set_random_seeds((base_seed + worker_id) % 2 ** 32, cuda)


def downsample(dataset, random_state=None):
    """Downsample a dataset so that all classes are balanced.

    Every class present in the labels is randomly reduced to the size of
    the scarcest class.

    Parameters
    ----------
    dataset:
        Dataset object that will be downsampled.
    random_state: int | None
        If not None, used to seed the global NumPy random generator before
        sampling (0 is a valid seed).

    Returns
    -------
    torch.utils.data.Subset:
        Downsampled subset.

    np.array:
        Downsampled subjects mask. Useful since the Subset object does not
        contain the subject info (contrary to braindecode.WindowsDataset).
    """
    if random_state is not None:
        np.random.seed(random_state)

    y = get_labels(dataset)
    if len(y) == 0:
        # Nothing to balance: return an empty subset and an empty mask.
        empty = np.array([], dtype=int)
        return Subset(dataset, empty), empty

    subjects_mask = get_subjects(dataset)

    # Work on the labels actually present, so that non-contiguous or
    # negative class ids are handled (np.bincount would count absent
    # classes as 0 and reject negative labels).
    classes, counts = np.unique(y, return_counts=True)
    scarce_class_count = counts.min()

    kept = []
    for label in classes:
        class_indices = np.where(y == label)[0]
        np.random.shuffle(class_indices)
        kept.append(class_indices[:scarce_class_count])
    indices = np.hstack(kept)

    return Subset(dataset, indices), subjects_mask[indices]


def get_dataset(name="SleepPhysionet", n_subjects=2, n_jobs=1):
    """
    Preprocessing pipeline in one single function to make the rest of the code
    more concise. If the user wants to add a new dataset with a specific
    preprocessing pipeline, it can be done in this function.

    Parameters:
    -----------
    name: str
        Name of the dataset that will be used as an argparse argument in
        make_learning_curve.py. Supported values are "SleepPhysionet" and
        "BCI".
    n_subjects: int
        Number of subjects to extract from the dataset.
    n_jobs:
        Number of workers for the parallelisation of the windowing.

    Returns:
    --------
    windows: BaseConcatDataset:
        Preprocessed windows, ready for the training!

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported datasets.

    Note
    ----
    The Sleep EDF contains 78 subjects but their IDs are not evenly spaced
    in [0, 77] (it would be too easy). Instead, indices range from 0 to 82
    with [39, 68, 69, 78, 79] missing. This dataset implementation makes up
    for this peculiarity. Though this fact must be acknowledged if you are
    looking for a specific subject ID.
    """

    if name == "SleepPhysionet":
        SUBJECT_IDS = np.delete(np.arange(83), [39, 68, 69, 78, 79])

        dataset = EEGDataset(
            SleepPhysionet(
                subject_ids=SUBJECT_IDS[:n_subjects],
                recording_ids=None,
                preload=True,
                load_eeg_only=True))
        # Preprocessing: convert to µV, low-pass at 30Hz, standardize.
        preprocessors = [
            Preprocessor(lambda x: x * 1e6),
            Preprocessor('filter', l_freq=None, h_freq=30),
            Preprocessor(zscore)]
        preprocess(dataset, preprocessors)

        return dataset.get_windows(
            mapping={  # We merge stages 3 and 4 following AASM standards.
                'Sleep stage W': 0,
                'Sleep stage 1': 1,
                'Sleep stage 2': 2,
                'Sleep stage 3': 3,
                'Sleep stage 4': 3,
                'Sleep stage R': 4},
            window_size_samples=3000,
            window_stride_samples=3000,
            preload=True,
            n_jobs=n_jobs)

    elif name == "BCI":
        dataset = EEGDataset(
            MOABBDataset(
                dataset_name="BNCI2014001",
                # Subject ids are 1-based (BNCI2014001 has 9 subjects).
                subject_ids=list(np.arange(n_subjects) + 1)))

        # Preprocessing: keep EEG channels, convert to µV, band-pass between
        # 4 and 38Hz, standardize.
        preprocessors = [
            Preprocessor('pick_types', eeg=True, meg=False, stim=False),
            Preprocessor(lambda x: x * 1e6),
            Preprocessor('filter', l_freq=4, h_freq=38),
            Preprocessor(zscore)]
        preprocess(dataset, preprocessors)

        return dataset.get_windows(
            mapping={'feet': 0,
                     'left_hand': 1,
                     'right_hand': 2,
                     'tongue': 3},
            preload=True,
            n_jobs=n_jobs)

    # ValueError is a subclass of Exception, so callers catching the
    # previously raised generic Exception keep working.
    raise ValueError(
        "Dataset not found: {!r}. Supported datasets are 'SleepPhysionet' "
        "and 'BCI'.".format(name))


class ClassWiseAugmentation(Transform):
    """Subclass of Transform. Allows creating handcrafted augmentations that
    apply a different augmentation to each class.

    Parameters
    ----------
    aug_per_class: dict
        Dictionary that has classes as keys and augmentations as values.
        Each augmentation is called as ``aug(X, y)`` and must return a pair
        whose first element is the transformed ``X``.
    """

    def __init__(self, aug_per_class):
        self.aug_per_class = aug_per_class
        super().__init__()

    def __repr__(self):
        return str(self.aug_per_class)

    def forward(self, X, y):
        """Apply to each sample the augmentation registered for its class.

        Parameters
        ----------
        X: torch.Tensor
            Batch of inputs; the first dimension is indexed by the class
            masks computed from ``y``.
        y: torch.Tensor
            Labels of the batch.

        Returns
        -------
        torch.Tensor
            Copy of ``X`` where the samples of each class were replaced by
            their augmented version.
        y
            The labels, unchanged.

        Raises
        ------
        copy.Error
            If ``y`` contains a class that has no entry in ``aug_per_class``.
        """
        tr_X = X.clone()
        # Could be changed to apply the identity transformation to classes
        # for which no augmentation is specified.
        # torch.unique (rather than np.unique) also works when ``y`` lives
        # on a CUDA device.
        for c in torch.unique(torch.as_tensor(y)).tolist():
            if c not in self.aug_per_class:
                # NOTE: the exception type is kept as-is for backward
                # compatibility with existing callers.
                raise Error(
                    "Unknown class {} found.\n Make sure to parse a class-wise"
                    " augmentation dict that covers all the classes."
                    .format(c))

        for c, augmentation in self.aug_per_class.items():
            mask = y == c
            # Vectorised reduction instead of iterating the mask element by
            # element with the builtin ``any``.
            if mask.any():
                tr_X[mask, ...], _ = augmentation(X[mask, ...], y[mask])
        return tr_X, y
