import pytest
import torch
import torch.nn as nn
from nesim.losses.cross_layer_correlation.loss import (
    CrossLayerCorrelationLoss,
    SinglePairHookContainer,
)

# Layer sizes swept by the parametrized linear-layer test.
linear_layer_possible_params = dict(
    in_size=[1, 64],
    hidden_size=[32, 64],
    out_size=[10, 32, 64],
)

# Channel counts and kernel sizes swept by the parametrized convolutional tests.
conv_layer_possible_params = dict(
    in_size=[1, 16],
    hidden_size=[32, 64],
    out_size=[32, 64],
    kernel_size=[1, 3],
)

# Batch sizes shared by every parametrized test in this module.
batch_sizes = [1, 2]

# Spatial extents of the random inputs fed to the convolutional tests.
input_heights = [10, 32]
input_widths = [10, 32]


class FlattenAlongHeightWidth(nn.Module):
    """Collapse the two trailing dimensions of a tensor by averaging them.

    Despite the name, nothing is reshaped: the last dimension is averaged
    away first, then the dimension that becomes last is averaged away too.
    In this module it sits between a ``Conv2d`` and a ``Linear`` layer.
    """

    def forward(self, x):
        """Return ``x`` with its last two dimensions reduced by their mean."""
        mean_over_last = x.mean(dim=-1)
        return mean_over_last.mean(dim=-1)


@pytest.mark.parametrize("in_size", linear_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", linear_layer_possible_params["out_size"])
@pytest.mark.parametrize("hidden_size", linear_layer_possible_params["hidden_size"])
@pytest.mark.parametrize("batch_size", batch_sizes)
def test_on_linear(in_size, out_size, hidden_size, batch_size):
    """Smoke-test the cross-layer correlation loss on a stack of linear layers.

    Builds a three-layer ``nn.Linear`` model, pairs its second and third
    layers in a ``SinglePairHookContainer``, runs one forward pass on random
    input, and checks that ``compute()`` returns a tensor.

    Args:
        in_size: Input feature count of the first layer.
        out_size: Output feature count of the last layer.
        hidden_size: Width of the hidden layers.
        batch_size: Number of samples in the random input batch.
    """
    model = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.Linear(hidden_size, hidden_size),
        nn.Linear(hidden_size, out_size),
    )

    cross_layer_correlation_loss = CrossLayerCorrelationLoss(
        pair_hook_containers=[
            SinglePairHookContainer(layer1=model[1], layer2=model[2], scale=1.0)
        ]
    )

    input_tensor = torch.randn(batch_size, in_size)

    # The model output itself is not inspected; the forward pass is run before
    # compute(), presumably so the hook container can capture the paired
    # layers' activations (confirm against the nesim implementation).
    model(input_tensor)

    loss = cross_layer_correlation_loss.compute()
    assert torch.is_tensor(loss)


