import shutil
import os
from os.path import join

import pytest
import numpy as np
import pandas as pd
from sklearn.utils import check_random_state

from braindecode.augmentation.transforms import TimeReverse

from eeg_augment.training_utils import load_preproc_mass
from eeg_augment.training_utils import load_preproc_physionet
from eeg_augment.training_utils import make_real_dataset_invariant
from eeg_augment.training_utils import CrossvalModel, prepare_training
from eeg_augment.training_utils import sample_invariant_decision_function
from eeg_augment.train import train_with_different_transforms, launch_training
from eeg_augment.auto_augmentation import PHYSIONET_BEST_TRANSFORMS


os.nice(5)


@pytest.mark.parametrize("dataset_to_load", ['mass', 'edf'])
def test_bids_loading(dataset_to_load):
    loading_func = {'mass': load_preproc_mass, 'edf': load_preproc_physionet}
    ds, _, _ = loading_func[dataset_to_load](n_subj=3)


@pytest.mark.parametrize("rand_n,transforms_collection", [
    (1, None),
    (2, None),
    (1, PHYSIONET_BEST_TRANSFORMS)
])
def test_randaugment(
    training_dir,
    small_real_dataset,
    rand_n,
    transforms_collection
):
    dataset, ch_names, sfreq = small_real_dataset
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=dataset,
        sfreq=sfreq,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms="randaugment",
        random_state=None,
        magnitudes=0.5,
        ordered_ch_names=ch_names,
        randaugment_seq_len=rand_n,
        randaugment_transform_collec=transforms_collection,
    )


@pytest.mark.parametrize("seed,transforms,n_jobs", [
    (None, None, 1),
    (None, ['no-aug', 'flip'], 1),
    (29, ['no-aug', 'flip'], 1),
    (29, ['no-aug', 'flip'], 2)
])
def test_multi_standard_pipeline(
    training_dir, small_real_dataset, seed, transforms, n_jobs
):
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=small_real_dataset[0],
        sfreq=100,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms=transforms,
        random_state=seed,
        n_jobs=n_jobs,
    )


@pytest.mark.parametrize("transforms", [None, 'flip', ['no-aug', 'flip']])
def test_enforced_invariance(rng_seed, training_dir, small_real_dataset,
                             transforms):
    invariant_decision_func = sample_invariant_decision_function(
        input_size=6000,
        enforce_inv='flip',
        random_state=rng_seed,
    )

    windows_dataset = make_real_dataset_invariant(
        small_real_dataset[0],
        invariant_decision_func,
    )
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=windows_dataset,
        sfreq=100,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms=transforms,
    )


@pytest.mark.parametrize("ds,model,transforms", [
    ("dummy", 'lin', TimeReverse(0.5)),
    (None, 'quad', TimeReverse(0.5)),
    (None, 'lin', 'no-aug')
])
def test_setup_exceptions(rng_seed, small_real_dataset, ds, model, transforms):
    if ds == "dummy":
        rng = check_random_state(rng_seed)
        ds = [(rng.randn(2, 10), 0, None)]
    if ds is None:
        ds, _, _ = small_real_dataset
    with pytest.raises(ValueError):
        launch_training(
            training_dir='temp_test_dir/',
            epochs=3,
            windows_dataset=ds,
            sfreq=100,
            batch_size=16,
            n_folds=2,
            train_size_over_valid=0.5,
            transforms=transforms,
            model_to_use=model,
            device='cpu',
        )


def test_dummy_ds_train(training_dir, dummy_dataset):
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=dummy_dataset,
        sfreq=100,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms=['no-aug', 'flip'],
        model_to_use='lin'
    )


@pytest.mark.parametrize("data_ratios,grouped_subset", [
    ('log2', True),
    ([0.5, 1.0], True),
    ('log2', False),
])
def test_data_ratios(
    training_dir,
    small_real_dataset,
    data_ratios,
    grouped_subset
):
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=small_real_dataset[0],
        sfreq=100,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms='flip',
        data_ratios=data_ratios,
        max_ratios=2,
        grouped_subset=grouped_subset,
    )


