import copy
import os
from typing import Iterable

import numpy as np
import pandas as pd
import torch
from braindecode import EEGClassifier
from braindecode.augmentation import AugmentedDataLoader, IdentityTransform
from braindecode.models import ShallowFBCSPNet, SleepStagerChambon2018
from braindecode.util import set_random_seeds
from joblib import Parallel, delayed
from sklearn.metrics import balanced_accuracy_score, confusion_matrix
from sklearn.model_selection import GroupKFold, train_test_split
from sklearn.utils.class_weight import compute_class_weight
from skorch.callbacks import EarlyStopping
from skorch.helper import predefined_split
from torch.utils.data.dataset import Subset

from BAE.utils import (find_device, get_labels, get_subjects, group_split,
                       worker_init_fn)


def get_EEGClassifier(
        dataset_name="SleepPhysionet",
        clf_params=None,
        random_state=None):
    """Generate a braindecode EEGClassifier with default parameters.

    The network is chosen from ``dataset_name``. The classifier parameters
    start from the defaults defined in this function, and any of them can be
    overridden through ``clf_params``.

    Parameters
    ----------
    dataset_name: str, optional
        "SleepPhysionet" builds a ``SleepStagerChambon2018`` network (2
        channels, 5 classes, 100 Hz, 30 s windows). "BCI" builds a
        ``ShallowFBCSPNet`` (22 channels, 4 classes, 1000-sample windows).
    clf_params: dict, optional
        Classifier parameters for the braindecode.EEGClassifier. Each entry
        overrides the default of the same name. The 'module' entry is always
        the network built here.
    random_state: int | None, optional
        Seed passed to ``set_random_seeds`` before the network is created,
        so that the weight initialization is reproducible. ``0`` is a valid
        seed.

    Returns
    -------
    braindecode.EEGClassifier

    Raises
    ------
    ValueError
        If ``dataset_name`` is not "SleepPhysionet" or "BCI".
    """
    # Validate first: an unknown name used to surface only later, as an
    # UnboundLocalError on `model`.
    if dataset_name not in ("SleepPhysionet", "BCI"):
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            "expected 'SleepPhysionet' or 'BCI'.")

    device = find_device()[0]

    # `is not None` so that a seed of 0 is honoured as well.
    if random_state is not None:
        set_random_seeds(random_state, device)

    DEFAULT_MODEL_PARAMS_PHYSIONET = {
        'n_channels': 2,
        'n_classes': 5,
        'sfreq': 100,
        'input_size_s': 30,
        'time_conv_size_s': 0.5,
        'apply_batch_norm': True}
    DEFAULT_MODEL_PARAMS_BCI = {
        'in_chans': 22,
        'n_classes': 4,
        'input_window_samples': 1000,
        'final_conv_length': 'auto',
        'batch_norm': True}
    DEFAULT_CLF_PARAMS = {
        'criterion': torch.nn.CrossEntropyLoss,
        'criterion__weight': None,
        'optimizer': torch.optim.Adam,
        'iterator_train': AugmentedDataLoader,
        'iterator_train__num_workers': 4,
        'iterator_train__worker_init_fn': worker_init_fn,
        'iterator_train__transforms': [
            IdentityTransform()],
        'train_split': None,
        'optimizer__lr': 1e-3,
        'callbacks': [
            'balanced_accuracy',
            ('early stopping',
             EarlyStopping(
                 patience=30))],
        'batch_size': 16,
        'device': device}

    # The network is built after seeding so its initial weights depend on
    # `random_state`.
    if dataset_name == "SleepPhysionet":
        model = SleepStagerChambon2018(**DEFAULT_MODEL_PARAMS_PHYSIONET)
    else:  # "BCI" (validated above)
        model = ShallowFBCSPNet(**DEFAULT_MODEL_PARAMS_BCI)

    if device.type == 'cuda':
        DEFAULT_CLF_PARAMS['iterator_train__pin_memory'] = True
        # Fix for joblib multiprocessing
        DEFAULT_CLF_PARAMS['iterator_train__multiprocessing_context'] = 'fork'

    if clf_params:
        DEFAULT_CLF_PARAMS.update(clf_params)

    # Set last so that clf_params cannot replace the network.
    DEFAULT_CLF_PARAMS['module'] = model

    return EEGClassifier(**DEFAULT_CLF_PARAMS)


