import argparse
import os
from os.path import join
from numbers import Real
import copy
from collections.abc import Iterable
import yaml

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix,\
    log_loss, cohen_kappa_score
from sklearn.utils.class_weight import compute_class_weight,\
    compute_sample_weight
from sklearn.utils import check_random_state
from joblib import Parallel, delayed
import torch
from torch.utils.data import Subset
from skorch.helper import predefined_split
from skorch.callbacks import EpochScoring, Checkpoint, TrainEndCheckpoint,\
    EarlyStopping, LoadInitState
from mne import set_log_level
from mne_bids import BIDSPath, read_raw_bids, get_entity_vals
from braindecode import EEGClassifier
from braindecode.datasets.sleep_physionet import SleepPhysionet
from braindecode.models import SleepStagerChambon2018
from braindecode.datautil.windowers import create_windows_from_events
from braindecode.util import set_random_seeds
from braindecode.datautil.preprocess import zscore, MNEPreproc, NumpyPreproc,\
    preprocess
from braindecode.datautil.mne import create_from_mne_raw
from braindecode.augmentation.base import BaseDataLoader

from eeg_augment.utils import grouped_split, stratified_split, log2_grid,\
    linear_grid, get_groups, flexible_int, flexible_float,\
    ENFORCEABLE_TRANSFORMS, find_device, worker_init_fn
from eeg_augment.sanity_check import make_invariant_function_from_transform,\
    make_real_dataset_invariant, LinearRegression, IIDGaussianDataset,\
    AR1Dataset


# Annotation description -> integer class label. Sleep stages 3 and 4 share
# the same label: they are merged following AASM standards.
TARGETS_MAPPING = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4,
}

# Registry of grid generators selectable by name in `check_grid`. Each value
# is called there as `generator(max_value, n_values)`.
POSSIBLE_GRID_GENERATORS = {
    'log2': log2_grid,
    'lin': linear_grid,
}


