import os
from functools import partial

import mne
from numpy import random
import torch
import pytest
import numpy as np
from torch import nn
from sklearn.utils import check_random_state

from braindecode.datautil import create_from_mne_epochs
from braindecode.augmentation.transforms import TimeReverse
from braindecode.augmentation.base import IdentityTransform
from braindecode.augmentation.base import Compose, BaseDataLoader

from eeg_augment.diff_aug.base import SubpolicyStage
from eeg_augment.diff_aug.base import DiffSubpolicy, DiffAugmentationPolicy

from eeg_augment.diff_aug.diff_transforms import DiffSignFlip
from eeg_augment.diff_aug.diff_transforms import DiffTimeMask
from eeg_augment.diff_aug.diff_transforms import DiffTransform
from eeg_augment.diff_aug.diff_transforms import DiffTimeReverse
from eeg_augment.diff_aug.diff_transforms import DiffGaussianNoise
from eeg_augment.diff_aug.diff_transforms import DiffChannelSymmetry
from eeg_augment.diff_aug.diff_transforms import DiffRandomXRotation
from eeg_augment.diff_aug.diff_transforms import DiffRandomYRotation
from eeg_augment.diff_aug.diff_transforms import DiffRandomZRotation

from tests.diff_aug.test_diff_transforms import CH_NAMES_PARAMS
from tests.conftest import DEVICES

# Lower this process's scheduling priority while running these CPU-heavy
# optimisation tests. ``os.nice`` is only available on Unix, so skip it
# elsewhere instead of failing at import time.
if hasattr(os, "nice"):
    os.nice(5)


def dummy_k_operation(X, y, k, *args, **kwargs):
    """Toy operation replacing every entry of ``X`` by ``k``.

    ``y`` is returned untouched. Any extra positional or keyword arguments
    are accepted and ignored (presumably whatever ``DiffTransform`` forwards
    to its operation — confirm against its implementation).
    """
    filled = k * torch.ones_like(X)
    return filled, y


def common_tranform_assertions(input_batch, output_batch, expected_X=None,
                               eps=1e-5):
    """Assert that a transform conserves shapes, labels and device.

    Optionally also checks that the transformed features match
    ``expected_X`` up to an absolute tolerance.

    Parameters
    ----------
    input_batch : tuple
        The batch given to the transform: a tensor X of shape
        (batch_size, n_channels, sequence_len) and a tensor y of shape
        (batch_size,).
    output_batch : tuple
        The batch output by the transform. Should have two elements: the
        transformed X and y.
    expected_X : tensor, optional
        The expected first element of output_batch, which will be compared to
        it. By default None, meaning no value check is done.
    eps : float, optional
        Strict upper bound on the maximum absolute difference between the
        transformed X and ``expected_X``. By default 1e-5.
    """
    X, y = input_batch
    tr_X, tr_y = output_batch
    assert tr_X.shape == X.shape
    assert tr_X.shape[0] == tr_y.shape[0]
    assert torch.equal(tr_y, y)
    assert X.device == tr_X.device
    # ``Tensor.max()`` raises on empty tensors; an empty output that passed
    # the shape checks above has no values left to compare.
    if expected_X is not None and tr_X.numel() > 0:
        assert torch.abs(tr_X - expected_X).max() < eps


@pytest.mark.parametrize("k1,k2,expected", [
    (1, 0, 0),
    (0, 1, 1)
])
def test_transform_composition(random_batch, k1, k2, expected):
    """Composing two constant-filling transforms yields the last constant."""
    X, _ = random_batch
    # Each dummy operation overwrites X with its own constant, so the output
    # of the composition is set by the second transform.
    first = DiffTransform(partial(dummy_k_operation, k=k1), 1)
    second = DiffTransform(partial(dummy_k_operation, k=k2), 1)
    composed = Compose([first, second])
    target = expected * torch.ones(X.shape, device=X.device)

    common_tranform_assertions(random_batch, composed(*random_batch), target)


@pytest.mark.parametrize("probability", [0, 1])
def test_transform_with_kwargs(random_batch, rng_seed, probability):
    """Keyword arguments given to DiffTransform reach its operation."""
    k = check_random_state(rng_seed).randint(10)
    transform = DiffTransform(
        dummy_k_operation,
        initial_probability=probability,
        k=k
    )
    X, _ = random_batch
    if probability:
        # Transform applied: every entry is replaced by k.
        expected_tensor = torch.ones(X.shape, device=X.device) * k
    else:
        # Transform never applied: X is returned unchanged.
        expected_tensor = X
    common_tranform_assertions(
        random_batch,
        transform(*random_batch),
        expected_tensor
    )