def fit_and_predict(
        clf,
        train_set,
        test_set,
        valid_set=None,
        epochs=5,
        random_state=None):
    """Train a classifier on a train set and use it to infer over a test set.

    The loss is re-weighted with 'balanced' class weights computed from the
    labels of ``train_set``. ``valid_set`` is handed to skorch as the
    validation split used during training.

    Parameters
    ----------
    clf: braindecode.EEGClassifier
        Classifier used. It is modified in place (parameters set, then fit).
    train_set: Dataset
        Dataset used for the training of the model.
    test_set: Dataset
        Dataset used for inference.
    valid_set: Dataset, optional
        Dataset used for validation during training. When it is non-empty,
        it is also predicted on.
    epochs: int, optional
        Number of epochs for the training.
    random_state: int | None, optional
        Seed passed to ``set_random_seeds`` before training. ``0`` is a
        valid seed.

    Returns
    -------
    dict
        Predictions keyed by 'train' and 'test', plus 'valid' when
        ``valid_set`` is truthy (not None and not empty).
    """
    device = find_device()[0]

    # `is not None` so that a seed of 0 is honoured as well.
    if random_state is not None:
        set_random_seeds(random_state, device)

    labels = get_labels(train_set)
    # NOTE(review): the weights follow the order of np.unique(labels). This
    # matches the class indices only if every class occurs in train_set —
    # confirm for very small training subsets.
    class_weights = compute_class_weight(
        'balanced', classes=np.unique(labels), y=labels)
    clf.set_params(criterion__weight=torch.Tensor(class_weights).to(device))

    # skorch peculiarity: 'train_split' stands for validation split
    clf.set_params(train_split=predefined_split(valid_set))
    clf.fit(train_set, y=None, epochs=epochs)

    predictions = {
        'train': clf.predict(train_set),
        'test': clf.predict(test_set),
    }
    # Truthiness skips None and, for datasets defining __len__, empty sets.
    if valid_set:
        predictions['valid'] = clf.predict(valid_set)

    return predictions


def compute_score(train_set, test_set, predictions, valid_set=None):
    """Compute some metrics based on the predictions and ground truth values.

    Parameters
    ----------
    train_set: Dataset
        Dataset containing ground truth values. Element 1 of each item is
        used as the label.
    test_set: Dataset
        Dataset containing ground truth values (same item layout).
    predictions: dict
        Output of fit_and_predict(), keyed by 'train', 'test' and optionally
        'valid'.
    valid_set: Dataset, optional
        Dataset containing ground truth values (same item layout). It is
        skipped when None or empty, mirroring fit_and_predict(), which does
        not predict on it in that case.

    Returns
    -------
    list of dict
        One dictionary per evaluated subset, in the order train, test, valid.
        Each dictionary has the keys 'balanced_accuracy',
        'confusion_matrix', 'set', 'y_pred' and 'y_true'. It also has one
        'balanced_accuracy_class{label}' entry per label present in the
        ground truth (one-vs-rest balanced accuracy).
    """
    subsets = [('train', train_set), ('test', test_set)]
    # Previously valid_set=None crashed (KeyError on predictions['valid']).
    if valid_set:
        subsets.append(('valid', valid_set))

    scores = []
    for key, subset in subsets:
        y_pred = predictions[key]
        y_true = np.array([subset[i][1] for i in range(len(subset))])

        score = {
            'balanced_accuracy': balanced_accuracy_score(
                y_true=y_true, y_pred=y_pred),
            'confusion_matrix': confusion_matrix(y_true=y_true, y_pred=y_pred),
            'set': key,
            'y_pred': y_pred,
            'y_true': y_true}

        # One-vs-rest balanced accuracy for every class seen in y_true.
        for label in np.unique(y_true):
            score[f'balanced_accuracy_class{label}'] = balanced_accuracy_score(
                y_true=y_true == label, y_pred=y_pred == label)

        scores.append(score)

    return scores


