import unittest

import torch
from torch.nn import functional as F

from sngp.model import SNGP_WideResNet28_10_cifar


class TestSNGP(unittest.TestCase):
    """Smoke tests for ``SNGP_WideResNet28_10_cifar``."""

    @unittest.skipUnless(torch.cuda.is_available(), "SNGP smoke test requires a CUDA device")
    def test_smoketest_sngp(self) -> None:
        """Train for a few steps, then exercise the precision/covariance updates and MC inference.

        Checks that:
          * the forward pass returns logits of shape (batch_size, num_classes);
          * ``model.prec`` sums to its expected prior value before any update and
            changes after a forward pass with ``update_prec=True``;
          * ``model.cov`` sums to zero until ``compute_cov()`` is called;
          * ``model.mc`` returns logits that yield a finite cross-entropy loss.
        """
        batch_size, num_classes = 128, 10
        x = torch.randn(batch_size, 3, 32, 32).cuda()
        # torch.randint's upper bound is exclusive; use num_classes so every label
        # in 0..num_classes-1 can appear (the old bound of 9 never produced class 9).
        y = torch.randint(0, num_classes, (batch_size,)).cuda()

        model = SNGP_WideResNet28_10_cifar(
            resnet_kwargs=dict(num_classes=num_classes),
            sngp_kwargs=dict(num_classes=num_classes),
        ).cuda()
        opt = torch.optim.SGD(model.parameters(), lr=4e-2, momentum=0.9, weight_decay=5e-4)

        model.train()
        for _ in range(5):
            opt.zero_grad()
            logits = model(x)
            self.assertEqual(logits.size(0), batch_size)
            self.assertEqual(logits.size(1), num_classes)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            opt.step()

        model.eval()
        with torch.no_grad():
            # Value this test expects model.prec.sum() to have before any precision update.
            prior_prec = model.classes * model.gp_h_dim * model.s
            self.assertEqual(model.prec.sum(), prior_prec)
            # Only the side effect on model.prec matters here; the returned logits are unused.
            model(x, update_prec=True, y=y)
            self.assertNotEqual(model.prec.sum(), prior_prec)

            self.assertEqual(model.cov.sum(), 0)
            model.compute_cov()
            self.assertNotEqual(model.cov.sum(), 0)

            # x is already on the GPU, so no extra .cuda() is needed.
            _, logit = model.mc(x)
            # Put the targets on the same device as the MC logits. The old y.cpu() would
            # raise a device-mismatch error whenever the logits live on the GPU.
            loss = F.cross_entropy(logit, y.to(logit.device))
            self.assertTrue(torch.isfinite(loss).item())
