from inspect import signature
from numbers import Real
from copy import deepcopy
from itertools import product
import os
from os.path import join, basename
from random import random
import sqlite3
from sqlite3 import Error
from inspect import isclass
from time import time
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import log_loss
from sklearn.utils.class_weight import compute_sample_weight
import optuna
from optuna.integration import SkorchPruningCallback
from optuna.pruners import MedianPruner, NopPruner
from optuna.samplers import RandomSampler, BaseSampler, TPESampler
from optuna.trial import TrialState
from sklearn.utils import check_random_state
import torch
from torch.utils.data import Subset
from torch.utils.data._utils.collate import default_collate
from skorch.callbacks import Checkpoint, LoadInitState
from skorch.helper import predefined_split

from braindecode.augmentation.base import Transform, Compose,\
    IdentityTransform, BaseDataLoader
from braindecode.augmentation.functionals import identity
from braindecode import EEGClassifier

from eeg_augment.training_utils import (
    fit_and_predict, CrossvalModel, make_args_parser, handle_dataset_args,
    compute_class_weights_dict, prepare_training, read_config,
    set_random_seeds,
)
from eeg_augment.utils import (
    POSSIBLE_TRANSFORMS, get_global_rngs_states, set_global_rngs_states
)


# PHYSIONET_ALLOWED_TRANSFORMS = list(POSSIBLE_TRANSFORMS.keys())
# PHYSIONET_ALLOWED_TRANSFORMS.remove('channel-sym')

# Named subsets of transform keys (presumably keys of POSSIBLE_TRANSFORMS --
# TODO confirm). The MASS list extends the Physionet one with three extra
# transforms.
PHYSIONET_BEST_TRANSFORMS = ['flip', 'sign', 'ft-surrogate', 'freq-shift']
MASS_BEST_TRANSFORMS = [
    *PHYSIONET_BEST_TRANSFORMS,
    'channel-sym',
    'time-mask',
    'bandstop',
]
GLOBAL_TRANSFORMS_SUBSET = [
    'flip',
    'sign',
    'ft-surrogate',
    'freq-shift',
    'channel-sym',
]

# Maps a search-algorithm name to the optuna sampler class implementing it.
SEARCH_ALGOS = {
    'random-search': RandomSampler,
    'tpe': TPESampler,
}


def load_retrain_results_if_exist(training_dir, fold=None, epoch=None):
    """Load the last saved row of retraining results, if one exists, and
    infer the next step index.

    Parameters
    ----------
    training_dir : str | pathlib.Path
        Directory expected to contain 'search_perf_results.csv'.
    fold : int | None, optional
        When passed, only rows whose "fold" column equals this value are
        considered. Defaults to None (all rows).
    epoch : int | None, optional
        When passed, the last row whose "tot_search_epochs" column equals
        this value is loaded, instead of the very last row. Defaults to None.

    Returns
    -------
    step_idx : int
        Value of the "step_idx" column of the loaded row plus one, or 0 when
        no row was loaded.
    last_results : pandas.Series | None
        The loaded row, or None when the results file does not exist or
        contains no row matching ``fold`` / ``epoch``.
    """
    save_path = Path(training_dir) / 'search_perf_results.csv'
    if not save_path.exists():
        return 0, None

    prev_results = pd.read_csv(save_path)
    # Only consider the given fold when one is passed
    if fold is not None:
        prev_results = prev_results.loc[
            prev_results["fold"] == fold, :
        ].reset_index(drop=True)
    # Only consider rows of the desired epoch when one is passed
    if epoch is not None:
        prev_results = prev_results.loc[
            prev_results["tot_search_epochs"] == epoch, :
        ]
    # No usable row (empty file, unknown fold or unknown epoch): behave as if
    # nothing had been saved, instead of failing with an IndexError
    if prev_results.empty:
        return 0, None

    last_results = prev_results.iloc[-1, :]
    return last_results["step_idx"] + 1, last_results


def _check_transform_tuple(transform_tuple):
    """Check the consistency of a tuple (transform_name, probability,
    magnitude) and output it in a form suited for easy instantiation

    Parameters
    ----------
    transform_tuple : tuple
        Should be a tuple of the form (transform_name, probability, magnitude).
        transform_name must be "randaugment" or a key of
        eeg_augment.utils.POSSIBLE_TRANSFORMS, probability a real number in
        [0, 1] and magnitude either None or a real number in [0, 1].

    Returns
    -------
    str
        transform_name
    dict
        Containing the probability, and the magnitude when it is not None.

    Raises
    ------
    AssertionError
        When the tuple has the wrong length, the name is not a known string
        or the probability is not a real number in [0, 1].
    ValueError
        When the magnitude is neither None nor a real number in [0, 1].
    """
    assert len(transform_tuple) == 3,\
        f"transform_tuples should be of length 3. Got {len(transform_tuple)}"
    transform_key, probability, magnitude = transform_tuple
    assert isinstance(transform_key, str),\
        "First element of transform_tuple should be a string. " +\
        f"Got {transform_key}."
    # The message is built as a single string: passing a tuple as assertion
    # message would be displayed as a raw tuple repr
    assert (
        transform_key == "randaugment" or transform_key in POSSIBLE_TRANSFORMS
    ), (
        "transforms must be one or more among "
        f"{list(POSSIBLE_TRANSFORMS.keys())}. Got {transform_key}."
    )
    assert isinstance(probability, Real) and 0 <= probability <= 1,\
        f"probability must be a float between 0 and 1. Got {probability}."
    transform_set_params = {'probability': probability}
    if isinstance(magnitude, Real) and 0 <= magnitude <= 1:
        transform_set_params['magnitude'] = magnitude
    elif magnitude is not None:  # UNTESTED
        raise ValueError(
            f"magnitude must be a float between 0 and 1. Got {magnitude}"
        )
    # A None magnitude is simply left out, so that defaults can apply
    return transform_key, transform_set_params


def _make_transform_objects(
    transform_tuples,
    ordered_ch_names,
    sfreq,
    randaugment_seq_len=None,
    randaugment_transform_collec=None,
    defaults_key="default",
    random_state=None,
):
    """Create list of transform objects out of list of tuples of the form
    (transform_name, probability, magnitude)

    Parameters
    ----------
    transform_tuples : tuple | list | None
        Either a single tuple of the form (transform_name, probability,
        magnitude) or a list of such tuples, where transform_name is a str
        listed in eeg_augment.utils.POSSIBLE_TRANSFORMS (or "randaugment")
        and probability and magnitude are floats between 0 and 1.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    sfreq : int
        Sampling frequency. Used to instantiate Transforms such as
        FrequencyShift and BandstopFilter.
    randaugment_seq_len : int, optional
        Length of transforms sequence to sample in case RandAugment is used.
    randaugment_transform_collec : list | None, optional
        Collection of possible transforms to sample from when using
        RandAugment. Defaults to None (all available transforms).
    defaults_key : str, optional
        What default parameters to fetch. Can be "default", "edf" (for
        Physionet dataset default magnitudes) or "mass" (for MASS dataset
        default magnitudes). Defaults to "default".
    random_state : int | numpy.random.Generator, optional
        Used to create or as random number generator within transforms created.
        Defaults to None.

    Returns
    -------
    list | IdentityTransform
        List of instantiated transforms (one per tuple, in the same order), or
        a single IdentityTransform when transform_tuples is None.

    Raises
    ------
    ValueError
        When transform_tuples is neither None, a 3-tuple nor a list.
    """
    if transform_tuples is None:  # UNTESTED
        return IdentityTransform()
    if isinstance(transform_tuples, tuple) and len(transform_tuples) == 3:
        transform_tuples = [transform_tuples]

    # We instantiate the rng passed to all composed Transforms as soon as
    # possible, so that instantiated operations affect each other and don't
    # evolve based on identical RNGs evolving in parallel
    rng = check_random_state(random_state)

    if not isinstance(transform_tuples, list):
        raise ValueError(
            'When not omitted, transform_tuples should be a tuple or list of '
            'tuples of the form (transform_name, probability, magnitude) and '
            'type (str, float, float).'
        )

    transforms = []
    for transform_tuple in transform_tuples:
        transform_key, transform_set_params = _check_transform_tuple(
            transform_tuple
        )
        # TODO: Ideally, this randaugment branches should be removed when I
        # can do everything with BaseAugmentationSearcher class
        # Separate branch for the RandAugment case
        if transform_key == "randaugment":
            transform_class = RandAugment
            transform_params = {
                'sequence_length': randaugment_seq_len,
                'transforms': randaugment_transform_collec,
            }
        else:
            transform_class, transform_defaults = POSSIBLE_TRANSFORMS[
                transform_key
            ]
            default_params = transform_defaults.get(
                defaults_key,
                transform_defaults.get("default", None)
            )
            # Work on a copy: updating the dict stored in POSSIBLE_TRANSFORMS
            # in place would leak the values set here (probability,
            # magnitude, sfreq, ...) into every later instantiation.
            # A missing defaults entry is treated as "no default parameters".
            transform_params = dict(default_params or {})
        # Update default params with those passed as arguments
        # Will override magnitudes and probabilities if those are passed
        transform_params.update(transform_set_params)
        # Montage and sampling frequency are only forwarded to transforms
        # whose constructor accepts them
        accepted_params = signature(transform_class).parameters
        if "ordered_ch_names" in accepted_params and \
                ordered_ch_names is not None:
            transform_params["ordered_ch_names"] = ordered_ch_names
        if "sfreq" in accepted_params and sfreq is not None:
            transform_params["sfreq"] = sfreq
        print(
            f"Instantiating transform {transform_key}",
            f"with kwargs: {transform_params}"
        )
        transforms.append(
            transform_class(random_state=rng, **transform_params)
        )
    return transforms


