#
# License: BSD (3-clause)

import pytest
import tempfile

import torch
from sklearn.utils import check_random_state

from eeg_augment.sanity_check import AR1Dataset
from eeg_augment.training_utils import prep_physionet_dataset
from eeg_augment.training_utils import sample_invariant_decision_function

# Devices available on this machine: always the CPU, plus CUDA when torch
# reports it as available.
DEVICES = ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']


# Number of subjects requested for the small real dataset fixture.
N_SUBJ_REAL_DS = 4
# Number of samples in the synthetic dataset fixture.
DUMMY_DS_LEN = 1024


def pytest_addoption(parser):
    """Register the ``--seed`` command line option.

    The option stores an integer and defaults to 42 when not supplied.

    Parameters
    ----------
    parser : object
        Option parser handed over by pytest; only its ``addoption`` method
        is used here.
    """
    seed_option = {"action": "store", "type": int, "default": 42}
    parser.addoption("--seed", **seed_option)


@pytest.fixture
def rng_seed(pytestconfig):
    """Return the value of the ``seed`` option read from the pytest config.

    Parameters
    ----------
    pytestconfig : object
        Configuration object injected by pytest.

    Returns
    -------
    The value returned by ``pytestconfig.getoption("seed")``.
    """
    return pytestconfig.getoption("seed")


@pytest.fixture(scope="function")
def training_dir():
    """Yield the path of a fresh temporary directory.

    The directory is created through ``tempfile.TemporaryDirectory`` and is
    cleaned up when the fixture is torn down, whether or not the test using
    it succeeded.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    try:
        yield tmp_dir.name
    finally:
        # Explicit equivalent of leaving the ``with`` block.
        tmp_dir.cleanup()


@pytest.fixture(scope="session")
def small_real_dataset():
    """Build a small dataset with ``prep_physionet_dataset``.

    The fixture is session-scoped, so the dataset is prepared once and then
    shared by every test requesting it. ``N_SUBJ_REAL_DS`` subjects are
    requested, with recording id 1 only.
    """
    # NOTE(review): the unit of the crop bounds is not visible from this
    # file; see prep_physionet_dataset for their meaning.
    crop_bounds = (6000, 20_000)
    selected_recordings = [1]
    return prep_physionet_dataset(
        n_subj=N_SUBJ_REAL_DS,
        recording_ids=selected_recordings,
        crop=crop_bounds,
    )


@pytest.fixture
def dummy_dataset(rng_seed, dummy_ds_class=AR1Dataset):
    """Create a synthetic dataset whose invariant operation is seeded.

    Parameters
    ----------
    rng_seed : int
        Seed forwarded as ``random_state`` to
        ``sample_invariant_decision_function``.
    dummy_ds_class : type
        Dataset class to instantiate. Defaults to ``AR1Dataset``, which
        additionally receives AR(1)-specific settings.

    Returns
    -------
    An instance of ``dummy_ds_class`` holding ``DUMMY_DS_LEN`` samples.
    """
    # input_size=20 equals the number of entries of an X_shape=(2, 10)
    # sample -- presumably on purpose; confirm against the helper.
    invariant_op = sample_invariant_decision_function(
        input_size=20,
        enforce_inv='flip',
        n_classes=5,
        enforce_inv_std=0,
        random_state=rng_seed,
    )

    # Settings that are only passed on when building an AR1Dataset.
    ar1_only_params = {}
    if dummy_ds_class == AR1Dataset:
        ar1_only_params = {
            'ar1_coefs': (-0.33, 0.4),
            'offsets': (0., 0.),
            'noise_std': (1.8, 0.2),
        }

    return dummy_ds_class(
        X_shape=(2, 10),
        n_samples=DUMMY_DS_LEN,
        invariant_op=invariant_op,
        **ar1_only_params,
    )


@pytest.fixture
def random_batch(rng_seed, batch_size=5):
    """Generate a seeded random batch of features with all-zero labels.

    Parameters
    ----------
    rng_seed : int
        Seed passed to ``check_random_state``.
    batch_size : int
        Number of elements in the batch.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` is a float tensor of shape
        ``(batch_size, 66, 51)`` filled with random floats in [0, 1), placed
        on CUDA when available and on the CPU otherwise, and ``y`` is a
        tensor of ``batch_size`` zeros.
    """
    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'

    values = check_random_state(rng_seed).random((batch_size, 66, 51))
    features = torch.from_numpy(values).float().to(target_device)

    # NOTE(review): the labels are not moved to ``target_device``; confirm
    # this is what the consuming tests expect when CUDA is available.
    labels = torch.zeros(batch_size)
    return features, labels
