import pytest
import torch
import torchvision.models as models
from nesim.losses.nesim_loss import (
    NesimConfig,
    NesimLoss,
    NeighbourhoodCosineSimilarity,
    LaplacianPyramid,
)

# Module-level ResNet-18 built with randomly initialised weights (weights=None,
# so no pretrained checkpoint is loaded).
model: torch.nn.Module = models.resnet18(weights=None)
# Number of optimizer steps taken per case in the training test.
n_training_steps: int = 4
# Scale table: column i of every row forms parametrized test case i.
# A scale of None means the layer's loss is only watched, never backpropagated.
scales_possible_values_for_each_layer = {
    "layer4.0.conv1": [0.1, 0.1, 1, None],
    "layer3.1.conv2": [0.2, None, None, None],
    "fc": [0.3, 0.3, None, None],
}

# Every layer must provide a value for every case; otherwise the transpose
# below would silently drop (or, previously, index past) some entries.
_scale_list_lengths = {
    len(values) for values in scales_possible_values_for_each_layer.values()
}
if len(_scale_list_lengths) != 1:
    raise ValueError(
        "every layer in scales_possible_values_for_each_layer must list "
        "the same number of scale values"
    )

# Transpose the table into one {layer_name: scale} dict per test case,
# preserving the layer order of the table.
all_scale_values_list = [
    dict(zip(scales_possible_values_for_each_layer, case_values))
    for case_values in zip(*scales_possible_values_for_each_layer.values())
]


@pytest.mark.parametrize("scales", all_scale_values_list)
def test_loss_resnet18_conv_and_linear_layers(scales: dict):
    """Check the return types of ``NesimLoss.compute`` for one scale setting.

    Each layer in ``scales`` gets both a NeighbourhoodCosineSimilarity and a
    LaplacianPyramid config. With ``reduce_mean=True`` the result must be a
    tensor when at least one layer has a non-None scale, and ``None`` when
    every scale is None. With ``reduce_mean=False`` the result must be a dict.
    """
    layer_wise_configs = [
        ## scale = None -> just watch the layer's loss, do not backprop
        NeighbourhoodCosineSimilarity(layer_name=key, scale=value)
        for key, value in scales.items()
    ] + [
        LaplacianPyramid(layer_name=key, scale=value, shrink_factor=[2.0])
        for key, value in scales.items()
    ]

    config = NesimConfig(
        layer_wise_configs=layer_wise_configs,
    )

    loss = NesimLoss(model=model, config=config, device="cpu")
    mean_loss = loss.compute(reduce_mean=True)

    # Works for any number of parametrized layers, not just three.
    expects_backprop = any(scale is not None for scale in scales.values())
    if expects_backprop:
        assert torch.is_tensor(mean_loss)
    else:
        assert mean_loss is None

    dict_loss = loss.compute(reduce_mean=False)
    assert isinstance(dict_loss, dict)


@pytest.mark.parametrize("scales", all_scale_values_list)
def test_loss_resnet18_conv_and_linear_layers_training(scales: dict):
    """Take a few Adam steps on the NeSim loss and check that it decreases.

    A fresh ResNet-18 is built for every parametrized case. The optimizer
    steps therefore never modify the module-level ``model``, and each case
    starts from untrained weights regardless of test order.
    """
    # Local model: optimizer.step() below mutates parameters in place.
    fresh_model = models.resnet18(weights=None)

    layer_wise_configs = [
        ## scale = None -> just watch the layer's loss, do not backprop
        NeighbourhoodCosineSimilarity(layer_name=key, scale=value)
        for key, value in scales.items()
    ] + [
        LaplacianPyramid(layer_name=key, scale=value, shrink_factor=[2.0])
        for key, value in scales.items()
    ]

    config = NesimConfig(
        layer_wise_configs=layer_wise_configs,
    )

    nesim_loss_handler = NesimLoss(model=fresh_model, config=config, device="cpu")
    optimizer = torch.optim.Adam(fresh_model.parameters(), lr=1e-4)

    losses = []
    for _ in range(n_training_steps):
        optimizer.zero_grad()
        loss = nesim_loss_handler.compute()
        # None means no layer has a scale set, so there is nothing to backprop.
        if loss is not None:
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

    if any(scale is not None for scale in scales.values()):
        # At least one layer contributes to the loss, so every step must yield
        # a value; otherwise the decrease check below would pass vacuously.
        assert len(losses) == n_training_steps

    if losses:
        assert losses[-1] < losses[0]
