from unittest import TestCase

import torch

from chroma.layers.norm import MaskedBatchNorm1d


class TestBatchNorm(TestCase):
    def test_norm(self):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        B, C, L = (3, 5, 7)
        x1 = torch.randn(B, C, L).to(device)
        mean1 = x1.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L)
        var1 = ((x1 - mean1) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (
            B * L
        )
        x2 = torch.randn(B, C, L).to(device)
        mean2 = x2.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L)
        var2 = ((x2 - mean2) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (
            B * L
        )

        mbn = MaskedBatchNorm1d(C)
        mbn = mbn.to(device)

        # Test without mask in train
        mbn.train()
        out = mbn(x1)
        self.assertTrue(mean1.allclose(mbn.running_mean))
        self.assertTrue(var1.allclose(mbn.running_var))
        normed = (x1 - mean1) / torch.sqrt(var1 + mbn.eps) * mbn.weight + mbn.bias
        self.assertTrue(normed.allclose(out))
        out = mbn(x2)
        normed = (x2 - mean2) / torch.sqrt(var2 + mbn.eps) * mbn.weight + mbn.bias
        self.assertTrue(normed.allclose(out))
        self.assertTrue(
            mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2)
        )
        self.assertTrue(
            mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2)
        )

        # Without mask in eval
        mbn.eval()
        out = mbn(x1)
        self.assertTrue(
            mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2)
        )
        self.assertTrue(
            mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2)
        )
        normed = (x1 - mbn.running_mean) / torch.sqrt(
            mbn.running_var + mbn.eps
        ) * mbn.weight + mbn.bias
        self.assertTrue(normed.allclose(out))

        # Check that masking with all ones doesn't change values
        mask = x1.new_ones((B, 1, L))
        outm = mbn(x1, input_mask=mask)
        self.assertTrue(outm.allclose(out))
        mbn.eval()
        out = mbn(x2)
        outm = mbn(x2, input_mask=mask)
        self.assertTrue(outm.allclose(out))

        # With mask in train
        mask = torch.randn(B, 1, L)
        mask = mask > 0.0
        mask = mask.to(device)
        n = mask.sum()
        mean1 = (x1 * mask).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
        var1 = (((x1 * mask) - mean1) ** 2).sum(dim=0, keepdim=True).sum(
            dim=2, keepdim=True
        ) / n
        mbn = MaskedBatchNorm1d(C)
        mbn = mbn.to(device)
        mbn.train()
        out = mbn(x1, input_mask=mask)
        self.assertTrue(mean1.allclose(mbn.running_mean))
        self.assertTrue(var1.allclose(mbn.running_var))
        normed = (x1 * mask - mean1) / torch.sqrt(
            var1 + mbn.eps
        ) * mbn.weight + mbn.bias
        self.assertTrue(normed.allclose(out))
        # With mask in eval
        mbn.eval()
        out = mbn(x1, input_mask=mask)
        normed = (x1 * mask - mbn.running_mean) / torch.sqrt(
            mbn.running_var + mbn.eps
        ) * mbn.weight + mbn.bias
        self.assertTrue(normed.allclose(out))