@pytest.mark.parametrize("probabilities,magnitudes,n_probas,n_mags", [
    (0.5, 0.5, None, None),
    ('lin', 0.5, 3, 3),
    (0.5, 'lin', 3, 3),
])
def test_magnitude_prob_grid(
    training_dir,
    small_real_dataset,
    probabilities,
    magnitudes,
    n_probas,
    n_mags
):
    train_with_different_transforms(
        experiment_dir=training_dir,
        epochs=3,
        windows_dataset=small_real_dataset[0],
        sfreq=100,
        batch_size=16,
        n_folds=2,
        train_size_over_valid=0.5,
        transforms='flip',
        probabilities=probabilities,
        magnitudes=magnitudes,
        n_probas=n_probas,
        n_mags=n_mags,
    )


def test_data_ratios_exception(small_real_dataset):
    with pytest.raises(ValueError):
        train_with_different_transforms(
            experiment_dir='temp_test_dir/',
            windows_dataset=small_real_dataset[0],
            epochs=3,
            sfreq=100,
            device='cpu',
            data_ratios=True,
        )


def test_splits_persistence(
    training_dir, small_real_dataset, rng_seed,
):
    ds, _, sfreq = small_real_dataset
    _, rng, model, model_params, shared_callbacks = prepare_training(
        windows_dataset=ds,
        sfreq=sfreq,
        random_state=rng_seed,
        batch_size=16,
        lr=1e-3,
        num_workers=4,
        early_stop=True,
        n_classes=5,
    )

    cross_val_training = CrossvalModel(
        training_dir,
        model,
        model_params=model_params,
        shared_callbacks=shared_callbacks,
        balanced_loss=True,  # Not settable for now
        monitor='valid_bal_acc_best',  # Not settable for now
        should_checkpoint=True,  # Not settable for now
        log_tensorboard=True,  # Not settable for now
        random_state=rng,
        n_folds=2,
        train_size_over_valid=0.5,
    )

    # Run two different consecutive trainings with the same object and store
    # the corresponding data splits
    splits = list()
    for _ in range(2):
        cross_val_training.learning_curve(
            windows_dataset=ds,
            epochs=2,
            data_ratios=None,
            n_jobs=2,
        )

        splits.append(cross_val_training.split_indices)

    # Check splits are the same
    for split0, split1 in zip(splits[0], splits[1]):
        for indices0, indices1 in zip(split0[2:], split1[2:]):
            assert all(indices0 == indices1)


@pytest.mark.parametrize("model_type,n_jobs,transform", [
    (None, 1, 'no-aug'),   # OK
    ('lin', 2, 'no-aug'),  # Ok
    ('lin', 2, 'flip'),    # Ok
    (None, 2, 'no-aug'),   # Ok
])
def test_training_reproducibility(
    rng_seed,
    training_dir,
    small_real_dataset,
    model_type,
    n_jobs,
    transform,
):
    dataset, ch_names, sfreq = small_real_dataset

    results = list()
    for i in range(2):
        training_dir_i = join(training_dir, f"run_{i}")
        save_path = join(training_dir_i, 'test_crossval_results.pkl')

        launch_training(
            training_dir=training_dir_i,
            epochs=1,
            windows_dataset=dataset,
            sfreq=sfreq,
            ordered_ch_names=ch_names,
            batch_size=16,
            n_jobs=n_jobs,
            random_state=rng_seed,
            n_folds=2,
            train_size_over_valid=0.5,
            data_ratios=0.1,
            grouped_subset=False,
            model_to_use=model_type,
            num_workers=4,
            n_classes=5,
            transforms=[(transform, 0.5, None)],
        )
        results.append(pd.read_pickle(save_path))

    for dataset in ["train", "valid", "test"]:
        for metric in ["bal_acc", "loss"]:
            label = f"{dataset}_{metric}"
            res0 = results[0][label].values
            res1 = results[1][label].values
            assert np.all((res0 - res1)**2 < 1e-6), f"Inconsistency in {label}"
    shutil.rmtree(training_dir)
