import pytest
import torch

from sde.models import load_trained_vgg


@pytest.mark.parametrize(
    "ckpt_path, num_classes",
    [("../lightning_logs/version_4/checkpoints/epoch=49-step=150.ckpt", 4)],
)
def test_load_trained_vgg(ckpt_path, num_classes):
    """Smoke-test that ``load_trained_vgg`` returns a ``torch.nn.Module``.

    Args:
        ckpt_path: Path to a saved checkpoint file. The parametrized value is
            a relative path, so it is resolved against the directory pytest
            is run from.
        num_classes: Class count forwarded to ``load_trained_vgg``.

    The test is skipped (not failed) when ``ckpt_path`` does not exist, since
    the checkpoint is a locally produced artifact and its absence says nothing
    about whether the loader works.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import section.
    from pathlib import Path

    if not Path(ckpt_path).exists():
        pytest.skip(f"checkpoint not found: {ckpt_path}")

    # Load on CPU so the test does not require a GPU.
    vgg = load_trained_vgg(ckpt_path, num_classes, device='cpu')

    assert vgg is not None, "load_trained_vgg returned None"
    assert isinstance(vgg, torch.nn.Module), (
        f"expected a torch.nn.Module, got {type(vgg).__name__}"
    )