def sample_subpolicy(
    trial,
    subpolicies_length,
    transforms_family,
    ordered_ch_names,
    sfreq,
    random_state,
    suffix=None,
):
    """Sample a subpolicy (sequence of transforms) using a trial object

    Parameters
    ----------
    trial : object
        Object exposing `suggest_categorical` and `suggest_float` (e.g. an
        optuna trial), used to draw operations, probabilities and magnitudes.
    subpolicies_length : int
        Number of consecutive transforms in the subpolicy.
    transforms_family : list
        Transform names the operations are drawn from.
    ordered_ch_names : list
        Forwarded to `_make_transform_objects`.
    sfreq : int
        Forwarded to `_make_transform_objects`.
    random_state : int | numpy.random.RandomState | None
        Forwarded to `_make_transform_objects`.
    suffix : str | None, optional
        Text inserted in the sampled parameter names, e.g. to tell classes
        apart. Defaults to None (nothing inserted).

    Returns
    -------
    Compose
        Composition of the sampled transforms.
    """
    # Empty tag when no class label has to be added to parameter names
    tag = '' if suffix is None else suffix

    # One (operation, probability, magnitude) triplet per position. Tuple
    # elements are evaluated left to right, so the trial is queried in the
    # order operation -> probability -> magnitude for each position.
    # Upper-bound is not included for the floats (unless made discrete)
    drawn_tuples = [
        (
            trial.suggest_categorical(
                f'operation{tag}_{pos}', transforms_family
            ),
            trial.suggest_float(f'probability{tag}_{pos}', 0, 1),
            trial.suggest_float(f'magnitude{tag}_{pos}', 0, 1),
        )
        for pos in range(subpolicies_length)
    ]

    # Turn the sampled parameters into transform objects and chain them
    return Compose(
        _make_transform_objects(
            drawn_tuples,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=random_state,
        )
    )


def sample_classwise_subpolicy(
    trial,
    classes,
    subpolicies_length,
    transforms_family,
    ordered_ch_names,
    sfreq,
    random_state,
):
    """Sample one subpolicy per class and wrap them in a ClasswiseSubpolicy

    Parameters
    ----------
    trial : object
        Trial forwarded to `sample_subpolicy`.
    classes : iterable
        Class labels; one subpolicy is sampled for each of them, with
        parameter names suffixed by "_class<label>".
    subpolicies_length : int
        Number of consecutive transforms in each subpolicy.
    transforms_family : list
        Transform names the operations are drawn from.
    ordered_ch_names : list
        Forwarded to `sample_subpolicy`.
    sfreq : int
        Forwarded to `sample_subpolicy`.
    random_state : int | numpy.random.RandomState | None
        Converted once with `check_random_state`; the resulting object is
        shared by the subpolicies of all classes.

    Returns
    -------
    ClasswiseSubpolicy
        Built from the dictionary {class label: sampled subpolicy}.
    """
    shared_rng = check_random_state(random_state)
    policies_by_class = {}
    for label in classes:
        policies_by_class[label] = sample_subpolicy(
            trial=trial,
            subpolicies_length=subpolicies_length,
            transforms_family=transforms_family,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=shared_rng,
            suffix=f"_class{label}"
        )
    return ClasswiseSubpolicy(policies_by_class)


def make_train_and_see_objective(
    model,
    train_set,
    valid_set,
    ordered_ch_names,
    sfreq,
    model_params,
    shared_callbacks,
    pretrain_path=None,
    subpolicies_length=2,
    transforms_family=None,
    metric_to_monitor="valid_bal_acc_best",
    balanced_loss=True,
    epochs=300,
    classes=None,
    classwise=False,
    random_state=None,
):
    """Objective factory where for each new trial, the augmentations are used
    to retrain the model, which is then evaluated on the validation set

    Parameters
    ----------
    model : torch.nn.Module
        torch Module modeling a neural net classifier. A deep copy of it is
        trained in each trial.
    train_set : torch.util.data.Dataset
        Dataset used for training the model.
    valid_set : torch.util.data.Dataset
        Dataset used for earlystopping and for assessing the subpolicy used
        during training.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    sfreq : int
        Sampling frequency. Used to instantiate Transforms such as
        FrequencyShift and BandstopFilter.
    model_params : dict
        Parameters to pass to skorch.classifier.NeuralNetClassifer
        constructor (other than callbacks and splits).
        See https://skorch.readthedocs.io/en/stable/classifier.html for
        more details. This dictionary is not modified: each trial works on a
        shallow copy in which 'iterator_train__transforms' is set.
    shared_callbacks : list | None
        List of callback objects (torch or skorch) shared across all
        folds (or of tuples (`str`, `Callback`)). Should not be used for
        logging/checkpointing/loading callbacks ; see other args for this.
        None is treated as an empty list.
    pretrain_path : str, optional
        Path to training folder where the checkpoint of a trained model is, so
        that it can be used to warmstart the algorithm.
    subpolicies_length : int, optional
        Length of subpolicies to look for (consecutive Transform objects).
        Defaults to 2.
    transforms_family : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations except 'no-aug' will be considered. Defaults to None.
    metric_to_monitor : str, optional
        Metric to use for pruning. By default 'valid_bal_acc_best'.
    balanced_loss : boolean, optional
        Whether to balance the passed loss with the classes frequencies
        from the split training sets. By default True.
    epochs : int, optional
        Maximum number of epochs to train in each split. By default 300.
    classes : list | None, optional
        Class labels for which a subpolicy is sampled when classwise is True.
        Defaults to None, which is replaced by [0, 1, 2, 3, 4].
    classwise : bool, optional
        Whether to sample one subpolicy per class instead of a single one.
        Defaults to False.
    random_state : int | np.random.RandomState | None, optional
        Used for seeding random number generator.

    Returns
    -------
    callable
        Objective function taking a trial as input and outputting a
        validation loss (weighted log loss on valid_set).
    """
    if transforms_family is None:
        transforms_family = list(POSSIBLE_TRANSFORMS.keys())
        transforms_family.remove('no-aug')

    if classes is None:
        classes = list(range(5))

    def _objective(trial):
        # NOTE(review): random_state is handed over unchanged at every
        # trial; with an int seed each trial presumably starts from an
        # identically seeded generator -- confirm this is intended.
        if classwise:
            # Sample subpolicies per class
            subpolicy = sample_classwise_subpolicy(
                trial=trial,
                classes=classes,
                subpolicies_length=subpolicies_length,
                transforms_family=transforms_family,
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=random_state,
            )
        else:
            # Sample just one subpolicy
            subpolicy = sample_subpolicy(
                trial=trial,
                subpolicies_length=subpolicies_length,
                transforms_family=transforms_family,
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=random_state,
            )

        # Per-trial copy, so that the dictionary owned by the caller (and
        # shared by all trials) is never written to
        trial_model_params = dict(model_params)
        trial_model_params['iterator_train__transforms'] = subpolicy

        callbacks = [] if shared_callbacks is None else list(shared_callbacks)
        callbacks.append(
            ("pruning", SkorchPruningCallback(trial, metric_to_monitor))
        )

        if pretrain_path is not None:
            checkpoint = Checkpoint(pretrain_path)
            warmstart = LoadInitState(checkpoint)
            callbacks.append(("warmstart", warmstart))

        predicted_probas, class_weights = fit_and_predict(
            model=deepcopy(model),
            train_set=train_set,
            valid_set=valid_set,  # Used for earlystopping
            test_set=valid_set,   # On purpose: used for assessing trial
            epochs=epochs,
            model_params=trial_model_params,
            balanced_loss=balanced_loss,
            callbacks=callbacks,
        )

        # Items of valid_set are unpacked as triplets whose second element
        # is the label
        y_true = np.array([y for _, y, _ in valid_set])
        y_proba = predicted_probas['test']
        test_weights = compute_sample_weight(class_weights, y_true)
        return log_loss(y_true, y_proba, sample_weight=test_weights)
    return _objective


