import shutil
import pytest
from os.path import join

import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from torch.utils.data import Subset

from braindecode.augmentation.base import Compose
from braindecode.augmentation.transforms import SignFlip
from braindecode.augmentation.transforms import TimeMask
from braindecode.augmentation.transforms import TimeReverse
from braindecode.augmentation.transforms import ChannelSymmetry

from eeg_augment.train import launch_training
from eeg_augment.auto_augmentation import AugmentationPolicy
from eeg_augment.auto_augmentation import evaluate_discrete_policy_search
from eeg_augment.diff_aug.diff_auto_augmentation import (
    evaluate_diff_policy_search, DARTS
)

from .conftest import DEVICES
from eeg_augment.training_utils import prepare_training, _get_split_indices
from eeg_augment.utils import get_groups
from tests.diff_aug.test_diff_transforms import CH_NAMES_PARAMS


@pytest.mark.parametrize("batchwise", [False, True])
@pytest.mark.parametrize("device", DEVICES)
def test_policy_forward(rng_seed, batchwise, device):
    """Identically-seeded policies must produce identical augmentations.

    An ``AugmentationPolicy`` made of three two-step sub-policies is built
    twice from scratch. Before each build, torch's global seed and a fresh
    ``np.random.RandomState`` are both set from ``rng_seed``. A random batch
    is then drawn and passed through the policy. The augmented signals of
    the two passes must be element-wise equal.
    """
    def seeded_forward():
        # Reset every source of randomness before building the policy so
        # that both passes start from the same state.
        torch.manual_seed(rng_seed)
        rng = np.random.RandomState(rng_seed)
        subpolicies = [
            Compose([
                TimeReverse(0.5, random_state=rng),
                SignFlip(0.5, random_state=rng),
            ]),
            Compose([
                TimeMask(0.5, 0.1, random_state=rng),
                SignFlip(0.5, random_state=rng),
            ]),
            Compose([
                ChannelSymmetry(0.5, magnitude=0.1,
                                random_state=rng, **CH_NAMES_PARAMS),
                SignFlip(0.5, random_state=rng),
            ]),
        ]
        policy = AugmentationPolicy(
            subpolicies, batchwise=batchwise, random_state=rng
        )
        # Drawn from torch's global RNG, which was seeded above
        inputs = torch.randn(16, 6, 3000, device=device)
        targets = torch.ones(16, device=device)
        return policy(inputs, targets)

    first_output = seeded_forward()
    second_output = seeded_forward()

    # Only the augmented signals (first element of the output) are compared
    augmented_first = first_output[0].detach().cpu().numpy()
    augmented_second = second_output[0].detach().cpu().numpy()
    assert np.all(augmented_first == augmented_second)


@pytest.mark.parametrize("sampler,metric", [
    ("random-search", "autoaug"),
    ("random-search", "match"),
    ("tpe", "match"),
])
def test_policy_search(
    rng_seed,
    training_dir,
    small_real_dataset,
    sampler,
    metric,
):
    """Smoke test: a tiny discrete policy search runs end to end.

    Only checks that ``evaluate_discrete_policy_search`` completes without
    raising for each sampler/metric combination. Its outputs are not
    inspected.
    """
    windows_dataset, ordered_ch_names, sfreq = small_real_dataset
    search_kwargs = dict(
        training_dir=training_dir,
        windows_dataset=windows_dataset,
        ordered_ch_names=ordered_ch_names,
        sfreq=sfreq,
        sampler=sampler,
        metric=metric,
        random_state=rng_seed,
        # Small search / training budgets
        epochs=2,
        n_trials=2,
        n_folds=2,
        n_jobs=2,
        batch_size=16,
        subpolicies_length=2,
        policy_size_per_fold=2,
        train_size_over_valid=0.5,
    )
    evaluate_discrete_policy_search(**search_kwargs)


