#
# License: BSD (3-clause)

import os

import pytest
from sklearn.utils import check_random_state
import torch
from torch.fft import fft

from braindecode.augmentation.base import IdentityTransform

from eeg_augment.diff_aug.diff_transforms import *


# Run the tests in this module at a reduced scheduling priority.
# ``os.nice`` is only available on Unix, so skip it on other platforms
# instead of failing when the module is imported/collected.
if hasattr(os, "nice"):
    os.nice(5)


def inner_learning_loop(diff_transform, ref_transform, optimizer, batch_size,
                        fft_loss, rng):
    """Take one optimization step pulling ``diff_transform`` towards a reference.

    A fresh batch of random signals with shape ``(batch_size, 6, 300)`` is
    drawn from ``rng`` (labels are all zeros). It is passed through
    ``ref_transform`` and then through ``diff_transform``. The mean squared
    difference between the two outputs is back-propagated and ``optimizer``
    takes a single step.

    Parameters
    ----------
    diff_transform : callable
        Transform being learned, called as ``diff_transform(X, y)`` and
        returning an ``(X, y)`` pair.
    ref_transform : callable
        Target transform with the same calling convention.
    optimizer : torch.optim.Optimizer
        Optimizer holding the parameters to update.
    batch_size : int
        Number of random examples in the batch.
    fft_loss : bool
        When True, compare the outputs' FFTs (mean squared modulus of their
        difference) instead of the raw time-domain outputs.
    rng : numpy.random.RandomState
        Source of the random input signals.

    Returns
    -------
    loss : torch.Tensor
        Scalar loss, computed before the optimizer step.
    diff_transform : callable
        The very same transform object that was passed in.
    """
    signals = torch.as_tensor(rng.randn(batch_size, 6, 300)).float()
    targets = torch.zeros(batch_size)

    # The reference is evaluated first, in case it consumes randomness.
    reference_out, _ = ref_transform(signals, targets)
    learned_out, _ = diff_transform(signals, targets)

    if not fft_loss:
        loss = (reference_out - learned_out).pow(2).mean()
    else:
        spectral_gap = fft(reference_out) - fft(learned_out)
        loss = spectral_gap.abs().pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss, diff_transform


def learning_transform(
    rng,
    diff_transform,
    ref_transform=None,
    epochs=100,
    batch_size=16,
    freeze_mag=False,
    freeze_proba=False,
    freeze_weights=True,
    lr=0.01,
    fft_loss=False,
):
    """Train ``diff_transform`` to mimic ``ref_transform`` and record progress.

    Runs ``epochs`` single-batch SGD steps through ``inner_learning_loop``
    and prints a progress line every 10 iterations.

    Parameters
    ----------
    rng : numpy.random.RandomState
        Random state used to draw every training batch.
    diff_transform : object with ``parameters()``
        Differentiable transform whose parameters are optimized. Its
        ``probability`` and ``magnitude`` (the latter may be None) are
        recorded. ``_probability``, ``_magnitude`` and ``weights`` are read
        only for the printed progress line.
    ref_transform : callable or None
        Target transform. If None, a new ``IdentityTransform(1.0)`` is
        built for this call.
    epochs : int
        Number of optimization steps.
    batch_size : int
        Number of random examples per step.
    freeze_mag, freeze_proba, freeze_weights : bool
        Only control which gradients/statistics are printed. They do not
        freeze anything by themselves. Callers (e.g. ``test_gradient_descent``)
        disable ``requires_grad`` on the corresponding parameters beforehand.
    lr : float
        SGD learning rate.
    fft_loss : bool
        Compare outputs in the frequency domain instead of the time domain.

    Returns
    -------
    convergence : list of dict
        One entry per iteration with keys ``'loss'``, ``'proba'``,
        ``'mag'`` (None when the transform has no magnitude) and ``'iter'``.
    """
    if ref_transform is None:
        # Built per call rather than as a default argument: calls never share
        # one module instance, and importing this file does not instantiate
        # a transform.
        ref_transform = IdentityTransform(1.0)

    optimizer = torch.optim.SGD(diff_transform.parameters(), lr=lr)
    n_samples_total = batch_size * epochs
    convergence = []

    for i in range(epochs):
        loss, diff_transform = inner_learning_loop(
            diff_transform=diff_transform,
            ref_transform=ref_transform,
            optimizer=optimizer,
            batch_size=batch_size,
            fft_loss=fft_loss,
            rng=rng,
        )

        loss_value = loss.item()
        proba = diff_transform.probability.item()
        mag = None
        if diff_transform.magnitude is not None:
            mag = diff_transform.magnitude.item()
        convergence.append(
            {'loss': loss_value, 'proba': proba, 'mag': mag, 'iter': i}
        )

        if i % 10 == 0:
            fields = [f"loss: {loss_value:>7f}", f"proba: {proba:>7f}"]
            if mag is not None:
                fields.append(f"mag: {mag:>7f}")
            if not freeze_proba:
                p_grad = diff_transform._probability.grad.item()
                fields.append(f"p_grad: {p_grad:>7f}")
            if not freeze_mag and mag is not None:
                m_grad = diff_transform._magnitude.grad.item()
                fields.append(f"m_grad: {m_grad:>7f}")
            if not freeze_weights:
                w_std = diff_transform.weights.std().item()
                fields.append(f"w std: {w_std:>7f}")
            # Progress counter in samples seen so far: [done/total].
            print(", ".join(fields) + f" [{i * batch_size}/{n_samples_total}]")
    return convergence


