import unittest
from unittest import TestCase
from functools import reduce
import copy

import torch

from src.dl.models.cell_types.residual import ResNetLayer


class TestResNetLayer(TestCase):
    """Unit tests for ``ResNetLayer``: gradient flow and weight-norm behaviour."""

    # Problem sizes shared by every test in this case.
    N_CHANNELS = 16
    N_INNER_CHANNELS = 64
    BATCH_SIZE = 15
    IM_SIDE = 32

    def setUp(self):
        # Seed the RNG so parameter initialisation and dummy inputs are reproducible across runs.
        torch.manual_seed(0)

    def _make_inputs(self):
        """Return a pair ``(z, x)`` of random tensors of shape (BATCH_SIZE, N_CHANNELS, IM_SIDE, IM_SIDE)."""
        shape = (self.BATCH_SIZE, self.N_CHANNELS, self.IM_SIDE, self.IM_SIDE)
        return torch.randn(shape), torch.randn(shape)

    def test_forward(self):
        """Make sure a backward pass produces a non-zero gradient for every learnable parameter."""
        resnet_layer = ResNetLayer(n_channels=self.N_CHANNELS, n_inner_channels=self.N_INNER_CHANNELS)
        z, x = self._make_inputs()

        # Run forward pass.
        out = resnet_layer(z=z, x=x)
        # Arbitrary scalar loss: per-sample sum of squares, averaged over the batch.
        loss = torch.mean(torch.sum(out ** 2, dim=(1, 2, 3)), dim=0)
        loss.backward()

        # Check each parameter separately so a failure names the offending parameter.
        for name, p in resnet_layer.named_parameters():
            with self.subTest(parameter=name):
                # A parameter not reached by the graph keeps grad=None; report that explicitly
                # instead of letting torch.allclose raise a TypeError.
                self.assertIsNotNone(p.grad, f"{name} received no gradient")
                self.assertFalse(
                    torch.allclose(p.grad, torch.zeros_like(p)),
                    f"{name} has an all-zero gradient",
                )

    def test_weight_norm(self):
        """Make sure that rescaling the weight-norm parameters leaves the layer output unchanged."""
        resnet_layer = ResNetLayer(
            n_channels=self.N_CHANNELS,
            n_inner_channels=self.N_INNER_CHANNELS,
            use_weight_norm=True,
        )
        z, x = self._make_inputs()

        # Run forward pass.
        out1 = resnet_layer(z=z, x=x)

        # Double the selected parameters in place: magnitude changes, direction does not.
        # NOTE(review): the "conv_g" filter depends on ResNetLayer's parameter naming, which is not
        # visible from this file. Confirm that it selects the parameters whose scaling the layer
        # is meant to be invariant to.
        scaled = []
        with torch.no_grad():
            for name, p in resnet_layer.named_parameters():
                if "conv_g" in name:
                    p.mul_(2)
                    scaled.append(name)
        # Guard against a vacuous test: if nothing matched, out2 would trivially equal out1.
        self.assertTrue(scaled, "no parameter name contains 'conv_g'; nothing was rescaled")

        # Rerun forward pass.
        out2 = resnet_layer(z=z, x=x)

        # Assert (not just print) that the outputs match. The tolerance is loose enough for
        # float32 rounding; comparing the difference against zeros allowed only 1e-8 absolute error.
        self.assertTrue(
            torch.allclose(out1, out2, rtol=1e-5, atol=1e-6),
            f"output changed after rescaling {scaled}; "
            f"max abs diff = {(out1 - out2).abs().max().item():.3e}",
        )


if __name__ == "__main__":
    # Run from the repository root:
    #   python -m unittest -v src.dl.models.cell_types.test_residual
    unittest.main()