@pytest.mark.parametrize("model_type,n_jobs,method,classwise", [
    (None, 1, "autoaug", False),  # Pass
    ('lin', 2, "match", False),   # Pass
    (None, 2, "autoaug", False),  # Pass
    (None, 2, "autoaug", True),   # Pass
    (None, 2, "match", False),    # Pass
])
def test_policy_search_reproducibility(
    rng_seed,
    training_dir,
    small_real_dataset,
    model_type,
    n_jobs,
    method,
    classwise,
):
    """Two identically-seeded discrete policy searches must agree.

    The same TPE search and assessment is run twice, in separate
    sub-directories of ``training_dir``. For every fold, the columns of
    ``trials.pkl`` whose name contains "value", "magnitude" or
    "probability" must match between runs. The train/valid/test balanced
    accuracy and loss in ``search_perf_results.csv`` must also match.
    """
    n_folds = 2
    data_ratio = 0.25
    tol = 1e-6
    model_name = "Chambon" if model_type is None else model_type
    windows_dataset, ch_names, sfreq = small_real_dataset
    classes = np.unique(np.hstack([ds.y for ds in windows_dataset.datasets]))
    # Fold directories are numbered 1..n_folds (``fold{k}of{n_folds}``)
    folds = range(1, n_folds + 1)
    # Substrings identifying the trial columns that must be reproducible
    compared_keys = ("value", "magnitude", "probability")

    # Do the same search and assessment twice, and store search and train
    # results
    results = list()
    learned_policies = list()
    for i in range(2):
        training_dir_i = join(training_dir, f"run_{i}")
        save_path = join(training_dir_i, 'search_perf_results.csv')

        evaluate_discrete_policy_search(
            training_dir=training_dir_i,
            epochs=1,
            windows_dataset=windows_dataset,
            sampler="tpe",
            metric=method,
            sfreq=sfreq,
            subpolicies_length=1,
            n_trials=3,
            policy_size_per_fold=2,
            batch_size=16,
            n_folds=n_folds,
            train_size_over_valid=0.5,
            random_state=rng_seed,
            ordered_ch_names=ch_names,
            n_jobs=n_jobs,
            data_ratio=data_ratio,
            grouped_subset=False,
            model_to_use=model_type,
            classwise=classwise,
            classes=classes,
        )
        results.append(pd.read_csv(save_path))
        learned_policies_in_run_i = dict()
        for fold in folds:
            fold_path = join(training_dir_i, f'fold{fold}of{n_folds}')
            subset_path = join(fold_path, f'subset_{data_ratio}_samples')
            trials = pd.read_pickle(join(subset_path, "trials.pkl"))
            # Wall-clock timing columns are not expected to be reproducible
            learned_policies_in_run_i[fold] = trials.drop(
                columns=["datetime_start", "datetime_complete", "duration"]
            )
        learned_policies.append(learned_policies_in_run_i)

    # Check search results are the same
    for fold in folds:
        for column in learned_policies[0][fold]:
            if not any(key in column for key in compared_keys):
                continue
            val1 = learned_policies[0][fold][column].values
            val2 = learned_policies[1][fold][column].values
            # Any floating dtype (not only float64) is compared with a
            # tolerance; a NaN in both runs at the same position counts as
            # a match
            if np.issubdtype(val1.dtype, np.floating):
                both_nan = np.isnan(val1) & np.isnan(val2)
                same = np.all(((val1 - val2) ** 2 < tol) | both_nan)
            else:
                same = np.array_equal(val1, val2)
            assert same, (
                f"Irreproducible policy search with model {model_name}"
                f", {n_jobs} jobs and {method} metric. "
                f"{column} doesn't match for fold {fold}!"
                f" {val1} != {val2}"
            )

    # Check policies assessment results are the same
    for split_name in ["train", "valid", "test"]:
        for metric in ["bal_acc", "loss"]:
            label = f"{split_name}_{metric}"
            res0 = results[0][label].values
            res1 = results[1][label].values
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < tol), (
                f"Irreproducible policy assessment with model {model_name}, "
                f"{n_jobs} jobs and {method} method. "
                f"Inconsistent {label} error: {sq_err}"
            )
    shutil.rmtree(training_dir)


