import pytest

import torch
import torch.nn as nn

from nesim.losses.neighbourhood_cossim.loss import NeighbourhoodCosineSimilarityLoss
from nesim.utils.grid_size import find_rectangle_dimensions

# Run on a GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Parameter grid for the nn.Linear layers used by the tests below: every
# combination of input width and output width is exercised via
# pytest.mark.parametrize.
linear_layer_possible_params = dict(
    in_size=[64, 128],
    out_size=[15, 512, 1024],
)


@pytest.mark.parametrize("in_size", linear_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", linear_layer_possible_params["out_size"])
def test_linear(in_size, out_size):
    """
    Sanity-check NeighbourhoodCosineSimilarityLoss on a plain nn.Linear layer.

    Checks:
    1. get_loss() returns a torch tensor.
    2. That tensor is 0-dimensional (a scalar).
    3. The leading two dimensions of ``grid_container.grid`` equal
       (height, width) as reported by find_rectangle_dimensions(area=out_size).
    """
    layer = nn.Linear(in_size, out_size)
    expected_dims = find_rectangle_dimensions(area=out_size)

    calculator = NeighbourhoodCosineSimilarityLoss(layer=layer, device=device)
    result = calculator.get_loss()

    # Check 1: the loss is a tensor.
    assert torch.is_tensor(result), "Expected loss to be a torch tensor"
    # Check 2: the loss is a scalar tensor.
    assert result.dim() == 0
    # Check 3: the grid's (height, width) matches the rectangle helper.
    grid_hw = calculator.grid_container.grid.shape[:2]
    assert grid_hw == (expected_dims.height, expected_dims.width)


@pytest.mark.parametrize("in_size", linear_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", linear_layer_possible_params["out_size"])
def test_fast_vs_original(in_size, out_size):
    """
    Test that get_loss_fast() matches get_loss_original() on a freshly
    initialised nn.Linear layer.

    Both losses are computed by the same NeighbourhoodCosineSimilarityLoss
    instance (padding=0, kernel_size=3, stride=1) and must agree within an
    absolute tolerance of 0.11.
    """
    linear_layer = nn.Linear(in_size, out_size)

    loss_calculator = NeighbourhoodCosineSimilarityLoss(
        layer=linear_layer, device=device, padding=0, kernel_size=3, stride=1
    )
    loss_fast = loss_calculator.get_loss_fast()
    loss_original = loss_calculator.get_loss_original()

    # NOTE(review): atol=0.11 is slightly looser than the 0.1 used in the
    # training-loop test; confirm whether the difference is intentional.
    assert torch.allclose(loss_fast, loss_original, atol=0.11), (
        f"get_loss_fast() = {loss_fast} differs from "
        f"get_loss_original() = {loss_original} by more than atol=0.11"
    )


@pytest.mark.parametrize("in_size", linear_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", linear_layer_possible_params["out_size"])
@pytest.mark.parametrize("n_training_steps", [1, 2, 10, 20])
def test_training_loop(in_size, out_size, n_training_steps):
    """
    Test that get_loss_fast() and get_loss_original() stay in agreement while
    an nn.Linear layer is trained with Adam on get_loss_fast().

    At each of ``n_training_steps`` steps, the fast loss (which is
    back-propagated) is compared against the original loss computed from the
    same pre-step weights. The comparison is repeated once more after the
    final step. All comparisons use an absolute tolerance of 0.1.

    Note: this test only checks that the two implementations agree; it does
    not assert that the loss decreases during training.
    """
    linear_layer = nn.Linear(in_size, out_size)

    loss_calculator = NeighbourhoodCosineSimilarityLoss(
        layer=linear_layer, device=device, padding=0, kernel_size=3, stride=1
    )
    optimizer = torch.optim.Adam(linear_layer.parameters(), lr=1e-2)

    for _ in range(n_training_steps):
        optimizer.zero_grad()
        loss = loss_calculator.get_loss_fast()
        # The reference loss is only compared, never back-propagated, so no
        # autograd graph is needed for it.
        with torch.no_grad():
            loss_original = loss_calculator.get_loss_original()
        loss.backward()
        optimizer.step()
        # Both values were computed before optimizer.step(), i.e. from the
        # same weights, so they are directly comparable.
        assert torch.allclose(loss, loss_original, atol=0.1), (
            f"during training: get_loss_fast() = {loss} differs from "
            f"get_loss_original() = {loss_original} by more than atol=0.1"
        )

    # Post-training check on the updated weights. Nothing is back-propagated
    # here, so avoid building an autograd graph.
    with torch.no_grad():
        loss = loss_calculator.get_loss_fast()
        loss_original = loss_calculator.get_loss_original()
    assert torch.allclose(loss, loss_original, atol=0.1), (
        f"after training: get_loss_fast() = {loss} differs from "
        f"get_loss_original() = {loss_original} by more than atol=0.1"
    )
