import pytest
import torch
import torch.nn as nn
from nesim.bimt.loss import BIMTLoss, BIMTConfig, BrainInspiredLayer
from nesim.utils.getting_modules import get_module_by_name

# Sizes of consecutive layers: entry i is the input size (channels/features)
# of layer i and entry i + 1 is its output size.
size_sequences = [
    [3, 10, 20],
    [1, 3, 3],
    [1, 3, 1],
    [1, 10, 1],
    [1, 1, 1],
]
# Odd Conv2d kernel sizes from 1 to 7.
kernel_sizes = list(range(1, 8, 2))
# Number of optimizer steps taken in the loss-decrease test.
possible_num_train_steps = [5, 10]


@pytest.mark.parametrize("size_sequence", size_sequences)
@pytest.mark.parametrize("kernel_size", kernel_sizes)
def test_bimt_loss_conv(size_sequence: list, kernel_size: int):
    """Wrap a stack of Conv2d layers with BIMT and check the wrapped modules.

    One ``nn.Conv2d`` is built per consecutive (in, out) channel pair in
    ``size_sequence``. After ``BIMTLoss.init_modules_for_training`` the loss
    must be a tensor, and every registered position of the ``nn.Sequential``
    must hold a ``BrainInspiredLayer`` whose inner ``.layer`` is still a Conv2d.
    """
    channel_pairs = list(zip(size_sequence[:-1], size_sequence[1:]))
    # nn.Sequential names its children "0", "1", ... by position.
    names = [str(idx) for idx, _ in enumerate(channel_pairs)]

    fake_model = nn.Sequential(
        *(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size)
            for in_ch, out_ch in channel_pairs
        )
    )

    bimt = BIMTLoss(layer_names=names, distance_between_nearby_layers=0.2)
    fake_model = bimt.init_modules_for_training(model=fake_model)
    assert torch.is_tensor(bimt.forward(model=fake_model))

    for name in names:
        wrapped = get_module_by_name(module=fake_model, name=name)
        assert isinstance(wrapped, BrainInspiredLayer)
        assert isinstance(
            wrapped.layer, nn.Conv2d
        ), f"Expected BrainInspiredLayer().layer to be nn.Conv2d but got: {type(wrapped.layer)}"


@pytest.mark.parametrize("size_sequence", size_sequences)
def test_bimt_loss_linear(size_sequence: list):
    """Wrap a stack of Linear layers with BIMT and check the wrapped modules.

    One ``nn.Linear`` is created for every consecutive (in, out) pair in
    ``size_sequence``. After ``init_modules_for_training`` each layer must be
    reachable by its ``nn.Sequential`` index as a ``BrainInspiredLayer`` whose
    inner ``.layer`` is still an ``nn.Linear``. The BIMT loss must be a tensor.
    """
    target_names = list(map(str, range(len(size_sequence) - 1)))
    fake_model = nn.Sequential(
        *[
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(size_sequence, size_sequence[1:])
        ]
    )

    bimt = BIMTLoss(
        layer_names=target_names,
        distance_between_nearby_layers=0.2,
    )
    fake_model = bimt.init_modules_for_training(model=fake_model)
    loss = bimt.forward(model=fake_model)
    assert torch.is_tensor(loss)

    for name in target_names:
        module = get_module_by_name(module=fake_model, name=name)
        assert isinstance(module, BrainInspiredLayer)
        inner = module.layer
        assert isinstance(
            inner, nn.Linear
        ), f"Expected BrainInspiredLayer().layer to be nn.Linear but got: {type(inner)}"