class CrossvalModel:
    """Wrapper of skorch NeuralNetClassifier object for K-fold training

    Parameters
    ----------
    training_dir : str
        Path to directory where elements should be saved. It is created if
        it does not exist yet.
    model : torch.nn.Module
        torch Module modeling a neural net classifier. A deep copy of it is
        trained for each split.
    model_params : dict, optional
        Parameters to pass to skorch.classifier.NeuralNetClassifer
        constructor (other than callbacks and splits). By default None.
        See https://skorch.readthedocs.io/en/stable/classifier.html for
        more details.
    n_folds : int, optional
        Number of folds for cross-validation. It will generate n_fold
        triplets (train, valid, test) sets. By default 5.
    shared_callbacks : list, optional
        List of tuples (`str`, `Callback`) shared across all folds. Should
        not be used for logging/checkpointing/loading callbacks ; see other
        args for this. By default None.
    balanced_loss : boolean, optional
        Whether to balance the passed loss with the classes frequencies
        from the split training sets. By default True.
    monitor : str, optional
        Metric to use for checkpointing. By default 'valid_loss_best'.
    should_checkpoint : boolean, optional
        Whether to save model checkpoints into split-specific folders:
        training_dir/fold{k}of{n_folds}/subset_{ratio}_samples.
        By default True.
    should_load_state : boolean, optional
        Whether to load states from latest checkpoint. Works split-wise.
        By default True.
    log_tensorboard : boolean, optional
        Whether to log metrics for monitoring on Tensorboard. By default
        True. NOTE(review): the flag is only stored on the instance; no
        logging callback is created by the methods of this class.
    train_size_over_valid : float, optional
        Float between 0 and 1 setting the ratio used to split each training
        fold into a train and a validation set. By default 0.8.
    random_state : int | None, optional
        Used for seeding random number generator.
    **kwargs
        Ignored. A warning is printed when any is passed.
    """
    def __init__(
        self,
        training_dir,
        model,
        model_params=None,
        n_folds=5,
        shared_callbacks=None,
        balanced_loss=True,
        monitor='valid_loss_best',
        should_checkpoint=True,
        should_load_state=True,
        log_tensorboard=True,
        train_size_over_valid=0.8,
        random_state=None,
        **kwargs
    ):
        # Only warn when something was actually ignored
        if kwargs:
            print(
                "Warning!! Unused kwargs passed to CrossvalModel init: ",
                kwargs
            )
        assert isinstance(n_folds, int) and n_folds > 0,\
            "n_folds should be a positive int"
        self.n_folds = n_folds

        assert isinstance(training_dir, str), "training_dir should be a str"
        os.makedirs(training_dir, exist_ok=True)
        self.training_dir = training_dir

        assert isinstance(model, torch.nn.Module),\
            "model should be a torch.nn.Module object."
        self.model = model

        if model_params is None:
            model_params = {}
        assert isinstance(model_params, dict), "model_params should be a dict."
        self.model_params = model_params

        if shared_callbacks is None:
            shared_callbacks = []
        assert (
            isinstance(shared_callbacks, list) and
            all(
                isinstance(c, tuple) and len(c) > 0 and isinstance(c[0], str)
                for c in shared_callbacks
            )
        ), (
            "shared_callbacks should be a list of tuples "
            "(str, torch/skorch callback)."
        )
        self.shared_callbacks = shared_callbacks

        self.monitor = monitor

        # XXX : you should use Python 3 typing system here to simplify
        # your code
        assert isinstance(should_checkpoint, bool),\
            "should_checkpoint should be a boolean."
        self.should_checkpoint = should_checkpoint

        assert isinstance(should_load_state, bool),\
            "should_load_state should be a boolean."
        self.should_load_state = should_load_state

        assert isinstance(balanced_loss, bool),\
            "balanced_loss should be a boolean."
        self.balanced_loss = balanced_loss

        assert isinstance(log_tensorboard, bool),\
            "log_tensorboard should be a boolean."
        self.log_tensorboard = log_tensorboard

        assert (isinstance(train_size_over_valid, Real) and
                0 <= train_size_over_valid <= 1),\
            "train_size_over_valid should be a float between 0 and 1."
        self.train_size_over_valid = train_size_over_valid

        # Mainly used for splitting the data
        self.splitting_random_state = random_state

        # Used to generate independent seeds passed to each parallel job
        self.rng = np.random.default_rng(random_state)

        # Both are computed lazily on the first call to _crossval_apply and
        # then reused, so that all calls share the same splits and seeds
        self.split_indices = None
        self.child_states = None

    def _fit_and_score(
        self,
        split,
        epochs,
        windows_dataset,
        model_params,
        random_state,
        **kwargs
    ):
        """Train and tests a copy of self.model on the desired split

        Parameters
        ----------
        split : tuple
            Tuple containing the fold index, the training set proportion and
            the indices of the training, validation and test set.
        epochs : int
            Maximum number of epochs for the training.
        windows_dataset : torch.utils.data.Dataset
            Dataset that will be split and used for training, validation and
            testing.
        model_params : dict
            Modified copy of self.model_params.
        random_state : int | None
            Seed to use for RNGs.
        **kwargs
            Forwarded to `fit_and_predict`.

        Returns
        -------
        dict
            Dictionary containing the fold index, the number of folds, the
            subset ratio and, for the training, validation and test sets,
            the balanced accuracy, the loss, the kappa score and the
            confusion matrix.
        """

        fold, subset_ratio, train_subset_idx, valid_idx, test_idx = split

        # The device entry may be missing or given as a str: torch.device
        # normalizes both str and torch.device values
        device = torch.device(self.model_params.get('device', 'cpu'))
        set_random_seeds(
            seed=random_state,
            cuda=device.type == "cuda"
        )

        fold_path = join(self.training_dir, f'fold{fold}of{self.n_folds}')
        subset_path = join(fold_path, f'subset_{subset_ratio}_samples')

        train_subset = Subset(windows_dataset, train_subset_idx)
        test_set = Subset(windows_dataset, test_idx)
        valid_set = Subset(windows_dataset, valid_idx)

        print(
            f"---------- Fold {fold} out of {self.n_folds} |",
            f"Training size: {len(train_subset)} ----------"
        )

        # Build a fresh list so that self.shared_callbacks is never mutated
        callbacks = list(self.shared_callbacks)
        callbacks += make_training_specific_callbacks(
            subset_path,
            metric_to_monitor=self.monitor,
            should_checkpoint=self.should_checkpoint,
            should_load_state=self.should_load_state
        )

        predicted_probas, class_weights = fit_and_predict(
            model=copy.deepcopy(self.model),
            train_set=train_subset,
            valid_set=valid_set,
            test_set=test_set,
            epochs=epochs,
            model_params=model_params,
            balanced_loss=self.balanced_loss,
            callbacks=callbacks,
            **kwargs
        )

        results_per_subset = {
            'fold': fold, 'n_fold': self.n_folds, 'subset_ratio': subset_ratio,
        }
        keys_and_ds = zip(
            ["train", "valid", "test"],
            [train_subset, valid_set, test_set]
        )
        for key, ds in keys_and_ds:
            # Evaluate metrics on all 3 datasets and add to results returned
            y_true = np.array([y for _, y, _ in ds])
            # Weights come from the training set class frequencies (or are
            # None when the loss is not balanced)
            ds_weights = compute_sample_weight(class_weights, y_true)
            y_proba = predicted_probas[key]
            y_pred = np.argmax(y_proba, axis=1)

            results_per_subset.update({
                f'{key}_bal_acc': balanced_accuracy_score(y_true, y_pred),
                f'{key}_confusion_matrix': confusion_matrix(y_true, y_pred),
                f'{key}_loss': log_loss(y_true, y_proba,
                                        sample_weight=ds_weights),
                f'{key}_cohen_kappa_score': cohen_kappa_score(y_true, y_pred),
            })
        return results_per_subset

    def _make_child_random_states(self, n_jobs):
        """Generate independent seeds from the object's RNG

        Parameters
        ----------
        n_jobs : int
            Number of seeds to generate.

        Returns
        -------
        list
            List of `n_jobs` integer seeds, each drawn from a SeedSequence
            spawned from the one underlying self.rng.
        """
        # get the SeedSequence of the passed RNG: prefer the public
        # attribute and fall back to the private one, which is the only one
        # available in older numpy versions
        bit_generator = self.rng.bit_generator
        seed_sequencer = getattr(bit_generator, 'seed_seq', None)
        if seed_sequencer is None:
            seed_sequencer = bit_generator._seed_seq

        # create  independent seed sequencers
        child_seed_sequencers = seed_sequencer.spawn(n_jobs)

        # get independent seeds
        return [
            sequencer.generate_state(1)[0]
            for sequencer in child_seed_sequencers
        ]

    def _crossval_apply(
        self,
        windows_dataset,
        epochs,
        function_to_apply,
        data_ratios=None,
        max_ratios=None,
        grouped_subset=True,
        n_jobs=1,
        verbose=False,
        **kwargs
    ):
        """Cross-validation-like application of a desired function taking a
        dataset, a number of epochs and training-validation-test set split
        indices as arguments.

        Splits and seeds are computed on the first call and cached on the
        object: `data_ratios`, `max_ratios` and `grouped_subset` are hence
        only taken into account the first time this method is called.

        Parameters
        ----------
        windows_dataset : torch.utils.data.Dataset
            Whole dataset used for training, validation and testing.
        epochs : int
            Number of epochs to pass to the function.
        function_to_apply : callable
            Function applied to each different split. Should accept at least
            the following arguments: `split`, 'random_state`, `epochs`,
            `windows_dataset`.
        data_ratios : float | list | str | None, optional
            Proportions to use for subsetting the training set. Defaults to
            None.
        max_ratios : int | None, optional
            Number of values requested from the grid generator when
            `data_ratios` is a str. By default None.
        grouped_subset : bool, optional
            Whether training subsets are built with a grouped split (True)
            or a stratified split (False). By default True.
        n_jobs : int, optional
            Number of jobs passed to joblib.Parallel, by default 1
        verbose : bool, optional
            Verbosity passed to joblib.Parallel, by default False
        **kwargs
            Forwarded to `function_to_apply`.

        Returns
        -------
        list
            Values returned by `function_to_apply`, one per split.
        """
        # Compute train, valid and test indices triplets for all folds/subsets
        if getattr(self, "split_indices", None) is None:
            kf = GroupKFold(n_splits=self.n_folds)
            groups = get_groups(windows_dataset)
            # Store those in an attribute so that splits are not recomputed for
            # a given object
            self.split_indices = _get_split_indices(
                cv=kf,
                windows_dataset=windows_dataset,
                groups=groups,
                train_size_over_valid=self.train_size_over_valid,
                data_ratios=data_ratios,
                max_ratios=max_ratios,
                grouped_subset=grouped_subset,
                random_state=self.splitting_random_state,
            )

        # Use the object's rng to generate independent seeds for each
        # parallel job
        if self.child_states is None:
            self.child_states = self._make_child_random_states(
                len(self.split_indices)
            )

        # Loop (in parallel) across each indices triplet and corresponding seed
        # and use them to launch a function (e.g training with a given
        # augmentation, searching augmentations, etc.)
        parallel = Parallel(
            n_jobs=n_jobs,
            verbose=verbose,
            prefer='processes',
        )

        results_per_fold = parallel(delayed(function_to_apply)(
            split=split,
            random_state=seed,
            epochs=epochs,
            windows_dataset=windows_dataset,
            **kwargs
        ) for (split, seed) in zip(self.split_indices, self.child_states))

        return results_per_fold

    def learning_curve(
        self,
        windows_dataset,
        epochs,
        data_ratios=None,
        max_ratios=None,
        grouped_subset=True,
        n_jobs=1,
        verbose=False,
        **kwargs
    ):
        """Fit model to dataset and evaluate cross-validated performance

        Parameters
        ----------
        windows_dataset : torch.data.utils.ConcatDataset
            Dataset to fit.
        epochs : int
            Number of epochs.
        data_ratios : list | float | str | None, optional
            Float or list of floats between 0 and 1 or a str (only "log2" and
            "lin" supported for now). Each element will be used to build a
            subset of the cross-validated training sets (valid and test sets
            are conserved). Omitting it or setting it to None, is equivalent to
            setting it to [1.] (using the whole training set). If "log2" is
            passed, then a log2 scale of training sizes will be used. By
            default None.
        max_ratios : int | None, optional
            Maximum number of subsets to be built. Useful when ratios are
            computed automatically. Ignored when data_ratios is omitted or
            a list.
        grouped_subset : bool, optional
            Whether to compute training subsets taking groups (subjects) into
            account or not. When False, stratified spliting will be used to
            build the subsets. By default True.
        n_jobs : int, optional
            Number of workers to use for parallelizing across splits. By
            default 1.
        verbose : bool, optional
            By default False.
        **kwargs
            Forwarded to `_fit_and_score` for each split.

        Returns
        -------
        pandas.DataFrame
            One row per split with the metrics computed by `_fit_and_score`.
            It is also pickled to training_dir/test_crossval_results.pkl.
        """
        results_per_fold = self._crossval_apply(
            windows_dataset,
            epochs,
            function_to_apply=self._fit_and_score,
            model_params=self.model_params.copy(),
            data_ratios=data_ratios,
            max_ratios=max_ratios,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
            **kwargs
        )
        results_per_fold = pd.DataFrame(results_per_fold)

        results_per_fold.to_pickle(
            join(self.training_dir, 'test_crossval_results.pkl')
        )
        return results_per_fold