# Extra constructor keyword arguments giving the channel ordering, passed to
# the channel-aware transforms in the test table (DiffChannelSymmetry and the
# DiffRandom*Rotation cases).
CH_NAMES_PARAMS = dict(
    ordered_ch_names=['C3', 'C4', 'F3', 'F4', 'O1', 'O2'],
)


@pytest.mark.parametrize(
    "transform_class,params,epochs,eps,freeze_m,freeze_p,lr,fft_loss", [
        (DiffTimeReverse, None, 100, 1e-8, True, False, 0.01, False),
        (DiffSignFlip, None, 100, 1e-8, True, False, 0.01, False),
        (DiffChannelSymmetry, CH_NAMES_PARAMS, 100, 1e-8, True, False, 0.01, False),
        (DiffFTSurrogate, None, 100, 1e-8, True, False, 0.05, False),
        (DiffMissingChannels, None, 200, 1e-3, True, False, 0.01, False),
        (DiffFrequencyShift, None, 50, 1e-8, True, False, 0.01, True),
        (DiffGaussianNoise, None, 100, 1e-5, True, False, 10, False),
        (DiffTimeMask, None, 200, 1e-5, True, False, 0.1, False),
        (DiffShuffleChannels, None, 200, 1e-3, True, False, 0.01, False),
        (DiffRandomZRotation, CH_NAMES_PARAMS, 50, 5e-3, True, False, 1, False),
        (DiffRandomYRotation, CH_NAMES_PARAMS, 50, 5e-3, True, False, 1, False),
        (DiffRandomXRotation, CH_NAMES_PARAMS, 50, 5e-3, True, False, 1, False),
        # TODO: enable once DiffBandstopFilter is finished:
        # (DiffBandstopFilter, None, 100, 1e-5, True, False, 0.001, False),
        (DiffFTSurrogate, None, 100, 1e-6, False, True, 0.05, False),
        (DiffMissingChannels, None, 200, 1e-3, False, True, 0.01, False),
        (DiffFrequencyShift, None, 100, 1e-8, False, True, 1e-4, True),
        (DiffGaussianNoise, None, 100, 1e-5, False, True, 10, False),
        (DiffTimeMask, None, 100, 1e-5, False, True, 0.1, False),
        (DiffShuffleChannels, None, 100, 1e-3, False, True, 1e8, False),
        (DiffRandomZRotation, CH_NAMES_PARAMS, 50, 5e-3, False, True, 1, False),
        (DiffRandomYRotation, CH_NAMES_PARAMS, 50, 5e-3, False, True, 1, False),
        (DiffRandomXRotation, CH_NAMES_PARAMS, 50, 5e-3, False, True, 1, False),
        # TODO: enable once DiffBandstopFilter is finished:
        # (DiffBandstopFilter, None, 100, 1e-5, False, True, 0.001, True),
    ])
def test_gradient_descent(
    rng_seed,
    transform_class,
    params,
    epochs,
    eps,
    freeze_m,
    freeze_p,
    lr,
    fft_loss,
):
    """Check that a differentiable augmentation can be trained back to identity.

    The transform is fitted with ``learning_transform`` against its default
    reference transform (``IdentityTransform(1.0)``).

    - With the magnitude frozen, the learned probability must end below
      ``eps``.
    - With the probability frozen (and set to 1.0), the learned magnitude
      must end below ``eps``.

    In both cases the final loss must also be below ``eps``.
    """
    transform_kwargs = {} if params is None else params
    rng = check_random_state(rng_seed)

    # A frozen probability starts (and stays) at 1.0; otherwise it starts
    # at 0.5 and is learned.
    diff_transform = transform_class(
        initial_probability=1.0 if freeze_p else 0.5,
        initial_magnitude=0.5,
        random_state=rng,
        **transform_kwargs
    )

    # Some transforms have no magnitude parameter at all.
    if freeze_m and diff_transform._magnitude is not None:
        diff_transform._magnitude.requires_grad = False
    if freeze_p:
        diff_transform._probability.requires_grad = False

    history = learning_transform(
        rng=rng,
        diff_transform=diff_transform,
        epochs=epochs,
        freeze_mag=freeze_m,
        freeze_proba=freeze_p,
        lr=lr,
        fft_loss=fft_loss,
    )
    last_step = history[-1]

    assert last_step['loss'] < eps

    # Only the parameters that were free to move are expected to shrink.
    if freeze_m:
        tracked_keys = ('proba',)
    elif freeze_p:
        tracked_keys = ('mag',)
    else:
        tracked_keys = ('mag', 'proba')
    assert min(last_step[key] for key in tracked_keys) < eps
