import unittest
from typing import List, Tuple

import torch
from torch import nn

from layers.bnn import BayesianLayer, BayesianNet
from layers.spectral_norm import SpectralNorm
from layers.spectral_norm_conv import SpectralNormConv

# Short alias for torch.Tensor, used in the type annotations throughout this module.
T = torch.Tensor


class Linear(BayesianNet):
    """Small fully connected test network: 10 -> 10 -> 1 features.

    Both `nn.Linear` layers are wrapped in `BayesianLayer`, with a LeakyReLU between them.
    """

    def __init__(self) -> None:
        super().__init__()
        hidden = BayesianLayer(nn.Linear(10, 10))
        head = BayesianLayer(nn.Linear(10, 1))
        self.layers = nn.Sequential(hidden, nn.LeakyReLU(inplace=True), head)

    def forward(self, x: T) -> T:
        """Run `x` through the sequential stack and return its output."""
        return self.layers(x)  # type: ignore


class Conv(BayesianNet):
    """Small convolutional test network ending in a 100-way linear layer.

    Three identical stages (3x3 conv -> batch norm -> LeakyReLU -> 2x2 max pool) are followed by
    a flatten and a `Linear(128, 100)`. Every conv, batch-norm and linear module is wrapped in
    `BayesianLayer`.
    """

    def __init__(self) -> None:
        super().__init__()
        modules: List[nn.Module] = []
        in_channels = 3  # only the first stage sees the 3-channel input
        for _ in range(3):
            modules.append(BayesianLayer(nn.Conv2d(in_channels, 128, 3)))
            modules.append(BayesianLayer(nn.BatchNorm2d(128)))
            modules.append(nn.LeakyReLU(inplace=True))
            modules.append(nn.MaxPool2d(2, 2))
            in_channels = 128

        modules.append(nn.Flatten())
        modules.append(BayesianLayer(nn.Linear(128, 100)))
        self.layers = nn.Sequential(*modules)

    def forward(self, x: T) -> T:
        """Run `x` through the sequential stack and return its output."""
        return self.layers(x)  # type: ignore


def get_params_before_after(net: nn.Module, loss: T) -> Tuple[List[T], List[T]]:
    """Snapshot the gradients of `net` before and after backpropagating `loss`.

    Despite the name, the returned lists hold gradients, not parameter values. Only parameters
    with `requires_grad=True` are included, in `net.parameters()` order. A parameter whose
    `.grad` is `None` is represented by a zero tensor shaped like the parameter.

    Args:
        net: module whose parameter gradients are inspected.
        loss: tensor on which `backward()` is called.

    Returns:
        A `(before, after)` pair of equally long lists of gradient snapshots. Both lists hold
        independent copies, so later backward passes or in-place zeroing of the gradients do
        not alter them.
    """

    def snapshot() -> List[T]:
        # Clone so the snapshot never aliases the live `.grad` buffers, which are
        # modified in place by subsequent backward passes / gradient zeroing.
        return [
            torch.clone(p.grad) if p.grad is not None else torch.zeros_like(p)
            for p in net.parameters()
            if p.requires_grad
        ]

    grads_before = snapshot()
    loss.backward()
    grads_after = snapshot()

    return grads_before, grads_after


class TestBayesian(unittest.TestCase):
    """Checks that both the task loss and the KL term produce gradients in the test networks."""

    def _assert_some_grad_changed(self, net: nn.Module, loss: T) -> None:
        """Backpropagate `loss` and assert that at least one gradient of `net` changed.

        Gradients are compared element-wise rather than by their sums, so a change whose
        elements happen to cancel out in the sum is still detected.
        """
        grads_before, grads_after = get_params_before_after(net, loss)
        changed = any(
            not torch.equal(before, after)
            for before, after in zip(grads_before, grads_after)
        )
        self.assertTrue(changed)

    def test_bayesian_layer(self) -> None:
        """
        This test only checks that at least one parameter's gradient has changed after a backward pass.
        It previously failed intermittently because a single layer showed no change, which must be possible.
        Since every parameter is a Bayesian parameter, verifying that at least one of them has changed
        is good enough.
        """
        lin, conv = Linear(), Conv()

        # The optimizers are only used to clear gradients; no optimization step is taken.
        linopt, convopt = torch.optim.Adam(lin.parameters()), torch.optim.Adam(conv.parameters())
        for (net, opt, xsize, loss_fn, y) in zip(
            [lin, conv],
            [linopt, convopt],
            [(32, 10), (32, 3, 28, 28)],
            [nn.MSELoss(), nn.CrossEntropyLoss()],
            [torch.randn(32), torch.randint(0, 100, (32,))],
        ):

            for _ in range(2):
                x = torch.randn(*xsize)
                out = net(x)

                # TEST loss is getting through to the Bayesian parameters ==========================
                loss = loss_fn(out.squeeze(1), y)
                opt.zero_grad()
                self._assert_some_grad_changed(net, loss)

                _ = net.kl()  # zero out the kl from this pass

                # TEST kl is getting through to the Bayesian parameters ===========================
                x = torch.randn(*xsize)
                # Fresh forward pass before reading kl(); the output itself is not needed.
                # NOTE(review): presumably kl() depends on state set during forward — confirm in layers.bnn.
                net(x)

                opt.zero_grad()
                self._assert_some_grad_changed(net, net.kl())


class TestSpectralNorm(unittest.TestCase):
    """Smoke tests for the `SpectralNorm` and `SpectralNormConv` wrappers."""

    def _check_layer(self, lyr: nn.Module, ins: T) -> None:
        """Run `lyr` on `ins` in its initial mode and after `eval()`, checking weights and output dims."""
        weight_before = torch.clone(lyr.base_layer.weight)  # type: ignore
        outs = lyr(ins)
        weight_after = torch.clone(lyr.base_layer.weight)  # type: ignore

        # the forward pass is expected to have modified the wrapped layer's weight
        self.assertTrue((weight_before - weight_after).sum().item() != 0)
        self.assertEqual(outs.size(0), 32)
        self.assertEqual(outs.size(1), 128)

        # same output dimensions are expected after switching to eval mode
        lyr.eval()
        outs = lyr(ins)
        self.assertEqual(outs.size(0), 32)
        self.assertEqual(outs.size(1), 128)

    def test_smoketest_spectral_norm(self) -> None:
        """Exercise `SpectralNorm` around a linear layer for every `ctype` option."""
        for ctype in ("none", "scalar", "vector"):
            lyr = SpectralNorm(nn.Linear(64, 128), ctype=ctype)
            ins = torch.randn(32, 64)
            self._check_layer(lyr, ins)

    def test_smoketest_conv_spectral_norm(self) -> None:
        """Exercise `SpectralNormConv` around a 3x3 convolution for every `ctype` option."""
        for ctype in ("none", "scalar", "vector"):
            lyr = SpectralNormConv(nn.Conv2d(64, 128, kernel_size=3), (128, 64, 28, 28), ctype=ctype)
            ins = torch.randn(32, 64, 28, 28)
            self._check_layer(lyr, ins)