def check_grid(grid, max_value=None, n_values=None):
    """Normalize a grid specification into a collection of values

    Parameters
    ----------
    grid : str | Iterable | numbers.Real | None
        When a str, name of a generator from POSSIBLE_GRID_GENERATORS, which
        is called with `max_value` and `n_values`. When another iterable,
        it is converted to a list. A single real number (int or float) or
        None is wrapped into a one-element list.
    max_value : optional
        First argument passed to the grid generator. Required when `grid` is
        a str. By default None.
    n_values : optional
        Second argument passed to the grid generator. By default None.

    Returns
    -------
    list
        Values of the grid (for a str `grid`, whatever the generator
        returns).

    Raises
    ------
    ValueError
        When `grid` is of none of the supported types.
    KeyError
        When `grid` is a str which is not a known generator name.
    """
    # str has to be tested first since strings are iterables too
    if isinstance(grid, str):
        assert max_value is not None, "max_value required by grid generators"
        return POSSIBLE_GRID_GENERATORS[grid](max_value, n_values)
    if isinstance(grid, Iterable):
        return list(grid)
    # Real covers ints as well, so that passing 1 behaves like passing 1.
    if grid is None or isinstance(grid, Real):
        return [grid]
    raise ValueError(
        "grid can be a str, an iterable, a real number or None. "
        f"Got {type(grid)}."
    )


def _get_split_indices(
    cv,
    windows_dataset,
    groups,
    train_size_over_valid,
    data_ratios,
    max_ratios,
    grouped_subset=True,
    random_state=None
):
    """Compute (train subset, valid, test) indices for all folds and ratios

    Each fold yielded by `cv` is first divided into a test set and a
    train+valid set. The latter is then split into train and valid with a
    grouped split, and finally one training subset is built per ratio.

    Parameters
    ----------
    cv : cross-validation splitter
        Object whose `split(windows_dataset, groups=groups)` method yields
        (train+valid indices, test indices) pairs.
    windows_dataset : torch.utils.data.Dataset
        Dataset to split. Its elements are unpacked as triplets whose second
        item is the target.
    groups : numpy.ndarray
        Group of each element of the dataset.
    train_size_over_valid : float
        Ratio passed to `grouped_split` for the train/valid split.
    data_ratios : float | list | str | None
        Grid specification of the training subset ratios, as accepted by
        `check_grid`. None is equivalent to [1.].
    max_ratios : int | None
        Number of values passed to `check_grid`.
    grouped_subset : bool, optional
        Whether subsets are built with `grouped_split` (True) or with
        `stratified_split` (False). By default True.
    random_state : int | None, optional
        Passed to all the splitting functions. By default None.

    Returns
    -------
    list
        Tuples (fold number starting at 1, ratio, training subset indices,
        validation indices, test indices).
    """
    ratios_spec = [1.] if data_ratios is None else data_ratios

    # Targets are only needed by the stratified splitting
    if not grouped_subset:
        targets = np.array([y for _, y, _ in windows_dataset])

    all_splits = []
    folds = cv.split(windows_dataset, groups=groups)
    for fold_number, (train_valid_idx, test_idx) in enumerate(folds, start=1):
        train_idx, valid_idx = grouped_split(
            indices=train_valid_idx,
            ratio=train_size_over_valid,
            groups=groups[train_valid_idx],
            random_state=random_state
        )
        for ratio in check_grid(ratios_spec, len(train_idx), max_ratios):
            if grouped_subset:
                subset_idx, _ = grouped_split(
                    indices=train_idx,
                    ratio=ratio,
                    groups=groups[train_idx],
                    random_state=random_state
                )
            else:
                subset_idx, _ = stratified_split(
                    indices=train_idx,
                    ratio=ratio,
                    targets=targets[train_idx],
                    random_state=random_state
                )
            all_splits.append(
                (fold_number, ratio, subset_idx, valid_idx, test_idx)
            )
    return all_splits


