import os

import pytest
from sklearn.utils import check_random_state

from braindecode.augmentation.transforms import TimeReverse

from eeg_augment.sanity_check import AR1Dataset, IIDGaussianDataset,\
    sanity_checker, make_invariant_function_from_transform


# Lower this process's scheduling priority (niceness += 5) so the model
# training run by these tests is less intrusive on the host. ``os.nice`` is
# only available on Unix; skip it elsewhere rather than failing at import.
if hasattr(os, "nice"):
    os.nice(5)


@pytest.mark.parametrize("device", [None, 'cpu'])
def test_AR1_training(rng_seed, training_dir, device):
    """Run the full sanity-check pipeline on the AR(1) toy dataset.

    A target function invariant to ``TimeReverse`` is built from random
    parameters, then ``sanity_checker`` trains on ``AR1Dataset`` with a
    ``TimeReverse(0.5)`` augmentation. The test has no explicit assertions:
    it passes as long as ``sanity_checker`` completes without raising.

    Parameters
    ----------
    rng_seed : int
        Seed fixture (presumably provided by ``conftest.py`` — confirm).
        Used both for the invariance-function parameters and as
        ``random_state`` for ``sanity_checker``, so runs are reproducible
        (consistent with ``test_IIDGaussian_training``).
    training_dir : path-like
        Fixture forwarded as ``save_path`` (presumably a temporary
        directory — confirm in ``conftest.py``).
    device : str or None
        Device forwarded to ``sanity_checker``. ``None`` presumably lets it
        choose a default device; verify against ``sanity_checker``.
    """
    rng = check_random_state(rng_seed)
    # The parameter matrix is scaled by 10 here, whereas the IID Gaussian
    # test uses unscaled draws; kept as-is.
    flip_invariance = make_invariant_function_from_transform(
        TimeReverse(1.), 2, 10 * rng.random((4, 20)), rng.random(4), 0.
    )

    sanity_checker(
        AR1Dataset,
        flip_invariance,
        TimeReverse(0.5),
        X_shape=(2, 10),
        y_shape=4,
        n_folds=5,
        n_epochs=5,
        train_size=1024,
        valid_size=256,
        batch_size=64,
        save_path=training_dir,
        classif=True,
        ar1_coefs=(0.33, 0.4),
        offsets=(1., 1.),
        noise_std=(0.2, 0.2),
        patience=4,
        verbose=False,
        device=device,
        # Seed data generation / training too, as the IID Gaussian test does.
        random_state=rng_seed,
    )


def test_IIDGaussian_training(rng_seed, training_dir):
    """Run the full sanity-check pipeline on the i.i.d. Gaussian toy dataset.

    A ``TimeReverse``-invariant target function is built from seeded random
    parameters, then ``sanity_checker`` trains on ``IIDGaussianDataset``
    (on CPU) with a ``TimeReverse(0.5)`` augmentation. The test has no
    explicit assertions: it passes if ``sanity_checker`` does not raise.

    Parameters
    ----------
    rng_seed : int
        Seed fixture (presumably from ``conftest.py`` — confirm). Used for
        the target-function parameters and as ``random_state``.
    training_dir : path-like
        Fixture forwarded as ``save_path``.
    """
    generator = check_random_state(rng_seed)
    # Random parameters handed to make_invariant_function_from_transform
    # (presumably weights/biases of the target function — confirm). Drawn
    # in the same order as before: the (4, 20) array first, then the (4,)
    # array.
    param_matrix = generator.random((4, 20))
    param_vector = generator.random(4)
    target_fn = make_invariant_function_from_transform(
        TimeReverse(1.), 2, param_matrix, param_vector, 0.
    )

    checker_options = dict(
        X_shape=(2, 10),
        y_shape=4,
        n_folds=5,
        n_epochs=5,
        train_size=1024,
        valid_size=256,
        batch_size=128,
        save_path=training_dir,
        classif=True,
        patience=4,
        verbose=False,
        device='cpu',
        random_state=rng_seed,
    )
    sanity_checker(
        IIDGaussianDataset, target_fn, TimeReverse(0.5), **checker_options
    )
