from collections import defaultdict
import inspect
from typing import Tuple
from copy import deepcopy
import os
from os.path import join
from time import time
from pathlib import Path
from datetime import datetime
from shutil import copy

import numpy as np
import pandas as pd
from sklearn.metrics import balanced_accuracy_score
from sklearn.utils.validation import check_random_state
import torch
from torch import nn, Tensor
from torch import autograd
from torch.utils.data import DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable

from braindecode.augmentation.base import Compose
from braindecode.augmentation import transforms

from eeg_augment.training_utils import prepare_training, make_args_parser,\
    handle_dataset_args, make_class_proportions_tensor, stratified_split,\
    get_groups, grouped_split, read_config, set_random_seeds
from eeg_augment.auto_augmentation import BaseAugmentationSearcher,\
    AugmentationPolicy, ClasswiseSubpolicy, load_retrain_results_if_exist
from eeg_augment.diff_aug.base import DiffAugmentationPolicy,\
    DiffClasswisePolicy, diff_policy_to_standard_policy,\
    make_diff_transforms_subset
from eeg_augment.diff_aug.dada import DADAPolicy
from eeg_augment.utils import get_global_rngs_states, set_global_rngs_states
from eeg_augment.utils import estimate_Hessian_largest_eig


def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])


def _transfer_weights_to_cw_policy(classwise_policy, non_cw_policy):
    with torch.no_grad():
        for subpolicy, cw_subpolicy in zip(
            non_cw_policy.diff_subpolicies, classwise_policy.diff_subpolicies
        ):
            for cls_subpol in cw_subpolicy.subpolicies_per_class.values():
                for stage, cw_stage in zip(
                    subpolicy.stages, cls_subpol.stages
                ):
                    cw_stage._weights = nn.Parameter(stage._weights.clone())
                    for transform, cw_transform in zip(
                        stage.transforms, cw_stage.transforms
                    ):
                        cw_transform._probability = nn.Parameter(
                            transform.probability.clone()
                        )
                        if transform.magnitude is not None:
                            cw_transform._magnitude = nn.Parameter(
                                transform.magnitude.clone()
                            )
    return classwise_policy


class GradientBaseSearcher(BaseAugmentationSearcher):
    """Shared machinery for gradient-based augmentation policy searchers.

    A differentiable augmentation policy is learned for each
    cross-validation fold (see ``search_policy``). This base class provides
    the data loaders and criteria, the policy/optimizer factory,
    checkpointing helpers, the conversion of learned differentiable
    policies into standard ``AugmentationPolicy`` objects, and a retraining
    step that is skipped for folds flagged in ``self.retrain_fold``.

    Parameters
    ----------
    training_dir : str
        Directory where logs and learned policies are written.
    model : torch.nn.Module
        Model used during the search (forwarded to the parent class).
    subpolicies_length : int, optional
        Length of each subpolicy (forwarded to the policies as
        ``subpolicy_len``). By default 2.
    policy_size_per_fold : int, optional
        Number of subpolicies of each learned policy (forwarded as
        ``n_subpolicies``). By default 5.
    transforms_family : optional
        Forwarded to the parent class.
    use_transforms_subset : bool, optional
        Whether to build a subset of differentiable transforms with
        ``make_diff_transforms_subset`` when creating policies.
    policy_lr : float | None, optional
        Learning rate of the policy optimizer. Defaults to
        ``self.model_params["lr"]``.
    classwise : bool, optional
        Whether to learn a ``DiffClasswisePolicy``.
    grad_est : optional
        Gradient estimator forwarded to the differentiable policies.
    dada : bool, optional
        Whether to learn a ``DADAPolicy`` (ignored when ``classwise``).
    **kwargs
        Forwarded to ``BaseAugmentationSearcher``.
    """

    def __init__(
        self,
        training_dir,
        model,
        subpolicies_length=2,
        policy_size_per_fold=5,
        transforms_family=None,
        use_transforms_subset=False,
        policy_lr=None,
        classwise=False,
        grad_est=None,
        dada=False,
        **kwargs
    ):
        super().__init__(
            training_dir,
            model,
            subpolicies_length=subpolicies_length,
            policy_size_per_fold=policy_size_per_fold,
            transforms_family=transforms_family,
            **kwargs
        )
        self.use_transforms_subset = use_transforms_subset
        self.classwise = classwise
        self.dada = dada
        self.grad_est = grad_est
        # Used to know whether the current learned policy was updated for each
        # fold, which allows to skip fold retrain in many cases
        self.retrain_fold = {i: True for i in range(1, self.n_folds + 1)}
        if policy_lr is None:
            policy_lr = self.model_params["lr"]
        self.policy_lr = policy_lr

    def convert_policy(self, step_idx, log=True):
        """Convert the learned DiffAugmentationPolicy of each fold into an
        equivalent standard AugmentationPolicy (in place).

        Very handy to be able to call ``learning_curve`` afterwards and assess
        the learned policies in a common training pipeline.

        Parameters
        ----------
        step_idx : int
            Search step index, used when logging the converted policies.
        log : bool, optional
            Whether to save the converted policies structure to
            ``learned_policies.csv``. By default True.
        """
        # A new RNG seeded with self.retraining_seed is used to instantiate
        # the policies and transforms used for retraining. This leads to the
        # same random transforms for every new retraining, avoiding some
        # reproducibility issues.
        retrain_rng = check_random_state(self.retraining_seed)
        # The dtype is explicit because 2**32 does not fit in the default
        # integer type on platforms where it is 32 bits wide.
        child_states = retrain_rng.randint(
            0, 2**32, len(self.learned_policies), dtype=np.int64
        )

        self.learned_policies = {
            fold: diff_policy_to_standard_policy(
                policy, check_random_state(seed)
            )
            for (fold, policy), seed in zip(
                self.learned_policies.items(), child_states
            )
        }

        # Save the learned policies up to this step on a csv
        if log:
            self._log_policy(step_idx)

    def _log_policy(self, step_idx):
        """Save the structure of the converted learned policies in a csv.

        ``<training_dir>/learned_policies.csv`` is overwritten (with header)
        when ``step_idx == 0`` and appended to (without header) otherwise.
        """
        structure_by_fold = list()
        for fold, policy in self.learned_policies.items():
            assert isinstance(policy, AugmentationPolicy), (
                "self.learned_policies has to be converted before logging it. "
                "convert_policy() should be called before this method."
            )
            policy_struc = policy.get_structure()
            policy_struc["fold"] = fold
            policy_struc["step_idx"] = step_idx
            structure_by_fold.append(policy_struc)
        structure_log_path = Path(self.training_dir) / 'learned_policies.csv'
        is_first_step = step_idx == 0
        pd.concat(structure_by_fold, ignore_index=True).to_csv(
            structure_log_path,
            header=is_first_step,
            index=False,
            mode="w" if is_first_step else "a",
        )

    def make_search_valid_loaders(self, search_set, valid_set, test_set,
                                  n_classes=None):
        """Creates the search, validation and test DataLoader objects from
        the corresponding datasets, using the settings from self.model_params

        Parameters
        ----------
        search_set : torch.utils.data.Dataset
            Training set used for the search.
        valid_set : torch.utils.data.Dataset
            Validation set.
        test_set : torch.utils.data.Dataset
            Test set.
        n_classes : int | None, optional
            Number of classes, only used when ``self.classwise`` is True.
            When None, it is inferred by iterating over ``search_set``
            (which loads every sample).

        Returns
        -------
        tuple
            Tuple containing the three DataLoaders.
        """
        def _make_loader(dataset, loader_batch_size):
            # All three loaders share the ``iterator_train__*`` settings
            return DataLoader(
                dataset,
                batch_size=loader_batch_size,
                pin_memory=self.model_params.get(
                    "iterator_train__pin_memory",
                    False
                ),
                num_workers=self.model_params.get(
                    'iterator_train__num_workers',
                    0
                ),
                worker_init_fn=self.model_params.get(
                    "iterator_train__worker_init_fn",
                    None
                ),
                multiprocessing_context=self.model_params.get(
                    "iterator_train__multiprocessing_context",
                    None
                ),
                drop_last=True,
            )

        # So that the effective search batch size keeps the same in a
        # classwise setting
        batch_size = self.model_params["batch_size"]
        if self.classwise:
            if n_classes is None:
                n_classes = len(np.unique([y for _, y, _ in search_set]))
            batch_size = batch_size * n_classes

        train_loader = _make_loader(search_set, batch_size)
        # NOTE(review): with drop_last=True, the last partial batch of the
        # valid/test sets is dropped whenever they hold more than 1000
        # samples and their size is not a multiple of 1000 — confirm this is
        # intended for evaluation.
        valid_loader = _make_loader(valid_set, min(1000, len(valid_set)))
        test_loader = _make_loader(test_set, min(1000, len(test_set)))
        return train_loader, valid_loader, test_loader

    def load_trained_classifier(self, pretrain_base_path, fold, subset_ratio):
        """Method used to warmstart a copy of self.model with pretrained
        weights for density matching methods (basically Fast AutoAugment only)

        Parameters
        ----------
        pretrain_base_path : str
            Path to root folder of the trained model.
        fold : int
            Fold to load the learned weights from.
        subset_ratio : float
            Proportion of the training set to load learned weights from.

        Returns
        -------
        torch.nn.Module
            Copy of self.model on self.device, with the weights from
            ``<warmstart_path>/params.pt`` loaded when a warmstart path is
            found (otherwise the copy keeps self.model's weights).
        """
        warmstart_path = self._get_warmstart_path(
            pretrain_base_path, fold, subset_ratio
        )
        trained_model = deepcopy(self.model)
        if warmstart_path is not None:
            trained_model.load_state_dict(
                torch.load(
                    join(warmstart_path, "params.pt"),
                    map_location=self.device,
                )
            )
        trained_model.to(self.device)
        return trained_model

    def _make_classwise_policy(
        self,
        ordered_ch_names,
        sfreq,
        classes,
        diff_transforms,
        random_state,
    ):
        """Instantiate a DiffClasswisePolicy on self.device.

        NOTE(review): ``diff_transforms`` is accepted but not forwarded to
        the policy, so ``use_transforms_subset`` has no effect here — confirm
        whether DiffClasswisePolicy should receive it.
        """
        random_state_obj = check_random_state(random_state)
        if classes is None:
            # XXX: Classes should be output by functions creating the
            # datasets and passed together with channels and frequency
            # (actually, they should be attributes of braindecode datasets
            # ideally)
            classes = [0, 1, 2, 3, 4]
        return DiffClasswisePolicy(
            n_subpolicies=self.policy_size_per_fold,
            subpolicy_len=self.subpolicies_length,
            ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=random_state_obj,
            classes=classes,
            grad_est=self.grad_est,
        ).to(self.device)

    def _make_standard_policy(
        self,
        ordered_ch_names,
        sfreq,
        diff_transforms,
        random_state,
    ):
        """Instantiate a DiffAugmentationPolicy on self.device.

        NOTE(review): ``diff_transforms`` is accepted but not forwarded to
        the policy, so ``use_transforms_subset`` has no effect here — confirm
        whether DiffAugmentationPolicy should receive it.
        """
        random_state_obj = check_random_state(random_state)
        return DiffAugmentationPolicy(
            n_subpolicies=self.policy_size_per_fold,
            subpolicy_len=self.subpolicies_length,
            ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=random_state_obj,
            grad_est=self.grad_est,
        ).to(self.device)

    def create_policy_and_optimizer(self, ordered_ch_names, sfreq,
                                    random_state, classes=None):
        """Create a new trainable policy and its optimizer.

        The policy type is DiffClasswisePolicy when ``self.classwise``,
        otherwise DADAPolicy when ``self.dada``, otherwise
        DiffAugmentationPolicy.

        Returns
        -------
        tuple
            ``(trainable_policy, optimizer, policy_copy)`` where
            ``policy_copy`` is a deepcopy of the policy, meant to store the
            best policy.
        """
        diff_transforms = None
        random_state_obj = check_random_state(random_state)
        if self.use_transforms_subset:
            diff_transforms = make_diff_transforms_subset(
                sfreq=sfreq,
                ordered_ch_names=ordered_ch_names,
                random_state=random_state_obj,
            )

        if self.classwise:
            trainable_policy = self._make_classwise_policy(
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=random_state_obj,
                classes=classes,
                diff_transforms=diff_transforms,
            )
        elif self.dada:
            # Re-derived from random_state: for an int seed this yields a
            # fresh generator, unaffected by make_diff_transforms_subset
            random_state_obj = check_random_state(random_state)
            trainable_policy = DADAPolicy(
                n_subpolicies=self.policy_size_per_fold,
                subpolicy_len=self.subpolicies_length,
                ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=random_state_obj,
            ).to(self.device)
        else:
            trainable_policy = self._make_standard_policy(
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=random_state_obj,
                diff_transforms=diff_transforms,
            )

        # betas is passed explicitly, so model_params["optimizer"] must
        # accept it (Adam-like optimizers)
        optimizer = self.model_params["optimizer"](
            trainable_policy.parameters(),
            lr=self.policy_lr,
            betas=(0., 0.999)
        )
        # This deepcopy does not mess with the rng because the latter is reset
        # in the convert_policy method
        return trainable_policy, optimizer, deepcopy(trainable_policy)

    def _epoch_from_chkpt(self, checkpoint):
        """Extract the epoch number from a checkpoint path.

        Checkpoints are named ``{prefix}_e{epoch}_{timestamp}.pt`` by
        ``checkpoint_stuff``, so the epoch is the second-to-last
        ``_``-separated field, without its leading ``e``.
        """
        return int(checkpoint.as_posix().split("_")[-2][1:])

    def checkpoint_stuff(
        self,
        to_checkpoint,
        checkpoints_path,
        ckpt_prefix,
        epoch
    ):
        """Checkpoint everything from a dictionary (models, optimizers, rngs)
        into ``checkpoints_path / "{ckpt_prefix}_e{epoch}_{timestamp}.pt"``.

        Every entry must have ``state_dict()``, except ``"rng_states"`` which
        is a dict holding a ``"random_state_obj"`` (with ``get_state()``) and
        ``"global_rng_states"``.
        """
        checkpoint_dict = {}
        for obj_name, obj in to_checkpoint.items():
            if obj_name == "rng_states":
                # special case of rngs (object and globals)
                checkpoint_dict[obj_name] = {
                    "random_state_obj": obj["random_state_obj"].get_state(),
                    "global_rng_states": obj["global_rng_states"],
                }
            else:
                # All other objects: model, policy, optimizers, ...
                checkpoint_dict[obj_name] = obj.state_dict()
        # NOTE(review): the timestamp contains ':' which is not allowed in
        # Windows file names.
        check_name = (
            checkpoints_path /
            f"{ckpt_prefix}_e{epoch}_{datetime.now():%y-%m-%d %H:%M:%S}.pt"
        )
        torch.save(checkpoint_dict, check_name)

    def load_checkpoints(self, checkpoints_path, ckpt_prefix, to_checkpoint):
        """Load the most recently modified checkpoint matching ``ckpt_prefix``
        into the objects of ``to_checkpoint`` (in place).

        Returns
        -------
        int
            Epoch to restart from: checkpointed epoch + 1, or 1 when no
            checkpoint exists.
        """
        # Get all checkpoints sorted by modification time, oldest first.
        existing_checkpoints = sorted(
            checkpoints_path.glob(f'{ckpt_prefix}*.pt'),
            key=lambda t: t.stat().st_mtime
        )
        if not existing_checkpoints:
            return 1

        # Load the checkpoint modified last
        last_checkpoint = existing_checkpoints[-1]
        checkpoint_dict = torch.load(
            last_checkpoint, map_location=self.device,
        )

        # Load checkpoint for all objects listed in to_checkpoint
        for obj_name, obj in to_checkpoint.items():
            saved_state = checkpoint_dict[obj_name]
            if obj_name == "rng_states":
                # special case of rngs: restore the RandomState object and
                # the record of global rng states
                obj["random_state_obj"].set_state(
                    saved_state["random_state_obj"]
                )
                obj["global_rng_states"] = saved_state["global_rng_states"]
            else:
                # All other objects: model, policy, optimizers, ...
                obj.load_state_dict(saved_state)
            try:
                obj.to(self.device)
            except AttributeError:
                # Optimizers and the rng_states dict have no `to` method, so
                # there is nothing to move to the device
                continue
        return self._epoch_from_chkpt(last_checkpoint) + 1

    def load_history_if_exist(self, search_history_path, epoch_to_fetch=None):
        """Load the best validation accuracy and the next step index from a
        search history csv, if it exists.

        Parameters
        ----------
        search_history_path : pathlib.Path
            Path to a csv with ``best_valid_bal_acc`` and ``step_idx``
            columns (and ``epoch`` when ``epoch_to_fetch`` is given).
        epoch_to_fetch : int | None, optional
            When given, read the single row whose ``epoch`` equals it
            (useful for warmstarting). Otherwise read the last row.

        Returns
        -------
        tuple
            ``(step_idx, best_valid_acc)`` where ``step_idx`` is the step
            following the one read. ``(0, 0)`` when the file does not exist,
            or holds no rows and no epoch was requested.
        """
        best_valid_acc = 0
        step_idx = 0
        if not search_history_path.exists():
            return step_idx, best_valid_acc

        prev_results = pd.read_csv(search_history_path)
        idx = -1
        if epoch_to_fetch is not None:
            # Fetch row index corresponding to epoch (useful for warmstarting)
            matches = prev_results.index[
                prev_results["epoch"] == epoch_to_fetch
            ].tolist()
            assert len(matches) == 1, (
                f"Expected exactly one row matching epoch {epoch_to_fetch} "
                f"in {search_history_path}, found {len(matches)}."
            )
            idx = matches[0]
        elif prev_results.empty:
            # Header-only file: nothing to resume from
            return step_idx, best_valid_acc
        best_valid_acc = prev_results["best_valid_bal_acc"].values[idx]
        step_idx = prev_results["step_idx"].values[idx] + 1
        return step_idx, best_valid_acc

    def load_retrain_results_if_exist(self, fold):
        """Load the last retraining results of ``fold`` from
        ``self.training_dir`` and the current step index.

        When no results exist, the returned results default to zero best
        validation and test balanced accuracies.
        """
        step_idx, last_fold_results = load_retrain_results_if_exist(
            self.training_dir, fold=fold
        )
        if last_fold_results is None:
            last_fold_results = {
                "best_valid_bal_acc": 0,
                "best_test_bal_acc": 0
            }
        return step_idx, last_fold_results

    def instantiate_criteria(self, search_set, valid_set, test_set):
        """Creates criteria objects for searching, validating and testing from
        the classes stored in self.model_params["criterion"]

        Computes the corresponding class weights for each dataset when
        self.balanced_loss is True. The validation and test criteria use
        ``reduction='sum'``.

        Parameters
        ----------
        search_set : torch.utils.data.Dataset
            Training set
        valid_set : torch.utils.data.Dataset
            Validation set
        test_set : torch.utils.data.Dataset
            Test set

        Returns
        -------
        tuple
            Tuple containing the criteria objects for searching, validating and
            testing.
        """
        search_class_weights = make_class_proportions_tensor(
            search_set,
            self.balanced_loss,
            self.device,
        )

        valid_class_weights = make_class_proportions_tensor(
            valid_set,
            self.balanced_loss,
            self.device,
        )

        test_class_weights = make_class_proportions_tensor(
            test_set,
            self.balanced_loss,
            self.device,
        )

        search_criterion = self.model_params["criterion"](
            weight=search_class_weights
        )
        valid_criterion = self.model_params["criterion"](
            weight=valid_class_weights,
            reduction='sum',
        )
        test_criterion = self.model_params["criterion"](
            weight=test_class_weights,
            reduction='sum',
        )
        return search_criterion, valid_criterion, test_criterion

    def set_tensorboard_up(self, fold, subset_ratio):
        """Create the logging folder of the fold and return a SummaryWriter
        on it when ``self.log_tensorboard`` is True, otherwise None."""
        if self.log_tensorboard:
            main_logs_path = join(self.training_dir, "logs")
            fold_logs_path = join(
                main_logs_path,
                f'fold{fold}of{self.n_folds}',
                f'subset_{subset_ratio}_samples',
            )
            os.makedirs(fold_logs_path, exist_ok=True)
            return SummaryWriter(fold_logs_path)
        return None

    def search_policy(
        self,
        windows_dataset,
        epochs,
        data_ratio=None,
        grouped_subset=True,
        n_jobs=1,
        verbose=False,
        **kwargs
    ):
        """Search for best augmentation policy

        Stores the learned policy of each fold in ``self.learned_policies``
        and whether each fold should be retrained in ``self.retrain_fold``.

        Parameters
        ----------
        windows_dataset : torch.data.utils.Dataset
            Dataset to use for search, validation and testing.
        epochs : int
            Number of epochs of search.
        data_ratio : float | None, optional
            Float between 0 and 1 or None. Will be used to build a
            subset of the cross-validated training sets (valid and test sets
            are conserved). Omitting it or setting it to None, is equivalent to
            setting it to [1.] (using the whole training set). By default None.
        grouped_subset : bool, optional
            Whether to compute training subsets taking groups (subjects) into
            account or not. When False, stratified spliting will be used to
            build the subsets. By default True.
        n_jobs : int, optional
            Number of workers to use for parallelizing across splits. By
            default 1.
        verbose : bool, optional
            By default False.
        """
        assert isinstance(data_ratio, float) or data_ratio is None,\
            "Only a single data_ratio value is supported for now."
        self._epochs = epochs
        folds_policies_and_retrain_orders = self._crossval_apply(
            windows_dataset,
            epochs,
            function_to_apply=self._search_in_fold,
            data_ratios=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
            **kwargs
        )
        folds_policies_tuples, should_retrain = list(
            zip(*folds_policies_and_retrain_orders)
        )
        self.learned_policies = dict(folds_policies_tuples)
        self.retrain_fold = dict(should_retrain)

    def _fit_and_score(
        self,
        split,
        warmstart_base_path=None,
        warmstart_epoch=None,
        **kwargs,
    ):
        """Train and test a copy of self.model on the desired split using the
        learned policy, unless the fold is flagged as not needing retraining
        (previous results are then returned).

        Parameters
        ----------
        split : tuple
            Tuple containing the fold index, the training set proportion and
            the indices of the training, validation and test set.
        warmstart_base_path : str | None, optional
            Path to folder used for warmstarting.
        warmstart_epoch : int | None, optional
            Epoch to use for warmstarting.
        **kwargs
            Forwarded to the parent ``_fit_and_score`` (e.g. epochs,
            windows_dataset, model_params, random_state).

        Returns
        -------
        dict
            Fold results from the parent method (balanced accuracy, loss,
            kappa score and confusion matrix of each set), augmented with
            the best validation and test balanced accuracies so far
            (``best_valid_bal_acc`` and ``best_test_bal_acc``).
        """
        fold, subset_ratio, _, _, _ = split
        # Look for existing results file and if there is, read step_idx
        # and best metrics from it (for this fold only ofc)
        step_idx, last_fold_results = self.load_retrain_results_if_exist(fold)
        # But load from warmstart when applicable
        if (
            step_idx == 0 and
            warmstart_base_path is not None and
            warmstart_epoch is not None
        ):
            # XXX: add a warning if warm start path does not exists
            step_idx, last_fold_results = load_retrain_results_if_exist(
                warmstart_base_path, epoch=warmstart_epoch, fold=fold,
            )
            if last_fold_results is None:
                # No results in the warmstart folder: same defaults as
                # self.load_retrain_results_if_exist
                last_fold_results = {
                    "best_valid_bal_acc": 0,
                    "best_test_bal_acc": 0,
                }

        # Truthiness rather than `is True`, so numpy booleans also work
        if self.retrain_fold[fold]:
            # If fold should be retrained (because best policy changed), call
            # parent retraining method
            fold_results = super()._fit_and_score(split, **kwargs)
            # If the accuracy obtained is better than the previous best value,
            # update it
            for ds in ["valid", "test"]:
                curr_acc = fold_results[f"{ds}_bal_acc"]
                best_acc = last_fold_results[f"best_{ds}_bal_acc"]
                if curr_acc > best_acc:
                    best_acc = curr_acc
                fold_results[f"best_{ds}_bal_acc"] = best_acc
        else:
            print(f">>> Skipping retraining for fold {fold}, step {step_idx}")
            # Otherwise, just return previous results
            fold_results = last_fold_results
            if isinstance(fold_results, pd.Series):
                fold_results = fold_results.to_dict()

        # Write retraining metrics on tensorboard (when enabled) to be able
        # to follow the training
        writer = self.set_tensorboard_up(fold=fold, subset_ratio=subset_ratio)
        if writer is not None:
            for ds in ["train", "valid", "test"]:
                scalars = [
                    (f"Retraining-loss/{ds}", f"{ds}_loss"),
                    (f"Retraining-bal-acc/{ds}", f"{ds}_bal_acc"),
                ]
                if ds != "train":
                    scalars.append(
                        (f"Retraining-bal-acc/best-{ds}", f"best_{ds}_bal_acc")
                    )
                for tag, key in scalars:
                    # Default results of a skipped fold lack some metrics
                    if key in fold_results:
                        writer.add_scalar(tag, fold_results[key], step_idx)
            # Flush events to disk and release the event file
            writer.close()
        return fold_results