@pytest.mark.parametrize("model_type,method,classwise", [
    ('lin', "match", False),
    (None, "autoaug", False),
    (None, "autoaug", True),
    (None, "match", False),
])
def test_step_size_robustness_in_discrete_search(
    rng_seed,
    training_dir,
    small_real_dataset,
    model_type,
    method,
    classwise,
):
    """Assessment frequency must not change a discrete policy search.

    Two identically-seeded TPE searches of ``n_trials`` trials are run. The
    first assesses policies every 2 trials, the second every 4 trials.
    The reproducible columns of each fold's ``trials.pkl`` must match.
    The assessment metrics recorded once all ``n_trials`` trials were run
    must also match.
    """
    n_folds = 2
    n_trials = 4
    data_ratio = 0.1
    search_tol = 1e-6
    assess_tol = 1e-7
    model_name = "Chambon" if model_type is None else model_type
    n_jobs = 1
    windows_dataset, ch_names, sfreq = small_real_dataset
    classes = np.unique(np.hstack([ds.y for ds in windows_dataset.datasets]))
    # Fold directories are numbered 1..n_folds (``fold{k}of{n_folds}``)
    folds = range(1, n_folds + 1)
    # Substrings identifying the trial columns that must be reproducible
    compared_keys = ("value", "magnitude", "probability")

    # Do the same search and assessment twice, and store search and train
    # results
    results = list()
    learned_policies = list()
    for i in range(2):
        training_dir_i = join(training_dir, f"run_{i}")
        save_path = join(training_dir_i, 'search_perf_results.csv')

        # Run 0 assesses every 2 trials, run 1 every 4 trials
        eval_step = (i + 1) * 2

        evaluate_discrete_policy_search(
            training_dir=training_dir_i,
            epochs=1,
            n_trials=n_trials,
            eval_step=eval_step,
            windows_dataset=windows_dataset,
            sampler="tpe",
            metric=method,
            sfreq=sfreq,
            subpolicies_length=1,
            policy_size_per_fold=1,
            batch_size=16,
            n_folds=n_folds,
            train_size_over_valid=0.5,
            random_state=rng_seed,
            ordered_ch_names=ch_names,
            n_jobs=n_jobs,
            data_ratio=data_ratio,
            grouped_subset=False,
            model_to_use=model_type,
            classwise=classwise,
            classes=classes,
        )
        results.append(pd.read_csv(save_path))
        learned_policies_in_run_i = dict()
        for fold in folds:
            fold_path = join(training_dir_i, f'fold{fold}of{n_folds}')
            subset_path = join(fold_path, f'subset_{data_ratio}_samples')
            trials = pd.read_pickle(join(subset_path, "trials.pkl"))
            # Wall-clock timing columns are not expected to be reproducible
            learned_policies_in_run_i[fold] = trials.drop(
                columns=["datetime_start", "datetime_complete", "duration"]
            )
        learned_policies.append(learned_policies_in_run_i)

    # Check search results are the same
    for fold in folds:
        for column in learned_policies[0][fold]:
            if not any(key in column for key in compared_keys):
                continue
            val1 = learned_policies[0][fold][column].values
            val2 = learned_policies[1][fold][column].values
            # Any floating dtype (not only float64) is compared with a
            # tolerance; a NaN in both runs at the same position counts as
            # a match
            if np.issubdtype(val1.dtype, np.floating):
                both_nan = np.isnan(val1) & np.isnan(val2)
                same = np.all(((val1 - val2) ** 2 < search_tol) | both_nan)
            else:
                same = np.array_equal(val1, val2)
            assert same, (
                "Policy search not robust to step size with model "
                f"{model_name}, {n_jobs} jobs and {method} metric. "
                f"{column} doesn't match for fold {fold}!"
                f" {val1} != {val2}"
            )

    # Check the assessment results obtained after all trials are the same
    for split_name in ["train", "valid", "test"]:
        for metric in ["bal_acc", "loss"]:
            label = f"{split_name}_{metric}"
            res0 = results[0].query(
                "tot_trials == @n_trials")[label].values
            res1 = results[1].query(
                "tot_trials == @n_trials")[label].values
            # Fail loudly instead of passing vacuously on an empty
            # selection (np.all([]) is True) or broadcasting mismatched
            # lengths
            assert res0.size > 0 and res0.shape == res1.shape, (
                f"Expected matching, non-empty {label} results at "
                f"tot_trials == {n_trials}, got {res0} and {res1}."
            )
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < assess_tol), (
                "Policy search not robust to step size with model "
                f"{model_name} and {method} method. "
                f"Inconsistent {label} values {res0} and {res1}. "
                f"Squared error {sq_err} > {assess_tol}."
            )
    shutil.rmtree(training_dir)


