import unittest
from argparse import Namespace

import torch

from maml.model import MAMLMiniImageNet, MAMLOmniglot


class TestMAMLModels(unittest.TestCase):
    """Smoke tests for the MAML backbones.

    Each test builds a model for an ``n_way`` classification task, feeds it a
    random image batch, and checks that the output has exactly one row of
    ``n_way`` values per input sample.
    """

    # Shared by both tests so the expected output shape is stated only once.
    BATCH_SIZE = 32
    N_WAY = 5

    def _assert_output_shape(self, model, image_shape: tuple) -> None:
        """Run ``model`` on a random batch and assert a ``(batch, n_way)`` output.

        Args:
            model: The model under test; called as ``model(x)`` on a tensor.
            image_shape: Per-sample input shape as ``(channels, height, width)``.
        """
        x = torch.randn(self.BATCH_SIZE, *image_shape)
        # A shape check needs no gradients, so skip building the autograd graph.
        with torch.no_grad():
            out = model(x)
        # Compare the whole shape rather than only dims 0 and 1. Otherwise an
        # output with extra trailing dims (e.g. un-flattened logits) would pass.
        self.assertEqual(tuple(out.shape), (self.BATCH_SIZE, self.N_WAY))

    def test_smoketest_omniglot(self) -> None:
        """Omniglot model maps 1x28x28 images to ``n_way`` outputs each."""
        model = MAMLOmniglot(Namespace(n_way=self.N_WAY))
        self._assert_output_shape(model, (1, 28, 28))

    def test_smoketest_miniimagenet(self) -> None:
        """MiniImageNet model maps 3x84x84 images to ``n_way`` outputs each."""
        model = MAMLMiniImageNet(Namespace(n_way=self.N_WAY))
        self._assert_output_shape(model, (3, 84, 84))