def fit_and_score_one_proportion(dataset,
                                 train_valid_indices,
                                 test_indices,
                                 clf,
                                 proportion=0.5,
                                 epochs=5,
                                 fold=None,
                                 random_state=None,
                                 fold_random_state=None,
                                 aug_magnitude=None,
                                 subjects_mask_train_valid=None):
    """Determines training, validation and test scores for
    a model trained on a subset that consists of randomly sampled windows
    from the train set.

    Parameters
    ----------
    dataset: Dataset
    train_valid_indices: list
        List of indices to use for training and validation.
    test_indices: list
        List of indices to use for testing.
    clf: NeuralNetClassifier
        It is (re)initialized in place. Training happens on a deep copy.
    proportion: float
        Proportion of the training split to use for the training.
    epochs: int
        Number of epochs for the training.
    fold: int
        Fold number that will be logged in the output DataFrame.
    random_state: int | None, optional
        Seed or random number generator to use for the generation of a
        sub-training set.
    fold_random_state: int | None, optional
        Seed or random number generator to use for the initialization of
        the model and the train/validation split. It allows trainings with
        the same validation split and the same model initialization, in
        order to assess the influence of the training set size with all
        other things being equal. ``0`` is a valid seed.
    aug_magnitude: float | None
        Augmentation parameter magnitude. When not None (0.0 included), an
        'aug_magnitude' column is added to the output. It would be cleaner
        to use something like
        clf.get_params()['iterator_train__transforms'][i].get_params(),
        though the output is not homogeneous between different transforms.
    subjects_mask_train_valid: list | np.ndarray
        Subject id of each window in ``train_valid_indices``. It is only
        needed if the dataset does not contain the subject information
        (e.g. it is a torch Subset object).

    Returns
    -------
    pd.DataFrame
        One row per subset (train, test, valid) with the scores from
        compute_score() plus the training metadata.

    Note
    ----
    The choice to pass train_valid indices and a dataset instead of
    train_valid_set may seem weird, but it is motivated by the fact that
    GroupKFold returns indices. Consequently it is easier to pass
    indices in order to make the Parallel(delayed()) line shorter.
    """
    # Process-wide settings requesting deterministic cuBLAS/cuDNN kernels.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # `is not None` so that a seed of 0 is honoured as well.
    if fold_random_state is not None:
        set_random_seeds(fold_random_state, find_device()[0])
    clf.initialize()

    if not isinstance(subjects_mask_train_valid, Iterable):
        subjects_mask_train_valid = get_subjects(dataset)[train_valid_indices]
    # Accept plain lists, as documented; the comparison below needs an array.
    subjects_mask_train_valid = np.asarray(subjects_mask_train_valid)
    assert len(train_valid_indices) == len(subjects_mask_train_valid)
    # NOTE(review): hard-coded sanity bound on subject ids. It rejects any
    # dataset whose subject ids exceed 90 — confirm this is intended.
    assert np.all(subjects_mask_train_valid <= 90)

    train_valid_set = Subset(dataset, train_valid_indices)
    test_set = Subset(dataset, test_indices)

    labels_mask = get_labels(train_valid_set)

    # Subject-grouped train/validation split. It depends only on
    # fold_random_state, so it is shared by every proportion of a fold.
    train_indices, valid_indices = group_split(
        train_valid_set,
        groups=subjects_mask_train_valid,
        train_size=0.8,
        random_state=fold_random_state)

    if proportion == 1:
        sub_train_indices = train_indices
    else:
        # Label-stratified random subsample of the training windows.
        sub_train_indices, _ = train_test_split(
            train_indices,
            train_size=proportion,
            random_state=random_state,
            stratify=labels_mask[train_indices])

    train_subset = Subset(train_valid_set, sub_train_indices)
    valid_set = Subset(train_valid_set, valid_indices)

    # Train a copy so that `clf` itself keeps its freshly initialized weights.
    classifier_copy = copy.deepcopy(clf)

    predictions = fit_and_predict(
        classifier_copy,
        train_set=train_subset,
        valid_set=valid_set,
        test_set=test_set,
        epochs=epochs,
        random_state=fold_random_state)
    subset_scores = compute_score(
        train_set=train_subset,
        valid_set=valid_set,
        test_set=test_set,
        predictions=predictions)

    # Loop-invariant: the first transform of the training data loader.
    augmentation = clf.get_params()['iterator_train__transforms'][0]
    scores = []
    for score in subset_scores:
        score.update({'proportion': proportion,
                      'train_windows': len(train_subset),
                      'fold': fold,
                      'sub_train_indices': sub_train_indices,
                      'augmentation': augmentation,
                      'fold_random_state': fold_random_state,
                      })
        # `is not None` so that a magnitude of 0.0 is still recorded.
        if aug_magnitude is not None:
            score['aug_magnitude'] = aug_magnitude
        scores.append(score)

    return pd.DataFrame(scores).sort_values(['fold'])