def fit_and_predict(
    model,
    train_set,
    valid_set,
    test_set,
    epochs,
    model_params=None,
    balanced_loss=True,
    callbacks=None,
    **kwargs
):
    """Fit a classifier on the training set and predict class probabilities
    on the training, validation and test sets

    Parameters
    ----------
    model : torch.nn.Module
        Model to train.
    train_set : torch.util.data.Dataset
        Dataset to train on.
    valid_set : torch.util.data.Dataset
        Validation set, used as predefined split by the classifier.
    test_set : torch.util.data.Dataset
        Dataset to use to evaluate the trained model.
    epochs : int
        Number of training epochs.
    model_params : dict | None, optional
        Parameters to pass to the EEGClassifier constructor. They take
        precedence over the module, split and callbacks set here. Defaults
        to None.
    balanced_loss : boolean, optional
        Whether to weight the criterion with the class weights computed
        from the training set. By default True.
    callbacks : list, optional
        List of callbacks passed to the classifier. Defaults to None.
    **kwargs
        Forwarded to the `fit` method of the classifier.

    Returns
    -------
    tuple
        Dictionary of predicted probabilities with keys "train", "valid"
        and "test", and the class weights dictionary (None when
        `balanced_loss` is False).
    """
    extra_params = {} if model_params is None else model_params

    classifier_params = dict(
        module=model,
        train_split=predefined_split(valid_set),
        callbacks=callbacks,
    )
    classifier_params.update(extra_params)

    class_weights_dict = compute_class_weights_dict(train_set, balanced_loss)
    if class_weights_dict is not None:
        # The weights have to live on the same device as the model
        target_device = extra_params.get('device', 'cpu')
        weights = torch.Tensor(list(class_weights_dict.values()))
        classifier_params['criterion__weight'] = weights.to(target_device)

    clf = EEGClassifier(**classifier_params)

    # `y` is None since targets are already supplied by the dataset
    clf.fit(train_set, y=None, epochs=epochs, **kwargs)

    named_sets = (
        ("train", train_set),
        ("valid", valid_set),
        ("test", test_set),
    )
    predicted_probas = {
        name: clf.predict_proba(subset) for name, subset in named_sets
    }
    return predicted_probas, class_weights_dict


def compute_class_weights_dict(train_set, balanced_loss):
    """Compute balanced class weights from the targets of a dataset

    Parameters
    ----------
    train_set : iterable
        Dataset whose elements are triplets with the target as second item.
    balanced_loss : bool
        Whether weights should be computed at all.

    Returns
    -------
    dict | None
        Mapping from each class label present in `train_set` to its
        'balanced' weight, sorted by label. None when `balanced_loss` is
        False.
    """
    if not balanced_loss:
        return None

    train_y = np.array([y for _, y, _ in train_set])
    classes = np.unique(train_y)
    class_weights = compute_class_weight(
        'balanced',
        classes=classes,
        y=train_y
    )
    # Key the weights by the actual labels rather than by their position:
    # both only coincide when labels are exactly 0..K-1, and consumers such
    # as compute_sample_weight look weights up by label
    return dict(zip(classes.tolist(), class_weights))


def make_class_proportions_tensor(dataset, balanced_loss, device):
    """Build a float tensor of class weights for a dataset

    Parameters
    ----------
    dataset : iterable
        Dataset passed to `compute_class_weights_dict`.
    balanced_loss : bool
        Whether weights should be computed at all.
    device : str | torch.device
        Device to create the tensor on.

    Returns
    -------
    torch.Tensor | None
        Class weights in the order of the weights dictionary, or None when
        `balanced_loss` is False.
    """
    if not balanced_loss:
        return None

    weights_by_class = compute_class_weights_dict(dataset, balanced_loss)
    return torch.as_tensor(
        list(weights_by_class.values()),
        device=device,
        dtype=torch.float,
    )


def make_training_specific_callbacks(
    logging_path,
    metric_to_monitor='valid_loss_best',
    should_checkpoint=True,
    should_load_state=False
):
    """Create the checkpointing and state loading callbacks of a training

    Parameters
    ----------
    logging_path : str
        Directory used by the checkpoint callbacks.
    metric_to_monitor : str, optional
        Metric monitored by the Checkpoint callback. By default
        'valid_loss_best'.
    should_checkpoint : bool, optional
        Whether to add the checkpoint callbacks. By default True.
    should_load_state : bool, optional
        Whether to add a callback loading the state from the checkpoint. Note
        that the checkpoint callbacks are then added as well, even when
        `should_checkpoint` is False. By default False.

    Returns
    -------
    list
        List of tuples (name, callback). Empty when both flags are False.
    """
    if not (should_checkpoint or should_load_state):
        return []

    checkpointing = Checkpoint(
        dirname=logging_path,
        monitor=metric_to_monitor
    )
    named_callbacks = [
        ('checkpoint', checkpointing),
        ('end_checkpoint', TrainEndCheckpoint(dirname=logging_path)),
    ]

    if should_load_state:
        named_callbacks.append(('load_state', LoadInitState(checkpointing)))
    return named_callbacks


def make_sleep_stager_model(
    windows_dataset,
    device,
    sfreq,
    n_classes=5,
    parallel=False
):
    """Instantiate a SleepStagerChambon2018 model matching a dataset

    Parameters
    ----------
    windows_dataset : torch.data.utils.ConcatDataset
        Dataset which will be used for training the model. The shape of the
        input of its first element gives the number of channels and the
        number of time samples.
    device : str | torch.device
        Device to put model on.
    sfreq : int
        Sampling frequency in Hz.
    n_classes : int, optional
        Number of classes to predict. By default 5
    parallel : bool, optional
        Whether to parallelize the model across GPUs using
        torch.nn.DataParallel. Will only be done if more than one GPU is
        available. By default False.

    Returns
    -------
    torch.nn.Module
    """
    # The input of the first window is (channels, time samples)
    first_window = windows_dataset[0][0]
    n_channels, input_size_samples = first_window.shape

    window_duration_s = input_size_samples / sfreq
    model = SleepStagerChambon2018(
        n_channels,
        sfreq,
        n_classes=n_classes,
        input_size_s=window_duration_s
    )

    if parallel:  # NOT TESTED
        n_gpus = torch.cuda.device_count()
        if n_gpus > 1:
            print("Using ", n_gpus, "GPUs!")
            model = torch.nn.DataParallel(model)
    return model.to(device)


