
import numpy as np
import torch
from BAE.training import (fit_and_score_one_proportion, get_EEGClassifier,
                          parallel_learning_curve)
from BAE.utils import downsample, get_labels
from braindecode.augmentation import TimeReverse


def test_fit_and_score_one_proportion(
        PhysionetDataset,
        random_state,):
    """Check that ``fit_and_score_one_proportion`` is reproducible.

    Two classifiers are built and trained independently with identical
    seeds (including a ``TimeReverse`` augmentation seeded from a fresh
    ``np.random.RandomState``) on the same train/valid/test split. Their
    test-set predictions must match exactly.

    ``PhysionetDataset`` and ``random_state`` are pytest fixtures, presumably
    provided by a conftest.py that is not visible here.
    """
    n_samples = len(PhysionetDataset)
    # Deterministic split: first 4/5 of the samples for training/validation,
    # the remainder for testing.
    n_train = int(n_samples * 4 / 5)
    train_valid_indices = np.arange(n_train)
    test_indices = np.arange(n_train, n_samples)

    def _run_once():
        # Build a freshly seeded classifier and return its scores for one run.
        # A new RandomState per run guarantees both augmentations start from
        # the same RNG state.
        rng = np.random.RandomState(random_state)
        clf_params = {'iterator_train__transforms': [TimeReverse(
            probability=0.5,
            random_state=rng)]}
        clf = get_EEGClassifier(
            clf_params=clf_params,
            random_state=random_state)
        return fit_and_score_one_proportion(PhysionetDataset,
                                            train_valid_indices,
                                            test_indices,
                                            clf=clf,
                                            proportion=0.5,
                                            epochs=2,
                                            fold=None,
                                            random_state=random_state,
                                            fold_random_state=random_state,
                                            aug_magnitude=None,
                                            subjects_mask_train_valid=None)

    scores_1 = _run_once()
    scores_2 = _run_once()

    # Compare row counts first: iterating over only one run's rows would
    # silently skip extra rows (or fail obscurely) if the runs disagreed.
    assert len(scores_1) == len(scores_2)
    for i, (y_pred_1, y_pred_2) in enumerate(
            zip(scores_1.y_pred, scores_2.y_pred)):
        assert np.array_equal(y_pred_1, y_pred_2), (
            'y_pred differs between the two runs for row {}: unique '
            'predictions {} vs {}'.format(
                i, np.unique(y_pred_1), np.unique(y_pred_2)))


def test_downsample(PhysionetDataset, random_state):
    """Check that ``downsample`` is deterministic for a fixed ``random_state``.

    Two calls with the same seed must yield identical labels, identical
    returned masks, identical dataset lengths and identical signal content
    for the first windows.
    """
    downsample_1, mask_1 = downsample(
        PhysionetDataset, random_state=random_state)
    labels_1 = get_labels(downsample_1)

    downsample_2, mask_2 = downsample(
        PhysionetDataset, random_state=random_state)
    labels_2 = get_labels(downsample_2)

    assert np.array_equal(labels_1, labels_2)
    assert np.array_equal(mask_1, mask_2)
    assert len(downsample_1) == len(downsample_2)

    # Spot-check the first element (index 0) of the first few items. The
    # count is bounded by the dataset length so a small downsampled set does
    # not raise IndexError.
    n_checked = min(20, len(downsample_1))
    for i in range(n_checked):
        assert np.array_equal(downsample_1[i][0], downsample_2[i][0])


def test_lr_curve(PhysionetDataset, random_state):
    """Check that ``parallel_learning_curve`` is reproducible.

    The dataset is downsampled and a learning curve is computed twice, each
    time with a freshly built, identically seeded classifier. The test then
    checks three things:

    - the feature-extractor parameters of the two classifiers match;
    - the per-fold random states match;
    - the predictions in every result row match.
    """
    downsample_1, mask_1 = downsample(
        PhysionetDataset, random_state=random_state)

    clf_params_1 = {'iterator_train__transforms': [TimeReverse(
        probability=0.5,
        random_state=random_state)]}
    clf_1 = get_EEGClassifier(
        clf_params=clf_params_1,
        random_state=random_state)
    score_1 = parallel_learning_curve(
        downsample_1,
        clf=clf_1,
        K=2,
        proportions=[0.5],
        epochs=2,
        n_jobs=1,
        random_state=random_state,
        subjects_mask=mask_1,)

    clf_params_2 = {'iterator_train__transforms': [TimeReverse(
        probability=0.5,
        random_state=random_state)]}
    clf_2 = get_EEGClassifier(
        clf_params=clf_params_2,
        random_state=random_state)

    downsample_2, mask_2 = downsample(
        PhysionetDataset, random_state=random_state)

    score_2 = parallel_learning_curve(
        downsample_2,
        clf=clf_2,
        K=2,
        proportions=[0.5],
        epochs=2,
        n_jobs=1,
        random_state=random_state,
        subjects_mask=mask_2,)

    # NOTE(review): this assumes parallel_learning_curve trains the passed
    # classifier's module in place rather than a clone; otherwise only the
    # initial weights are compared here. TODO confirm.
    # Materialize the parameter lists: zip() alone would silently truncate
    # if the two modules had different numbers of parameters.
    params_1 = list(clf_1.module.feature_extractor.parameters())
    params_2 = list(clf_2.module.feature_extractor.parameters())
    assert len(params_1) == len(params_2)
    for param1, param2 in zip(params_1, params_2):
        # torch.equal also fails on a shape mismatch instead of broadcasting.
        assert torch.equal(param1, param2)

    # The fold random states do not depend on the row, so compare them once.
    assert np.array_equal(
        score_1.fold_random_state,
        score_2.fold_random_state)

    assert len(score_1) == len(score_2)
    for i, (y_pred_1, y_pred_2) in enumerate(
            zip(score_1.y_pred, score_2.y_pred)):
        assert np.array_equal(y_pred_1, y_pred_2), (
            'y_pred differs between the two runs for row {}'.format(i))
