from datasets.cifar import CIFAR10SimCLRDataModule
import torch
import pytest
import pytorch_lightning as pl
from model_zoo.simclr import SimCLRResNet50Model


class TestModels:
    """Smoke tests for the SimCLR models in ``model_zoo``.

    Each model in ``MODELS`` is checked for a forward pass with the expected
    output shape, and for a ``fast_dev_run`` fit against ``dm``.
    """

    # NOTE(review): the datamodule and models are constructed when this class
    # body executes (i.e. at pytest collection time), and the same model
    # instances are shared by every parametrized test below -- ``test_train``
    # updates their weights in place.
    dm = CIFAR10SimCLRDataModule()
    MODELS = [SimCLRResNet50Model(datamodule=dm)]

    @pytest.mark.parametrize("model", MODELS)
    def test_inference(self, model):
        """A forward pass on a random batch yields ``(batch_size, hidden_dim)``."""
        batch_size = self.dm.batch_size
        # NOTE(review): 224x224 inputs rather than CIFAR's native 32x32 --
        # confirm this matches the resolution the model is meant to accept.
        x = torch.rand(batch_size, 3, 224, 224)
        # Only the output shape is checked, so don't build an autograd graph.
        with torch.no_grad():
            out = model(x)
        assert out.shape == (batch_size, model.hidden_dim)

    @pytest.mark.parametrize("model", MODELS)
    def test_train(self, model):
        """A ``fast_dev_run`` fit on the datamodule completes without raising."""
        # Use a single device of whatever accelerator is available, so the
        # test also runs on CPU-only and single-GPU hosts (a hard-coded GPU
        # index such as ``devices=[1]`` fails on both).
        trainer = pl.Trainer(fast_dev_run=True, accelerator="auto", devices=1)
        trainer.fit(model, datamodule=self.dm)