def make_approx_density_match_objective(
    model,
    train_set,
    valid_set,
    ordered_ch_names,
    model_params,
    pretrain_path,
    sfreq,
    shared_callbacks=None,
    subpolicies_length=2,
    transforms_family=None,
    metric_to_monitor="valid_bal_acc_best",
    balanced_loss=True,
    epochs=300,
    classes=None,
    classwise=False,
    random_state=None,
):
    """Objective factory where for each new trial, the sampled augmentations
    are applied to the validation set, on which an already initialized
    (optionally pretrained) classifier is evaluated without any retraining

    Parameters
    ----------
    model : torch.nn.Module
        torch Module modeling a neural net classifier. It is wrapped once in
        an EEGClassifier, which is reused by all trials.
    train_set : torch.util.data.Dataset
        Dataset passed to `compute_class_weights_dict` to weight the loss.
    valid_set : torch.util.data.Dataset
        Dataset on which the sampled subpolicy is applied and the classifier
        is evaluated.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    model_params : dict
        Parameters set on the EEGClassifier at each trial. Must contain the
        key 'device'. Takes precedence over the validation split and iterator
        parameters set by the objective.
    pretrain_path : str | None
        Path to training folder containing the "params.pt" file of a trained
        model, loaded once into the classifier. Nothing is loaded when None.
    sfreq : int
        Sampling frequency, forwarded to the subpolicy sampling functions.
    shared_callbacks : list, optional
        Always ignored, exists for compatibility.
    subpolicies_length : int, optional
        Length of subpolicies to look for (consecutive Transform objects).
        Defaults to 2.
    transforms_family : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations except 'no-aug' will be considered. Defaults to None.
    metric_to_monitor : str, optional
        Always ignored, exists for compatibility.
    balanced_loss : boolean, optional
        Forwarded to `compute_class_weights_dict`. By default True.
    epochs : int, optional
        Always ignored, exists for compatibility.
    classes : list | None, optional
        Class labels for which a subpolicy is sampled when classwise is True.
        Defaults to None, which is replaced by [0, 1, 2, 3, 4].
    classwise : bool, optional
        Whether to sample one subpolicy per class instead of a single one.
        Defaults to False.
    random_state : int | np.random.RandomState | None, optional
        Used for seeding random number generator.

    Returns
    -------
    callable
        Objective function taking a trial as input and outputting a
        validation loss (weighted log loss on valid_set).
    """
    classifier = EEGClassifier(model)
    classifier.initialize()
    if pretrain_path is not None:
        classifier.load_params(f_params=join(pretrain_path, "params.pt"))
    # Unfortunately, Skorch moves all models loaded to cpu, unless device is
    # specified together with module. Changing the EEGClassifier later does not
    # update the module device
    classifier.module = classifier.module.to(model_params['device'])

    if transforms_family is None:
        transforms_family = [*POSSIBLE_TRANSFORMS.keys()]
        transforms_family.remove('no-aug')

    if classes is None:
        classes = list(range(5))

    def _objective(trial):
        # Arguments common to both sampling functions
        sampling_kwargs = dict(
            trial=trial,
            subpolicies_length=subpolicies_length,
            transforms_family=transforms_family,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=random_state,
        )
        if classwise:
            # One subpolicy per class
            subpolicy = sample_classwise_subpolicy(
                classes=classes, **sampling_kwargs
            )
        else:
            # A single subpolicy
            subpolicy = sample_subpolicy(**sampling_kwargs)

        # Note that we set the **validation** iterator here. Entries of
        # model_params come last and hence override the three keys above
        classifier.set_params(**{
            'train_split': predefined_split(valid_set),
            'iterator_valid': BaseDataLoader,
            'iterator_valid__transforms': subpolicy,
            **model_params,
        })
        class_weights_dict = compute_class_weights_dict(
            train_set, balanced_loss
        )
        with torch.no_grad():
            y_proba = classifier.predict_proba(valid_set)

        # Items of valid_set are unpacked as triplets whose second element
        # is the label
        y_true = np.array([y for _, y, _ in valid_set])
        sample_weights = compute_sample_weight(class_weights_dict, y_true)
        return log_loss(y_true, y_proba, sample_weight=sample_weights)
    return _objective


# Maps a search-metric name to the objective factory implementing it.
SEARCH_METRICS = {
    'autoaug': make_train_and_see_objective,
    'match': make_approx_density_match_objective,
}