def make_linear_model(
    dataset,
    device,
    n_classes=4,
):
    """Instantiate a LinearRegression model matching a dataset

    Parameters
    ----------
    dataset : torch.data.utils.Dataset
        Dataset which will be used for training the model. The input size of
        the model is the number of values in the input of its first element.
    device : str | torch.device
        Device to put model on.
    n_classes : int, optional
        Size of the model output. By default 4.

    Returns
    -------
    torch.nn.Module
    """
    sample_shape = dataset[0][0].shape
    n_features = np.prod(sample_shape)
    linear_model = LinearRegression(
        input_size=n_features,
        output_size=n_classes,
    )
    return linear_model.to(device)


def prep_physionet_dataset(
    mne_data_path=None,
    n_subj=None,
    recording_ids=None,
    window_size_s=30,
    sfreq=100,
    should_preprocess=True,
    should_normalize=True,
    high_cut_hz=30,
    crop_wake_mins=30,
    crop=None,
    preload=False,
):
    """Import, create and preprocess SleepPhysionet dataset.

    Parameters
    ----------
    mne_data_path : str, optional
        Path to put the fetched data in. By default None
    n_subj : int | None, optional
        Number of subjects to import. If omitted, all subjects will be imported
        and used.
    recording_ids : list | None, optional
        List of recording indices (int) to be imported per subject. If omitted,
        all recordings will be imported and used (i.e. [1,2]).
    window_size_s : int, optional
        Window size in seconds defining each sample. By default 30.
    sfreq : int, optional
        Sampling frequency in Hz, used to convert the window size into
        samples. No resampling is done here. By default 100
    should_preprocess : bool, optional
        Whether to preprocess the data with a low-pass filter and microvolts
        scaling (and cropping, when `crop` is set). By default True.
    should_normalize : bool, optional
        Whether to normalize (zscore) the windows. By default True.
    high_cut_hz : int, optional
        Cut frequency to use for low-pass filter in case of preprocessing. By
        default 30.
    crop_wake_mins : int, optional
        Number of minutes of wake time to keep before the first sleep event
        and after the last sleep event. Used to reduce the imbalance in this
        dataset. Default of 30 mins.
    crop : tuple | None
        If not None, crop the raw data with (tmin, tmax). Useful for
        testing fast. Only applied when `should_preprocess` is True.
    preload : bool, optional
        Whether to preload raw signals in the RAM.

    Returns
    -------
    windows_dataset : braindecode.datasets.BaseConcatDataset
        Windowed dataset.
    ch_names : list
        Names of the channels, i.e. ['Fpz', 'Pz'].
    sfreq : int
        Sampling frequency in Hz, as passed in arguments.
    """

    if n_subj is None:
        subject_ids = None
    else:
        subject_ids = range(n_subj)

    dataset = SleepPhysionet(
        subject_ids=subject_ids,
        recording_ids=recording_ids,
        crop_wake_mins=crop_wake_mins,
        path=mne_data_path
    )

    set_log_level(False)

    preprocessors = [
        # convert from volt to microvolt, directly modifying the array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # low-pass filter (no low cut frequency is set)
        MNEPreproc(
            fn='filter',
            l_freq=None,
            h_freq=high_cut_hz,
            verbose=False
        ),
    ]

    if crop is not None:
        # Crop after the scaling and before the filtering
        preprocessors.insert(
            1,
            MNEPreproc(
                fn='crop',
                tmin=crop[0],
                tmax=crop[1]
            )
        )

    if should_preprocess:
        # Transform the data
        preprocess(dataset, preprocessors, bar=True)

    window_size_samples = window_size_s * sfreq

    # Non-overlapping windows: the stride equals the window size
    windows_dataset = create_windows_from_events(
        dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_size_samples, preload=preload,
        mapping=TARGETS_MAPPING, verbose=False,
    )

    if should_normalize:
        preprocess(windows_dataset, [MNEPreproc(fn=zscore)])
    # Return the sampling frequency actually used for windowing, so that
    # callers stay consistent when a non-default value is passed
    return windows_dataset, ['Fpz', 'Pz'], sfreq


def load_preproc_bids(
    sfreq=None,
    bids_root=None,
    n_subj=None,
    window_size_s=30,
    should_normalize=True,
    multisession=False,
    preload=False,
):
    """Load recordings from a BIDS directory and cut them into windows

    Parameters
    ----------
    sfreq : float | None, optional
        Sampling frequency in Hz, used to convert the window size into
        samples. No resampling is done here. If None, it is read from the
        first loaded recording. By default None.
    bids_root : str, optional
        Root of the BIDS dataset. By default None.
    n_subj : int | None, optional
        Number of subjects to load. All subjects are loaded when it is None,
        non-positive or not smaller than the number of subjects found. By
        default None.
    window_size_s : int, optional
        Window size in seconds defining each sample. By default 30.
    should_normalize : bool, optional
        Whether to normalize (zscore) the windows. By default True.
    multisession : bool, optional
        Whether each subject folder contains session folders ('ses-*'), in
        which case one recording is loaded per session. By default False.
    preload : bool, optional
        Passed to `create_from_mne_raw`. By default False.

    Returns
    -------
    windows_dataset
        Windowed dataset returned by `create_from_mne_raw`.
    ch_names : list
        Channel names of the first recording, truncated at the first "-".
    sfreq : float
        Sampling frequency used for windowing.
    """

    subjects_to_load = get_entity_vals(bids_root, 'subject')
    if n_subj is not None and 0 < n_subj < len(subjects_to_load):
        subjects_to_load = subjects_to_load[:n_subj]

    list_of_raws = list()
    for subject in subjects_to_load:
        bids_path = BIDSPath(
            subject=subject,
            root=bids_root,
            suffix='eeg',
            datatype='eeg'
        )
        if multisession:
            subject_dir = f'{bids_path.root}/sub-{bids_path.subject}'
            # Sort for a reproducible loading order, and skip entries which
            # are not session folders (e.g. hidden or sidecar files)
            sessions = sorted(
                entry for entry in os.listdir(subject_dir)
                if entry.startswith('ses-')
            )
            for session in sessions:
                bids_path.update(session=session[len('ses-'):])
                raw = read_raw_bids(bids_path, verbose=False)
                list_of_raws.append(raw)
        else:
            raw = read_raw_bids(bids_path, verbose=False)
            list_of_raws.append(raw)

    if sfreq is None:
        sfreq = list_of_raws[0].info['sfreq']
    # The sampling frequency read from the recording is a float: a number
    # of samples has to be an int
    window_size_samples = int(round(window_size_s * sfreq))

    # Non-overlapping windows: the stride equals the window size
    windows_dataset = create_from_mne_raw(
        list_of_raws,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=window_size_samples,
        window_stride_samples=window_size_samples,
        preload=preload,
        drop_last_window=False,
        mapping=TARGETS_MAPPING
    )

    if should_normalize:
        preprocess(windows_dataset, [MNEPreproc(fn=zscore)])
    ch_names = [channel.split("-")[0]
                for channel in list_of_raws[0].info['ch_names']]
    return windows_dataset, ch_names, sfreq