def parallel_learning_curve(
        dataset,
        clf,
        K=4,
        proportions=np.logspace(4, 6, 2, base=1 / 2),
        epochs=5,
        n_jobs=1,
        random_state=None,
        subjects_mask=None,
):
    """ Determines cross-validated training, validation and test scores for
    different training set sizes. A cross-validation generator splits the
    whole dataset K times (based on subject ids) into training and test data.
    Afterwards, a classifier is trained on subsets of randomly selected
    windows that contain a given proportion of the train set.

    Parameters
    ----------
    dataset: Dataset
        Must be a windows dataset.
    clf: braindecode.EEGClassifier
        Classifier to use.
    K: int
        Number of folds for cross validation (GroupKFold on subjects).
    proportions: list | np.array
        List of the proportions used for different trainings. The default
        is [1/16, 1/64].
    epochs: int
        Number of epochs used for each training.
    n_jobs: int
        Number of joblib jobs. One job is one (fold, proportion) training.
    random_state: int | None, optional
        Seed of the generator that draws the per-fold and per-proportion
        seeds.
    subjects_mask: list | np.array
        Subject id of each window of ``dataset``. It is only needed if the
        dataset does not contain the subject information (e.g. it is a
        torch Subset object).

    Returns
    -------
    pd.DataFrame
        One row per (fold, proportion, set). It contains, among others, the
        following columns:

        fold: int
        proportion: float
            Proportion of the training set used.
        set: str
            'train', 'test' or 'valid'.
        balanced_accuracy: float
            Balanced accuracy score.
        balanced_accuracy_class{label}: float
            One-vs-rest balanced accuracy for each label in the ground truth.
        confusion_matrix: np.array
        train_windows: int
            Number of windows used for training.

    Note
    ----
    Within a fold, the validation set is identical for all proportions. It
    holds roughly len(dataset) * 0.2 * (K - 1) / K windows: 20% of the
    train/validation windows, split by subject.
    """
    rng = np.random.default_rng(random_state)
    folds_random_state = rng.integers(0, 100_000, size=K)
    props_random_state = rng.integers(0, 100_000, size=(K, len(proportions)))

    indices = np.arange(len(dataset))
    if not isinstance(subjects_mask, Iterable):
        subjects_mask = get_subjects(dataset)
    # Fancy indexing below needs an array, even when a list was given.
    subjects_mask = np.asarray(subjects_mask)
    assert len(indices) == len(subjects_mask)

    group_kfold = GroupKFold(n_splits=K)
    splits = group_kfold.split(indices, groups=subjects_mask)
    # One task per (fold, proportion). All proportions of a fold share the
    # same fold_random_state, hence the same init and validation split.
    tasks = [
        delayed(fit_and_score_one_proportion)(
            dataset,
            train_valid_indices,
            test_indices,
            clf=clf,
            proportion=p,
            epochs=epochs,
            fold=k,
            random_state=props_random_state[k, i],
            fold_random_state=folds_random_state[k],
            subjects_mask_train_valid=subjects_mask[train_valid_indices])
        for k, (train_valid_indices, test_indices) in enumerate(splits)
        for i, p in enumerate(proportions)
    ]
    prop_scores = Parallel(n_jobs=n_jobs)(tasks)

    return pd.concat(prop_scores, axis=0, ignore_index=True)