class BaseAugmentationSearcher(CrossvalModel):
    """Class useful to search augmentation policies with desired search
    algorithm and metric and to assess the learned policy with cross-validation

    Parameters other than training_dir, model, subpolicies_length,
    policy_size_per_fold, transforms_family and random_state are forwarded
    as keyword arguments to the parent class constructor.

    Parameters
    ----------
    training_dir : str
        Path to directory where elements should be saved.
    model : torch.nn.Module
        torch Module modeling a neural net classifier.
    subpolicies_length: int, optional
        Length of subpolicies to look for (consecutive Transform objects).
        Defaults to 2.
    policy_size_per_fold: int | None, optional
        Number of best subpolicies to select per fold. Defaults to 10.
    transforms_family : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations will be considered. Defaults to None.
    model_params : dict, optional
        Parameters to pass to skorch.classifier.NeuralNetClassifer
        constructor (other than callbacks and splits). By default None.
        See https://skorch.readthedocs.io/en/stable/classifier.html for
        more details.
    n_folds : int, optional
        Number of folds for cross-validation. It will generate n_fold
        triplets (train, valid, test) sets. By default 5.
    shared_callbacks : list, optional
        List of callback objects (torch or skorch) shared across all
        folds (or of tuples (`str`, `Callback`)). Should not be used for
        logging/checkpointing/loading callbacks ; see other args for this.
        By default None.
    balanced_loss : boolean, optional
        Whether to balance the passed loss with the classes frequencies
        from the split training sets. By default True.
    monitor : str, optional
        Metric to use for checkpointing. By default 'valid_loss_best'.
    should_checkpoint : boolean, optional
        Whether to save model checkpoints (including training history) into
        fold-specific folders: training_dir/foldnofN.
        Will save based on improvements in metric self.monitor and when
        training stops. By default True.
    should_load_state : boolean, optional
        Whether to load states from latest checkpoint. Works fold-wise.
        By default True.
    log_tensorboard : boolean, optional
        Whether to log metrics for monitoring on Tensorboard. Will be saved
        in training_dir/foldnofN/logs when set to True. By default True.
    train_size_over_valid : float, optional
        Float between 0 and 1 setting the ratio used to split each training
        fold into a train and a validation set. By default 0.8.
    random_state : int | None, optional
        Used for seeding random number generator.
    """
    def __init__(
        self,
        training_dir,
        model,
        subpolicies_length=2,
        policy_size_per_fold=10,
        transforms_family=None,
        random_state=None,
        **kwargs
    ):
        assert isinstance(subpolicies_length, int) and subpolicies_length > 0,\
            "subpolicies_length has to be a positive int."
        self.subpolicies_length = subpolicies_length
        assert (
            isinstance(policy_size_per_fold, int) and policy_size_per_fold > 0
        ) or policy_size_per_fold is None,\
            "policy_size_per_fold has to be a positive int or None."
        self.policy_size_per_fold = policy_size_per_fold
        assert (
            isinstance(transforms_family, list) or transforms_family is None
        ), "transforms_family should be None or a list."
        self.transforms_family = transforms_family
        # Filled by search_policy: maps a fold to the policy found for it
        self.learned_policies = dict()
        self.retraining_seed = random_state  # Used to assess policies learned
        # Attribute used for storing fold-specific RandomState objects and
        # global rng states
        self.random_states = defaultdict(dict)
        super().__init__(training_dir, model, random_state=random_state,
                         **kwargs)

    @property
    def device(self):
        """Device stored under the "device" key of self.model_params."""
        return self.model_params["device"]

    @property
    def cuda(self):
        """bool: Whether the configured device is a CUDA device."""
        # torch.device() accepts both device strings (e.g. "cuda") and
        # torch.device objects, whereas a plain string has no `.type`
        return torch.device(self.device).type == "cuda"

    def _init_global_and_specific_rngs(self, fold, random_state):
        """Seeds global rngs and random state object (stored as an attribute
        for persistence across search calls).

        Parameters
        ----------
        fold : int
            Fold corresponding to the RNGs.
        random_state : int
            Seed to use for all RNGs.
        """
        set_global_rngs_states(
            seed=random_state,
            cuda=self.cuda,
        )
        # The fold-specific random state object is only created once, so that
        # it keeps evolving across successive calls
        if "random_state_obj" not in self.random_states[fold]:
            self.random_states[fold]["random_state_obj"] = check_random_state(
                random_state
            )

    def _set_global_rngs_from_previous_calls(
        self,
        fold,
    ):
        """Sets global RNGs using states from previous calls when available.
        Important function to ensure persistent randomness and
        DA search robustness to different step sizes (number of search epochs
        between each evaluation).

        Parameters
        ----------
        fold : int
            Fold whose saved states should be restored. Nothing is done when
            no state was saved for this fold.
        """
        # Set python, numpy and torch global rngs using previous states
        # when available
        saved_states = self.random_states[fold]
        if "global_rng_states" in saved_states:
            set_global_rngs_states(
                states=saved_states["global_rng_states"],
                cuda=self.cuda,
            )

    def _save_current_global_rng_states(self, fold):
        """Fetches the states of global RNGs of torch, cuda, numpy and python
        and stores them in the attribute self.random_states

        Parameters
        ----------
        fold : int
            Fold under which the states are stored.
        """
        self.random_states[fold]["global_rng_states"] = get_global_rngs_states(
            cuda=self.cuda,
        )

    def _get_warmstart_path(self, warmstart_base_path, fold, subset_ratio):
        """Builds the path to the warmstart folder of a fold and subset ratio

        Parameters
        ----------
        warmstart_base_path : str | None
            Root folder of the pretraining results.
        fold : int
            Fold index, used in the folder name 'fold{fold}of{n_folds}'.
        subset_ratio : float
            Used in the folder name 'subset_{subset_ratio}_samples'.

        Returns
        -------
        str | None
            warmstart_base_path/fold{fold}of{n_folds}/
            subset_{subset_ratio}_samples, or None when warmstart_base_path
            is None.
        """
        warmstart_path = None
        if warmstart_base_path is not None:
            fold_path = join(
                warmstart_base_path, f'fold{fold}of{self.n_folds}'
            )
            warmstart_path = join(fold_path, f'subset_{subset_ratio}_samples')
        return warmstart_path

    def _search_in_fold(
        self,
        split,
        random_state,
        epochs,
        windows_dataset,
        *args,
        **kwargs
    ):
        """Placeholder search for a single fold: ignores its arguments and
        returns the tuple (None, None).

        Presumably meant to be overridden by subclasses so as to return the
        tuple (fold, policy_found), which search_policy collects into
        self.learned_policies.
        """
        policy_found = None
        fold = None
        return fold, policy_found

    def search_policy(
        self,
        windows_dataset,
        epochs,
        data_ratio=None,
        grouped_subset=True,
        n_jobs=1,
        verbose=False,
        **kwargs
    ):
        # TODO: Explain how the windows_dataset is split
        """Search for best augmentation policy

        The (fold, policy) tuples returned by self._crossval_apply are stored
        as a dictionary in self.learned_policies.

        Parameters
        ----------
        windows_dataset : torch.data.utils.ConcatDataset
            Dataset to fit.
        epochs : int
            Number of epochs.
        data_ratio : float | None, optional
            Float between 0 and 1 or None. Will be used to build a
            subset of the cross-validated training sets (valid and test sets
            are conserved). Omitting it or setting it to None, is equivalent to
            setting it to [1.] (using the whole training set). By default None.
        grouped_subset : bool, optional
            Whether to compute training subsets taking groups (subjects) into
            account or not. When False, stratified spliting will be used to
            build the subsets. By default True.
        n_jobs : int, optional
            Number of workers to use for parallelizing across splits. By
            default 1.
        verbose : bool, optional
            By default False.
        """
        assert isinstance(data_ratio, float) or data_ratio is None,\
            "Only a single data_ratio value is supported for now."
        folds_policies_tuples = self._crossval_apply(
            windows_dataset,
            epochs,
            function_to_apply=self._search_in_fold,
            data_ratios=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
            **kwargs
        )
        self.learned_policies = dict(folds_policies_tuples)

    def _fit_and_score(
        self,
        split,
        epochs,
        model_params,
        windows_dataset,
        random_state=None,
        **kwargs
    ):
        """Train and tests a copy of self.model on the desired split using the
        learned policy

        Parameters
        ----------
        split : tuple
            Tuple containing the fold index, the training set proportion and
            the indices of the training, validation and test set.
        epochs : int
            Maximum number of epochs for the training.
        model_params : dict
            Modified copy of self.model_params. Its
            "iterator_train__transforms" entry is set here to the policy
            learned for the fold.
        windows_dataset : torch.utils.data.Dataset
            Dataset that will be split and used for training, validation and
            testing.
        random_state : int | None
            Seed to use for RNGs.

        Returns
        -------
        dict
            Dictionary containing the balanced accuracy, the loss, the kappa
            score and the confusion matrix for the training, validation and
            test sets.

        Raises
        ------
        KeyError
            When no policy is stored in self.learned_policies for the fold.
        """
        # The fold index is the first element of the split
        fold = split[0]
        augmentation_policy = self.learned_policies[fold]
        model_params["iterator_train__transforms"] = augmentation_policy
        return super()._fit_and_score(
            split=split,
            random_state=random_state,
            epochs=epochs,
            windows_dataset=windows_dataset,
            model_params=model_params,
            **kwargs
        )


def start_db(db_file):
    """Open a connection to the SQLite database stored in db_file.

    Parameters
    ----------
    db_file : str
        Database location, passed as is to sqlite3.connect.

    Returns
    -------
    sqlite3.Connection | None
        The opened connection, or None when sqlite3 raised an error (the
        error is printed).
    """
    connection = None
    try:
        connection = sqlite3.connect(db_file)
    except Error as err:
        print(err)
        close_db(connection)
        return None
    return connection


def close_db(conn):
    """Close the connection to a SQLite database.

    Does nothing when conn is falsy (e.g. None).
    """
    if not conn:
        return
    conn.close()


def derive_study_name(experiment_path):
    """Take the path to an experiment folder and return the folder name, to
    create a study with.

    Trailing path separators are ignored, so that "runs/exp1/" and
    "runs/exp1" both give "exp1".

    Parameters
    ----------
    experiment_path : str | os.PathLike
        Path to the experiment folder.

    Returns
    -------
    str
        Last component of the path.
    """
    path = os.fspath(experiment_path)
    if isinstance(path, str):
        # basename returns '' for paths ending with a separator
        path = path.rstrip(os.sep + (os.altsep or ""))
    return basename(path)