@pytest.mark.parametrize("model_type,n_jobs,method,classwise,dada,perturb", [
    (None, 1, "darts", False, False, False),   # pass
    (None, 1, "fraa", False, False, False),    # pass
    ('lin', 2, "darts", False, False, False),  # pass
    (None, 2, "darts", False, False, False),   # pass
    (None, 2, "darts", True, False, False),    # pass
    (None, 2, "darts", False, True, False),    # pass
    (None, 2, "darts", False, False, True),    # pass
    (None, 2, "fraa", False, False, False),    # pass
])
def test_diff_policy_search_reproducibility(
    rng_seed,
    training_dir,
    small_real_dataset,
    model_type,
    n_jobs,
    method,
    classwise,
    dada,
    perturb,
):
    """Two identically-seeded differentiable policy searches must agree.

    The same search and assessment is run twice, in separate
    sub-directories of ``training_dir``. Several checks must then pass:

    - the selected operations in ``learned_policies.csv`` are identical;
    - their probabilities and magnitudes match within ``policy_tol``;
    - each fold's search-history valid/test balanced accuracy matches;
    - the final train/valid/test assessment metrics match.
    """
    windows_dataset, ch_names, sfreq = small_real_dataset
    n_folds = 2
    data_ratio = 0.1
    # Squared-error tolerance on selected probabilities / magnitudes
    policy_tol = 1e-4
    # Squared-error tolerance on search and assessment metrics
    tol = 1e-6
    model_name = "Chambon" if model_type is None else model_type
    # Fold directories are numbered 1..n_folds (``fold{k}of{n_folds}``)
    folds = range(1, n_folds + 1)

    # Do the same search and assessment twice, and store search and train
    # results
    results = list()
    search_metrics = list()
    learned_policies = list()
    for i in range(2):
        training_dir_i = join(training_dir, f"run_{i}")
        save_path = join(training_dir_i, 'search_perf_results.csv')
        learned_policies_path = join(training_dir_i, 'learned_policies.csv')

        evaluate_diff_policy_search(
            training_dir=training_dir_i,
            search_epochs=1,
            assess_epochs=1,
            subpolicies_length=1,
            policy_size_per_fold=2,
            windows_dataset=windows_dataset,
            sfreq=sfreq,
            ordered_ch_names=ch_names,
            batch_size=16,
            n_jobs=n_jobs,
            random_state=rng_seed,
            n_folds=n_folds,
            train_size_over_valid=0.5,
            data_ratio=data_ratio,
            grouped_subset=False,
            algo=method,
            model_to_use=model_type,
            classwise=classwise,
            dada=dada,
            perturb_policy=perturb,
        )
        results.append(pd.read_csv(save_path))
        search_metrics_in_run_i = dict()
        for fold in folds:
            fold_path = join(training_dir_i, f'fold{fold}of{n_folds}')
            subset_path = join(fold_path, f'subset_{data_ratio}_samples')
            history_path = join(subset_path, "search_history.csv")
            # The wall-clock "time" column is not expected to be reproducible
            search_metrics_in_run_i[fold] = pd.read_csv(history_path).drop(
                columns=["time"]
            )
        search_metrics.append(search_metrics_in_run_i)
        learned_policies.append(pd.read_csv(learned_policies_path))

    # Check search selections are the same
    assert learned_policies[0]["operation"].equals(
        learned_policies[1]["operation"]
    ), (
        f"Irreproducible policy search with model {model_name}, "
        f"{n_jobs} jobs and {method} method. Selected operations don't match."
    )
    for label in ["probability", "magnitude"]:
        sq_err = (learned_policies[0][label] - learned_policies[1][label])**2
        assert np.all(sq_err < policy_tol), (
            f"Irreproducible policy search with model {model_name}, "
            f"{n_jobs} jobs and {method} method. Selected {label} don't match:"
            f" {sq_err}"
        )

    # Check search metrics are the same
    metric = "bal_acc"
    for fold in folds:
        for split_name in ["valid", "test"]:
            label = f"{split_name}_{metric}"
            res0 = search_metrics[0][fold][label].values
            res1 = search_metrics[1][fold][label].values
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < tol), (
                f"Irreproducible policy search with model {model_name}, "
                f"{n_jobs} jobs and {method} method. Search metrics "
                f"{label} don't match: {res0} != {res1}"
            )

    # Check policies assessment results are the same
    for split_name in ["train", "valid", "test"]:
        for metric in ["bal_acc", "loss"]:
            label = f"{split_name}_{metric}"
            res0 = results[0][label].values
            res1 = results[1][label].values
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < tol), (
                f"Irreproducible policy assessment with model {model_name}, "
                f"{n_jobs} jobs and {method} method. "
                f"Inconsistent {label} error: {sq_err}"
            )
    shutil.rmtree(training_dir)