def test_transform_proba_exception(random_batch, rng_seed):
    """A non-numeric probability makes ``DiffTransform`` raise an
    AssertionError."""
    k = check_random_state(rng_seed).randint(10)
    # Only the constructor lives inside ``pytest.raises`` so that an
    # AssertionError coming from the setup lines cannot make the test pass.
    with pytest.raises(AssertionError):
        DiffTransform(
            dummy_k_operation,
            'a',
            k=k
        )


@pytest.fixture(scope="module")
def concat_windows_dataset():
    """Build a small BaseConcatDataset of WindowsDatasets from PhysioNet.

    Loads the listed EEGBCI recordings of subject 22, wraps each raw in
    ``mne.Epochs`` around a single event at sample 0, and cuts them into
    50-sample windows with a 50-sample stride (``drop_last_window=False``).
    Built once and shared by all tests of this module.
    """
    subject = 22
    recordings = [5, 6, 9, 10, 13, 14]
    edf_paths = mne.datasets.eegbci.load_data(
        subject, recordings, update_path=False)

    raws = [
        mne.io.read_raw_edf(edf_path, preload=True, stim_channel='auto')
        for edf_path in edf_paths
    ]
    epochs_per_raw = [
        mne.Epochs(raw, [[0, 0, 0]], tmin=0, baseline=None)
        for raw in raws
    ]
    return create_from_mne_epochs(
        epochs_per_raw,
        window_size_samples=50,
        window_stride_samples=50,
        drop_last_window=False,
    )


def dummy_transform(random_state=None):
    """Build a ``DiffTransform`` around ``dummy_k_operation``.

    The constant ``k`` is drawn uniformly from the integers in [0, 10).

    Parameters
    ----------
    random_state : None | int | numpy.random.RandomState, optional
        Source of randomness for ``k``. The default None draws from numpy's
        global random state; pass a seed or RandomState for reproducible
        transforms.
    """
    # check_random_state(None) returns numpy's global RandomState, i.e. the
    # same stream used by ``np.random.randint``.
    k = check_random_state(random_state).randint(10)
    return DiffTransform(dummy_k_operation, 1, k=k)


# Test BaseDataLoader with 0, 1 and 2 composed transforms; a single transform
# is passed both inside a list and bare.
@pytest.mark.parametrize("nb_transforms,no_list", [
    (0, False), (1, False), (1, True), (2, False)
])
def test_data_loader(concat_windows_dataset, nb_transforms, no_list):
    batch_size = 128
    transforms = [dummy_transform() for _ in range(nb_transforms)]
    if no_list:
        transforms = transforms[0]
    data_loader = BaseDataLoader(
        concat_windows_dataset,
        transforms=transforms,
        batch_size=batch_size
    )
    for idx_batch, (batch_x, batch_y) in enumerate(data_loader):
        if idx_batch >= 3:
            break
        # Each batch must be non-empty, bounded by batch_size and carry one
        # label per example.
        assert 0 < len(batch_x) <= batch_size
        assert len(batch_x) == len(batch_y)


def test_data_loader_exception(concat_windows_dataset):
    """Passing a string as ``transforms`` must raise a TypeError."""
    bad_transforms = 'a'
    with pytest.raises(TypeError):
        BaseDataLoader(
            concat_windows_dataset,
            transforms=bad_transforms,
            batch_size=128,
        )


def _make_diff_transforms(diff_transform_classes_and_params, rng,
                          freeze_mag, freeze_proba):
    """Instantiate the differentiable transforms, freezing params if asked."""
    # Frozen probabilities are pinned to 1.0, trainable ones start at 0.5.
    initial_probability = 1.0 if freeze_proba else 0.5
    operations = list()
    for diff_transform_class, params in diff_transform_classes_and_params:
        diff_transform = diff_transform_class(
            initial_probability=initial_probability,
            initial_magnitude=0.5,
            random_state=rng,
            **params
        )
        # Some transforms have no magnitude parameter (``_magnitude`` None).
        if freeze_mag and diff_transform._magnitude is not None:
            diff_transform._magnitude.requires_grad = False
        if freeze_proba:
            diff_transform._probability.requires_grad = False
        operations.append(diff_transform)
    return nn.ModuleList(operations)