class DiffDensityMatching(GradientBaseSearcher):
    """Differentiable augmentation search against a frozen classifier.

    For each fold, a pretrained classifier is loaded and frozen, and a
    differentiable augmentation policy is optimized so that the frozen model
    still classifies the augmented search samples correctly (the policy is
    trained by minimizing the classification loss of the frozen model on
    augmented data). After every epoch the policy is evaluated on the
    validation and test sets; the policy with the best validation balanced
    accuracy is kept and returned.
    """

    def _train(self, policy, model, device, train_loader, optimizer, epoch,
               criterion, log_interval, writer):
        """Optimize the policy for one epoch through the frozen ``model``.

        Parameters
        ----------
        policy : nn.Module
            Differentiable augmentation policy, called as
            ``policy(data, target)``; the first element of its output is the
            augmented batch.
        model : nn.Module
            Frozen classifier, kept in eval mode.
        device : torch.device | str
            Device batches are moved to.
        train_loader : DataLoader
            Yields ``(data, target, _)`` triplets.
        optimizer : torch.optim.Optimizer
            Optimizer stepping the policy (presumably over the policy
            parameters only, as built by ``create_policy_and_optimizer``).
        epoch : int
            Current epoch, used for logging.
        criterion : callable
            Loss ``criterion(output, target)``.
        log_interval : int
            Print and log to tensorboard every ``log_interval`` batches.
        writer : SummaryWriter | None
            Tensorboard writer; nothing is logged when ``None``.

        Returns
        -------
        train_loss : float
            Mean of the per-batch losses over the epoch.
        bal_acc : float
            Balanced accuracy of ``model`` on the augmented batches.
        """
        # Set the policy in training mode and the model in eval mode.
        policy.train()
        model.eval()
        ground_truth = list()
        predictions = list()
        train_loss = 0
        for batch_idx, (data, target, _) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # The loss is backpropagated through the frozen model into the
            # policy parameters.
            tr_data, _ = policy(data, target)
            output = model(tr_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            ground_truth += target.tolist()
            pred = output.detach().argmax(dim=1)
            predictions += pred.tolist()
            train_loss += batch_loss
            if batch_idx % log_interval == 0:
                print(
                    f'Train Epoch: {epoch}'
                    f' [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                    f' ({100. * batch_idx / len(train_loader):.0f}%)]'
                    f'\tLoss: {batch_loss:.6f}'
                )
                n_iter = epoch * len(train_loader) + batch_idx
                iter_bal_acc = balanced_accuracy_score(
                    target.tolist(), pred.tolist()
                )
                if writer is not None:
                    # Batch-level scalars: log this batch's loss (not the
                    # running sum) and index every point by the global
                    # iteration so successive batches don't overwrite
                    # each other.
                    writer.add_scalar('Loss/batch-train', batch_loss, n_iter)
                    writer.add_scalar(
                        'Accuracy/batch-train',
                        iter_bal_acc,
                        n_iter
                    )
                    for tag, values in (
                        ("Parameters/probabilities",
                         policy.all_probabilities),
                        ("Parameters/magnitudes", policy.all_magnitudes),
                        ("Parameters/weights", policy.all_weights),
                        ("Grads/probabilities", policy.all_prob_grads),
                        ("Grads/magnitudes", policy.all_mag_grads),
                        ("Grads/weights", policy.all_weight_grads),
                    ):
                        writer.add_histogram(tag, values, n_iter)

        train_loss /= (batch_idx + 1)
        bal_acc = balanced_accuracy_score(ground_truth, predictions)
        print(
            f'\nTraining set -- Running loss: {train_loss:.4f},'
            f' Balanced Accuracy:  {bal_acc * 100:.1f}%\n'
        )
        return train_loss, bal_acc

    def _test(self, policy, model, device, test_loader, criterion,
              valid=False):
        """Evaluate the frozen model on policy-augmented batches.

        Parameters
        ----------
        policy : nn.Module
            Augmentation policy (put in eval mode), applied to every batch.
        model : nn.Module
            Frozen classifier (put in eval mode).
        device : torch.device | str
            Device batches are moved to.
        test_loader : DataLoader
            Yields ``(data, target, _)`` triplets.
        criterion : callable
            Loss ``criterion(output, target)``.
        valid : bool
            Only changes the set name printed in the summary.

        Returns
        -------
        test_loss : float
            Sum of the per-batch criterion values divided by the number of
            samples in the dataset.
        bal_acc : float
            Balanced accuracy over the whole loader.
        """
        set_name = 'Validation set' if valid else 'Test set'
        model.eval()
        policy.eval()
        test_loss = 0
        ground_truth = list()
        predictions = list()
        with torch.no_grad():
            for data, target, _ in test_loader:
                data, target = data.to(device), target.to(device)
                tr_data, _ = policy(data, target)
                output = model(tr_data)
                # sum up batch loss
                test_loss += criterion(output, target).item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1)
                predictions += pred.tolist()
                ground_truth += target.tolist()

        # NOTE(review): this is a per-sample average only if ``criterion``
        # sums over the batch; with mean reduction it is the mean batch loss
        # divided by ~batch size -- confirm the criterion's reduction.
        test_loss /= len(test_loader.dataset)
        bal_acc = balanced_accuracy_score(ground_truth, predictions)

        print(
            f'\n{set_name} -- Average loss: {test_loss:.4f},'
            f' Balanced Accuracy: {bal_acc * 100:.1f}%\n'
        )
        return test_loss, bal_acc

    def _search_in_fold(
        self,
        split,
        random_state,
        epochs,
        windows_dataset,
        pretrain_base_path,
        ordered_ch_names,
        sfreq,
        train_policy_on='same',
        **kwargs
    ):
        """Search an augmentation policy on one fold.

        Parameters
        ----------
        split : tuple
            ``(fold, subset_ratio, train_subset_idx, valid_idx, test_idx)``.
        random_state : int | RandomState | None
            Forwarded to ``_init_global_and_specific_rngs``.
        epochs : int
            Number of epochs to run in this call; training resumes from the
            last checkpoint found, if any.
        windows_dataset : Dataset
            Dataset indexed by the split indices, yielding ``(X, y, _)``.
        pretrain_base_path : str | Path
            Location the frozen classifier is loaded from.
        ordered_ch_names : list of str
            Forwarded to the policy constructor.
        sfreq : float
            Forwarded to the policy constructor.
        train_policy_on : 'same' | int | float
            How ``valid_idx`` is split into a policy-search set and a new
            validation set. ``'same'`` uses a stratified split with
            ``ratio=len(train_subset_idx)`` (presumably a search set as large
            as the training subset); a number uses a grouped split with that
            ratio.
        **kwargs
            Ignored.

        Returns
        -------
        tuple
            ``((fold, best_policy), (fold, retrain_this_fold))`` where
            ``best_policy`` has gradients turned off and
            ``retrain_this_fold`` is True iff the best validation accuracy
            improved during this call.

        Raises
        ------
        ValueError
            If ``train_policy_on`` is neither ``'same'``, an int nor a float.
        """
        # Get train and valid splits from dataset
        fold, subset_ratio, train_subset_idx, valid_idx, test_idx = split

        # Seed global RNGs and local RandomState object for this fold
        self._init_global_and_specific_rngs(
            fold=fold,
            random_state=random_state
        )

        # Init retrain bool for this fold
        retrain_this_fold = False

        # The search is made on a subset of the valid, which is split using
        # the same length as the training set
        # XXX: This makes no sense any more and should be removed
        if train_policy_on == "same":
            targets = np.array([y for _, y, _ in windows_dataset])
            search_idx, new_valid_idx = stratified_split(
                indices=valid_idx,
                ratio=len(train_subset_idx),
                targets=targets[valid_idx],
                random_state=self.splitting_random_state,
            )
        elif isinstance(train_policy_on, (float, int)):
            groups = get_groups(windows_dataset)
            search_idx, new_valid_idx = grouped_split(
                indices=valid_idx,
                ratio=train_policy_on,
                groups=groups[valid_idx],
                random_state=self.splitting_random_state,
            )
            print("len(search_idx): ", len(search_idx))
            print("len(new_valid_idx): ", len(new_valid_idx))
        else:
            raise ValueError(
                "train_policy_on can be either 'same' an int or"
                f" a float. Got {type(train_policy_on)}: {train_policy_on}"
            )
        search_set = Subset(windows_dataset, search_idx)
        valid_set = Subset(windows_dataset, new_valid_idx)
        test_set = Subset(windows_dataset, test_idx)

        print(
            f"---------- Fold {fold} out of {self.n_folds} |",
            f"Training size: {len(train_subset_idx)} ----------"
        )

        # Make loaders
        (
            train_loader,
            valid_loader,
            test_loader
        ) = self.make_search_valid_loaders(search_set, valid_set, test_set)

        # Load existing history and fetch best accuracy and step index from it
        fold_path = Path(self.training_dir) / f'fold{fold}of{self.n_folds}'
        subset_path = fold_path / f'subset_{subset_ratio}_samples'
        search_history_path = subset_path / "search_history.csv"
        step_idx, best_valid_acc = self.load_history_if_exist(
            search_history_path
        )

        # Create training and validation criteria
        (
            train_criterion,
            valid_criterion,
            test_criterion
        ) = self.instantiate_criteria(search_set, valid_set, test_set)

        # Set tensorboard writer (in a central folder)
        writer = self.set_tensorboard_up(fold=fold, subset_ratio=subset_ratio)

        # Load pretrained model and turn off its gradients. ``requires_grad_``
        # freezes every parameter; assigning a ``requires_grad`` attribute on
        # the module would not. Gradients still flow through the frozen model
        # back to the policy.
        trained_model = self.load_trained_classifier(
            pretrain_base_path=pretrain_base_path,
            fold=fold,
            subset_ratio=subset_ratio,
        )
        trained_model.requires_grad_(False)

        # Create a new DiffAugmentationPolicy object and its optimizer
        policy, optimizer, best_policy = self.create_policy_and_optimizer(
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=self.random_states[fold]["random_state_obj"],
        )

        to_checkpoint = {
            'policy': policy,
            'optimizer': optimizer,
            'best_policy': best_policy,
        }

        # Setup checkpointing path, to save parameters for every epochs
        checkpoints_path = subset_path / "checkpoints"
        os.makedirs(checkpoints_path, exist_ok=True)
        ckpt_prefix = 'diff_density_matching'

        # Set global and specific RNGs for this fold
        self._set_global_rngs_from_previous_calls(fold=fold)

        # If checkpoints exist in the correct folder, load the last one to
        # continue the policy training
        start_epoch = self.load_checkpoints(
            checkpoints_path, ckpt_prefix, to_checkpoint
        )

        results = list()
        start = time()

        log_interval = max(int(len(train_loader) / 10), 1)

        for epoch in range(start_epoch, start_epoch + epochs):
            tr_loss, tr_acc = self._train(
                policy=policy,
                model=trained_model,
                device=self.device,
                train_loader=train_loader,
                optimizer=optimizer,
                epoch=epoch,
                criterion=train_criterion,
                log_interval=log_interval,
                writer=writer,
            )
            valid_loss, valid_acc = self._test(
                policy=policy,
                model=trained_model,
                device=self.device,
                test_loader=valid_loader,
                criterion=valid_criterion,
                valid=True,
            )
            # The test set must be reported as such (valid=False)
            test_loss, test_acc = self._test(
                policy=policy,
                model=trained_model,
                device=self.device,
                test_loader=test_loader,
                criterion=test_criterion,
                valid=False,
            )
            time_since_start = time() - start

            # Store best valid accuracy if better then previous values
            # and overwrite best model (output by the search)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                to_checkpoint["best_policy"] = best_policy = deepcopy(policy)
                # Ask for retraining only if improvement observed
                retrain_this_fold = True
            else:
                print(
                    f">> No validation improvement at this epoch ({epoch}) "
                    f"for fold {fold}"
                )

            if writer is not None:
                writer.add_scalar('Loss/search', tr_loss, epoch)
                writer.add_scalar('Loss/valid', valid_loss, epoch)
                writer.add_scalar('Loss/test', test_loss, epoch)
                writer.add_scalar('Accuracy/search', tr_acc, epoch)
                writer.add_scalar('Accuracy/valid', valid_acc, epoch)
                writer.add_scalar('Accuracy/test', test_acc, epoch)
                writer.add_scalar('Accuracy/best-valid', best_valid_acc, epoch)
                writer.add_scalar('time', time_since_start, epoch)

            results.append({
                "fold": fold,
                "n_folds": self.n_folds,
                "step_idx": step_idx,
                "epoch": epoch,
                "time": time_since_start,
                "search_loss": tr_loss,
                "valid_loss": valid_loss,
                "test_loss": test_loss,
                "search_bal_acc": tr_acc,
                "valid_bal_acc": valid_acc,
                "test_bal_acc": test_acc,
                "best_valid_bal_acc": best_valid_acc,
            })

            # Create a checkpoint with all objects listed in to_checkpoint
            self.checkpoint_stuff(
                to_checkpoint=to_checkpoint,
                checkpoints_path=checkpoints_path,
                ckpt_prefix=ckpt_prefix,
                epoch=epoch,
            )

        # Start a new csv (with header) if this is the first step or if no
        # history file exists yet; otherwise append to it.
        if step_idx == 0 or not search_history_path.exists():
            mode = "w"
            header = True
        else:
            mode = "a"
            header = False
        pd.DataFrame(results).to_csv(
            search_history_path,
            header=header,
            index=False,
            mode=mode,
        )

        # Store in memory the current global RNG states for future search steps
        self._save_current_global_rng_states(fold=fold)

        # Return the best model found up to here, based on (density match)
        # validation accuracy
        return (
            (fold, best_policy.requires_grad_(False)),
            (fold, retrain_this_fold)
        )