@pytest.mark.parametrize("model_type,method,classwise", [
    (None, "darts", True),    # Pass
    (None, "darts", False),   # Pass
    (None, "fraa", False),    # Pass
    ('lin', "darts", False),  # Pass
])
def test_step_size_robustness_in_diff_search(
    rng_seed,
    training_dir,
    small_real_dataset,
    model_type,
    method,
    classwise,
):
    """Assessment frequency must not change a differentiable policy search.

    Two identically-seeded searches of ``search_epochs`` epochs are run.
    The first assesses the learned policies after every search epoch
    (``eval_step=1``); the second assesses them only after two epochs
    (``eval_step=2``). Three things must match between the runs: the
    policies used for each run's last assessment, each fold's search
    history, and the assessment metrics recorded after all search epochs.
    """
    windows_dataset, ch_names, sfreq = small_real_dataset
    n_folds = 2
    search_epochs = 2
    # NOTE: this test was reported to pass with data_ratio = 0.25 but not
    # with the value below
    data_ratio = 0.1
    model_name = "Chambon" if model_type is None else model_type
    n_jobs = 1
    tol = 1e-7
    # Fold directories are numbered 1..n_folds (``fold{k}of{n_folds}``)
    folds = range(1, n_folds + 1)

    # Do the same search and assessment twice, with a different number of
    # search epochs between assessments, and store search and train results
    results = list()
    learned_policies = list()
    search_metrics = list()
    for i in range(2):
        training_dir_i = join(training_dir, f"run_{i}")
        save_path = join(training_dir_i, 'search_perf_results.csv')
        learned_policies_path = join(training_dir_i, 'learned_policies.csv')

        # In the first run we do search > assess > search > assess
        # and in the second run we do search > search > assess
        eval_step = i + 1

        evaluate_diff_policy_search(
            training_dir=training_dir_i,
            search_epochs=search_epochs,
            assess_epochs=1,
            eval_step=eval_step,
            subpolicies_length=1,
            policy_size_per_fold=2,
            windows_dataset=windows_dataset,
            sfreq=sfreq,
            ordered_ch_names=ch_names,
            batch_size=16,
            n_jobs=n_jobs,
            random_state=rng_seed,
            n_folds=n_folds,
            train_size_over_valid=0.5,
            data_ratio=data_ratio,
            grouped_subset=False,
            algo=method,
            model_to_use=model_type,
            grad_est="gumbel",
            classwise=classwise,
        )
        results.append(pd.read_csv(save_path))
        learned_policies.append(pd.read_csv(learned_policies_path))
        search_metrics_in_run_i = dict()
        for fold in folds:
            fold_path = join(training_dir_i, f'fold{fold}of{n_folds}')
            subset_path = join(fold_path, f'subset_{data_ratio}_samples')
            history_path = join(subset_path, "search_history.csv")
            # The wall-clock "time" column is not expected to be reproducible
            search_metrics_in_run_i[fold] = pd.read_csv(history_path).drop(
                columns=["time"]
            )
        search_metrics.append(search_metrics_in_run_i)

    # In the first run, only keep the learned policies used for its last
    # assessment (highest step index)
    last_step_idx = learned_policies[0]["step_idx"].max()
    learned_policies[0] = learned_policies[0].query(
        "step_idx == @last_step_idx"
    ).reset_index(drop=True)

    # Check search selections are the same
    assert learned_policies[0]["operation"].equals(
        learned_policies[1]["operation"]
    ), (
        f"Policy search not robust to step size with model {model_name} "
        f"and {method} method. Selected operations don't match."
    )
    for label in ["probability", "magnitude"]:
        sq_err = (learned_policies[0][label] - learned_policies[1][label])**2
        assert np.all(sq_err < tol), (
            f"Policy search not robust to step size with model {model_name} "
            f"and {method} method. Selected {label} don't match: {sq_err}"
        )

    # Check search metrics are the same
    metric = "bal_acc"
    for fold in folds:
        for split_name in ["valid", "test"]:
            label = f"{split_name}_{metric}"
            res0 = search_metrics[0][fold][label].values
            res1 = search_metrics[1][fold][label].values
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < tol), (
                f"Policy search not robust to step size with model "
                f"{model_name}, {n_jobs} jobs and {method} method. Search "
                f"metrics {label} don't match: {res0} != {res1}"
            )

    # Check the assessment results obtained after all search epochs are the
    # same
    for split_name in ["train", "valid", "test"]:
        for metric in ["bal_acc", "loss"]:
            label = f"{split_name}_{metric}"
            res0 = results[0].query(
                "tot_search_epochs == @search_epochs")[label].values
            res1 = results[1].query(
                "tot_search_epochs == @search_epochs")[label].values
            # Fail loudly instead of passing vacuously on an empty
            # selection (np.all([]) is True) or broadcasting mismatched
            # lengths
            assert res0.size > 0 and res0.shape == res1.shape, (
                f"Expected matching, non-empty {label} results at "
                f"tot_search_epochs == {search_epochs}, "
                f"got {res0} and {res1}."
            )
            sq_err = (res0 - res1)**2
            assert np.all(sq_err < tol), (
                "Policy search not robust to step size with model "
                f"{model_name} and {method} method. "
                f"Inconsistent {label} values {res0} and {res1}. "
                f"Squared error {sq_err} > {tol}."
            )
    shutil.rmtree(training_dir)


