import pytest

import torch
import torch.nn as nn

from nesim.losses.laplacian_pyramid.loss import LaplacianPyramidLoss
from nesim.utils.grid_size import find_rectangle_dimensions

# Candidate values for each argument of ``test_conv``. Every key is fed to its
# own ``pytest.mark.parametrize`` decorator on that test.
conv_layer_possible_params = dict(
    in_channels=[64, 128],
    out_channels=[15, 512, 1024],
    kernel_size=[3, 5],
    stride=[1, 2, 3],
    # Each entry is a one-element list; the test passes the same list as both
    # ``factor_h`` and ``factor_w`` of the loss.
    shrink_factor=[[2.0], [3.0]],
)


@pytest.mark.parametrize("in_channels", conv_layer_possible_params["in_channels"])
@pytest.mark.parametrize("out_channels", conv_layer_possible_params["out_channels"])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("stride", conv_layer_possible_params["stride"])
@pytest.mark.parametrize("shrink_factor", conv_layer_possible_params["shrink_factor"])
def test_conv(in_channels, out_channels, kernel_size, stride, shrink_factor):
    """
    Run LaplacianPyramidLoss on a freshly built Conv2d layer and verify that:

    1. get_loss() returns a torch tensor
    2. that tensor is 0-D
    3. the first two dims of the grid equal (height, width) as reported by
       find_rectangle_dimensions(area=out_channels)
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    expected_dims = find_rectangle_dimensions(area=out_channels)

    # The same shrink factor is used for both the height and width arguments.
    pyramid_loss = LaplacianPyramidLoss(
        layer=layer,
        factor_h=shrink_factor,
        factor_w=shrink_factor,
        device="cpu",
    )
    loss_value = pyramid_loss.get_loss()

    # 1: the loss is a tensor
    assert torch.is_tensor(loss_value), "Expected loss to be a torch tensor"

    # 2: the loss has no dimensions
    assert loss_value.dim() == 0

    # 3: leading grid dimensions match the expected rectangle
    grid_hw = pyramid_loss.grid_container.grid.shape[:2]
    assert grid_hw == (expected_dims.height, expected_dims.width)