class BlackBoxAugmentationSearcher(BaseAugmentationSearcher):
    """Class useful to search augmentation policies with desired discrete
    search algorithm and metric and to assess the learned policy with
    cross-validation

    Parameters
    ----------
    training_dir : str
        Path to directory where elements should be saved.
    model : torch.nn.Module
        torch Module modeling a neural net classifier.
    objective_factory : callable | None
        Factory function which outputs an objective function with the
        signature optuna.trial -> float. It is called with the following
        keyword arguments:
        * model (pytorch.nn.Module): model to use for training and evaluation,
        * train_set (pytorch.util.data.Dataset): training set,
        * valid_set (pytorch.util.data.Dataset): validation set,
        * ordered_ch_names (list of str): ordered list of channels in dataset,
        * sfreq (float): Sampling frequency.
        * pretrain_path (str): where to fetch weights to warmstart the model.
        * model_params (dict): parameters to pass to braindecode.EEGClassifier,
        * shared_callbacks (list): callbacks shared across splits,
        * subpolicies_length (int): length of subpolicies to look for,
        * transforms_family (list): transforms to sample from,
        * balanced_loss (bool): whether to train with balanced loss,
        * metric_to_monitor (str): metric used to prune,
        * epochs (int): maximum number of training epochs,
        * classes (iterable | None): value of the `classes` argument below,
        * classwise (bool): value of the `classwise` argument below,
        * random_state (int | np.random.RandomState): Random state.
        See `make_train_and_see_objective` for an example.
    sampler: class inheriting from optuna.samplers.BaseSampler | None
        Search algorithm to use to find the best subpolicies. The class
        itself is expected (not an instance); it is instantiated without
        arguments for every search.
    subpolicies_length: int, optional
        Length of subpolicies to look for (consecutive Transform objects).
        Defaults to 2.
    n_trials: int, optional
        Number of iterations allowed to the search algorithm. Defaults to 100.
    policy_size_per_fold: int, optional
        Number of best subpolicies to select per fold. Defaults to 10.
    transforms_family : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations will be considered. Defaults to None.
    classwise : bool, optional
        When True, the best trials are converted into ClasswiseSubpolicy
        objects (one subpolicy per element of `classes`) instead of plain
        subpolicies. Defaults to False.
    classes : iterable, optional
        Class labels used to build classwise subpolicies. Trial parameters
        are looked up with the suffix ``_class{c}`` for each label ``c``.
        Only iterated over when `classwise` is True. Defaults to None.
    model_params : dict, optional
        Parameters to pass to skorch.classifier.NeuralNetClassifer
        constructor (other than callbacks and splits). By default None.
        See https://skorch.readthedocs.io/en/stable/classifier.html for
        more details.
    shared_callbacks : list, optional
        List of callback objects (torch or skorch) shared across all
        folds (or of tuples (`str`, `Callback`)). Should not be used for
        logging/checkpointing/loading callbacks ; see other args for this.
        By default None.
    balanced_loss : boolean, optional
        Whether to balance the passed loss with the classes frequencies
        from the split training sets. By default True.
    monitor : str, optional
        Metric to use for checkpointing. By default 'valid_loss_best'.
    should_checkpoint : boolean, optional
        Whether to save model checkpoints (including training history) into
        fold-specific folders: training_dir/foldnofN.
        Will save based on improvements in metric self.monitor and when
        training stops. By default True.
    should_load_state : boolean, optional
        Whether to load states from latest checkpoint. Works fold-wise.
        By default True.
    log_tensorboard : boolean, optional
        Whether to log metrics for monitoring on Tensorboard. Will be saved
        in training_dir/foldnofN/logs when set to True. By default True.
    train_size_over_valid : float, optional
        Float between 0 and 1 setting the ratio used to split each training
        fold into a train and a validation set. By default 0.8.
    random_state : int | None, optional
        Used for seeding random number generator.
    """
    def __init__(
        self,
        training_dir,
        model,
        objective_factory,
        sampler,
        subpolicies_length=2,
        policy_size_per_fold=10,
        transforms_family=None,
        classwise=False,
        classes=None,
        **kwargs
    ):
        super().__init__(
            training_dir=training_dir,
            model=model,
            subpolicies_length=subpolicies_length,
            policy_size_per_fold=policy_size_per_fold,
            transforms_family=transforms_family,
            **kwargs
        )
        assert callable(objective_factory) or objective_factory is None,\
            "objective_factory should be callable."
        self.objective_factory = objective_factory
        # issubclass validates the class without building a throwaway
        # sampler instance
        assert sampler is None or (
            isclass(sampler) and issubclass(sampler, BaseSampler)
        ), "sampler must be class inheriting from optuna.samplers.BaseSampler"
        self.sampler = sampler
        self.classwise = classwise
        self.classes = classes

    def _from_trial_to_classwise_subpolicy(
        self,
        trial,
        ordered_ch_names,
        sfreq,
        random_state,
    ):
        """Build a ClasswiseSubpolicy from the parameters of a trial

        One subpolicy is built per element ``c`` of ``self.classes``, reading
        the trial parameters suffixed with ``_class{c}``.
        """
        subpolicies_per_class = {
            c: self._from_trial_to_subpolicy(
                trial=trial,
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                suffix=f'_class{c}',
                random_state=random_state,
            ) for c in self.classes
        }
        return ClasswiseSubpolicy(subpolicies_per_class)

    def _from_trial_to_subpolicy(
        self,
        trial,
        ordered_ch_names,
        sfreq,
        random_state,
        suffix=None,
    ):
        """Build a subpolicy (Compose) from the parameters of a trial

        Parameters
        ----------
        trial : optuna trial
            Trial whose ``params`` contain, for each position ``i`` below
            ``self.subpolicies_length``, the keys ``operation{suffix}_{i}``,
            ``probability{suffix}_{i}`` and ``magnitude{suffix}_{i}``.
        ordered_ch_names : list
            Forwarded to `_make_transform_objects`.
        sfreq : float
            Forwarded to `_make_transform_objects`.
        random_state : int | numpy.random.RandomState | None
            Forwarded to `_make_transform_objects`.
        suffix : str | None, optional
            String inserted in the parameter names. None is treated as "".

        Returns
        -------
        Compose
            Sequence of the transforms described by the trial.
        """
        if suffix is None:
            suffix = ""

        triplets = [(
                trial.params[f'operation{suffix}_{i}'],
                trial.params[f'probability{suffix}_{i}'],
                trial.params[f'magnitude{suffix}_{i}']
        ) for i in range(self.subpolicies_length)]

        transforms = _make_transform_objects(
            transform_tuples=triplets,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=random_state,
        )
        return Compose(transforms)

    def _search_in_fold(
        self,
        split,
        random_state,
        epochs,
        windows_dataset,
        ordered_ch_names,
        sfreq,
        n_trials=100,
        pruning=True,
        pretrain_base_path=None,
    ):
        """Run (or continue) the policy search for one split

        The optuna study is persisted in a SQLite file located in
        ``training_dir/fold{fold}of{n_folds}/subset_{ratio}_samples`` and is
        reloaded when it already exists, so successive calls accumulate
        trials.

        Parameters
        ----------
        split : tuple
            Tuple (fold index, subset ratio, training indices, validation
            indices, <ignored>).
        random_state : int | None
            Forwarded to ``self._init_global_and_specific_rngs``.
        epochs : int
            Forwarded to the objective factory.
        windows_dataset : torch.utils.data.Dataset
            Dataset indexed by the training and validation indices of
            `split`.
        ordered_ch_names : list
            Forwarded to the objective factory and used to build the
            returned subpolicies.
        sfreq : float
            Forwarded to the objective factory and used to build the
            returned subpolicies.
        n_trials : int | None, optional
            Number of trials passed to ``study.optimize``. Defaults to 100.
        pruning : bool, optional
            Use a MedianPruner when True, a NopPruner otherwise. Defaults to
            True.
        pretrain_base_path : str | None, optional
            Forwarded to ``self._get_warmstart_path``. Defaults to None.

        Returns
        -------
        tuple | None
            ``(fold, AugmentationPolicy)`` made of the
            ``self.policy_size_per_fold`` completed trials with the lowest
            value. None when a ``sqlite3.Error`` was raised (the error is
            printed).
        """
        assert (
            isinstance(n_trials, int) and n_trials > 0
        ) or n_trials is None, "n_trials has to be a positive int."

        fold, subset_ratio, train_subset_idx, valid_idx, _ = split

        # Seed global RNGs and local RandomState object for this fold
        self._init_global_and_specific_rngs(
            fold=fold,
            random_state=random_state
        )

        # Make folder where to save all results
        fold_path = join(self.training_dir, f'fold{fold}of{self.n_folds}')
        subset_path = join(fold_path, f'subset_{subset_ratio}_samples')
        os.makedirs(subset_path, exist_ok=True)

        # Use indices to split training and validation sets
        train_subset = Subset(windows_dataset, train_subset_idx)
        valid_set = Subset(windows_dataset, valid_idx)

        print(
            f"---------- Fold {fold} out of {self.n_folds} |",
            f"Training size: {len(train_subset)} ----------"
        )

        # Find path to checkpoint (when needed)
        warmstart_path = self._get_warmstart_path(
            pretrain_base_path, fold, subset_ratio
        )

        # Define the objective (it works on a private copy of self.model)
        model = deepcopy(self.model)

        objective = self.objective_factory(
            model=model,
            train_set=train_subset,
            valid_set=valid_set,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            pretrain_path=warmstart_path,
            model_params=self.model_params,
            shared_callbacks=self.shared_callbacks,
            subpolicies_length=self.subpolicies_length,
            transforms_family=self.transforms_family,
            balanced_loss=self.balanced_loss,
            metric_to_monitor=self.monitor,
            epochs=epochs,
            classes=self.classes,
            classwise=self.classwise,
            random_state=self.random_states[fold]["random_state_obj"],
        )
        pruner = MedianPruner() if pruning else NopPruner()

        # Start or load a db to make experiment persistent. The connection is
        # opened before the try block so that the finally clause below can
        # always release it, whatever exception interrupts the search.
        study_name = derive_study_name(self.training_dir)
        db_file = join(subset_path, study_name + '.db')
        conn = start_db(db_file)
        try:
            storage = optuna.storages.RDBStorage(
                url=f"sqlite:///{db_file}",
                engine_kwargs={"connect_args": {"timeout": 10}},
            )

            # Set global and specific RNGs for this fold
            self._set_global_rngs_from_previous_calls(fold=fold)

            # Start or load study and optimize
            study = optuna.create_study(
                direction='minimize',
                pruner=pruner,
                study_name=study_name,
                storage=storage,
                sampler=self.sampler(),
                load_if_exists=True,
            )

            # Make the sampler draw from the fold-specific RNG.
            # NOTE(review): this relies on private optuna attributes.
            fold_rng = self.random_states[fold]["random_state_obj"]
            study.sampler._rng = fold_rng
            try:
                study.sampler._random_sampler._rng = fold_rng
            except AttributeError:
                # Samplers without an inner `_random_sampler` only need the
                # assignment above
                pass

            study.optimize(objective, n_trials=n_trials)

            # Store in memory the current global RNG states for future search
            # steps
            self._save_current_global_rng_states(fold=fold)

            # Save trials to dataframe
            study.trials_dataframe().to_pickle(join(subset_path, "trials.pkl"))

            all_finished_trials = [
                trial for trial in study.trials
                if trial.state == TrialState.COMPLETE
            ]
            sorted_trials = sorted(all_finished_trials, key=lambda x: x.value)

            # We start a new RNG with the initial seed to instantiate policies
            # and transforms used for the retraining. This leads to the same
            # random transforms for every new retraining, avoiding some
            # reproducibility issues
            retrain_rng = check_random_state(self.retraining_seed)

            # pick best subpolicies and return them as the learned policy
            if self.classwise:
                trial_to_subpolicy = self._from_trial_to_classwise_subpolicy
            else:
                trial_to_subpolicy = self._from_trial_to_subpolicy
            best_subpolicies = [
                trial_to_subpolicy(
                    trial=trial,
                    ordered_ch_names=ordered_ch_names,
                    sfreq=sfreq,
                    random_state=retrain_rng,
                ) for trial in sorted_trials[:self.policy_size_per_fold]
            ]

            return fold, AugmentationPolicy(
                best_subpolicies, random_state=retrain_rng,
            )
        except Error as e:
            # Best effort: sqlite errors are reported and None is returned
            print(e)
            return None
        finally:
            # Close connection to db on every exit path
            close_db(conn)