class ClassifierCritic(nn.Module):
    """Two-headed network sharing the trunk of a pretrained classifier.

    The last layer of ``base_module.fc`` is removed. Any layers before it are
    kept, so the trunk outputs exactly the features that this last layer
    consumed. Two new heads are built on these features: a linear classifier
    with the same number of outputs as the removed layer, and a small MLP
    critic producing one score per sample.

    Parameters
    ----------
    base_module : nn.Module
        Model exposing a sliceable ``fc`` attribute (e.g. ``nn.Sequential``)
        whose last element has ``in_features``/``out_features`` (e.g.
        ``nn.Linear``). Its ``forward`` is assumed to end by applying ``fc``
        (TODO confirm for new architectures). Modified in place.
    device : torch.device | str
        Device the two new heads are moved to. The trunk is not moved.
    """

    def __init__(self, base_module: nn.Module, device):
        super(ClassifierCritic, self).__init__()
        self.base_model = base_module
        last_layer = self.base_model.fc[-1]
        num_features = last_layer.in_features
        num_class = last_layer.out_features
        # Drop only the final layer: keeping the preceding ones (dropout,
        # hidden layers, ...) guarantees the trunk output has
        # ``num_features`` dimensions. Replacing the whole ``fc`` with an
        # identity breaks as soon as ``fc`` has more than one linear layer.
        self.base_model.fc = self.base_model.fc[:-1]
        # NOTE: the classification head is freshly initialized; the
        # pretrained weights of the removed layer are not reused.
        self.classifier = nn.Linear(num_features, num_class).to(device)
        self.discriminator = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.ReLU(),
            nn.Linear(num_features, 1)
        ).to(device)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(class_logits, critic_scores)`` for a batch.

        ``critic_scores`` is flattened to shape ``(batch,)``. Both outputs
        live on the same device as the trunk features.
        """
        features = self.base_model(input)
        return (
            self.classifier(features),
            self.discriminator(features).view(-1),
        )


class FasterAutoAugment(GradientBaseSearcher):
    """Faster AutoAugment-style adversarial augmentation search.

    A ``ClassifierCritic`` built from a pretrained classifier is trained
    jointly with a differentiable augmentation policy. The critic part is a
    WGAN critic with gradient penalty (weighted by ``gp_factor``) that learns
    to score unaugmented samples higher than augmented ones. The policy is
    trained to raise the critic score of augmented samples. Both steps add a
    classification loss weighted by ``penalty_weight``.

    Parameters
    ----------
    training_dir, model, subpolicies_length, policy_size_per_fold,
    transforms_family, **kwargs
        Forwarded to ``GradientBaseSearcher``.
    penalty_weight : float
        Weight of the classification loss terms.
    gp_factor : float
        Weight of the WGAN-GP gradient penalty.
    """

    def __init__(
        self,
        training_dir,
        model,
        subpolicies_length=2,
        policy_size_per_fold=5,
        transforms_family=None,
        penalty_weight=0.1,
        gp_factor=10,
        **kwargs
    ):
        super().__init__(
            training_dir,
            model,
            subpolicies_length=subpolicies_length,
            policy_size_per_fold=policy_size_per_fold,
            transforms_family=transforms_family,
            **kwargs
        )
        self.gp_factor = gp_factor
        self.penalty_weight = penalty_weight

    def gradient_penalty(
        self,
        classifier_critic,
        real: Tensor,
        fake: Tensor
    ) -> Tensor:
        """Compute the WGAN-GP gradient penalty of the critic head.

        The critic score is differentiated with respect to random per-sample
        interpolations between ``real`` and ``fake``. Deviations of the
        gradient norm from 1 are penalized quadratically.

        Parameters
        ----------
        classifier_critic : ClassifierCritic
            Module returning ``(class_logits, critic_scores)``.
        real, fake : Tensor
            Batches of identical shape, batch dimension first.

        Returns
        -------
        Tensor
            Scalar penalty (kept in the graph for double backprop).
        """
        # One interpolation coefficient per sample, broadcast over every
        # remaining dimension (identical to (batch, 1, 1) for 3D inputs).
        alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
        alpha = real.new_empty(alpha_shape).uniform_(0, 1)
        interpolated = alpha * real + (1 - alpha) * fake
        interpolated.requires_grad_()
        _, output = classifier_critic(interpolated)
        grad = torch.autograd.grad(
            outputs=output,
            inputs=interpolated,
            grad_outputs=torch.ones_like(output),
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
        # NOTE(review): the norm is taken over dim 1 only (e.g. over channels
        # for (batch, channels, times) inputs), not over the flattened
        # per-sample gradient of standard WGAN-GP -- confirm this is intended.
        return (grad.norm(2, dim=1) - 1).pow(2).mean()

    def _optim_step(self, data, policy, classifier_critic, policy_optim,
                    critic_optim, criterion):
        """Run one critic/classifier update, then one policy update.

        The batch is split into two halves of equal size; the last sample is
        dropped when the batch size is odd. The first half is augmented by
        the policy, and the second half is used as unaugmented ("real") data.

        Parameters
        ----------
        data : tuple of Tensor
            ``(inputs, targets)`` batch, already on the right device.
        policy : nn.Module
            Differentiable augmentation policy.
        classifier_critic : ClassifierCritic
            Shared-trunk classifier and critic.
        policy_optim, critic_optim : torch.optim.Optimizer
            Optimizers of the policy and of the classifier-critic.
        criterion : callable
            Classification loss ``criterion(logits, targets)``.

        Returns
        -------
        discriminator_loss : Tensor
            Detached scalar critic loss (WGAN terms + gradient penalty +
            weighted classification loss on the real half).
        penalty_loss : Tensor
            Detached scalar sum of the weighted classification losses on the
            augmented and real halves.
        """
        # First, the batch is split in two halves of the same size
        inputs, targets = data
        half = inputs.size(0) // 2
        a_input, a_target = inputs[:half], targets[:half]
        n_input, n_target = inputs[half:2 * half], targets[half:2 * half]
        ones = n_input.new_tensor(1.0)

        # TRAIN DISCRIMINATOR:
        # Then, we compute the critic loss and update the critic and classifier
        # weights

        # We turn on gradients computation for it
        classifier_critic.requires_grad_(True)
        classifier_critic.zero_grad()

        # Do a forward pass on the unaugmented batch-half, which outputs the
        # classifier output and the corresponding part of the WGan loss
        output, n_output = classifier_critic(n_input)
        d_n_loss = n_output.mean()

        loss = self.penalty_weight * criterion(output, n_target)  # L(f(X'),y')

        # Then we augment the other half of the batch...
        augmented, _ = policy(a_input, a_target)
        # we don't want to update the policy yet
        augmented_no_grad = augmented.clone().detach()

        # ... and do the forward pass on the augmented batch
        _output, a_output = classifier_critic(augmented_no_grad)
        d_a_loss = a_output.mean()  # other part of the WGan loss

        # We compute and add the gradient penalty term to the WGan loss
        gp = self.gp_factor * self.gradient_penalty(
            classifier_critic, n_input, augmented_no_grad
        )

        # And finally compute the whole WGan loss and update the
        # classifier-critic model
        discriminator_loss = -d_n_loss + d_a_loss + gp + loss
        discriminator_loss.backward()
        critic_optim.step()

        # TRAIN POLICY:
        # Now that the critic and classifier have been updated, we can update
        # the policy weights
        classifier_critic.requires_grad_(False)
        policy.zero_grad()

        # We need to compute the forward again (recording gradients this time)
        # f(A(X)):
        _output, a_output = classifier_critic(augmented)
        # L(f(A(X)), y):
        _loss = self.penalty_weight * criterion(_output, a_target)

        # Compute the gradients corresponding to the penalty part
        # (retain_graph necessary because we pass twice through the policy)
        _loss.backward(retain_graph=True)

        # And then compute gradients for the WGan part: backpropagating -1
        # makes the policy *increase* the critic score of augmented data.
        a_loss = a_output.mean()
        a_loss.backward(-ones)
        policy_optim.step()

        # Detach so callers don't keep the computation graphs alive
        return discriminator_loss.detach(), (_loss + loss).detach()

    def _train(
        self,
        policy,
        classifier_critic,
        policy_optim,
        critic_optim,
        device,
        train_loader,
        epoch,
        criterion,
        log_interval,
        writer,
    ):
        """Run one epoch of alternating critic/policy updates.

        Parameters
        ----------
        policy : nn.Module
            Differentiable augmentation policy (put in train mode).
        classifier_critic : ClassifierCritic
            Classifier-critic network (put in eval mode).
        policy_optim, critic_optim : torch.optim.Optimizer
            Optimizers passed to ``_optim_step``.
        device : torch.device | str
            Device batches are moved to.
        train_loader : DataLoader
            Yields ``(data, target, _)`` triplets.
        epoch : int
            Current epoch, used for logging.
        criterion : callable
            Classification loss.
        log_interval : int
            Print and log to tensorboard every ``log_interval`` batches.
        writer : SummaryWriter | None
            Tensorboard writer; nothing is logged when ``None``.

        Returns
        -------
        tuple of float
            Mean per-batch ``(total_loss, wgan_loss, penalty_loss)``, where
            ``total_loss = wgan_loss + penalty_loss``.
        """
        policy.train()
        classifier_critic.eval()
        run_wgan_loss = 0
        run_pen_loss = 0
        train_loss = 0
        for batch_idx, (data, target, _) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            wgan_loss, pen_loss = self._optim_step(
                data=(data, target),
                policy=policy,
                classifier_critic=classifier_critic,
                policy_optim=policy_optim,
                critic_optim=critic_optim,
                criterion=criterion,
            )
            wgan_value = wgan_loss.item()
            pen_value = pen_loss.item()
            loss = wgan_value + pen_value
            run_wgan_loss += wgan_value
            run_pen_loss += pen_value
            train_loss += loss
            if batch_idx % log_interval == 0:
                print(
                    f'Train Epoch: {epoch}'
                    f' [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                    f' ({100. * batch_idx / len(train_loader):.0f}%)]'
                    f'\tLoss: {loss:.6f}'
                    f'\tWgan loss: {wgan_value:.6f}'
                    f'\tPen loss: {pen_value:.6f}'
                )
                n_iter = epoch * len(train_loader) + batch_idx
                if writer is not None:
                    writer.add_scalar(
                        'Train/batch-wgan-loss', wgan_value, n_iter
                    )
                    writer.add_scalar(
                        'Train/batch-penalty-loss', pen_value, n_iter
                    )
                    writer.add_scalar('Train/batch-total-loss', loss, n_iter)
                    for tag, values in (
                        ("Parameters/probabilities",
                         policy.all_probabilities),
                        ("Parameters/magnitudes", policy.all_magnitudes),
                        ("Parameters/weights", policy.all_weights),
                        ("Grads/probabilities", policy.all_prob_grads),
                        ("Grads/magnitudes", policy.all_mag_grads),
                        ("Grads/weights", policy.all_weight_grads),
                    ):
                        writer.add_histogram(tag, values, n_iter)
        n_batches = batch_idx + 1
        train_loss /= n_batches
        run_wgan_loss /= n_batches
        run_pen_loss /= n_batches
        print(
            f'\nTraining set -- Running loss: {train_loss:.4f},'
            f' Running WGan loss:  {run_wgan_loss:.4f}'
            f' Running penalty loss:  {run_pen_loss:.4f}\n'
        )
        return train_loss, run_wgan_loss, run_pen_loss

    def _test(self, policy, classifier_critic, device, test_loader, criterion,
              valid=False):
        """Evaluate the classification head on policy-augmented batches.

        Parameters
        ----------
        policy : nn.Module
            Augmentation policy (put in eval mode), applied to every batch.
        classifier_critic : ClassifierCritic
            Only its classification output is used.
        device : torch.device | str
            Device batches are moved to.
        test_loader : DataLoader
            Yields ``(data, target, _)`` triplets.
        criterion : callable
            Classification loss.
        valid : bool
            Only changes the set name printed in the summary.

        Returns
        -------
        test_loss : float
            Sum of per-batch criterion values divided by the dataset size.
        bal_acc : float
            Balanced accuracy over the whole loader.
        """
        set_name = 'Validation set' if valid else 'Test set'
        classifier_critic.eval()
        policy.eval()
        test_loss = 0
        ground_truth = list()
        predictions = list()
        with torch.no_grad():
            for data, target, _ in test_loader:
                data, target = data.to(device), target.to(device)
                tr_data, _ = policy(data, target)  # XXX: Is this correct?...
                output, _ = classifier_critic(tr_data)
                # sum up batch loss
                test_loss += criterion(output, target).item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1)
                predictions += pred.tolist()
                ground_truth += target.tolist()

        # NOTE(review): per-sample average only if ``criterion`` sums over
        # the batch -- confirm the criterion's reduction.
        test_loss /= len(test_loader.dataset)
        bal_acc = balanced_accuracy_score(ground_truth, predictions)

        print(
            f'\n{set_name} -- Average loss: {test_loss:.4f},'
            f' Balanced Accuracy: {bal_acc * 100:.1f}%\n'
        )
        return test_loss, bal_acc

    def _search_in_fold(
        self,
        split,
        random_state,
        epochs,
        windows_dataset,
        pretrain_base_path,
        ordered_ch_names,
        sfreq,
        **kwargs
    ):
        """Search an augmentation policy on one fold.

        Parameters
        ----------
        split : tuple
            ``(fold, subset_ratio, train_subset_idx, valid_idx, test_idx)``.
        random_state : int | RandomState | None
            Forwarded to ``_init_global_and_specific_rngs``.
        epochs : int
            Number of epochs to run in this call; training resumes from the
            last checkpoint found, if any.
        windows_dataset : Dataset
            Dataset indexed by the split indices.
        pretrain_base_path : str | Path
            Location of the pretrained classifier used to build the critic.
        ordered_ch_names : list of str
            Forwarded to the policy constructor.
        sfreq : float
            Forwarded to the policy constructor.
        **kwargs
            Ignored.

        Returns
        -------
        tuple
            ``((fold, best_policy), (fold, retrain_this_fold))`` where
            ``best_policy`` has gradients turned off and
            ``retrain_this_fold`` is True iff the best validation accuracy
            improved during this call.
        """
        # Get train and valid splits from dataset
        fold, subset_ratio, train_subset_idx, valid_idx, test_idx = split

        # Seed global RNGs and local RandomState object for this fold
        self._init_global_and_specific_rngs(
            fold=fold,
            random_state=random_state
        )

        train_set = Subset(windows_dataset, train_subset_idx)
        valid_set = Subset(windows_dataset, valid_idx)
        test_set = Subset(windows_dataset, test_idx)

        # Init retrain bool for this fold
        retrain_this_fold = False

        print(
            f"---------- Fold {fold} out of {self.n_folds} |",
            f"Training size: {len(train_subset_idx)} ----------"
        )

        # Make loaders and get device from params
        (
            train_loader,
            valid_loader,
            test_loader
        ) = self.make_search_valid_loaders(train_set, valid_set, test_set)

        # Load existing history and fetch best accuracy and step index from it
        fold_path = Path(self.training_dir) / f'fold{fold}of{self.n_folds}'
        subset_path = fold_path / f'subset_{subset_ratio}_samples'
        search_history_path = subset_path / "search_history.csv"
        step_idx, best_valid_acc = self.load_history_if_exist(
            search_history_path
        )

        # Create training and validation criteria
        (
            train_criterion,
            valid_criterion,
            test_criterion
        ) = self.instantiate_criteria(train_set, valid_set, test_set)

        # Set tensorboard writer (in a central folder)
        writer = self.set_tensorboard_up(fold=fold, subset_ratio=subset_ratio)

        # Load pretrained model
        trained_model = self.load_trained_classifier(
            pretrain_base_path=pretrain_base_path,
            fold=fold,
            subset_ratio=subset_ratio,
        )

        # Use trained_model to build critic and classifier here
        classifier_critic = ClassifierCritic(trained_model, self.device)

        # NOTE(review): ``betas`` assumes an Adam-like optimizer class
        critic_optim = self.model_params["optimizer"](
            classifier_critic.parameters(),
            lr=self.model_params["lr"],
            betas=(0., 0.999)
        )

        # Create a new DiffAugmentationPolicy object and its optimizer
        policy, policy_optim, best_policy = self.create_policy_and_optimizer(
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=self.random_states[fold]["random_state_obj"],
        )

        # We need to checkpoint critic and optimizer together with the policy
        to_checkpoint = {
            'policy': policy,
            'classifier_critic': classifier_critic,
            'policy_optim': policy_optim,
            'critic_optim': critic_optim,
            'best_policy': best_policy,
            'rng_states': self.random_states[fold],
        }

        # Setup checkpointing path, to save parameters for every epochs
        checkpoints_path = subset_path / "checkpoints"
        os.makedirs(checkpoints_path, exist_ok=True)
        ckpt_prefix = 'faster_aa'

        # If checkpoints exist in the correct folder, load the last one to
        # continue the policy training
        start_epoch = self.load_checkpoints(
            checkpoints_path, ckpt_prefix, to_checkpoint
        )

        # Set global and specific RNGs for this fold
        self._set_global_rngs_from_previous_calls(fold=fold)

        results = list()
        start = time()

        log_interval = max(int(len(train_loader) / 10), 1)

        for epoch in range(start_epoch, epochs + start_epoch):
            tr_loss, tr_wgan, tr_pen = self._train(
                policy=policy,
                classifier_critic=classifier_critic,
                device=self.device,
                train_loader=train_loader,
                policy_optim=policy_optim,
                critic_optim=critic_optim,
                epoch=epoch,
                criterion=train_criterion,
                log_interval=log_interval,
                writer=writer,
            )
            valid_loss, valid_acc = self._test(
                policy=policy,
                classifier_critic=classifier_critic,
                device=self.device,
                test_loader=valid_loader,
                criterion=valid_criterion,
                valid=True,
            )
            # The test set must be reported as such (valid=False)
            test_loss, test_acc = self._test(
                policy=policy,
                classifier_critic=classifier_critic,
                device=self.device,
                test_loader=test_loader,
                criterion=test_criterion,
                valid=False,
            )
            time_since_start = time() - start

            # Store best valid accuracy if better then previous values
            # and overwrite best model (output by the search)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                to_checkpoint["best_policy"] = best_policy = deepcopy(policy)
                # Ask for retraining only if improvement observed
                retrain_this_fold = True
            else:
                print(
                    f">> No validation improvement at this epoch ({epoch}) "
                    f"for fold {fold}"
                )

            if writer is not None:
                writer.add_scalar('Loss/search', tr_loss, epoch)
                writer.add_scalar('Loss/search-wgan', tr_wgan, epoch)
                writer.add_scalar('Loss/search-pen', tr_pen, epoch)
                writer.add_scalar('Loss/valid', valid_loss, epoch)
                writer.add_scalar('Loss/test', test_loss, epoch)
                writer.add_scalar('Accuracy/valid', valid_acc, epoch)
                writer.add_scalar('Accuracy/test', test_acc, epoch)
                writer.add_scalar('Accuracy/best-valid', best_valid_acc, epoch)
                writer.add_scalar('time', time_since_start, epoch)

            results.append({
                "fold": fold,
                "n_folds": self.n_folds,
                "epoch": epoch,
                "step_idx": step_idx,
                "time": time_since_start,
                "search_loss": tr_loss,
                "search_wgan": tr_wgan,
                "search_pen": tr_pen,
                "valid_loss": valid_loss,
                "test_loss": test_loss,
                "valid_bal_acc": valid_acc,
                "test_bal_acc": test_acc,
                "best_valid_bal_acc": best_valid_acc,
            })

            # Store in memory the current global RNG states for future search
            # steps
            self._save_current_global_rng_states(fold=fold)

            # Create a checkpoint with all objects listed in to_checkpoint
            self.checkpoint_stuff(
                to_checkpoint=to_checkpoint,
                checkpoints_path=checkpoints_path,
                ckpt_prefix=ckpt_prefix,
                epoch=epoch,
            )

        # Save search history as a DataFrame.
        if step_idx == 0 or not search_history_path.exists():
            # Start new csv if none is found
            mode = "w"
            header = True
        else:
            mode = "a"
            header = False
        pd.DataFrame(results).to_csv(
            search_history_path,
            header=header,
            index=False,
            mode=mode,
        )

        # Return the best model found up to here, based on (density match)
        # validation accuracy
        return (
            (fold, best_policy.requires_grad_(False)),
            (fold, retrain_this_fold)
        )


class DARTS(GradientBaseSearcher):
    """Differentiable augmentation search following the DARTS bilevel scheme.

    The augmentation policy (upper-level variables) is optimized to reduce
    the loss of the model on validation batches, while the model
    (lower-level variables) is trained with SGD on training batches
    augmented by the policy. With ``unrolled=True``, the policy gradient is
    computed with the second-order approximation of DARTS: the model is
    unrolled by one virtual SGD step and the implicit gradient is estimated
    with a finite-difference Hessian-vector product.

    Parameters
    ----------
    training_dir : str
        Directory where fold folders, checkpoints and search histories are
        written.
    model :
        Classifier trained jointly with the policy.
    subpolicies_length : int, optional
        Forwarded to ``GradientBaseSearcher``. Defaults to 2.
    policy_size_per_fold : int, optional
        Forwarded to ``GradientBaseSearcher``. Defaults to 5.
    transforms_family : optional
        Forwarded to ``GradientBaseSearcher``.
    network_momentum : float, optional
        Momentum of the SGD model optimizer, also used when unrolling the
        model. Defaults to 0.9.
    network_weight_decay : float, optional
        Weight decay of the SGD model optimizer, also used when unrolling
        the model. Defaults to 3e-4.
    unrolled : bool, optional
        Whether to use the second-order (unrolled) policy update. Defaults
        to True.
    perturb_policy : bool, optional
        Whether to perturb the policy when augmenting the batches used to
        train the model. Defaults to False.
    pertub_start, pertub_end : float, optional
        Bounds of the perturbation magnitude, interpolated linearly with the
        epoch index (relative to ``self._epochs``). Defaults to 0.03 and 0.3.
    **kwargs
        Forwarded to ``GradientBaseSearcher``.
    """

    def __init__(
        self,
        training_dir,
        model,
        subpolicies_length=2,
        policy_size_per_fold=5,
        transforms_family=None,
        network_momentum=0.9,
        network_weight_decay=3e-4,
        unrolled=True,
        perturb_policy=False,
        pertub_start=0.03,
        pertub_end=0.3,
        **kwargs
    ):
        super().__init__(
            training_dir,
            model,
            subpolicies_length=subpolicies_length,
            policy_size_per_fold=policy_size_per_fold,
            transforms_family=transforms_family,
            **kwargs
        )
        self.network_momentum = network_momentum
        self.network_weight_decay = network_weight_decay
        self.unrolled = unrolled
        self.perturb_policy = perturb_policy
        self._pertub_start = pertub_start
        self._pertub_end = pertub_end

    def make_search_valid_loaders(self, search_set, valid_set, test_set):
        """Create the search, validation and test loaders.

        The search and test loaders come from the parent class. The
        validation loader is rebuilt here with the batch size of the search
        loader and the ``iterator_train__*`` options of ``model_params``,
        and drops its last incomplete batch.

        Returns
        -------
        tuple of DataLoader
            ``(train_loader, valid_loader, test_loader)``.
        """
        train_loader, _, test_loader = super().make_search_valid_loaders(
            search_set,
            valid_set,
            test_set
        )
        valid_loader = DataLoader(
            valid_set,
            batch_size=train_loader.batch_size,
            pin_memory=self.model_params.get(
                "iterator_train__pin_memory",
                False
            ),
            num_workers=self.model_params.get(
                'iterator_train__num_workers',
                0
            ),
            worker_init_fn=self.model_params.get(
                "iterator_train__worker_init_fn",
                None
            ),
            multiprocessing_context=self.model_params.get(
                "iterator_train__multiprocessing_context",
                None
            ),
            drop_last=True,
        )
        return train_loader, valid_loader, test_loader

    def instantiate_criteria(self, search_set, valid_set, test_set):
        """Create the loss criteria, adding one dedicated to policy updates.

        On top of the parent's criteria, builds ``valid_search_criterion``
        from ``model_params["criterion"]`` with class weights computed on
        ``valid_set``; it is the loss minimized w.r.t. the policy.

        Returns
        -------
        tuple
            ``(search_criterion, valid_search_criterion,
            valid_eval_criterion, test_criterion)``.
        """
        (
            search_criterion,
            valid_eval_criterion,
            test_criterion
        ) = super().instantiate_criteria(search_set, valid_set, test_set)
        valid_class_weights = make_class_proportions_tensor(
            valid_set,
            self.balanced_loss,
            self.device,
        )
        valid_search_criterion = self.model_params["criterion"](
            weight=valid_class_weights,
        )
        return (
            search_criterion,
            valid_search_criterion,
            valid_eval_criterion,
            test_criterion
        )

    def warmstart_with_adda_policy(
        self,
        classwise_policy,
        model,
        warmstart_base_path,
        fold,
        subset_ratio,
        epoch,
        ckpt_prefix,
        ordered_ch_names,
        sfreq,
        diff_transforms,
    ):
        """Initialize a class-wise policy and the model from an ADDA checkpoint.

        Loads the most recently modified checkpoint matching
        ``{ckpt_prefix}_e{epoch}*.pt`` in the warmstart folder of this fold,
        restores the model weights from it (in place) and transfers the
        weights of its (non class-wise) best policy into ``classwise_policy``.

        Returns
        -------
        tuple
            ``(policy, model)``. The inputs are returned unchanged when no
            warmstart folder is found for this fold.

        Raises
        ------
        FileNotFoundError
            If no checkpoint matching the pattern is found.
        """
        warmstart_path = self._get_warmstart_path(
            warmstart_base_path, fold, subset_ratio
        )
        if warmstart_path is None:
            return classwise_policy, model

        checkpoints_path = Path(warmstart_path) / "checkpoints"
        # NOTE(review): this pattern may also match longer epoch numbers
        # sharing the same prefix (e.g. e1 also matches e10); the most
        # recently modified match wins. Confirm checkpoint naming.
        existing_checkpoints = sorted(
            checkpoints_path.glob(f'{ckpt_prefix}_e{epoch}*.pt'),
            key=lambda t: t.stat().st_mtime
        )
        if not existing_checkpoints:
            raise FileNotFoundError(
                f"No checkpoint matching '{ckpt_prefix}_e{epoch}*.pt' found "
                f"in {checkpoints_path}: cannot warmstart from ADDA."
            )

        # Load the checkpoint modified last
        checkpoint_dict = torch.load(
            existing_checkpoints[-1], map_location=self.device,
        )
        std_policy = self._make_standard_policy(
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            diff_transforms=diff_transforms,
            random_state=self.random_states[fold]["random_state_obj"],
        )
        std_policy.load_state_dict(checkpoint_dict["best_policy"])
        model.load_state_dict(checkpoint_dict["model"])
        print(
            f"CADDA warmstarted with ADDA weights from epoch {epoch}."
        )
        ws_classwise_policy = _transfer_weights_to_cw_policy(
            classwise_policy, std_policy
        )
        return ws_classwise_policy, model

    def _compute_unrolled_model(
        self,
        policy,
        model,
        model_optim,
        input,
        target,
        train_criterion,
        eta,
    ):
        """Return a copy of ``model`` after one virtual SGD step.

        Reproduces one step of SGD with momentum (``self.network_momentum``)
        and weight decay (``self.network_weight_decay``) on the augmented
        training batch, ``w' = w - eta * (m + dL_tr/dw + wd * w)``, without
        changing the parameters of ``model``.
        """
        # Model forward pass on the augmented training batch
        # (might require a no_grad for the policy, see the discussion in
        # _backward_step_unrolled)
        aug_input, _ = policy(input, target)
        output = model(aug_input)
        loss = train_criterion(output, target)

        # Current model parameters, flattened
        theta = _concat(model.parameters()).data

        # Momentum term; zero when the optimizer holds no momentum buffer yet
        # (missing state -> KeyError, None buffer -> AttributeError)
        try:
            moment = _concat(
                model_optim.state[v]['momentum_buffer']
                for v in model.parameters()
            ).mul_(self.network_momentum)
        except (KeyError, AttributeError):
            moment = torch.zeros_like(theta)

        # SGD gradient including weight decay
        dtheta = _concat(
            torch.autograd.grad(loss, model.parameters())
        ).data + self.network_weight_decay * theta

        # Build a new copy of the model holding the updated parameters
        return self._construct_model_from_theta(
            theta.sub(other=moment + dtheta, alpha=eta), model,
        )

    def _construct_model_from_theta(self, theta, model):
        """Return a deep copy of ``model`` whose parameters are read from theta.

        ``theta`` is a flat tensor holding all the parameters of ``model``
        concatenated in ``model.named_parameters()`` order. Buffers are
        copied from ``model``'s state dict.

        Raises
        ------
        ValueError
            If ``theta`` does not hold exactly as many elements as the model
            parameters.
        """
        model_new = deepcopy(model)
        model_dict = model.state_dict()

        params, offset = {}, 0
        for name, param in model.named_parameters():
            # numel() stays an int for 0-dim parameters, unlike
            # np.prod(param.size()) which returns the float 1.0
            n_elements = param.numel()
            params[name] = theta[offset: offset + n_elements].view_as(param)
            offset += n_elements

        if offset != theta.numel():
            raise ValueError(
                f"theta holds {theta.numel()} elements but the model has "
                f"{offset} parameters."
            )
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.to(self.device)

    def _hessian_vector_product(
        self,
        policy,
        model,
        vector,
        input,
        target,
        train_criterion,
        r=1e-2
    ):
        """Finite-difference estimate of the Hessian-vector product (eq. 8).

        Computes ``(grad_a L_tr(w+, a) - grad_a L_tr(w-, a)) / (2 * R)`` with
        ``w+/- = w +/- R * vector`` and ``R = r / ||vector||``, where ``a``
        are the trainable policy parameters. The model parameters are
        restored (up to floating point rounding) before returning.

        Returns
        -------
        dict
            Maps each trainable policy parameter to its estimate, or to
            ``None`` when the parameter does not influence the loss.
        """
        policy_params = [p for p in policy.parameters() if p.requires_grad]

        vector_norm = _concat(vector).norm().item()
        if vector_norm == 0:
            # The product is zero for a null vector; dividing by the norm
            # would otherwise fill the model weights with NaNs below.
            return {p: torch.zeros_like(p) for p in policy_params}
        # epsilon (with the heuristic described in the DARTS footnote)
        R = r / vector_norm

        # w+ = w + epsilon * grad_w L_val(w', a)   (w' = unrolled model)
        for p, v in zip(model.parameters(), vector):
            p.data.add_(other=v, alpha=R)

        aug_input, _ = policy(input, target)
        output = model(aug_input)
        loss = train_criterion(output, target)  # L_tr(w+, a)
        # grad_a L_tr(w+, a); the graph is kept to reuse aug_input below
        grads_p = torch.autograd.grad(
            loss,
            policy_params,
            allow_unused=True,
            retain_graph=True
        )

        # w- = w - epsilon * grad_w L_val(w', a)
        for p, v in zip(model.parameters(), vector):
            p.data.sub_(other=v, alpha=2 * R)
        output = model(aug_input)
        loss = train_criterion(output, target)  # L_tr(w-, a)
        # grad_a L_tr(w-, a)
        grads_n = torch.autograd.grad(loss, policy_params, allow_unused=True)

        # Equation 8
        implicit_gradients = {}
        for p, g_p, g_n in zip(policy_params, grads_p, grads_n):
            if g_p is None:
                implicit_gradients[p] = None
            else:
                implicit_gradients[p] = (g_p - g_n).div_(2 * R)

        # Restore the original model parameters
        for p, v in zip(model.parameters(), vector):
            p.data.add_(other=v, alpha=R)

        return implicit_gradients

    def _backward_step_unrolled(
        self,
        policy,
        model,
        model_optim,
        input_train,
        target_train,
        input_valid,
        target_valid,
        train_criterion,
        valid_criterion,
        eta,
    ):
        """Accumulate the second-order DARTS gradient into the policy (eq. 7).

        The validation forward pass of the unrolled model does not go
        through the policy (validation batches are not augmented), so only
        the implicit term ``-eta * H.v`` of equation 7 is added to the
        policy gradients. This is also why the policy does not need to be
        copied when unrolling: the upper-level variable only appears in the
        lower-level (training) loss here, whereas it appears in both levels
        in the original DARTS problem.
        """
        # Unroll last model update...
        unrolled_model = self._compute_unrolled_model(
            policy=policy,
            model=model,
            model_optim=model_optim,
            input=input_train,
            target=target_train,
            train_criterion=train_criterion,
            eta=eta,
        )

        # ... and compute L_val(w', a) from eq 7 on the validation batch
        unrolled_output = unrolled_model(input_valid)
        unrolled_loss = valid_criterion(unrolled_output, target_valid)
        unrolled_loss.backward()

        # grad_w L_val(w', a); parameters left unused by the forward pass
        # get a zero gradient instead of crashing on a missing .grad
        vector = [
            v.grad.data if v.grad is not None else torch.zeros_like(v)
            for v in unrolled_model.parameters()
        ]

        # Hessian-vector product of equation 8
        implicit_grads = self._hessian_vector_product(
            policy=policy,
            model=model,
            vector=vector,
            input=input_train,
            target=target_train,
            train_criterion=train_criterion,
        )

        # grad_a L_val(unrolled_model, a) = - eta * implicit gradient
        for p, ig in implicit_grads.items():
            if ig is None:
                continue
            if p.grad is None:
                p.grad = -eta * ig.data
            else:
                p.grad.data.sub_(other=ig.data, alpha=eta)

    def _backward_step(
        self,
        model,
        input_valid,
        target_valid,
        valid_criterion,
    ):
        """First-order step: backpropagate the validation loss of ``model``.

        NOTE(review): the validation forward pass below does not go through
        the policy, so it only populates the gradients of ``model``. The
        following ``policy_optim.step()`` therefore most likely leaves the
        policy unchanged, and ``model_optim.zero_grad()`` in ``_train``
        discards the model gradients. Confirm whether ``unrolled=False`` is
        meant to learn anything.
        """
        output_valid = model(input_valid)
        loss = valid_criterion(output_valid, target_valid)
        loss.backward()

    def _update_policy(
        self,
        policy,
        model,
        model_optim,
        policy_optim,
        input_train,
        target_train,
        input_valid,
        target_valid,
        train_criterion,
        valid_criterion,
        eta,
    ):
        """Take one optimization step of the policy with ``policy_optim``.

        Gradients are zeroed, then computed with the unrolled (second-order)
        or the plain (first-order) backward step depending on
        ``self.unrolled``. ``eta`` is the learning rate used to unroll the
        model.
        """
        policy_optim.zero_grad()
        if self.unrolled:
            self._backward_step_unrolled(
                policy=policy,
                model=model,
                model_optim=model_optim,
                input_train=input_train,
                target_train=target_train,
                input_valid=input_valid,
                target_valid=target_valid,
                train_criterion=train_criterion,
                valid_criterion=valid_criterion,
                eta=eta,
            )
        else:
            self._backward_step(
                model=model,
                input_valid=input_valid,
                target_valid=target_valid,
                valid_criterion=valid_criterion,
            )
        policy_optim.step()

    @staticmethod
    def _cycle_batches(loader):
        """Yield batches from ``loader`` forever, restarting it when exhausted.

        Raises
        ------
        ValueError
            If ``loader`` does not yield any batch.
        """
        while True:
            is_empty = True
            for batch in loader:
                is_empty = False
                yield batch
            if is_empty:
                raise ValueError(
                    "Cannot cycle over a loader that yields no batch (is the "
                    "dataset smaller than the batch size with drop_last?)."
                )

    @staticmethod
    def _log_policy_histograms(policy, writer, n_iter):
        """Log the policy parameters and, when available, their gradients."""
        writer.add_histogram(
            "Parameters/probabilities",
            policy.all_probabilities,
            n_iter
        )
        writer.add_histogram(
            "Parameters/magnitudes",
            policy.all_magnitudes,
            n_iter
        )
        writer.add_histogram(
            "Parameters/weights",
            policy.all_weights,
            n_iter
        )
        for tag, grads in (
            ("Grads/probabilities", policy.all_prob_grads),
            ("Grads/magnitudes", policy.all_mag_grads),
            ("Grads/weights", policy.all_weight_grads),
        ):
            if grads is not None and len(grads) > 0:
                writer.add_histogram(tag, grads, n_iter)

    def _train(
        self,
        policy,
        model,
        policy_optim,
        model_optim,
        train_loader,
        valid_loader,
        train_criterion,
        valid_criterion,
        epoch,
        log_interval,
        writer,
    ):
        """Run one search epoch, alternating policy and model updates.

        For each training batch, the policy is first updated on a validation
        batch (see ``_update_policy``). Then the model takes one step of
        ``model_optim`` on the training batch augmented by the (optionally
        perturbed) policy. Validation batches are drawn by cycling through
        ``valid_loader``.

        Returns
        -------
        tuple
            ``(train_loss, bal_acc)``: mean of the per-batch training losses
            and balanced accuracy of the predictions made on the augmented
            training batches.
        """
        ground_truth = list()
        predictions = list()
        train_loss = 0
        # Keep a single iterator walking through the validation set: since
        # valid_loader is not shuffled, re-creating an iterator at every
        # step would always return the same first batch (and respawn
        # workers each time).
        valid_batches = self._cycle_batches(valid_loader)
        for batch_idx, (data, target, _) in enumerate(train_loader):
            model.train()

            data = data.to(self.device)
            target = target.to(self.device)

            # Next minibatch used to update the policy
            data_search, target_search, _ = next(valid_batches)
            data_search = data_search.to(self.device)
            target_search = target_search.to(self.device)

            self._update_policy(
                policy=policy,
                model=model,
                model_optim=model_optim,
                policy_optim=policy_optim,
                input_train=data,
                target_train=target,
                input_valid=data_search,
                target_valid=target_search,
                train_criterion=train_criterion,
                valid_criterion=valid_criterion,
                eta=self.model_params["lr"],
            )

            # Model step on the augmented batch (policy frozen)
            model_optim.zero_grad()
            with torch.no_grad():
                if self.perturb_policy:
                    epsilon = self._pertub_start + epoch * (
                        self._pertub_end - self._pertub_start
                    ) / self._epochs
                    policy.perturbation_on(epsilon)
                aug_data, _ = policy(data, target)
                policy.perturbation_off()
            output = model(aug_data)
            loss = train_criterion(output, target)

            loss.backward()
            # Gradient clipping is disabled for now:
            # nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
            model_optim.step()

            ground_truth += target.tolist()
            predictions += output.detach().argmax(dim=1).tolist()
            train_loss += loss.item()

            if batch_idx % log_interval == 0:
                print(
                    f'Train Epoch: {epoch}'
                    f' [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                    f' ({100. * batch_idx / len(train_loader):.0f}%)]'
                    f'\tLoss: {loss.item():.6f}'
                )
                if writer is not None:
                    n_iter = epoch * len(train_loader) + batch_idx
                    self._log_policy_histograms(policy, writer, n_iter)

        # Release the validation iterator (and its workers, if any)
        valid_batches.close()

        train_loss /= (batch_idx + 1)
        bal_acc = balanced_accuracy_score(ground_truth, predictions)
        print(
            f'\nTraining set -- Running loss: {train_loss:.4f},'
            f' Balanced Accuracy:  {bal_acc * 100:.1f}%\n'
        )
        return train_loss, bal_acc

    def _test(self, model, test_loader, criterion, valid=False):
        """Evaluate ``model`` on ``test_loader`` without gradient tracking.

        Returns
        -------
        tuple
            ``(test_loss, bal_acc)``: sum of the per-batch criterion values
            divided by the dataset size, and balanced accuracy of the argmax
            predictions.
        """
        set_name = 'Validation set' if valid else 'Test set'
        model.eval()
        test_loss = 0
        ground_truth = list()
        predictions = list()
        with torch.no_grad():
            for data, target, _ in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = model(data)
                # sum up batch loss
                test_loss += criterion(output, target).item()
                # get the index of the max log-probability
                pred = output.argmax(dim=1)
                predictions += pred.tolist()
                ground_truth += target.tolist()

        # NOTE(review): if ``criterion`` averages over each batch, dividing
        # by the number of samples (rather than batches) scales the reported
        # "Average loss" down by roughly the batch size. Confirm intent.
        test_loss /= len(test_loader.dataset)
        bal_acc = balanced_accuracy_score(ground_truth, predictions)

        print(
            f'\n{set_name} -- Average loss: {test_loss:.4f},'
            f' Balanced Accuracy: {bal_acc * 100:.1f}%\n'
        )
        return test_loss, bal_acc

    @staticmethod
    def _write_search_history(results, search_history_path, step_idx):
        """Write the epoch results to the search history csv.

        Starts a new file (with header) on the first step or when no file
        exists yet; otherwise appends to the existing one.
        """
        if step_idx == 0 or not search_history_path.exists():
            mode, header = "w", True
        else:
            mode, header = "a", False
        pd.DataFrame(results).to_csv(
            search_history_path,
            header=header,
            index=False,
            mode=mode,
        )

    def _search_in_fold(
        self,
        split,
        random_state,
        epochs,
        windows_dataset,
        pretrain_base_path,
        warmstart_base_path,
        ordered_ch_names,
        sfreq,
        warmstart_epoch,
        *args,
        **kwargs
    ):
        """Run ``epochs`` DARTS search epochs on one fold.

        Reloads existing checkpoints through ``load_checkpoints`` and, on the
        first step, can warmstart a class-wise policy from ADDA results. After
        every epoch it evaluates on the validation and test sets, logs to
        tensorboard and checkpoints. At the end it writes the results to
        ``search_history.csv``.

        Parameters
        ----------
        split : tuple
            ``(fold, subset_ratio, train_subset_idx, valid_idx, test_idx)``.

        Returns
        -------
        tuple
            ``((fold, best_policy), (fold, retrain_this_fold))``.
            ``best_policy`` is the policy with the best validation balanced
            accuracy, with gradients disabled. ``retrain_this_fold`` tells
            whether that accuracy improved during this call.

        Raises
        ------
        NotImplementedError
            If warmstarting is requested for a non class-wise policy.
        FileNotFoundError
            If warmstarting is requested but no warmstart folder or
            checkpoint is found.
        ValueError
            If warmstarting would overwrite an existing search history.
        """
        # Get train and valid splits from dataset
        fold, subset_ratio, train_subset_idx, valid_idx, test_idx = split

        # Seed global RNGs and local RandomState object for this fold
        self._init_global_and_specific_rngs(
            fold=fold,
            random_state=random_state
        )

        # Split the dataset into train, valid, test using precomputed indices
        train_set = Subset(windows_dataset, train_subset_idx)
        valid_set = Subset(windows_dataset, valid_idx)
        test_set = Subset(windows_dataset, test_idx)

        # Only ask for retraining if an improvement is observed in this call
        retrain_this_fold = False

        print(
            f"---------- Fold {fold} out of {self.n_folds} |",
            f"Training size: {len(train_subset_idx)} ----------"
        )

        # Load existing history and fetch best accuracy and step index from it
        fold_path = Path(self.training_dir) / f'fold{fold}of{self.n_folds}'
        subset_path = fold_path / f'subset_{subset_ratio}_samples'
        search_history_path = subset_path / "search_history.csv"
        step_idx, best_valid_acc = self.load_history_if_exist(
            search_history_path
        )

        # Create training and validation criteria
        (
            train_criterion,
            valid_search_criterion,
            valid_eval_criterion,
            test_criterion
        ) = self.instantiate_criteria(train_set, valid_set, test_set)

        # Set tensorboard writer (in a central folder)
        writer = self.set_tensorboard_up(fold=fold, subset_ratio=subset_ratio)

        # Load pretrained model when applicable. Otherwise just make a copy
        # of the model (in order to train it independently for each fold)
        model = self.load_trained_classifier(
            pretrain_base_path=pretrain_base_path,
            fold=fold,
            subset_ratio=subset_ratio,
        )

        # The model optimizer during search is standard SGD with momentum,
        # to facilitate unrolling
        model_optim = torch.optim.SGD(
            model.parameters(),
            self.model_params["lr"],
            momentum=self.network_momentum,
            weight_decay=self.network_weight_decay
        )

        # Create a new DiffAugmentationPolicy object and its optimizer
        policy, policy_optim, best_policy = self.create_policy_and_optimizer(
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            random_state=self.random_states[fold]["random_state_obj"],
        )

        # Elements to checkpoint for training persistence across searches
        to_checkpoint = {
            'policy': policy,
            'model': model,
            'policy_optim': policy_optim,
            'model_optim': model_optim,
            'best_policy': best_policy,
            'rng_states': self.random_states[fold],
        }

        # Setup checkpointing path, to save parameters for every epoch
        checkpoints_path = subset_path / "checkpoints"
        os.makedirs(checkpoints_path, exist_ok=True)
        ckpt_prefix = 'darts'

        # If checkpoints exist in the correct folder, load the last one to
        # continue the policy training (including previous RNG states)
        start_epoch = self.load_checkpoints(
            checkpoints_path, ckpt_prefix, to_checkpoint
        )

        # Set global and specific RNGs for this fold
        self._set_global_rngs_from_previous_calls(fold=fold)

        # Warmstart only if in first step
        if warmstart_base_path is not None and step_idx == 0:
            if not self.classwise:
                raise NotImplementedError(
                    "Only ADDA to CADDA warmstarting has been implemented."
                )

            warmstart_path = self._get_warmstart_path(
                warmstart_base_path, fold, subset_ratio
            )
            if warmstart_path is None:
                raise FileNotFoundError(
                    f"No warmstart folder found in {warmstart_base_path} "
                    f"for fold {fold} and subset ratio {subset_ratio}."
                )
            warmstart_search_history_path = Path(
                warmstart_path
            ) / "search_history.csv"
            # Refuse before touching anything to avoid erasing results
            if search_history_path.exists():
                raise ValueError(
                    "Found existing history file during warmstarting in "
                    f"{search_history_path}. Stopping here to avoid erasing "
                    "previous results."
                )

            # Load and convert desired ADDA weights into CADDA policy
            policy, model = self.warmstart_with_adda_policy(
                policy,
                model,
                warmstart_base_path=warmstart_base_path,
                fold=fold,
                subset_ratio=subset_ratio,
                epoch=warmstart_epoch,
                ckpt_prefix=ckpt_prefix,
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                diff_transforms=None,
            )
            # Keep checkpointed references in sync with the returned objects
            to_checkpoint["policy"] = policy
            to_checkpoint["model"] = model
            to_checkpoint["best_policy"] = best_policy = deepcopy(policy)
            # NOTE(review): _transfer_weights_to_cw_policy assigns new
            # nn.Parameter objects to the stage weights, while policy_optim
            # was built before warmstarting. Confirm that the optimizer still
            # updates the transferred weights.

            # Start from ADDA's history, best accuracy and step index
            copy(warmstart_search_history_path, search_history_path)
            step_idx, best_valid_acc = self.load_history_if_exist(
                warmstart_search_history_path,
                epoch_to_fetch=warmstart_epoch,
            )

            # Set start epoch as the warmstarting one
            start_epoch = warmstart_epoch + 1

        # Make loaders
        (
            train_loader,
            valid_loader,
            test_loader
        ) = self.make_search_valid_loaders(train_set, valid_set, test_set)

        results = list()
        start = time()

        log_interval = max(int(len(train_loader) / 10), 1)

        for epoch in range(start_epoch, epochs + start_epoch):
            tr_loss, tr_acc = self._train(
                policy=policy,
                model=model,
                policy_optim=policy_optim,
                model_optim=model_optim,
                train_loader=train_loader,
                valid_loader=valid_loader,
                train_criterion=train_criterion,
                valid_criterion=valid_search_criterion,
                epoch=epoch,
                log_interval=log_interval,
                writer=writer,
            )
            valid_loss, valid_acc = self._test(
                model=model,
                test_loader=valid_loader,
                criterion=valid_eval_criterion,
                valid=True
            )
            test_loss, test_acc = self._test(
                model=model,
                test_loader=test_loader,
                criterion=test_criterion,
                valid=False
            )
            time_since_start = time() - start

            # Store best valid accuracy if better than previous values
            # and overwrite best model (output by the search)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                to_checkpoint["best_policy"] = best_policy = deepcopy(policy)
                # Ask for retraining only if improvement observed
                retrain_this_fold = True
            else:
                print(
                    f">> No validation improvement at this epoch ({epoch}) "
                    f"for fold {fold}"
                )

            if writer is not None:
                writer.add_scalar('Loss/search', tr_loss, epoch)
                writer.add_scalar('Loss/valid', valid_loss, epoch)
                writer.add_scalar('Loss/test', test_loss, epoch)
                writer.add_scalar('Accuracy/search', tr_acc, epoch)
                writer.add_scalar('Accuracy/valid', valid_acc, epoch)
                writer.add_scalar('Accuracy/test', test_acc, epoch)
                writer.add_scalar('Accuracy/best-valid', best_valid_acc, epoch)
                writer.add_scalar('time', time_since_start, epoch)

            results.append({
                "fold": fold,
                "n_folds": self.n_folds,
                "epoch": epoch,
                "step_idx": step_idx,
                "time": time_since_start,
                "search_loss": tr_loss,
                "valid_loss": valid_loss,
                "test_loss": test_loss,
                "search_bal_acc": tr_acc,
                "valid_bal_acc": valid_acc,
                "test_bal_acc": test_acc,
                "best_valid_bal_acc": best_valid_acc,
            })

            # Store in memory the current global RNG states for future search
            # steps
            self._save_current_global_rng_states(fold=fold)

            # Create a checkpoint with all objects listed in to_checkpoint
            self.checkpoint_stuff(
                to_checkpoint=to_checkpoint,
                checkpoints_path=checkpoints_path,
                ckpt_prefix=ckpt_prefix,
                epoch=epoch,
            )

        # Save search history as a DataFrame (new csv if none is found)
        self._write_search_history(results, search_history_path, step_idx)

        # Return the best model found up to here, based on validation
        # accuracy
        return (
            (fold, best_policy.requires_grad_(False)),
            (fold, retrain_this_fold)
        )


# Gradient-based searcher classes, keyed by search algorithm name.
GRAD_SEARCH_ALGOS = dict(
    match=DiffDensityMatching,
    fraa=FasterAutoAugment,
    darts=DARTS,
)


def split_search_epochs_in_steps(tot_epochs, step_size):
    """Split a total number of search epochs into assessment steps.

    Parameters
    ----------
    tot_epochs : int
        Total number of search epochs.
    step_size : int
        Number of search epochs per step.

    Returns
    -------
    numpy.ndarray
        ``tot_epochs // step_size`` entries equal to ``step_size``, followed
        by one entry with the remaining epochs when ``tot_epochs`` is not a
        multiple of ``step_size``.
    """
    n_full_steps, remainder = divmod(tot_epochs, step_size)
    steps = [step_size for _ in range(int(n_full_steps))]
    if remainder > 0:
        steps.append(remainder)
    return np.array(steps)


def evaluate_diff_policy_search(
    training_dir,
    windows_dataset,
    sfreq,
    ordered_ch_names,
    search_epochs=50,
    assess_epochs=300,
    eval_step=1,
    subpolicies_length=2,
    policy_size_per_fold=5,
    algo="match",
    n_classes=5,
    pretrain_base_path=None,
    warmstart_base_path=None,
    warmstart_epoch=None,
    device=None,
    lr=1e-3,
    batch_size=128,
    num_workers=4,
    early_stop=True,
    model_to_use=None,
    data_ratio=None,
    grouped_subset=True,
    n_jobs=1,
    train_policy_on="same",
    grad_est=None,
    random_state=None,
    verbose=False,
    **kwargs
):
    """Look for optimal policy using desired optimization algorithm and compute
    crossvalidated test performance

    The search budget of ``search_epochs`` epochs is split into steps of
    ``eval_step`` epochs. After each step, the learned differentiable policy
    is converted into a standard one and assessed by retraining for
    ``assess_epochs`` epochs. Assessment results, together with the search
    duration and number of search epochs so far, are written to
    ``<training_dir>/search_perf_results.csv`` (a fresh file at step 0 or if
    missing, appended to otherwise).

    Parameters
    ----------
    training_dir : str
        Directory where checkpoints, search results and test set evaluation
        results are saved.
    windows_dataset : torch.util.data.Dataset
        Dataset to use for training, validation and test (after splitting).
    sfreq : float
        Sampling frequency of input data.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    search_epochs : int, optional
        Total number of epochs for searching the best policy. Defaults to 50.
    assess_epochs : int, optional
        Number of epochs used for each assessment training with current learned
        policy. Defaults to 300.
    eval_step : int, optional
        Number of consecutive search epochs to run before stopping the search
        and starting a new assessment of the current learned policy. Defaults
        to 1 (assess after every search epoch).
    subpolicies_length : int, optional
        Number of consecutive Transforms in a subpolicy, by default 2
    policy_size_per_fold : int, optional
        Size of policies (in number of subpolicies), by default 5
    algo : str, optional
        Key of ``GRAD_SEARCH_ALGOS`` encoding the optimization algorithm to use
        for the search: 'match' (default) for differentiable density matching,
        'fraa' for Faster AutoAugment with the WGan loss, or 'darts' for
        DARTS.
    n_classes : int, optional
        Number of classes in the dataset, by default 5
    pretrain_base_path : str, optional
        Path to training folder where the checkpoints of a trained model are.
        Used for density matching algorithms and for weightsharing. Default to
        None.
    warmstart_base_path : str, optional
        Path to training folder where the checkpoints of a trained model and
        policy are, as well as search history. Used to warmstart the algorithm.
        Default to None.
    warmstart_epoch : int | None, optional
        Epoch of the warmstart checkpoint. When both ``warmstart_base_path``
        and ``warmstart_epoch`` are set, previous search results are loaded
        from ``warmstart_base_path`` at this epoch instead of from
        ``training_dir``. Also forwarded to the search and to the assessment.
        Default to None.
    device : str, optional
        Device to use for training, by default None
    lr : float, optional
        Learning rate, by default 1e-3
    batch_size : int, optional
        Batch size, by default 128
    num_workers : int, optional
        Number of workers used for data loading. By default 4
    early_stop : bool, optional
        Whether to carry earlystopping during training. By default True
    model_to_use : str | None, optional
        Defines which net should be used. By default (None) will use
        SleepStager. If set to 'lin' will use one layer linear net.
    data_ratio : float | None, optional
        Float between 0 and 1 or None. Will be used to build a
        subset of the cross-validated training sets (valid and test sets
        are conserved). Omitting it or setting it to None, is equivalent to
        setting it to [1.] (using the whole training set). By default None.
    grouped_subset : bool, optional
        Whether to compute training subsets taking groups (subjects) into
        account or not. When False, stratified spliting will be used to
        build the subsets. By default True.
    n_jobs : int, optional
        Number of workers to use for parallelizing across splits. By
        default 1.
    train_policy_on : str, optional
        Forwarded unchanged to the searcher's ``search_policy`` method, which
        defines its semantics. By default "same".
    grad_est : str, optional
        Defines what gradient estimator to use for the operations weights.
        When None, a softmax of the weights will be used, producing a convex
        combination of all operations in the forward and backward pass. When
        "gumbel", the straight-through Gumbel-Softmax trick is used where
        a single operation is sampled per batch in the forward and
        gumbel-softmax distribution is optimized in the backward pass. If
        "relax", the behavior is similar to "gumbel" but where the gradient is
        estimated with no bias, using the RELAX estimator. Defaults to None.
    random_state : int | numpy.random.RandomState | None, optional
        Used to seed random number generator, by default None
    verbose : bool, optional
        By default False.
    **kwargs
        Additional keyword arguments forwarded to the searcher constructor.

    Raises
    ------
    ValueError
        If ``algo`` is not a key of ``GRAD_SEARCH_ALGOS``.
    """
    # Fail fast (before any costly setup) with an explicit message; ``algo``
    # may be None when the CLI option is omitted.
    if algo not in GRAD_SEARCH_ALGOS:
        raise ValueError(
            f"Unknown search algorithm {algo!r}. Expected one of "
            f"{sorted(GRAD_SEARCH_ALGOS)}."
        )

    # Parse arguments and prepare elements for the search and training
    _, _, model, model_params, shared_callbacks = prepare_training(
        windows_dataset,
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
        early_stop=early_stop,
        sfreq=sfreq,
        n_classes=n_classes,
        model_to_use=model_to_use,
        random_state=random_state,
    )

    # Create reuseable search class
    searcher = GRAD_SEARCH_ALGOS[algo](
        training_dir=training_dir,
        model=model,
        subpolicies_length=subpolicies_length,
        policy_size_per_fold=policy_size_per_fold,
        model_params=model_params,
        shared_callbacks=shared_callbacks,
        balanced_loss=True,  # Not settable for now
        monitor='valid_bal_acc_best',  # Not settable for now
        should_checkpoint=False,  # To restart model between assessments
        should_load_state=False,  # To restart model between assessments
        random_state=random_state,
        grad_est=grad_est,
        **kwargs
    )

    # Split the total number of trials in steps and start counters
    steps = split_search_epochs_in_steps(search_epochs, eval_step)
    curr_search_epochs = 0
    curr_duration = 0

    # Load results if they exist and infer step_idx, curr_search_epochs, etc.
    if warmstart_base_path is None or warmstart_epoch is None:
        first_step, last_results = load_retrain_results_if_exist(training_dir)
    else:
        # But load from warmstart when applicable
        first_step, last_results = load_retrain_results_if_exist(
            warmstart_base_path, epoch=warmstart_epoch,
        )

    if last_results is not None:
        curr_search_epochs = last_results["tot_search_epochs"]
        curr_duration = last_results["tot_search_duration"]

    save_path = Path(training_dir) / 'search_perf_results.csv'
    for step_idx, step_search_epochs in enumerate(steps):
        # Add step offset in case we are continuing a training
        step_idx += first_step
        print(f"=== STEP {step_idx} ===")

        # Continue policy search with new epochs budget for this step
        start_step_time = time()
        searcher.search_policy(
            windows_dataset,
            epochs=step_search_epochs,
            n_jobs=n_jobs,
            verbose=verbose,
            pretrain_base_path=pretrain_base_path,
            warmstart_base_path=warmstart_base_path,
            ordered_ch_names=ordered_ch_names,
            sfreq=sfreq,
            data_ratio=data_ratio,
            grouped_subset=grouped_subset,
            train_policy_on=train_policy_on,
            warmstart_epoch=warmstart_epoch,
        )
        step_duration = time() - start_step_time

        # Increment time and search epochs counters
        curr_duration += step_duration
        curr_search_epochs += step_search_epochs

        # Convert learned differentiable policy into regular one to use in
        # training
        searcher.convert_policy(step_idx)

        # With the timer paused, assess the learned policy up to now
        results = searcher.learning_curve(
            windows_dataset=windows_dataset,
            epochs=assess_epochs,
            data_ratios=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
            warmstart_base_path=warmstart_base_path,
            warmstart_epoch=warmstart_epoch,
        )

        # Add search duration and step information to results and save it
        results['step_idx'] = step_idx
        results['step_search_duration'] = step_duration
        results['tot_search_duration'] = curr_duration
        results['tot_search_epochs'] = curr_search_epochs
        if step_idx == 0 or not save_path.exists():
            results.to_csv(save_path, index=False)
        else:
            results.to_csv(
                save_path,
                index=False,
                header=False,
                mode='a'  # Append to the end of the file
            )


def make_tf_from_stage_block(
    stage_block,
    tf_mapping,
    ordered_ch_names,
    sfreq,
    random_state=None
):
    """Instantiate the transform described by the row labelled 0 of a block.

    Parameters
    ----------
    stage_block : pandas.DataFrame
        Records of one policy stage. Must contain the columns 'operation',
        'probability' and 'magnitude' and a row labelled 0 (callers reset
        the index before calling).
    tf_mapping : dict
        Maps operation names (values of the 'operation' column) to transform
        classes/factories.
    ordered_ch_names : list
        Channel names forwarded to the transform constructor.
    sfreq : float
        Sampling frequency forwarded to the transform constructor.
    random_state : optional
        Forwarded to the transform constructor. Defaults to None.

    Returns
    -------
    object
        Whatever ``tf_mapping[operation]`` returns when called with the
        probability, magnitude, channel names, sampling frequency and random
        state.
    """
    row = 0
    probability = stage_block.loc[row, 'probability']
    magnitude = stage_block.loc[row, 'magnitude']
    transform_factory = tf_mapping[stage_block.loc[row, 'operation']]
    return transform_factory(
        random_state=random_state,
        sfreq=sfreq,
        ordered_ch_names=ordered_ch_names,
        magnitude=magnitude,
        probability=probability,
    )


def _compose_from_subpolicy_block(
    subpolicy_block,
    transforms_mapping,
    ordered_ch_names,
    sfreq,
    rng,
):
    """Build a ``Compose`` with one transform per unique ``transform_idx``.

    Stages are instantiated in order of first appearance of their
    ``transform_idx`` value in ``subpolicy_block``, all sharing ``rng``.
    """
    list_of_tfs = list()
    for stage in subpolicy_block["transform_idx"].unique():
        stage_block = subpolicy_block.query(
            "transform_idx == @stage"
        ).reset_index(drop=True)
        list_of_tfs.append(
            make_tf_from_stage_block(
                stage_block=stage_block,
                tf_mapping=transforms_mapping,
                ordered_ch_names=ordered_ch_names,
                sfreq=sfreq,
                random_state=rng,
            )
        )
    return Compose(list_of_tfs)


def retrieve_learned_policy_from_record(
    learned_policies_record,
    transforms_mapping,
    step,
    ordered_ch_names,
    sfreq,
    random_state,
):
    """Rebuild, for each fold, the policy learned at a given search step.

    Parameters
    ----------
    learned_policies_record : pandas.DataFrame
        Record of learned policies with at least the columns 'step_idx',
        'fold', 'subpolicy_idx', 'transform_idx', 'operation', 'probability'
        and 'magnitude'. When a 'class' column is present, subpolicies are
        rebuilt as classwise subpolicies (one ``Compose`` per class).
    transforms_mapping : dict
        Maps operation names to transform classes.
    step : int
        Search step whose policies should be rebuilt.
    ordered_ch_names : list
        Channel names forwarded to the transform constructors.
    sfreq : float
        Sampling frequency forwarded to the transform constructors.
    random_state : int | numpy.random.RandomState | None
        Seed or RNG. A single RNG obtained with ``check_random_state`` is
        shared by every transform and policy built here.

    Returns
    -------
    dict
        Maps each fold found in ``learned_policies_record`` to an
        ``AugmentationPolicy``.
    """
    classwise = "class" in learned_policies_record
    if classwise:
        all_classes = learned_policies_record["class"].unique()

    step_learned_policy = dict()
    retrain_rng = check_random_state(random_state)
    for fold in learned_policies_record["fold"].unique():
        step_policy_block = learned_policies_record.query(
            "step_idx == @step and fold == @fold"
        ).reset_index(drop=True)
        step_policy_elements = list()
        for subpol_idx in step_policy_block["subpolicy_idx"].unique():
            step_subpolicy_block = step_policy_block.query(
                "subpolicy_idx == @subpol_idx"
            ).reset_index(drop=True)
            if classwise:
                subpol_per_class = dict()
                for cls in all_classes:
                    step_class_block = step_subpolicy_block[
                        step_subpolicy_block["class"] == cls
                    ].reset_index(drop=True)
                    subpol_per_class[cls] = _compose_from_subpolicy_block(
                        step_class_block,
                        transforms_mapping,
                        ordered_ch_names,
                        sfreq,
                        retrain_rng,
                    )
                step_policy_elements.append(
                    ClasswiseSubpolicy(subpol_per_class)
                )
            else:
                step_policy_elements.append(
                    _compose_from_subpolicy_block(
                        step_subpolicy_block,
                        transforms_mapping,
                        ordered_ch_names,
                        sfreq,
                        retrain_rng,
                    )
                )
        step_learned_policy[fold] = AugmentationPolicy(
            step_policy_elements, random_state=retrain_rng
        )
    return step_learned_policy


def replay_diff_policy_search_eval(
    training_dir,
    dir_to_replay,
    windows_dataset,
    sfreq,
    ordered_ch_names,
    search_epochs=50,  # Unused
    assess_epochs=300,
    eval_step=1,
    subpolicies_length=2,
    policy_size_per_fold=5,
    algo="match",
    n_classes=5,
    warmstart_base_path=None,
    device=None,
    lr=1e-3,
    batch_size=128,
    num_workers=4,
    early_stop=True,
    model_to_use=None,
    data_ratio=None,
    grouped_subset=True,
    n_jobs=1,
    train_policy_on="same",
    random_state=None,
    verbose=False,
    **kwargs
):
    """Replay the assessment of policies learned during a former search.

    Instead of searching, the policies learned at each step are rebuilt from
    ``<dir_to_replay>/learned_policies.csv`` and re-assessed. The search
    timing and epoch counters of each step are copied from
    ``<dir_to_replay>/search_perf_results.csv``. Results are written to
    ``<training_dir>/search_perf_results.csv`` (a fresh file at step 0 or if
    missing, appended to otherwise).

    Parameters
    ----------
    training_dir : str
        Directory where checkpoints, search results and test set evaluation
        results are saved.
    dir_to_replay : str
        Directory of the former search, containing 'search_perf_results.csv'
        and 'learned_policies.csv'.
    windows_dataset : torch.util.data.Dataset
        Dataset to use for training, validation and test (after splitting).
    sfreq : float
        Sampling frequency of input data.
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Only used for instantiating transforms needing this information. Has to
        be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This channel will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    search_epochs : int, optional
        Unused (steps are read from the replayed record). Kept for signature
        compatibility with ``evaluate_diff_policy_search``.
    assess_epochs : int, optional
        Number of epochs used for each assessment training with current learned
        policy. Defaults to 300.
    eval_step : int, optional
        Unused (steps are read from the replayed record). Kept for signature
        compatibility with ``evaluate_diff_policy_search``.
    subpolicies_length : int, optional
        Number of consecutive Transforms in a subpolicy, by default 2
    policy_size_per_fold : int, optional
        Size of policies (in number of subpolicies), by default 5
    algo : str, optional
        Key of ``GRAD_SEARCH_ALGOS`` encoding the optimization algorithm: 'match'
        (default) for differentiable density matching, 'fraa' for Faster
        AutoAugment with the WGan loss, or 'darts' for DARTS.
    n_classes : int, optional
        Number of classes in the dataset, by default 5
    warmstart_base_path : str, optional
        Not used by this function; kept for signature compatibility with
        ``evaluate_diff_policy_search``. Default to None.
    device : str, optional
        Device to use for training, by default None
    lr : float, optional
        Learning rate, by default 1e-3
    batch_size : int, optional
        Batch size, by default 128
    num_workers : int, optional
        Number of workers used for data loading. By default 4
    early_stop : bool, optional
        Whether to carry earlystopping during training. By default True
    model_to_use : str | None, optional
        Defines which net should be used. By default (None) will use
        SleepStager. If set to 'lin' will use one layer linear net.
    data_ratio : float | None, optional
        Float between 0 and 1 or None. Will be used to build a
        subset of the cross-validated training sets (valid and test sets
        are conserved). Omitting it or setting it to None, is equivalent to
        setting it to [1.] (using the whole training set). By default None.
    grouped_subset : bool, optional
        Whether to compute training subsets taking groups (subjects) into
        account or not. When False, stratified spliting will be used to
        build the subsets. By default True.
    n_jobs : int, optional
        Number of workers to use for parallelizing across splits. By
        default 1.
    train_policy_on : str, optional
        Not used by this function; kept for signature compatibility with
        ``evaluate_diff_policy_search``. By default "same".
    random_state : int | numpy.random.RandomState | None, optional
        Used to seed random number generator, by default None
    verbose : bool, optional
        By default False.
    **kwargs
        Additional keyword arguments forwarded to the searcher constructor.

    Raises
    ------
    ValueError
        If ``algo`` is not a key of ``GRAD_SEARCH_ALGOS``.
    """
    # Fail fast (before any costly setup) with an explicit message
    if algo not in GRAD_SEARCH_ALGOS:
        raise ValueError(
            f"Unknown search algorithm {algo!r}. Expected one of "
            f"{sorted(GRAD_SEARCH_ALGOS)}."
        )

    # Parse arguments and prepare elements for the search and training
    _, _, model, model_params, shared_callbacks = prepare_training(
        windows_dataset,
        lr=lr,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
        early_stop=early_stop,
        sfreq=sfreq,
        n_classes=n_classes,
        model_to_use=model_to_use,
        random_state=random_state,
    )

    # Create reuseable search class
    searcher = GRAD_SEARCH_ALGOS[algo](
        training_dir=training_dir,
        model=model,
        subpolicies_length=subpolicies_length,
        policy_size_per_fold=policy_size_per_fold,
        model_params=model_params,
        shared_callbacks=shared_callbacks,
        balanced_loss=True,  # Not settable for now
        monitor='valid_bal_acc_best',  # Not settable for now
        should_checkpoint=False,  # To restart model between assessments
        should_load_state=False,  # To restart model between assessments
        random_state=random_state,
        **kwargs
    )

    # Load search record to retrive search time, steps, etc.
    search_record_path = Path(dir_to_replay) / 'search_perf_results.csv'
    search_record = pd.read_csv(search_record_path)
    steps_made = search_record["step_idx"].unique()

    # Load learned policies record
    learned_policies_path = Path(dir_to_replay) / 'learned_policies.csv'
    learned_policies_record = pd.read_csv(learned_policies_path)

    # Mapping from name to Transform
    transforms_mapping = dict(inspect.getmembers(transforms, inspect.isclass))

    save_path = Path(training_dir) / 'search_perf_results.csv'
    for step in steps_made:
        print(f"=== STEP {step} ===")

        # Fetch running search metrics to copy (from the first record row of
        # this step)
        step_record = search_record.query(
            "step_idx == @step"
        ).reset_index(drop=True)
        curr_duration = step_record["tot_search_duration"][0]
        step_duration = step_record["step_search_duration"][0]
        curr_search_epochs = step_record["tot_search_epochs"][0]

        # Look up learned policy for each fold at this step and instantiate it
        searcher.learned_policies = retrieve_learned_policy_from_record(
            learned_policies_record,
            transforms_mapping,
            step,
            ordered_ch_names,
            sfreq,
            random_state
        )

        # With the timer paused, assess the learned policy up to now
        results = searcher.learning_curve(
            windows_dataset=windows_dataset,
            epochs=assess_epochs,
            data_ratios=data_ratio,
            grouped_subset=grouped_subset,
            n_jobs=n_jobs,
            verbose=verbose,
        )

        # Add search duration and step information to results and save it.
        # Like in evaluate_diff_policy_search, also write a header when the
        # file does not exist yet (e.g. replayed record not starting at 0).
        results['step_idx'] = step
        results['step_search_duration'] = step_duration
        results['tot_search_duration'] = curr_duration
        results['tot_search_epochs'] = curr_search_epochs
        if step == 0 or not save_path.exists():
            results.to_csv(save_path, index=False)
        else:
            results.to_csv(
                save_path,
                index=False,
                header=False,
                mode='a'  # Append to the end of the file
            )


def training_params_from_args():
    """Parse command-line arguments into keyword arguments for the search.

    Returns
    -------
    parameters : dict
        Keyword arguments for ``evaluate_diff_policy_search`` /
        ``replay_diff_policy_search_eval``. Entries from the optional config
        file (sections 'split', 'diff', 'policy' and 'training') override the
        command-line values.
    args : argparse.Namespace
        The raw parsed arguments.
    """
    parser = make_args_parser()
    parser.add_argument(
        "-l", "--subpolicy_length",
        type=int,
        default=2,
        help="Length of transforms sequence to sample."
    )

    parser.add_argument(
        "-pl", "--policy_length",
        type=int,
        default=5,
        help="Number of subpolicies in a policy."
    )

    parser.add_argument(
        "-t", "--n_trials", "--search_epochs",
        type=int, default=50,
        help="Total number of epochs of policy search."
    )

    parser.add_argument(
        "--eval_step", "--step",
        type=int, default=1,
        help="Number of search epochs between two policy evaluations."
    )

    parser.add_argument(
        "--best_transf_only",
        action="store_true",
        help="Whether to only use known best transforms.",
    )

    parser.add_argument(
        "--pretrain_path",
        help="Path to root of pretrained model to use for warmstarting "
             "(just for the model, not the policy)."
    )
    parser.add_argument(
        "-s", "--algo",
        default="match",
        choices=sorted(GRAD_SEARCH_ALGOS),
        help="Gradient-based algorithm for the search. Can be either 'match'"
             " for vanilla density matching, 'fraa' for Faster AutoAug or"
             " 'darts' for DARTS. Defaults to 'match'."
    )
    parser.add_argument(
        "-cw", "--classwise",
        action="store_true",
        help="Whether to use classwise augmentation."
    )
    parser.add_argument(
        "--tr_pol_on", type=int,
        help="Value passed as `train_policy_on` to the policy search. "
             "Defaults to 'same' when omitted."
    )
    parser.add_argument(
        "--replay",
        action="store_true",
        help="Whether to replay retraining from former search records."
    )
    parser.add_argument(
        "--dir_to_replay",
        help="Where to get former search records."
    )
    parser.add_argument(
        "--ws_path",
        help="Path to main folder of pretrained ADDA checkpoint to use for"
             " warmstarting CADDA."
    )
    parser.add_argument(
        "--ws_epoch",
        type=int,
        help="Epoch of pretrained ADDA checkpoint to use for warmstarting."
    )
    parser.add_argument(
        "--grad_est",
        help="Whether to use simple softmax of operation, the Gumbel-softmax"
             " trick or RELAX."
    )
    parser.add_argument(
        "--dada",
        action="store_true",
        help="Whether to use DADA Policies"
    )
    parser.add_argument(
        "-plr", "--policy_lr",
        type=float, default=None,
        help="Learning rate to use for the policy gradient steps"
    )
    parser.add_argument(
        "--perturb",
        action="store_true",
        help="Whether to perturb policy params in DART."
    )
    args = parser.parse_args()

    # Reject an incomplete replay request before loading the dataset, rather
    # than failing later on Path(None) inside the replay.
    if args.replay and args.dir_to_replay is None:
        parser.error("--dir_to_replay is required when --replay is set.")

    windows_dataset, ch_names, sfreq = handle_dataset_args(args)

    tr_pol_on = "same" if args.tr_pol_on is None else args.tr_pol_on

    parameters = {
        'training_dir': args.training_dir,
        'windows_dataset': windows_dataset,
        'search_epochs': args.n_trials,
        'assess_epochs': args.epochs,
        'eval_step': args.eval_step,
        'sfreq': sfreq,
        'pretrain_base_path': args.pretrain_path,
        'device': args.device,
        'lr': args.lr,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'random_state': args.random_state,
        'early_stop': args.early_stop,
        'n_folds': args.nfolds,
        'train_size_over_valid': args.train_size_over_valid,
        'model_to_use': args.model,
        'data_ratio': args.data_ratio,
        'grouped_subset': args.grouped_subset,
        'n_jobs': args.n_jobs,
        'ordered_ch_names': ch_names,
        'subpolicies_length': args.subpolicy_length,
        'policy_size_per_fold': args.policy_length,
        'use_transforms_subset': args.best_transf_only,
        'algo': args.algo,
        'classwise': args.classwise,
        'train_policy_on': tr_pol_on,
        'warmstart_base_path': args.ws_path,
        'warmstart_epoch': args.ws_epoch,
        'grad_est': args.grad_est,
        'dada': args.dada,
        'policy_lr': args.policy_lr,
        'perturb_policy': args.perturb,
    }

    if args.config:
        config = read_config(args.config)
        parameters.update(config["split"])
        parameters.update(config["diff"])
        parameters.update(config["policy"])
        parameters.update(config["training"])
    return parameters, args


if __name__ == "__main__":
    # Either run a fresh search + assessment, or re-assess the policies
    # recorded by a former search located in --dir_to_replay.
    training_params, args = training_params_from_args()
    if args.replay:
        replay_diff_policy_search_eval(
            dir_to_replay=args.dir_to_replay,
            **training_params,
        )
    else:
        evaluate_diff_policy_search(**training_params)