@pytest.mark.parametrize("size_sequence", size_sequences)
@pytest.mark.parametrize("kernel_size", kernel_sizes)
def test_bimt_loss_conv_and_linear(size_sequence: list, kernel_size: int):
    """BIMT should handle a model that mixes Conv2d and Linear layers.

    The model is a stack of ``nn.Conv2d`` layers followed by a stack of
    ``nn.Linear`` layers, both built from the same ``size_sequence``. Every
    layer of the ``nn.Sequential`` is registered with ``BIMTLoss``, so both
    the conv half and the linear half are exercised. Each registered layer
    must be wrapped in a ``BrainInspiredLayer`` that still holds a module of
    its original type.
    """
    num_pairs = len(size_sequence) - 1
    layers = [
        nn.Conv2d(size_sequence[i], size_sequence[i + 1], kernel_size=kernel_size)
        for i in range(num_pairs)
    ]
    layers.extend(
        nn.Linear(size_sequence[i], size_sequence[i + 1]) for i in range(num_pairs)
    )
    # Record the original module type at each position. After wrapping we can
    # then check that the right module sits inside each BrainInspiredLayer.
    expected_types = [type(layer) for layer in layers]
    # nn.Sequential names its children "0", "1", ... by position; register all
    # of them (conv indices first, then linear indices).
    layer_names = [str(i) for i in range(len(layers))]

    fake_model = nn.Sequential(*layers)
    bimt = BIMTLoss(
        layer_names=layer_names,
        distance_between_nearby_layers=0.2,
    )
    fake_model = bimt.init_modules_for_training(model=fake_model)
    assert torch.is_tensor(bimt.forward(model=fake_model))

    for name, expected_type in zip(layer_names, expected_types):
        brain_inspired_layer = get_module_by_name(module=fake_model, name=name)
        assert isinstance(brain_inspired_layer, BrainInspiredLayer)
        assert isinstance(brain_inspired_layer.layer, expected_type), (
            f"Expected BrainInspiredLayer().layer to be {expected_type.__name__} "
            f"but got: {type(brain_inspired_layer.layer)}"
        )


@pytest.mark.parametrize("size_sequence", size_sequences)
@pytest.mark.parametrize("kernel_size", kernel_sizes)
def test_bimt_from_config(size_sequence: list, kernel_size: int):
    """``BIMTLoss.from_config`` should build a usable loss from a ``BIMTConfig``.

    The model is a stack of ``nn.Conv2d`` layers followed by ``nn.Linear``
    layers built from ``size_sequence``. The registered layers are all conv
    layers plus the first linear layer, so the conv -> linear boundary is
    covered. The names are derived from ``size_sequence`` instead of being
    hard-coded, so any sequence with at least two entries yields names that
    exist in the model.
    """
    num_pairs = len(size_sequence) - 1
    layers = [
        nn.Conv2d(size_sequence[i], size_sequence[i + 1], kernel_size=kernel_size)
        for i in range(num_pairs)
    ]
    layers.extend(
        nn.Linear(size_sequence[i], size_sequence[i + 1]) for i in range(num_pairs)
    )
    model = nn.Sequential(*layers)

    # Conv layers occupy indices 0 .. num_pairs - 1 and the first linear layer
    # is at index num_pairs. For 3-element sequences this is ["0", "1", "2"].
    layer_names = [str(i) for i in range(num_pairs + 1)]
    config = BIMTConfig(
        layer_names=layer_names,
        scale=1.0,
        distance_between_nearby_layers=0.2,
        device="cpu",
    )

    bimt = BIMTLoss.from_config(config=config)
    model = bimt.init_modules_for_training(model=model)
    assert torch.is_tensor(bimt.forward(model=model))


@pytest.mark.parametrize("num_train_steps", possible_num_train_steps)
@pytest.mark.parametrize("size_sequence", size_sequences)
@pytest.mark.parametrize("kernel_size", kernel_sizes)
def test_bimt_going_down(size_sequence: list, kernel_size: int, num_train_steps: int):
    """
    This is the most important test: it makes sure that BIMT actually works,
    i.e. that minimizing the BIMT loss alone with plain SGD makes it go down.

    The model is a stack of Conv2d layers followed by Linear layers. Only the
    layer at ``nn.Sequential`` index "1" is registered with BIMT. For the
    3-element size sequences used here, that is the second Conv2d. The loss
    after the last step must be strictly lower than the loss at the first step.
    """
    layers = [
        nn.Conv2d(size_sequence[i], size_sequence[i + 1], kernel_size=kernel_size)
        for i in range(len(size_sequence) - 1)
    ]
    layers.extend(
        [
            nn.Linear(size_sequence[i], size_sequence[i + 1])
            for i in range(len(size_sequence) - 1)
        ]
    )

    model = nn.Sequential(*layers)

    config = BIMTConfig(
        layer_names=["1"], scale=1.0, distance_between_nearby_layers=1.0, device="cpu"
    )

    bimt = BIMTLoss.from_config(config=config)
    model = bimt.init_modules_for_training(model=model)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    losses = []
    for _ in range(num_train_steps):
        optimizer.zero_grad()
        loss = bimt.forward(model=model)
        # Store a plain float so the history does not keep autograd graphs alive.
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    assert (
        losses[-1] < losses[0]
    ), f"BIMT loss did not decrease over {num_train_steps} steps: {losses}"