def sample_subpolicy_and_apply(
    X, y,
    subpolicies,
    random_state,
    batchwise=False,
    *args,
    **kwargs,
):
    """Sample subpolicies uniformly and apply them to a batch

    Parameters
    ----------
    X : torch.Tensor
        Batch of inputs; its first axis is the batch axis.
    y : torch.Tensor
        Batch of labels, iterated in lockstep with `X`.
    subpolicies : sequence of callables
        Candidates to sample from. Each is called as ``subpolicy(X, y)`` and
        must return an iterable of tensors.
    random_state : int | numpy.random.RandomState | None
        Passed to ``sklearn.utils.check_random_state`` to get the RNG used
        for sampling.
    batchwise : bool, optional
        When True, a single subpolicy is sampled and applied to the whole
        batch. When False, one subpolicy is sampled per example. Defaults to
        False.
    *args, **kwargs
        Ignored.

    Returns
    -------
    Output of the sampled subpolicy when `batchwise` is True, otherwise the
    per-example outputs aggregated with ``default_collate``.
    """
    rng = check_random_state(random_state)
    if batchwise:
        sampled_subpolicy = rng.choice(subpolicies)
        return sampled_subpolicy(X, y)

    batch_size = X.shape[0]
    # sample one subpolicy per sample
    sampled_subpolicies = rng.choice(
        subpolicies, size=batch_size
    )

    # Loops across batch samples and applies sampled subpolicies
    unrolled_batch = [
        [
            # Only the simulated batch axis is removed: a bare squeeze()
            # would also drop genuine singleton axes (e.g. a single channel)
            out.squeeze(0) for out in sampled_subpolicy(
                Xi.unsqueeze(0),  # to simulate batch_size = 1
                yi.unsqueeze(0),  # to simulate batch_size = 1
            )
        ] for sampled_subpolicy, Xi, yi in zip(sampled_subpolicies, X, y)
    ]
    return default_collate(unrolled_batch)  # aggregates list of outputs


class AugmentationPolicy(Transform):
    """Set of augmentation subpolicies sampled uniformly at runtime

    Parameters
    ----------
    subpolicies : list
        Subpolicies to sample uniformly from.
    batchwise : bool, optional
        Whether to sample subpolicies per batch (True) or per sample (False).
        Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Used to sample operations to apply at runtime. Defaults to None.
    """
    def __init__(self, subpolicies, batchwise=True, random_state=None):
        transform_kwargs = dict(
            probability=1.0,
            operation=sample_subpolicy_and_apply,
            magnitude=None,
            mag_range=None,
            subpolicies=subpolicies,
            batchwise=batchwise,
            random_state=random_state,
        )
        super().__init__(**transform_kwargs)
        # Stored as a ModuleList after the parent constructor has run
        self.subpolicies = torch.nn.ModuleList(subpolicies)

    def get_structure(self):
        """Describe the policy as a single DataFrame

        The structure of every subpolicy is tagged with its position in the
        policy (column ``subpolicy_idx``) and all of them are concatenated.
        """
        def _tagged_structure(position, subpolicy):
            description = subpolicy.get_structure()
            description["subpolicy_idx"] = position
            return description

        return pd.concat(
            [
                _tagged_structure(position, subpolicy)
                for position, subpolicy in enumerate(self.subpolicies)
            ],
            ignore_index=True,
        )


class RandAugment(AugmentationPolicy):
    """RandAugment automatic augmentation policy

    As proposed in [1]_. The policy is made of every possible sequence of
    `sequence_length` operations taken (with repetition) from `transforms`,
    each operation being applied with probability 1 and the same magnitude.

    Parameters
    ----------
    magnitude : float
        Magnitude for all the transformations. Must be between 0 and 1.
    sequence_length : int
        Number of transformations to apply sequentially.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channels will be used to compute approximate sensors
        positions from a standard 10-20 montage when applicable.
    sfreq : int
        Sampling frequency. Used to instantiate Transforms such as
        FrequencyShift and BandstopFilter.
    probability : object, optional
        Always ignored, exists for compatibility.
    transforms : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations will be considered. Defaults to None.
    batchwise : bool, optional
        Whether to sample subpolicies per batch (True) or per sample (False).
        Defaults to True.
    random_state : int | numpy.random.Generator, optional
        Used to sample operations to apply at runtime. Defaults to None.

    References
    ----------
    .. [1] Cubuk, E. D., Zoph, B., Shlens, J., & Le, Q. V. (2020).
       Randaugment: Practical automated data augmentation with a
       reduced search space. In Proceedings of the IEEE/CVF Conference
       on Computer Vision and Pattern Recognition Workshops (pp. 702-703).

    """
    def __init__(
        self,
        magnitude,
        sequence_length,
        ordered_ch_names,
        sfreq,
        probability=None,
        transforms=None,
        batchwise=True,
        random_state=None
    ):
        # Resolve the pool of operations first: it is validated before any
        # other argument
        if transforms is None:
            transforms = list(POSSIBLE_TRANSFORMS.keys())
        elif not isinstance(transforms, list):
            raise ValueError("transforms should be a list fo strings.")
        self.transforms = transforms

        assert isinstance(sequence_length, int) and sequence_length > 0,\
            "sequence_length should be a positive int."
        self.sequence_length = sequence_length
        assert isinstance(ordered_ch_names, list),\
            "ordered_ch_names should be an ordered list of channel names."
        self.ordered_ch_names = ordered_ch_names

        # One subpolicy per element of the cartesian power of the pool.
        # NOTE(review): random_state is not forwarded to
        # _make_transform_objects here, only to the policy itself -- confirm
        # this is intended.
        subpolicies = []
        for operations in product(
            self.transforms, repeat=self.sequence_length
        ):
            triplets = [
                (operation, 1.0, magnitude) for operation in operations
            ]
            subpolicies.append(
                Compose(
                    _make_transform_objects(
                        triplets, ordered_ch_names, sfreq
                    )
                )
            )

        super().__init__(
            subpolicies=subpolicies,
            batchwise=batchwise,
            random_state=random_state
        )