def test_cadda_warmstarting(
    rng_seed,
    training_dir,
    small_real_dataset,
):
    """A CADDA searcher warmstarted from ADDA must match ADDA's valid metrics.

    The test proceeds in three steps:

    1. An ADDA searcher (``DARTS`` with ``classwise=False``) runs one search
       epoch on the first split and is evaluated on that split's
       validation set.
    2. A CADDA searcher (``DARTS`` with ``classwise=True``) is set up on the
       same split with zero search epochs, warmstarting from the ADDA
       checkpoint of epoch 1.
    3. The CADDA searcher is evaluated in the same way.

    Both validation losses and accuracies must be exactly equal.
    """
    # Unwrap dataset info and prepare shared model and params
    windows_dataset, ordered_ch_names, sfreq = small_real_dataset
    _, _, model, model_params, _ = prepare_training(
        windows_dataset,
        lr=0.001,
        batch_size=16,
        num_workers=4,
        device="cpu",
        early_stop=True,
        sfreq=sfreq,
        n_classes=5,
        model_to_use=None,
        random_state=rng_seed,
    )

    tr_val_ratio = 0.5
    n_folds = 2

    # Instantiate an ADDA searcher
    adda_path = join(training_dir, "adda-test")
    adda = DARTS(
        training_dir=adda_path,
        model=model,
        subpolicies_length=2,
        policy_size_per_fold=5,
        transforms_family=None,
        network_momentum=0.9,
        network_weight_decay=3e-4,
        unrolled=True,
        classwise=False,
        random_state=None,
        model_params=model_params,
        n_folds=n_folds,
        shared_callbacks=None,
        balanced_loss=True,
        monitor='valid_loss_best',
        should_checkpoint=True,
        should_load_state=True,
        log_tensorboard=True,
        train_size_over_valid=tr_val_ratio,
    )

    # Use it to access the first split of the dataset
    kf = GroupKFold(n_splits=adda.n_folds)
    groups = get_groups(windows_dataset)
    split_indices = _get_split_indices(
        cv=kf,
        windows_dataset=windows_dataset,
        groups=groups,
        train_size_over_valid=adda.train_size_over_valid,
        data_ratios=[0.1],
        max_ratios=None,
        grouped_subset=False,
        random_state=adda.splitting_random_state,
    )

    split = split_indices[0]
    fold, _, train_subset_idx, valid_idx, test_idx = split

    # Prepare valid weighted loss and data loader
    train_set = Subset(windows_dataset, train_subset_idx)
    valid_set = Subset(windows_dataset, valid_idx)
    test_set = Subset(windows_dataset, test_idx)

    _, _, valid_eval_criterion, _ = adda.instantiate_criteria(
        train_set, valid_set, test_set
    )

    _, valid_loader, _ = adda.make_search_valid_loaders(
        train_set, valid_set, test_set
    )

    # Carry out a single search epoch with ADDA
    adda._search_in_fold(
        split=split,
        random_state=rng_seed,
        epochs=1,
        windows_dataset=windows_dataset,
        warmstart_base_path=None,
        ordered_ch_names=ordered_ch_names,
        sfreq=sfreq,
        warmstart_epoch=None,
    )

    # Make sure the global RNGs are controlled
    adda._init_global_and_specific_rngs(
        fold=fold,
        random_state=rng_seed,
    )

    # And compute the valid loss and accuracy
    adda_valid_loss, adda_valid_acc = adda._test(
        model=model,
        test_loader=valid_loader,
        criterion=valid_eval_criterion,
        valid=True
    )

    # Now instantiate a comparable CADDA searcher
    cadda = DARTS(
        training_dir=join(training_dir, "cadda-test"),
        model=model,
        subpolicies_length=2,
        policy_size_per_fold=5,
        transforms_family=None,
        network_momentum=0.9,
        network_weight_decay=3e-4,
        unrolled=True,
        classwise=True,
        random_state=None,
        model_params=model_params,
        n_folds=n_folds,
        shared_callbacks=None,
        balanced_loss=True,
        monitor='valid_loss_best',
        should_checkpoint=True,
        should_load_state=True,
        log_tensorboard=True,
        train_size_over_valid=tr_val_ratio,
    )

    # Do all the steps in the search preceding the epochs loop
    # (indeed, notice that epochs=0 here)
    # so that the previous ADDA checkpoint is used for warmstarting
    cadda._search_in_fold(
        split=split,
        random_state=rng_seed,
        epochs=0,                           # <--
        windows_dataset=windows_dataset,
        warmstart_base_path=adda_path,      # <--
        ordered_ch_names=ordered_ch_names,
        sfreq=sfreq,
        warmstart_epoch=1,                  # <--
    )

    # Once again, set the global RNGs to control stochasticity of the policy
    cadda._init_global_and_specific_rngs(
        fold=fold,
        random_state=rng_seed,
    )

    # And compute valid loss and accuracy again, this time through the CADDA
    # searcher (using ``adda`` here would compare ADDA against itself)
    cadda_valid_loss, cadda_valid_acc = cadda._test(
        model=model,
        test_loader=valid_loader,
        criterion=valid_eval_criterion,
        valid=True
    )

    # Check that we get the same results
    assert cadda_valid_loss == adda_valid_loss
    assert cadda_valid_acc == adda_valid_acc
    shutil.rmtree(training_dir)


