import unittest

import torch

from protonet.model import ProtonetCNN4


class TestResNets(unittest.TestCase):
    """Smoke tests for the prototypical-network model ``ProtonetCNN4``.

    NOTE(review): the class name says "ResNets", but the only model exercised
    here is ``ProtonetCNN4``. The name is kept unchanged so existing test
    selectors (e.g. ``-k TestResNets``) keep matching.
    """

    def test_smoketest_protonet(self) -> None:
        """Run one forward pass and check the output is (n_query, n_way) and finite."""
        # Seed so any failure is reproducible rather than RNG-dependent.
        torch.manual_seed(0)
        n_way, n_shot, n_query = 10, 3, 20
        model = ProtonetCNN4(1, 64, 64, n_way)

        # Every class gets exactly n_shot support examples. Random labels
        # (torch.randint) can leave a class with no support samples, which
        # makes the result depend on the RNG draw.
        sx = torch.randn(n_way * n_shot, 1, 28, 28)
        sy = torch.arange(n_way).repeat_interleave(n_shot)
        # Query count deliberately differs from the support count so the
        # shape check below can tell which set the output rows follow.
        qx = torch.randn(n_query, 1, 28, 28)

        out = model(sx, sy, qx)

        self.assertEqual(tuple(out.shape), (n_query, n_way))
        # An empty or degenerate prototype would surface as NaN/inf here.
        self.assertTrue(bool(torch.isfinite(out).all()))