def load_preproc_mass(
    bids_root=None,
    **kwargs
):
    """Load the MASS dataset from its BIDS directory

    Parameters
    ----------
    bids_root : str | None, optional
        Root of the BIDS dataset. Defaults to
        ~/mne_data/mass-bids/6channels-128Hz/ when None.
    **kwargs
        Forwarded to `load_preproc_bids`.

    Returns
    -------
    Whatever `load_preproc_bids` returns when called with a sampling
    frequency of 128 Hz.
    """
    root = bids_root
    if root is None:
        root = os.path.expanduser("~/mne_data/mass-bids/6channels-128Hz/")

    return load_preproc_bids(sfreq=128, bids_root=root, **kwargs)


def load_preproc_physionet(
    bids_root=None,
    **kwargs
):
    """Load the Physionet sleep dataset from its BIDS directory

    Parameters
    ----------
    bids_root : str | None, optional
        Root of the BIDS dataset. Defaults to
        ~/mne_data/physionet-sleep-edf-bids when None.
    **kwargs
        Forwarded to `load_preproc_bids`.

    Returns
    -------
    Whatever `load_preproc_bids` returns when called with a sampling
    frequency of 100 Hz and multisession loading.
    """
    root = bids_root
    if root is None:
        root = os.path.expanduser("~/mne_data/physionet-sleep-edf-bids")

    return load_preproc_bids(
        sfreq=100,
        bids_root=root,
        multisession=True,
        **kwargs
    )


def sample_invariant_decision_function(
    input_size,
    enforce_inv,
    n_classes=5,
    enforce_inv_std=0,
    random_state=None,
):
    """Draw a random linear map and bias and turn them into a decision
    function invariant to the desired transform

    Coefficients and biases are sampled with the `random` method of the RNG
    and scaled by 10.

    Parameters
    ----------
    input_size : int
        Length of input samples.
    enforce_inv : str
        Invariance to enforce (key of ENFORCEABLE_TRANSFORMS).
    n_classes : int, optional
        Size of output. By default 5.
    enforce_inv_std : float, optional
        Std of white noise to add to the decision function. By default 0 (no
        noise).
    random_state : int | numpy.random.RandomState | None, optional
        Seed or random number generator, as accepted by
        sklearn.utils.check_random_state. By default None.

    Returns
    -------
    callable
    """
    is_known_transform = (
        isinstance(enforce_inv, str) and
        enforce_inv in ENFORCEABLE_TRANSFORMS
    )
    assert is_known_transform, (
        "Possible enforce_inv are " +
        str(ENFORCEABLE_TRANSFORMS.keys()) + f". Got {enforce_inv}."
    )

    rng = check_random_state(random_state)
    # The linear map has to be drawn before the bias to keep the sampled
    # values reproducible for a given seed
    linear_map_shape = (n_classes, input_size)
    random_linear_map = 10 * rng.random(linear_map_shape)
    random_bias = 10 * rng.random(n_classes)

    transform, order = ENFORCEABLE_TRANSFORMS[enforce_inv]
    return make_invariant_function_from_transform(
        transform,
        order,
        random_linear_map,
        random_bias,
        enforce_inv_std
    )


def make_shared_callbacks(balanced_acc=True, early_stop=True):
    """Create the callbacks shared by all trainings

    Parameters
    ----------
    balanced_acc : bool, optional
        Whether to score the balanced accuracy at each epoch, on both the
        training ('train_bal_acc') and the validation ('valid_bal_acc')
        sets. By default True.
    early_stop : bool, optional
        Whether to stop training after 30 epochs without improvement of
        'valid_bal_acc'. By default True.
        NOTE(review): this metric is only scored when `balanced_acc` is
        True - confirm the two flags are never set apart.

    Returns
    -------
    list
        List of tuples (name, callback).
    """
    named_callbacks = []

    if balanced_acc:
        scorers = (
            ('train_bal_acc', True),
            ('valid_bal_acc', False),
        )
        for scorer_name, on_train in scorers:
            scorer = EpochScoring(
                scoring='balanced_accuracy',
                on_train=on_train,
                name=scorer_name,
                lower_is_better=False,
            )
            named_callbacks.append((scorer_name, scorer))

    if early_stop:
        stopper = EarlyStopping(
            monitor='valid_bal_acc',
            patience=30,
            lower_is_better=False,
        )
        named_callbacks.append(('early_stopping', stopper))
    return named_callbacks


def make_vanilla_model_params(lr, batch_size, num_workers, device):
    """Create the parameters used to instantiate the classifier

    Parameters
    ----------
    lr : float
        Learning rate.
    batch_size : int
        Batch size.
    num_workers : int
        Number of workers used by the training data loader.
    device : str | torch.device
        Device to train on.

    Returns
    -------
    dict
        Parameters of the classifier: criterion, optimizer, learning rate,
        batch size, device and training data loader settings.
    """
    model_params = {
        'criterion': torch.nn.CrossEntropyLoss,  # Not settable for now
        'optimizer': torch.optim.Adam,  # Not settable for now
        'lr': lr,
        'batch_size': batch_size,
        'iterator_train': BaseDataLoader,
        'iterator_train__num_workers': num_workers,
        'device': device,
    }

    # These loader options are only set when worker processes are used
    if num_workers > 0:
        model_params['iterator_train__worker_init_fn'] = worker_init_fn
        model_params['iterator_train__multiprocessing_context'] = 'fork'
    # torch.device normalizes both str and torch.device values, so that
    # devices given as strings (e.g. 'cuda') are handled too
    if torch.device(device).type == 'cuda':
        model_params['iterator_train__pin_memory'] = True
    return model_params