def test_cadda_warmstarting_full_run(
    rng_seed,
    training_dir,
    small_real_dataset,
):
    """Smoke test: a full CADDA search can be warmstarted from an ADDA run.

    First, an ADDA search (``classwise=False``) with assessment is run.
    Then a CADDA search (``classwise=True``) is run that warmstarts from the
    ADDA run's directory at epoch 2. Only successful completion is checked;
    no outputs are compared.
    """
    dataset, ch_names, sfreq = small_real_dataset
    n_folds = 2
    data_ratio = 0.1
    model_type = None
    n_jobs = 2
    method = "darts"

    # First search and eval ADDA
    training_dir_adda = join(training_dir, "adda")
    evaluate_diff_policy_search(
        training_dir=training_dir_adda,
        search_epochs=2,
        assess_epochs=1,
        subpolicies_length=1,
        policy_size_per_fold=2,
        windows_dataset=dataset,
        sfreq=sfreq,
        ordered_ch_names=ch_names,
        batch_size=16,
        n_jobs=n_jobs,
        random_state=rng_seed,
        n_folds=n_folds,
        train_size_over_valid=0.5,
        data_ratio=data_ratio,
        grouped_subset=False,
        algo=method,
        model_to_use=model_type,
        classwise=False,
    )

    # Then do the same with CADDA, warmstarting with ADDA
    training_dir_cadda = join(training_dir, "cadda")
    evaluate_diff_policy_search(
        training_dir=training_dir_cadda,
        search_epochs=2,
        assess_epochs=1,
        subpolicies_length=1,
        policy_size_per_fold=2,
        windows_dataset=dataset,
        sfreq=sfreq,
        ordered_ch_names=ch_names,
        batch_size=16,
        n_jobs=n_jobs,
        random_state=rng_seed,
        n_folds=n_folds,
        train_size_over_valid=0.5,
        data_ratio=data_ratio,
        grouped_subset=False,
        algo=method,
        model_to_use=model_type,
        classwise=True,
        warmstart_base_path=training_dir_adda,
        warmstart_epoch=2,
    )
    shutil.rmtree(training_dir)