@pytest.mark.parametrize("in_size", conv_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", conv_layer_possible_params["out_size"])
@pytest.mark.parametrize("hidden_size", conv_layer_possible_params["hidden_size"])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("input_height", input_heights)
@pytest.mark.parametrize("input_width", input_widths)
@pytest.mark.parametrize("batch_size", batch_sizes)
def test_on_conv(
    in_size, out_size, hidden_size, batch_size, kernel_size, input_height, input_width
):
    """Smoke-test the cross-layer correlation loss on a stack of conv layers.

    Builds a three-layer ``nn.Conv2d`` model, pairs its second and third
    layers in a ``SinglePairHookContainer``, runs one forward pass on random
    input, and checks that ``compute()`` returns a tensor.

    Args:
        in_size: Input channel count of the first convolution.
        out_size: Output channel count of the last convolution.
        hidden_size: Channel count of the hidden convolutions.
        batch_size: Number of samples in the random input batch.
        kernel_size: Kernel size shared by all three convolutions.
        input_height: Height of the random input.
        input_width: Width of the random input.
    """
    model = nn.Sequential(
        nn.Conv2d(in_size, hidden_size, kernel_size=kernel_size),
        nn.Conv2d(hidden_size, hidden_size, kernel_size=kernel_size),
        nn.Conv2d(hidden_size, out_size, kernel_size=kernel_size),
    )

    cross_layer_correlation_loss = CrossLayerCorrelationLoss(
        pair_hook_containers=[
            SinglePairHookContainer(layer1=model[1], layer2=model[2], scale=1.0)
        ]
    )

    input_tensor = torch.randn(batch_size, in_size, input_height, input_width)

    # The model output itself is not inspected; the forward pass is run before
    # compute(), presumably so the hook container can capture the paired
    # layers' activations (confirm against the nesim implementation).
    model(input_tensor)

    loss = cross_layer_correlation_loss.compute()
    assert torch.is_tensor(loss)


@pytest.mark.parametrize("in_size", conv_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", conv_layer_possible_params["out_size"])
@pytest.mark.parametrize("hidden_size", conv_layer_possible_params["hidden_size"])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("input_height", input_heights)
@pytest.mark.parametrize("input_width", input_widths)
@pytest.mark.parametrize("batch_size", batch_sizes)
def test_on_conv_and_linear_mix(
    in_size, out_size, hidden_size, batch_size, kernel_size, input_height, input_width
):
    """Smoke-test the cross-layer correlation loss across a conv/linear pair.

    Builds a model of two ``nn.Conv2d`` layers, a ``FlattenAlongHeightWidth``
    reduction and a final ``nn.Linear``; pairs the second convolution with the
    linear layer, runs one forward pass on random input, and checks that
    ``compute()`` returns a tensor.

    Args:
        in_size: Input channel count of the first convolution.
        out_size: Output feature count of the linear layer.
        hidden_size: Channel count produced by both convolutions.
        batch_size: Number of samples in the random input batch.
        kernel_size: Kernel size shared by both convolutions.
        input_height: Height of the random input.
        input_width: Width of the random input.
    """
    model = nn.Sequential(
        nn.Conv2d(in_size, hidden_size, kernel_size=kernel_size),
        nn.Conv2d(hidden_size, hidden_size, kernel_size=kernel_size),
        FlattenAlongHeightWidth(),
        nn.Linear(hidden_size, out_size),
    )

    cross_layer_correlation_loss = CrossLayerCorrelationLoss(
        pair_hook_containers=[
            # model[2] is the parameter-free reduction, so the pair skips it.
            SinglePairHookContainer(layer1=model[1], layer2=model[3], scale=1.0)
        ]
    )

    input_tensor = torch.randn(batch_size, in_size, input_height, input_width)

    # The model output itself is not inspected; the forward pass is run before
    # compute(), presumably so the hook container can capture the paired
    # layers' activations (confirm against the nesim implementation).
    model(input_tensor)

    loss = cross_layer_correlation_loss.compute()
    assert torch.is_tensor(loss)


@pytest.mark.parametrize("in_size", conv_layer_possible_params["in_size"])
@pytest.mark.parametrize("out_size", conv_layer_possible_params["out_size"])
@pytest.mark.parametrize("hidden_size", conv_layer_possible_params["hidden_size"])
@pytest.mark.parametrize("kernel_size", conv_layer_possible_params["kernel_size"])
@pytest.mark.parametrize("input_height", input_heights)
@pytest.mark.parametrize("input_width", input_widths)
@pytest.mark.parametrize("batch_size", batch_sizes)
def test_on_conv_and_linear_mix_loss_going_down(
    in_size,
    out_size,
    hidden_size,
    batch_size,
    kernel_size,
    input_height,
    input_width,
    num_train_steps=10,
):
    """Check that optimizing the cross-layer correlation loss does not raise it.

    Builds the same conv/linear model as the mixed smoke test, then runs
    ``num_train_steps`` SGD steps (lr=1e-2) that minimize only the
    cross-layer correlation loss, and asserts that the last recorded loss is
    no larger than the first.

    Args:
        in_size: Input channel count of the first convolution.
        out_size: Output feature count of the linear layer.
        hidden_size: Channel count produced by both convolutions.
        batch_size: Number of samples in each random input batch.
        kernel_size: Kernel size shared by both convolutions.
        input_height: Height of the random inputs.
        input_width: Width of the random inputs.
        num_train_steps: Number of optimization steps to run.
    """
    model = nn.Sequential(
        nn.Conv2d(in_size, hidden_size, kernel_size=kernel_size),
        nn.Conv2d(hidden_size, hidden_size, kernel_size=kernel_size),
        FlattenAlongHeightWidth(),
        nn.Linear(hidden_size, out_size),
    )

    cross_layer_correlation_loss = CrossLayerCorrelationLoss(
        pair_hook_containers=[
            # model[2] is the parameter-free reduction, so the pair skips it.
            SinglePairHookContainer(layer1=model[1], layer2=model[3], scale=1.0)
        ]
    )

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    losses = []

    # NOTE(review): a fresh random batch is drawn every step and no seed is
    # set in this test, so the first-vs-last comparison below is made on
    # different inputs and is stochastic. Consider a fixed input or a seed if
    # this test ever turns out to be flaky.
    for _ in range(num_train_steps):
        input_tensor = torch.randn(batch_size, in_size, input_height, input_width)

        # The model output is not used; the forward pass is run before
        # compute(), presumably so the hook container sees fresh activations
        # (confirm against the nesim implementation).
        model(input_tensor)

        optimizer.zero_grad()
        loss = cross_layer_correlation_loss.compute()
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    assert losses[-1] <= losses[0]