class RandAugmentSearcher(BaseAugmentationSearcher):
    """Searcher whose policy is a fixed RandAugment (nothing is learned)

    This makes it possible to evaluate the RandAugment algorithm with the
    same CrossvalModel splits as the learned policies. The same RandAugment
    object is used for every fold.

    Parameters
    ----------
    training_dir : str
        Path to directory where elements should be saved.
    model : torch.nn.Module
        torch Module modeling a neural net classifier.
    magnitude : float
        Magnitude for all the transformations. Must be between 0 and 1.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channels will be used to compute approximate sensors
        positions from a standard 10-20 montage when applicable.
    sfreq : int
        Sampling frequency. Used to instantiate Transforms such as
        FrequencyShift and BandstopFilter.
    subpolicies_length: int, optional
        Length of subpolicies to sample (consecutive Transform objects).
        Defaults to 2.
    transforms_family : list, optional
        List of strings encoding transformations from
        eeg_augment.utils.POSSIBLE_TRANSFORMS to sample from. If omitted, all
        transformations will be considered. Defaults to None.
    random_state : int | None, optional
        Used for seeding random number generator.
    model_params : dict, optional
        Parameters to pass to skorch.classifier.NeuralNetClassifer
        constructor (other than callbacks and splits). By default None.
        See https://skorch.readthedocs.io/en/stable/classifier.html for
        more details.
    n_folds : int, optional
        Number of folds for cross-validation. It will generate n_fold
        triplets (train, valid, test) sets. By default 5.
    shared_callbacks : list, optional
        List of callback objects (torch or skorch) shared across all
        folds (or of tuples (`str`, `Callback`)). Should not be used for
        logging/checkpointing/loading callbacks ; see other args for this.
        By default None.
    balanced_loss : boolean, optional
        Whether to balance the passed loss with the classes frequencies
        from the split training sets. By default True.
    monitor : str, optional
        Metric to use for checkpointing. By default 'valid_loss_best'.
    should_checkpoint : boolean, optional
        Whether to save model checkpoints (including training history) into
        fold-specific folders: training_dir/foldnofN.
        Will save based on improvements in metric self.monitor and when
        training stops. By default True.
    should_load_state : boolean, optional
        Whether to load states from latest checkpoint. Works fold-wise.
        By default True.
    log_tensorboard : boolean, optional
        Whether to log metrics for monitoring on Tensorboard. Will be saved
        in training_dir/foldnofN/logs when set to True. By default True.
    train_size_over_valid : float, optional
        Float between 0 and 1 setting the ratio used to split each training
        fold into a train and a validation set. By default 0.8.
    """
    def __init__(
        self,
        training_dir,
        model,
        magnitude,
        ordered_ch_names,
        sfreq,
        subpolicies_length=2,
        transforms_family=None,
        random_state=None,
        **kwargs
    ):
        # Search-related arguments are explicitly disabled with None
        super().__init__(
            training_dir=training_dir,
            model=model,
            objective_factory=None,
            sampler=None,
            subpolicies_length=subpolicies_length,
            n_trials=None,
            policy_size_per_fold=None,
            transforms_family=transforms_family,
            random_state=random_state,
            **kwargs
        )
        shared_policy = RandAugment(
            magnitude=magnitude,
            sequence_length=self.subpolicies_length,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            transforms=self.transforms_family,
            random_state=self.rng
        )
        # Every fold refers to the very same policy object
        self.learned_policies = [shared_policy] * self.n_folds

    def search_policy(self, *args, **kwargs):
        """Do nothing: the policy is fixed at construction

        Overridden so that callers cannot overwrite ``learned_policies``.
        """
        return None


class ClasswiseSubpolicy(Transform):
    """Classwise subpolicy (applies different subpolicies according to the
    class)

    Parameters
    ----------
    subpolicies_per_class : dict
        Dictionary mapping classes to subpolicies (Compose transform objects).
    """
    def __init__(self, subpolicies_per_class):
        assert isinstance(subpolicies_per_class, dict),\
            "subpolicies_per_class should be a dictionary"
        assert all(
            isinstance(subpolicy, Transform)
            for subpolicy in subpolicies_per_class.values()
        ), "elements of subpolicies_per_class should be Transforms"
        self.subpolicies_per_class = subpolicies_per_class
        super().__init__(operation=identity)

    def forward(self, X, y):
        """Apply to each example the subpolicy registered for its class

        Parameters
        ----------
        X : torch.Tensor
            Batch of inputs, indexed along its first axis with the class
            masks computed from `y`.
        y : torch.Tensor
            Labels of the batch. Every label must be a key of
            ``self.subpolicies_per_class``.

        Returns
        -------
        tuple
            Transformed copy of `X` and `y` itself (labels produced by the
            subpolicies are discarded).

        Raises
        ------
        AssertionError
            When some label of `y` has no subpolicy.
        """
        tr_X = X.clone()
        # Used to check that all classes have been found. Pure torch ops are
        # used so the mask stays on the same device as y.
        masks_running_or = torch.zeros_like(y, dtype=torch.bool)
        for c, subpolicy in self.subpolicies_per_class.items():
            mask = y == c
            masks_running_or = masks_running_or | mask
            try:
                if mask.any():
                    tr_X[mask, ...], _ = subpolicy(X[mask, ...], y[mask])
            except Exception:
                # Dump the context of the failure before propagating it
                print("mask: ", mask)
                print("y: ", y)
                print("c: ", c)
                print("X: ", X.shape)
                raise
        assert torch.all(masks_running_or), (
            "Some classes in the batch"
            f" ({y[~masks_running_or]}) were not found within"
            " the subpolicy's branches "
            f"({list(self.subpolicies_per_class.keys())})"
        )
        return tr_X, y

    def get_structure(self):
        """ Returns a DataFrame describing the transforms making the object"""
        structure = list()
        for c, subpolicy in self.subpolicies_per_class.items():
            subpol_struct = subpolicy.get_structure()
            subpol_struct["class"] = c
            # Subpolicies may describe themselves with a plain dict
            if isinstance(subpol_struct, dict):
                subpol_struct = pd.DataFrame([subpol_struct])
            structure.append(subpol_struct)
        return pd.concat(structure, ignore_index=True)


def split_trials_in_steps(tot_trials, step_size):
    """Split a trial budget into consecutive steps of `step_size` trials

    Returns a numpy array holding `step_size` once per full step, followed
    by the remaining trials when the budget is not a multiple of
    `step_size`.
    """
    full_steps, remainder = divmod(tot_trials, step_size)
    chunks = [step_size] * int(full_steps)
    if remainder > 0:
        chunks.append(remainder)
    return np.array(chunks)


