from typing import Literal
import torch.nn as nn
import torch
import torch.nn.functional as F
import pytest
from nn_compression.networks import batch_norm_folding, extract_normalisation_gammas

import torch.fx as fx


class ExampleNet(nn.Module):
    """Small conv + linear test network with one residual (skip) connection.

    Three 3x3 convolutions (3 -> 16 -> 32 -> 32 channels) are each followed by
    a normalisation layer; the third one is added to the output of the second
    stage before the ReLU. The feature map is then pooled to 8x8, flattened and
    passed through two linear layers (-> 64 -> 10), each followed by a
    normalisation layer as well.

    Args:
        norm: ``"batch"`` uses ``nn.BatchNorm1d`` after the linear layers, any
            other value uses ``nn.LayerNorm``. The convolutional stages use
            ``nn.BatchNorm2d`` in both cases.
    """

    def __init__(self, norm: Literal["batch", "layer"] = "batch"):
        super().__init__()
        # Only the norm following the linear layers depends on ``norm``;
        # the conv stages are always batch-normalised.
        norm_conv = nn.BatchNorm2d
        norm_linear = nn.BatchNorm1d if norm == "batch" else nn.LayerNorm

        # NOTE: submodules are created in forward order. Keep this order, it
        # determines both the parameter ordering and the RNG consumption.
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1)
        self.bn1 = norm_conv(16)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1)
        self.bn2 = norm_conv(32)
        self.conv3 = nn.Conv2d(32, 32, 3, 1, padding=1)
        self.bn3 = norm_conv(32)
        self.fc1 = nn.Linear(32 * 8 * 8, 64)
        self.bn4 = norm_linear(64)
        self.fc2 = nn.Linear(64, 10)
        self.bn5 = norm_linear(10)
        self.init_weights()

    def init_weights(self):
        """Re-draw the scale (weight) of every norm layer from U(0, 1)."""
        norm_layers = (self.bn1, self.bn2, self.bn3, self.bn4, self.bn5)
        for layer in norm_layers:
            nn.init.uniform_(layer.weight, a=0.0, b=1.0)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last norm layer."""
        # Two plain conv -> norm -> ReLU stages.
        h = self.bn1(self.conv1(x))
        h = F.relu(h)
        h = self.bn2(self.conv2(h))
        h = F.relu(h)

        # Residual stage: the skip branch is added before the activation.
        residual = h
        h = self.bn3(self.conv3(h))
        h = F.relu(h + residual)

        # Pool to a fixed 8x8 map so fc1 always sees 32 * 8 * 8 features.
        h = F.adaptive_avg_pool2d(h, (8, 8))
        h = torch.flatten(h, 1)

        h = F.relu(self.bn4(self.fc1(h)))
        return self.bn5(self.fc2(h))


@pytest.mark.parametrize(
    "input_tensor",
    [torch.randn(4, 3, 32, 32), torch.randn(8, 3, 64, 64), torch.randn(4, 3, 16, 16)],
)
def test_folded_model_equality(input_tensor):
    """The folded model must reproduce the original model's eval-mode output.

    Args:
        input_tensor: random image batch supplied by the parametrisation.
    """
    model = ExampleNet()
    model.train()
    model(input_tensor)  # initialise some values for running mean and var
    model.eval()

    # Take the reference output *before* folding: should batch_norm_folding
    # modify ``model`` in place, evaluating it afterwards would compare the
    # folded network with itself and the test could never fail.
    with torch.no_grad():
        original_output = model(input_tensor)

    folded_model = batch_norm_folding(model)
    folded_model.eval()

    with torch.no_grad():
        folded_output = folded_model(input_tensor)

    assert torch.allclose(
        original_output, folded_output, atol=1e-6
    ), "Outputs are not the same!"


@pytest.mark.parametrize(
    "input_tensor",
    [torch.randn(4, 3, 32, 32), torch.randn(8, 3, 64, 64), torch.randn(4, 3, 16, 16)],
)
def test_batchnorm_folded_batchnorm_gone(input_tensor):
    """After folding, the traced graph must not call any ``bn*`` submodule.

    Args:
        input_tensor: random image batch supplied by the parametrisation.
    """
    net = ExampleNet()
    net.train()
    net(input_tensor)  # one training pass so the running statistics are set
    net.eval()
    folded = batch_norm_folding(net)

    # Symbolically trace the folded network and inspect every module call.
    traced = fx.symbolic_trace(folded)
    for node in traced.graph.nodes:
        if node.op != "call_module":
            # Only module calls have a string target naming a submodule.
            continue
        # ExampleNet names all of its normalisation layers "bn<k>".
        assert not node.target.startswith("bn")


@pytest.mark.parametrize(
    "input_tensor",
    [torch.randn(4, 3, 32, 32), torch.randn(8, 3, 64, 64), torch.randn(4, 3, 16, 16)],
)
def test_batchnorm_gamma_extraction(input_tensor):
    """Extracted gammas of the conv layers must match a manual computation.

    For each convolution the expected value is the weight of the batch norm
    that follows it, divided by ``sqrt(running_var + eps)`` of that batch norm.

    Args:
        input_tensor: random image batch supplied by the parametrisation.
    """
    net = ExampleNet()
    net.train()
    net(input_tensor)  # one training pass so the running statistics are set
    net.eval()

    extracted = extract_normalisation_gammas(net)

    # Convolution name -> batch norm applied directly after it.
    following_norm = (
        ("conv1", net.bn1),
        ("conv2", net.bn2),
        ("conv3", net.bn3),
    )
    for conv_name, bn in following_norm:
        expected = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # type: ignore
        assert torch.allclose(extracted[conv_name], expected, atol=1e-6)


@pytest.mark.parametrize(
    "input_tensor",
    [torch.randn(4, 3, 32, 32), torch.randn(8, 3, 64, 64), torch.randn(4, 3, 16, 16)],
)
def test_layernorm_extraction(input_tensor):
    """Extracted gammas of the linear layers must equal the LayerNorm weights.

    Args:
        input_tensor: random image batch supplied by the parametrisation.
    """
    net = ExampleNet(norm="layer")
    net.train()
    # The conv stages still use batch norm in this configuration, so a
    # training pass is run to set their running statistics.
    net(input_tensor)
    net.eval()

    extracted = extract_normalisation_gammas(net)

    # Linear layer name -> normalisation layer applied directly after it.
    following_norm = (
        ("fc1", net.bn4),
        ("fc2", net.bn5),
    )
    for linear_name, norm_layer in following_norm:
        assert torch.allclose(extracted[linear_name], norm_layer.weight, atol=1e-6)