def test_weightsharing(
    rng_seed,
    training_dir,
    small_real_dataset,
):
    """Smoke test: a DARTS search can start from a pretrained model.

    First, a model is trained without augmentation using
    ``launch_training``. Then a differentiable policy search is run with
    ``pretrain_base_path`` pointing at that training directory. Only
    successful completion is checked.
    """
    dataset, ch_names, sfreq = small_real_dataset
    n_folds = 2
    data_ratio = 0.1
    n_jobs = 1
    batch_size = 16

    # Train a model without augmentation for warmstarting the search
    training_dir_ws = join(training_dir, "ws")
    launch_training(
        training_dir=training_dir_ws,
        epochs=1,
        windows_dataset=dataset,
        sfreq=sfreq,
        batch_size=batch_size,
        ordered_ch_names=ch_names,
        data_ratio=data_ratio,
        grouped_subset=False,
        n_jobs=n_jobs,
        random_state=rng_seed,
        n_folds=n_folds,
        train_size_over_valid=0.5,
    )

    training_dir_search = join(training_dir, "search")
    evaluate_diff_policy_search(
        training_dir=training_dir_search,
        search_epochs=1,
        assess_epochs=1,
        subpolicies_length=1,
        policy_size_per_fold=2,
        windows_dataset=dataset,
        sfreq=sfreq,
        ordered_ch_names=ch_names,
        batch_size=batch_size,
        n_jobs=n_jobs,
        random_state=rng_seed,
        n_folds=n_folds,
        train_size_over_valid=0.5,
        data_ratio=data_ratio,
        grouped_subset=False,
        algo="darts",
        pretrain_base_path=training_dir_ws,
    )
    # Clean up, like the other search tests in this module
    shutil.rmtree(training_dir)
