import os
from itertools import combinations
import random

import torch
import numpy as np
from sklearn.model_selection import GroupKFold
import pytest

from eeg_augment.utils import (
    find_device, log2_grid, linear_grid, grouped_split, get_groups,
    get_global_rngs_states, set_global_rngs_states
)
from eeg_augment.training_utils import _get_split_indices
from tests.conftest import N_SUBJ_REAL_DS, DUMMY_DS_LEN


# Raise the niceness of the test process by 5 (i.e. lower its scheduling
# priority) -- presumably so that a test run does not hog a shared machine.
# os.nice is only available on Unix, so skip it elsewhere instead of failing
# at import time and preventing the whole module from being collected.
if hasattr(os, "nice"):
    os.nice(5)


@pytest.fixture(params=[1, 2])
def fake_groups(request):
    """Return 10 fake group labels with ``request.param`` samples per group.

    With a param of 1 every sample is its own group (labels 0..9); with a
    param of 2 consecutive samples are paired (labels 0, 0, 1, 1, ..., 4, 4).

    Returns
    -------
    numpy.ndarray
        1-D integer array of group labels.
    """
    group_size = request.param
    # np.repeat is used rather than np.hstack over a generator: the stacking
    # functions only accept real sequences (list/tuple) and raise a TypeError
    # on generators in recent NumPy releases.
    return np.repeat(np.arange(10 // group_size), group_size)


def test_log2_grid():
    """Check that log2_grid(8) yields exactly 0.25, 0.5 and 1."""
    assert np.array_equal(log2_grid(8), [0.25, 0.5, 1.])


def test_linear_grid():
    """Check that linear_grid(1, 5) yields 0, 0.25, 0.5, 0.75 and 1."""
    reference = [0., 0.25, 0.5, 0.75, 1.]
    computed = linear_grid(1, 5)
    assert np.array_equal(computed, reference)


def basic_split_assertions(train_idx, test_idx, groups):
    """Assert that two index sets share neither indices nor groups.

    Parameters
    ----------
    train_idx : array-like of int
        Indices of the first subset (e.g. the training set).
    test_idx : array-like of int
        Indices of the second subset (e.g. the test set).
    groups : numpy.ndarray
        Group label of every sample; it is indexed with ``train_idx`` and
        ``test_idx``.

    Raises
    ------
    AssertionError
        If the two subsets have an index in common, or if a group is
        represented in both of them.
    """
    # Here we assert the intersection of test and train is null
    assert set(train_idx).isdisjoint(test_idx),\
        "Got non-empty intersetion in split."

    # The group check makes no sense when either side is empty: an empty
    # subset cannot share a group with anything, and an empty index array
    # (float dtype by default) cannot even be used to index `groups`.
    if len(train_idx) == 0 or len(test_idx) == 0:
        return

    # Here we assert the intersection of groups in test and train is null
    assert set(groups[train_idx]).isdisjoint(groups[test_idx]),\
        "Got non-empty groups intersetion in split."


@pytest.mark.parametrize("ratio", [0.25, 0.5, 1.])
def test_grouped_split(rng_seed, fake_groups, ratio):
    """Check grouped_split returns index- and group-disjoint subsets.

    Ten consecutive indices are split with every parametrized ratio and
    every fake group layout, then validated by basic_split_assertions.
    """
    split = grouped_split(
        np.arange(10),
        ratio,
        groups=fake_groups,
        random_state=rng_seed,
    )
    train_idx, test_idx = split
    basic_split_assertions(train_idx, test_idx, fake_groups)


@pytest.mark.parametrize("use_real_data,expected_groups_len", [
    (True, N_SUBJ_REAL_DS),
    (False, DUMMY_DS_LEN)
])
def test_get_groups(
    small_real_dataset,
    dummy_dataset,
    use_real_data,
    expected_groups_len
):
    """Check how many distinct groups get_groups finds in each dataset.

    The real-data case uses the first element of the ``small_real_dataset``
    fixture and expects ``N_SUBJ_REAL_DS`` distinct groups; the dummy case
    uses ``dummy_dataset`` and expects ``DUMMY_DS_LEN`` distinct groups.
    """
    dataset = small_real_dataset[0] if use_real_data else dummy_dataset
    distinct_groups = np.unique(get_groups(dataset))
    assert len(distinct_groups) == expected_groups_len,\
        "Got groups of unexpected length."


def test_get_split_indices(small_real_dataset):
    """Check the index sets produced by _get_split_indices on real data.

    For every split, the index sets it carries (all entries from the third
    one onwards) must be pairwise disjoint, both in indices and in groups.
    In addition, consecutive splits belonging to the same fold must share
    the same validation and test sets (the last two entries of a split).
    """
    groups = get_groups(small_real_dataset[0])
    split_indices = _get_split_indices(
        GroupKFold(n_splits=3),
        small_real_dataset[0],
        groups,
        0.5,
        data_ratios=[0.5, 1.],
        max_ratios=None
    )
    # Reference fold and its validation/test sets. They start as None so
    # that the very first split always initialises them, whatever the value
    # of the first fold identifier is.
    prev_fold = None
    prev_valid_idx = None
    prev_test_idx = None
    for split in split_indices:
        fold = split[0]
        for idx1, idx2 in combinations(split[2:], 2):
            basic_split_assertions(idx1, idx2, groups)
        # Check test and valid sets are conserved when subsetting training set
        valid_idx, test_idx = split[-2:]
        if fold == prev_fold:
            assert np.array_equal(valid_idx, prev_valid_idx),\
                "Validation set changed between subsets of the same fold!"
            assert np.array_equal(test_idx, prev_test_idx),\
                "Test set changed between subsets of the same fold!"
        else:
            # New fold: remember it together with its validation/test sets
            # so that the following subsets of this fold are compared to it.
            prev_fold = fold
            prev_valid_idx = valid_idx
            prev_test_idx = test_idx


def test_global_rng_states_manip(rng_seed):
    """Check that restoring the global RNG states replays the same samples.

    The global generators are seeded, their states are captured, and one
    batch is drawn from each of torch, numpy and the stdlib ``random``
    module. After the captured states are restored, drawing again must
    give the very same values.
    """
    _, cuda = find_device("cuda:1")
    set_global_rngs_states(seed=rng_seed, cuda=cuda)
    saved_states = get_global_rngs_states(cuda=cuda)

    # First round of sampling, one draw per random number generator
    torch_draw = torch.randint(10, size=(10,))
    numpy_draw = np.random.randint(10, size=10)
    python_draw = [random.randint(0, 10) for _ in range(10)]

    # Rewind every generator to the state captured before sampling
    set_global_rngs_states(states=saved_states, cuda=cuda)

    # Second round: must reproduce the first one exactly
    assert torch.equal(torch_draw, torch.randint(10, size=(10,)))
    assert np.array_equal(numpy_draw, np.random.randint(10, size=10))
    assert python_draw == [random.randint(0, 10) for _ in range(10)]