def evaluate_discrete_policy_search(
    training_dir,
    epochs,
    windows_dataset,
    sfreq,
    sampler='random-search',
    metric='autoaug',
    classwise=False,
    classes=None,
    pretrain_base_path=None,
    subpolicies_length=2,
    n_trials=200,
    eval_step=10,
    policy_size_per_fold=5,
    n_classes=5,
    device=None,
    lr=1e-3,
    batch_size=128,
    num_workers=4,
    random_state=None,
    parallel=False,
    ordered_ch_names=None,
    early_stop=True,
    model_to_use=None,
    data_ratio=None,
    grouped_subset=True,
    n_jobs=1,
    verbose=False,
    **kwargs
):
    """Search a policy with a discrete search algorithm and metric, and
    compute its crossvalidated test performance every `eval_step` trials

    After each step, the results are written to
    ``training_dir/search_perf_results.csv`` (created at step 0, appended to
    afterwards).

    Parameters
    ----------
    training_dir : str
        Directory where checkpoints, search results and test set evaluation
        results are saved.
    epochs : int
        Number of epochs used during final training with learned policy (and
        during policy search depending on the metric chosen).
    windows_dataset : torch.util.data.Dataset
        Dataset to use for training, validation and test (after splitting).
    sfreq : float
        Sampling frequency of input data.
    sampler : str, optional
        Key of SEARCH_ALGOS selecting the search algorithm used to look for
        subpolicies (either 'random-search' or 'tpe'), by default
        'random-search'
    metric : str, optional
        Key of SEARCH_METRICS selecting the metric used to evaluate candidate
        subpolicies (can be 'autoaug' or 'match'), by default 'autoaug'
    classwise : bool, optional
        Forwarded to BlackBoxAugmentationSearcher. By default False.
    classes : iterable | None, optional
        Forwarded to BlackBoxAugmentationSearcher. By default None.
    pretrain_base_path : str, optional
        Path to training folder where the checkpoints of a trained model are.
        Used to warmstart the model. Default to None.
    subpolicies_length : int, optional
        Number of consecutive Transforms in a subpolicy, by default 2
    n_trials : int, optional
        Budget of the search algorithm (in number of trials), by default 200.
    eval_step : int, optional
        Number of trials to do before training a model from scratch with the
        learned policy to assess its current performance. Defaults to 10.
    policy_size_per_fold : int, optional
        Size of policies (in number of subpolicies), by default 5
    n_classes : int, optional
        Number of classes in the dataset, by default 5
    device : str, optional
        Device to use for training, by default None
    lr : float, optional
        Learning rate, by default 1e-3
    batch_size : int, optional
        Batch size, by default 128
    num_workers : int, optional
        Number of workers used for data loading. By default 4
    random_state : int | numpy.random.RandomState | None, optional
        Used to seed random number generator, by default None
    parallel : bool, optional
        Whether to parallelize on several GPUs, by default False
    ordered_ch_names : list, optional
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage. By default None.
    early_stop : bool, optional
        Whether to carry earlystopping during training. By default True
    model_to_use : str | None, optional
        Defines which net should be used. By default (None) will use
        SleepStager. If set to 'lin' will use one layer linear net.
    data_ratio : float | None, optional
        Float between 0 and 1 or None. Will be used to build a
        subset of the cross-validated training sets (valid and test sets
        are conserved). Omitting it or setting it to None, is equivalent to
        setting it to [1.] (using the whole training set). By default None.
    grouped_subset : bool, optional
        Whether to compute training subsets taking groups (subjects) into
        account or not. When False, stratified spliting will be used to
        build the subsets. By default True.
    n_jobs : int, optional
        Number of workers to use for parallelizing across splits. By
        default 1.
    verbose : bool, optional
        By default False.
    **kwargs
        Forwarded to the BlackBoxAugmentationSearcher constructor.
    """
    # Parse arguments and prepare elements for the search and training
    prepared = prepare_training(
        windows_dataset=windows_dataset,
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        early_stop=early_stop,
        sfreq=sfreq,
        n_classes=n_classes,
        parallel=parallel,
        device=device,
        model_to_use=model_to_use,
        random_state=random_state,
    )
    device, _, model, model_params, shared_callbacks = prepared

    # Searcher reused across all the steps
    searcher = BlackBoxAugmentationSearcher(
        training_dir,
        model,
        objective_factory=SEARCH_METRICS[metric],
        sampler=SEARCH_ALGOS[sampler],
        subpolicies_length=subpolicies_length,
        policy_size_per_fold=policy_size_per_fold,
        model_params=model_params,
        shared_callbacks=shared_callbacks,
        balanced_loss=True,  # Not settable for now
        monitor='valid_bal_acc_best',  # Not settable for now
        log_tensorboard=False,  # Not settable for now
        should_checkpoint=False,
        should_load_state=False,
        classwise=classwise,
        classes=classes,
        random_state=random_state,
        **kwargs
    )

    # Split the total number of trials in steps
    steps = split_trials_in_steps(n_trials, eval_step)

    # Resume the counters from previous results when they exist
    first_step, last_results = load_retrain_results_if_exist(training_dir)
    if last_results is None:
        tot_trials, tot_duration = 0, 0
    else:
        tot_trials = last_results["tot_trials"]
        tot_duration = last_results["tot_search_duration"]

    csv_path = join(training_dir, 'search_perf_results.csv')

    # NOTE(review): when resuming (first_step > 0) every entry of `steps` is
    # still executed, so the full budget is spent again on top of the trials
    # already recorded -- confirm this is intended.
    for step_idx, step_n_trials in enumerate(steps, start=first_step):
        print(f"=== STEP {step_idx} ===")

        # Continue policy search with new step budget (timed)
        start_step = time()
        searcher.search_policy(
            windows_dataset=windows_dataset,
            epochs=epochs,
            data_ratio=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            n_trials=int(step_n_trials),
            pretrain_base_path=pretrain_base_path,
        )
        step_duration = time() - start_step

        # Increment time and trials counters
        tot_duration += step_duration
        tot_trials += step_n_trials

        # The evaluation of the current policy is not counted in the timer
        results = searcher.learning_curve(
            windows_dataset=windows_dataset,
            epochs=epochs,
            data_ratios=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
        )

        # Add search duration and step information to results and save it
        results['step_idx'] = step_idx
        results['step_search_duration'] = step_duration
        results['tot_search_duration'] = tot_duration
        results['tot_trials'] = tot_trials

        # Step 0 (re)creates the file with its header, later steps append
        is_first_write = step_idx == 0
        results.to_csv(
            csv_path,
            index=False,
            header=is_first_write,
            mode='w' if is_first_write else 'a',
        )


def handle_transforms_pool_subsetting(args):
    """Select the pool of transform names to search over from CLI flags.

    Parameters
    ----------
    args : object
        Parsed arguments exposing the attributes ``best_transf_only``,
        ``tr_subset`` and ``dataset``.

    Returns
    -------
    list of str or None
        ``GLOBAL_TRANSFORMS_SUBSET`` when both ``best_transf_only`` and
        ``tr_subset`` are set; otherwise, with ``best_transf_only`` set, the
        dataset-specific list for ``"mass"`` or ``"edf"``. ``None`` in every
        other case (presumably meaning "no restriction" -- confirm against
        the consumer of ``transforms_family``).
    """
    # Without --best_transf_only no subsetting is requested at all; note that
    # --tr_subset alone has no effect.
    if not args.best_transf_only:
        return None

    # The common subset takes precedence over the per-dataset lists.
    if args.tr_subset:
        return GLOBAL_TRANSFORMS_SUBSET

    dataset_name = args.dataset
    if dataset_name == "mass":
        return MASS_BEST_TRANSFORMS
    if dataset_name == "edf":
        return PHYSIONET_BEST_TRANSFORMS

    # Unknown dataset: no predefined list of best transforms.
    return None


def training_params_from_args():
    """Parse the command line and build the policy-search keyword arguments.

    The parser returned by ``make_args_parser`` is extended with the options
    specific to discrete policy search. The dataset is then loaded through
    ``handle_dataset_args`` and everything is gathered into a dictionary
    meant to be unpacked into ``evaluate_discrete_policy_search``.

    Returns
    -------
    dict
        Keyword arguments for the search. When ``args.config`` is set, the
        ``split``, ``discrete``, ``policy`` and ``training`` sections of the
        config file are merged in, in that order, and take precedence over
        the values derived from the command line.

    Raises
    ------
    KeyError
        If a config file is given but lacks one of the four sections above
        (assuming ``read_config`` returns a plain mapping -- TODO confirm).
    """
    parser = make_args_parser()
    parser.add_argument(
        "-l", "--subpolicy_length",
        type=int,
        default=2,
        help="Length of transforms sequence to sample."
    )

    parser.add_argument(
        "-t", "--n_trials",
        type=int,
        default=200,
        help="Number of subpolicies trials."
    )

    parser.add_argument(
        "--eval_step", "--step",
        type=int, default=10,
        help="Number of trials between two policy evaluations."
    )

    parser.add_argument(
        "-pl", "--policy_length",
        type=int,
        default=5,
        help="Number of subpolicies in a policy."
    )

    parser.add_argument(
        "--best_transf_only",
        action="store_true",
        help="Whether to only use known best transforms.",
    )

    parser.add_argument(
        "--tr_subset",
        action="store_true",
        help="Whether to only use 5 best transforms common to both datasets.",
    )

    parser.add_argument(
        "-s", "--sampler",
        default="random-search",
        help="Search algorithm to use."
    )

    parser.add_argument(
        "-mt", "--metric",
        default="autoaug",
        help="Search metric to use."
    )

    parser.add_argument(
        "--pretrain_path",
        # Adjacent literals are concatenated verbatim, so the separating
        # space has to be part of one of them.
        help="Path to root of pretrained model to use for warmstarting. "
             "(just for the model, not the policy)"
    )
    parser.add_argument(
        "--classwise",
        action="store_true",
        help="Whether to carry classwise policy search."
    )
    args = parser.parse_args()

    windows_dataset, ch_names, sfreq = handle_dataset_args(args)
    # Distinct labels found across all the sub-datasets.
    classes = np.unique(np.hstack([ds.y for ds in windows_dataset.datasets]))

    # XXX: Should I fix this so that ChannelSym is forbidden for EDF
    # XXX: (Same for diff_auto_aug)
    transforms_collection = handle_transforms_pool_subsetting(args)

    parameters = {
        'training_dir': args.training_dir,
        'windows_dataset': windows_dataset,
        'epochs': args.epochs,
        'sfreq': sfreq,
        'sampler': args.sampler,
        'metric': args.metric,
        'pretrain_base_path': args.pretrain_path,
        'device': args.device,
        'lr': args.lr,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'random_state': args.random_state,
        'early_stop': args.early_stop,
        'n_folds': args.nfolds,
        'train_size_over_valid': args.train_size_over_valid,
        'model_to_use': args.model,
        'data_ratio': args.data_ratio,
        'grouped_subset': args.grouped_subset,
        'n_jobs': args.n_jobs,
        'ordered_ch_names': ch_names,
        'subpolicies_length': args.subpolicy_length,
        'n_trials': args.n_trials,
        'eval_step': args.eval_step,
        'policy_size_per_fold': args.policy_length,
        'transforms_family': transforms_collection,
        "classwise": args.classwise,
        "classes": classes,
    }

    if args.config:
        config = read_config(args.config)
        # Later sections win over earlier ones, and all of them win over the
        # command-line derived values set above.
        for section in ("split", "discrete", "policy", "training"):
            parameters.update(config[section])
    return parameters


if __name__ == "__main__":
    # Script entry point: turn the command line (and the optional config
    # file) into keyword arguments, then launch the policy search with them.
    training_params = training_params_from_args()
    evaluate_discrete_policy_search(**training_params)
