import pytest

import torch
import torch.nn as nn

from nesim.losses.ring.loss import RingLoss1D

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Candidate values for each parameter swept by the parametrized tests in
# this module; every key maps to the list of values tried for that argument.
conv_layer_possible_params = dict(
    in_channels=[64, 128],
    out_channels=[15, 512, 1024],
    kernel_size=[3, 5],
    stride=[1, 2, 3],
    freq_inner=[0],
    freq_outer=[1, 2, 5, 10],
)


@pytest.mark.parametrize("in_channels", conv_layer_possible_params["in_channels"])
@pytest.mark.parametrize("out_channels", conv_layer_possible_params["out_channels"])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("stride", conv_layer_possible_params["stride"])
@pytest.mark.parametrize("freq_inner", conv_layer_possible_params["freq_inner"])
@pytest.mark.parametrize("freq_outer", conv_layer_possible_params["freq_outer"])
def test_conv(in_channels, out_channels, kernel_size, stride, freq_inner, freq_outer):
    """
    Check the value returned by RingLoss1D.get_loss() for a fresh nn.Conv2d:
    1. it is a torch tensor
    2. that tensor is 0-D (has no dimensions)
    """
    layer_under_test = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    ring_loss = RingLoss1D(
        layer=layer_under_test,
        freq_inner=freq_inner,
        freq_outer=freq_outer,
    ).get_loss()

    # 1: the result must be a tensor
    assert isinstance(ring_loss, torch.Tensor), "Expected loss to be a torch tensor"
    # 2: the tensor must be zero-dimensional
    assert ring_loss.dim() == 0


@pytest.mark.parametrize("in_channels", conv_layer_possible_params["in_channels"])
@pytest.mark.parametrize("out_channels", [15, 64])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize("n_training_steps", [10])
@pytest.mark.parametrize("freq_inner", conv_layer_possible_params["freq_inner"])
@pytest.mark.parametrize("freq_outer", conv_layer_possible_params["freq_outer"])
def test_training_loop(
    in_channels,
    out_channels,
    kernel_size,
    stride,
    n_training_steps,
    freq_inner,
    freq_outer,
):
    """
    Test the following:
    1. get_loss_fast() should return the same result as get_loss_original()
    2. on a small training loop, both losses should go down with very close values
    """
    conv_layer = nn.Conv2d(
        in_channels, out_channels, kernel_size=kernel_size, stride=stride
    )

    loss_calculator = RingLoss1D(
        layer=conv_layer, freq_inner=freq_inner, freq_outer=freq_outer
    )
    optimizer = torch.optim.Adam(conv_layer.parameters(), lr=1e-2)
    losses = []
    for train_step_idx in range(n_training_steps):
        optimizer.zero_grad()
        loss = loss_calculator.get_loss()
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    assert losses[-1] < losses[0]