def prepare_training(
    windows_dataset,
    lr,
    batch_size,
    num_workers,
    early_stop,
    sfreq,
    n_classes,
    parallel=False,
    device=None,
    model_to_use=None,
    random_state=None,
):
    """Set up everything needed before a training can start.

    Resolves the device, seeds the random number generators (when a seed is
    given), instantiates the pytorch model and builds the classifier
    parameters and callbacks.

    Parameters
    ----------
    windows_dataset : torch.utils.data.Dataset
        Dataset handed over to the model factory.
    lr : float
        Learning rate.
    batch_size : int
        Batch size.
    num_workers : int
        Number of workers used for data loading.
    early_stop : bool
        Whether to carry early stopping during training.
    sfreq : float
        Sampling frequency (only forwarded to the default model factory).
    n_classes : int
        Number of classes.
    parallel : bool, optional
        Whether to parallelize on several GPUs (only forwarded to the
        default model factory), by default False
    device : str, optional
        Device to train on, resolved through ``find_device``, by default
        None
    model_to_use : str | None, optional
        Type of model to use: None for the default sleep stager, 'lin' for
        the linear model, by default None
    random_state : int | numpy.random.RandomState | None, optional
        Used to seed random number generators, by default None

    Returns
    -------
    tuple
        ``(device, rng, model, model_params, shared_callbacks)``.

    Raises
    ------
    ValueError
        If ``model_to_use`` is neither None nor 'lin'.
    """
    device, cuda = find_device(device)

    # Seed as broadly as possible and force deterministic kernels, so that
    # results can be reproduced. This must happen before the model is built.
    seed_requested = random_state is not None
    if seed_requested:
        set_random_seeds(seed=random_state, cuda=cuda)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False

    # Numpy generator handed back to the caller for its own sampling needs
    rng = check_random_state(random_state)

    wants_linear = model_to_use == 'lin'
    if model_to_use is not None and not wants_linear:
        raise ValueError(
            "Unsupported value for model_to_use. Can be either None or 'lin'."
        )

    if wants_linear:
        model = make_linear_model(
            windows_dataset,
            device=device,
            n_classes=n_classes,
        )
    else:
        model = make_sleep_stager_model(
            windows_dataset,
            device=device,
            sfreq=sfreq,
            n_classes=n_classes,
            parallel=parallel,
        )

    params = make_vanilla_model_params(
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
    )
    callbacks = make_shared_callbacks(early_stop=early_stop)

    return device, rng, model, params, callbacks


def make_args_parser():
    """Create the command-line parser of the cross-validation training.

    Note that a few flags are *negative* switches storing ``False`` into a
    positively-named destination: ``--reset_model`` -> ``should_load_state``,
    ``--no_early_stop`` -> ``early_stop`` and ``--ungrouped_subset`` ->
    ``grouped_subset``.

    Returns
    -------
    argparse.ArgumentParser
        Parser with all training, data and augmentation options registered.
    """
    parser = argparse.ArgumentParser(
        description='Train a model with cross-validation.'
    )

    parser.add_argument(
        'training_dir',
        type=str,
        help='Directory where training elements will be saved.'
    )

    parser.add_argument(
        '-e', '--epochs',
        default=300,
        type=int,
        help='Number of epochs for the training'
    )

    parser.add_argument(
        '-b', '--batch_size',
        default=16,
        type=int,
        help='Batch size'
    )

    parser.add_argument(
        '--lr',
        default=1e-3,
        type=float,
        help='(Initial) learning rate'
    )

    parser.add_argument(
        '-k', '--nfolds',
        default=5,
        type=int,
        help='Number of folds'
    )

    parser.add_argument(
        '--num_workers',
        default=4,
        type=int,
        help='Number of workers used for data loading'
    )

    parser.add_argument(
        '-d', '--device',
        type=str,
        help='Device to use.'
    )

    parser.add_argument(
        '-a', '--augment',
        action='append',
        type=str,
        help='Augmentations to use.'
    )

    # The help must be a single string: a tuple here breaks help rendering.
    parser.add_argument(
        '-c', '--compose',
        action='store_true',
        help=(
            'When using this option, augmentations passed will be applied'
            ' sequentially.'
        )
    )

    parser.add_argument(
        '--reset_model',
        action='store_false',
        dest='should_load_state',
        help="Whether to reset models on all folds (won't load checkpoints)."
    )

    parser.add_argument(
        '--no_early_stop',
        action='store_false',
        dest='early_stop',
        help="When this argument is passed, no early stopping will be used."
    )

    parser.add_argument(
        '--tv_split_ratio',
        dest='train_size_over_valid',
        type=float,
        default=0.8,
        help='Ratio of subjects used for training, compared to validation.'
    )

    parser.add_argument(
        '--mne_data_path',
        type=str,
        help='Path to where MNE data is stored once it is downloaded.'
    )

    parser.add_argument(
        '--n_subj',
        type=int,
        help='Number of subject recordings to use.'
    )

    parser.add_argument(
        '--rec_ids',
        action='append',
        type=int,
        help='Recordings to use (only works for Physionet dataset).'
    )

    parser.add_argument(
        '--random_state', '--rng',
        type=int,
        help='Used for seeding random generators.'
    )

    parser.add_argument(
        '--enforce_inv',
        type=str,
        help='When passed, used to create fake targets and enforce '
             'transform invariance.'
    )

    parser.add_argument(
        '--enforce_inv_std',
        type=float,
        default=0.,
        help='Standard deviation used to add iid Gaussian noise to fake '
             'targets.'
    )

    parser.add_argument(
        '--model',
        type=str,
        help='Defines what net to use. By default will be (SleepStager). '
             'If set to lin, will use a small linear model.'
    )

    parser.add_argument(
        '--dummy_data',
        type=str,
        help='Defines whether to use dummy data and which data generator to '
             'use. Can be either None (use real data), white or ar1.'
    )

    parser.add_argument(
        '--n_samples',
        type=int,
        default=1200,
        help='Number of dummy data samples (used only for dummy data).'
    )

    parser.add_argument(
        '--data_ratio',
        type=flexible_float,
        help='Whether and how to carry variable data regime training.'
    )

    parser.add_argument(
        '--max_ratios',
        type=int,
        help='Max number of subsets created when data_ratio passed.'
    )

    parser.add_argument(
        '-p', '--proba',
        type=flexible_float,
        default=0.5,
        help='Probability of augmenting. Can be a string or float in [0,1].'
    )

    parser.add_argument(
        '--n_probas',
        type=flexible_int,
        help='Number of probability values tried when proba is a string.'
    )

    parser.add_argument(
        '-m', '--mag',
        type=flexible_float,
        help='Augmentations magnitudes. Can be a string or float in [0,1].'
    )

    parser.add_argument(
        '--n_mags',
        type=flexible_int,
        help='Number of magnitude values tried when mag is a string.'
    )

    parser.add_argument(
        "--ungrouped_subset",
        action="store_false",
        dest='grouped_subset',
        help="When this argument is passed, training subsets will be computed "
             "without taking groups into account (in a stratified manner)."
    )

    parser.add_argument(
        "--n_jobs",
        type=int,
        default=1,
        help="Number of parallel trainings."
    )

    parser.add_argument(
        "--dataset",
        default="edf",
        help="Dataset to use. Can be either MASS ('mass') or Sleep Physionet "
             "EDF ('edf')."
    )

    parser.add_argument(
        "--preload", action='store_true',
        help="Whether to preload data to the RAM"
    )

    parser.add_argument(
        "--config",
        help="Where to get config file, if desired."
    )
    return parser