def learn_stage(
    seed,
    diff_transform_classes_and_params,
    ref_transform=None,
    epochs=100,
    batch_size=16,
    freeze_mag=False,
    freeze_proba=False,
    freeze_weights=False,
    lr=0.01,
    make_subpolicy=False,
    grad_est=None,
    device="cpu",
):
    """Train a SubpolicyStage (optionally wrapped in a DiffSubpolicy) so that
    its output matches the output of a reference transform.

    At each iteration a random batch is drawn and the mean squared error
    between ``ref_transform`` and the trained module outputs is minimised
    with SGD. Progress is printed every 10 iterations.

    Parameters
    ----------
    seed : int
        Seeds the numpy RandomState (shared by the transforms and the batch
        generation), ``torch.manual_seed`` and the SubpolicyStage.
    diff_transform_classes_and_params : list of (class, dict)
        Differentiable transform classes, each with extra keyword arguments
        forwarded to its constructor.
    ref_transform : callable, optional
        Called as ``ref_transform(X, y)``; the first returned element is the
        regression target. Defaults to a new ``IdentityTransform(1.0)``.
    epochs : int, optional
        Number of optimisation steps, one random batch each.
    batch_size : int, optional
        Number of examples per batch. Batches have shape
        (batch_size, 6, 300) and all-zero labels.
    freeze_mag, freeze_proba, freeze_weights : bool, optional
        Disable gradients for magnitudes, probabilities (then pinned to 1.0)
        and operation weights respectively.
    lr : float, optional
        SGD learning rate.
    make_subpolicy : bool, optional
        If True, wrap the stage in a ``DiffSubpolicy`` before training.
    grad_est : str or None, optional
        Gradient estimator forwarded to ``SubpolicyStage``.
    device : str or torch.device, optional
        Device holding the trained module and the batches.

    Returns
    -------
    convergence : list of dict
        One entry per iteration with keys 'loss', 'operations_weights',
        'operations_probabilities', 'operations_magnitudes' and 'iter'.
    """
    if ref_transform is None:
        # Built per call instead of once as a default argument, so calls
        # never share the same transform instance.
        ref_transform = IdentityTransform(1.0)

    rng = check_random_state(seed)
    torch.manual_seed(seed)

    operations = _make_diff_transforms(
        diff_transform_classes_and_params, rng, freeze_mag, freeze_proba)

    subpolicy = SubpolicyStage(
        operations,
        temperature=0.05,
        grad_est=grad_est,
        random_state=seed,
    ).to(device)
    if freeze_weights:
        subpolicy._weights.requires_grad = False

    if make_subpolicy:
        subpolicy = DiffSubpolicy(
            subpolicy_stages=[subpolicy],
        ).to(device)

    optimizer = torch.optim.SGD(subpolicy.parameters(), lr=lr)
    n_samples_total = batch_size * epochs

    convergence = list()
    for i in range(epochs):
        batch = (
            torch.as_tensor(
                rng.randn(batch_size, 6, 300),
                device=device,
            ).float(),
            torch.zeros(batch_size, device=device),
        )
        X_ref, _ = ref_transform(*batch)
        tr_X, _ = subpolicy(*batch)
        loss = ((X_ref - tr_X)**2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        probabilities = subpolicy.all_probabilities
        magnitudes = subpolicy.all_magnitudes
        # A DiffSubpolicy exposes the weights of all its stages, a bare
        # stage only its own.
        if make_subpolicy:
            weights = subpolicy.all_weights
        else:
            weights = subpolicy.weights

        convergence.append(
            {
                'loss': loss_value,
                'operations_weights': weights,
                'operations_probabilities': probabilities,
                'operations_magnitudes': magnitudes,
                'iter': i
            }
        )

        if i % 10 == 0:
            msg_parts = [
                f"loss: {loss_value:>7f}",
                f"av. proba: {np.mean(probabilities):>7f}",
                f"av. mag: {np.mean(magnitudes):>7f}",
            ]
            # Gradients are only reported for parameters being trained.
            if not freeze_proba:
                probabilities_grad = subpolicy.all_prob_grads
                msg_parts.append(
                    f"av. p_grad: {np.mean(probabilities_grad):>7f}")
            if not freeze_mag:
                magnitudes_grad = subpolicy.all_mag_grads
                msg_parts.append(
                    f"av. m_grad: {np.mean(magnitudes_grad):>7f}")
            msg_parts.append(f"weights_std: {weights.std():>7f}")
            print(
                ", ".join(msg_parts)
                + f" [{i * batch_size}/{n_samples_total}]"
            )
    return convergence


@pytest.mark.parametrize('use_subpolicy', [False, True])
@pytest.mark.parametrize('grad_est', [None, "gumbel", "relax"])
@pytest.mark.parametrize('device', DEVICES)
def test_subpolicy_proba(rng_seed, use_subpolicy, grad_est, device):
    """With magnitudes and weights frozen, imitating the default reference
    transform should drive every operation probability towards zero."""
    stage_ops = [
        (DiffTimeReverse, {}),
        (DiffSignFlip, {}),
        (DiffChannelSymmetry, CH_NAMES_PARAMS),
    ]
    last_step = learn_stage(
        seed=rng_seed,
        diff_transform_classes_and_params=stage_ops,
        epochs=500,
        freeze_mag=True,
        freeze_proba=False,
        freeze_weights=True,
        lr=0.01,
        make_subpolicy=use_subpolicy,
        grad_est=grad_est,
        device=device,
    )[-1]

    assert last_step['loss'] < 1e-8
    assert np.mean(last_step['operations_probabilities']) < 1e-4


@pytest.mark.parametrize("use_subpolicy", [False, True])
@pytest.mark.parametrize("grad_est", [None, "gumbel", "relax"])
def test_subpolicy_stage_mag(rng_seed, use_subpolicy, grad_est):
    """With probabilities (pinned to 1.0) and weights frozen, imitating the
    default reference transform should shrink all magnitudes to ~0."""
    stage_ops = [
        (DiffGaussianNoise, {}),
        (DiffTimeMask, {}),
        (DiffRandomZRotation, CH_NAMES_PARAMS),
        (DiffRandomYRotation, CH_NAMES_PARAMS),
        (DiffRandomXRotation, CH_NAMES_PARAMS),
    ]

    last_step = learn_stage(
        seed=rng_seed,
        diff_transform_classes_and_params=stage_ops,
        epochs=100,
        freeze_mag=False,
        freeze_proba=True,
        freeze_weights=True,
        lr=10,
        make_subpolicy=use_subpolicy,
        grad_est=grad_est,
    )[-1]

    assert np.mean(last_step['operations_magnitudes']) < 1e-6


@pytest.mark.parametrize("grad_est", [None, "gumbel", "relax"])
def test_subpolicy_stage_weights_reverse(rng_seed, grad_est):
    """Trained to imitate an always-applied TimeReverse, the stage should
    select DiffTimeReverse and learn its probability."""
    correct_proba = 1.0
    reverse_idx = 2
    stage_ops = [
        (DiffTimeMask, {}),
        (DiffRandomZRotation, CH_NAMES_PARAMS),
        (DiffTimeReverse, {}),
        (DiffRandomXRotation, CH_NAMES_PARAMS),
    ]
    solution_weights = torch.tensor([0., 0., 1., 0.])

    last_step = learn_stage(
        seed=rng_seed,
        diff_transform_classes_and_params=stage_ops,
        ref_transform=TimeReverse(probability=correct_proba, random_state=42),
        epochs=30,
        freeze_mag=False,
        freeze_proba=False,
        freeze_weights=False,
        lr=0.2,
        grad_est=grad_est,
    )[-1]
    assert last_step['loss'] < 1e-6

    # The weights must concentrate on the time-reversal operation ...
    learned_weights = last_step["operations_weights"].detach()
    assert torch.sum((solution_weights - learned_weights)**2) < 1e-6

    # ... whose probability must match the reference one.
    learned_proba = last_step['operations_probabilities'][reverse_idx]
    assert np.abs(learned_proba - correct_proba) < 1e-6


@pytest.mark.parametrize('device', DEVICES)
def test_diff_policy_forward(rng_seed, device):
    """A full DiffAugmentationPolicy can forward a random batch on each
    tested device."""
    torch.manual_seed(rng_seed)
    # The policy's parameters must live on the same device as the inputs;
    # without ``.to(device)`` they stay on the CPU while the batch does not.
    aug_policy = DiffAugmentationPolicy(
        n_subpolicies=5,
        subpolicy_len=2,
        ch_names=CH_NAMES_PARAMS["ordered_ch_names"],
        sfreq=100,
        random_state=rng_seed,
    ).to(device)
    fake_X = torch.randn(16, 6, 3000, device=device)
    fake_y = torch.ones(16, device=device)
    _ = aug_policy(fake_X, fake_y)