def read_config(config_path):
    """Load a YAML configuration file.

    Parameters
    ----------
    config_path : str
        Path to the YAML file.

    Returns
    -------
    object
        Content of the file as parsed by ``yaml.safe_load``.

    Raises
    ------
    ValueError
        If ``config_path`` does not point to an existing file.
    """
    if not os.path.isfile(config_path):
        raise ValueError(f"Could not find config file under {config_path}")
    # Decode explicitly as UTF-8 rather than with the locale default, so the
    # same config file is read identically on every platform.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _load_data(dataset, n_subj, rec_ids, mne_data_path, preload):
    if dataset == "edf":
        # windows_dataset, ch_names, sfreq = load_preproc_physionet(
        #    bids_root=args.mne_data_path,
        return prep_physionet_dataset(
            mne_data_path=mne_data_path,
            n_subj=n_subj,
            recording_ids=rec_ids,
            preload=preload,
        )
    elif dataset == "mass":
        return load_preproc_mass(
            bids_root=mne_data_path,
            n_subj=n_subj,
            preload=preload,
        )
    else:
        raise ValueError(
            "Possible values for `dataset` are 'edf'",
            f"(default) or 'mass'. Got {dataset}"
        )


def handle_dataset_args(args):
    """Build the dataset requested through the parsed arguments.

    When ``args.dummy_data`` is None, real data is loaded with
    ``_load_data``; the dataset name, number of subjects and recording ids
    come from the ``"data"`` section of the config file when ``args.config``
    is set, and from the command-line arguments otherwise. If
    ``args.enforce_inv`` is set, the dataset is then made invariant with
    ``make_real_dataset_invariant``.

    Otherwise a synthetic dataset ('white' or 'ar1') is generated, which
    requires ``args.enforce_inv``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments. The attributes read are ``dummy_data``, ``config``,
        ``dataset``, ``n_subj``, ``rec_ids``, ``mne_data_path``, ``preload``,
        ``enforce_inv``, ``enforce_inv_std``, ``random_state`` and
        ``n_samples``.

    Returns
    -------
    tuple
        ``(windows_dataset, ch_names, sfreq)``. ``ch_names`` and ``sfreq``
        are None for dummy data.

    Raises
    ------
    AssertionError
        If dummy data is requested without ``args.enforce_inv``.
    ValueError
        If ``args.dummy_data`` is neither None, 'white' nor 'ar1'.
    """
    ch_names = None
    sfreq = None

    if args.dummy_data is None:
        # Pick the data selection either from the config file or the CLI.
        if args.config:
            data_config = read_config(args.config)["data"]
            dataset = data_config["dataset"]
            n_subj = data_config["n_subj"]
            rec_ids = data_config.get("rec_ids", None)
        else:
            dataset = args.dataset
            n_subj = args.n_subj
            rec_ids = args.rec_ids

        windows_dataset, ch_names, sfreq = _load_data(
            dataset,
            n_subj,
            rec_ids,
            args.mne_data_path,
            args.preload,
        )

        if args.enforce_inv is not None:
            # NOTE(review): 6000 is presumably the flattened size of one
            # window (n_channels * n_times) -- confirm against the loaders.
            invariant_decision_func = sample_invariant_decision_function(
                input_size=6000,
                enforce_inv=args.enforce_inv,
                enforce_inv_std=args.enforce_inv_std,
                random_state=args.random_state
            )

            windows_dataset = make_real_dataset_invariant(
                windows_dataset,
                invariant_decision_func,
            )
        return windows_dataset, ch_names, sfreq

    # Raised explicitly (rather than through an `assert` statement) so that
    # the check is not stripped when running with `python -O`. The exception
    # type is kept for backward compatibility.
    if args.enforce_inv is None:
        raise AssertionError("enforce_inv needed when using dummy data.")

    # input_size=20 presumably matches the flattened (2, 10) samples below.
    invariant_decision_func = sample_invariant_decision_function(
        input_size=20,
        enforce_inv=args.enforce_inv,
        enforce_inv_std=args.enforce_inv_std,
        random_state=args.random_state
    )
    if args.dummy_data == 'white':
        windows_dataset = IIDGaussianDataset(
            (2, 10),
            random_state=args.random_state,
            n_samples=args.n_samples,
            invariant_op=invariant_decision_func,
        )
    elif args.dummy_data == 'ar1':
        # NOTE(review): unlike the 'white' case, no random_state is passed
        # here -- check whether AR1Dataset can be seeded.
        windows_dataset = AR1Dataset(
            (2, 10),
            n_samples=args.n_samples,
            invariant_op=invariant_decision_func,
            ar1_coefs=(-0.33, 0.4),
            offsets=(0., 0.),
            noise_std=(1.8, 0.2)
        )
    else:
        raise ValueError(
            "Got unexpected value for dummy_data. " +
            "Can be either white or ar1."
        )
    return windows_dataset, ch_names, sfreq
