import torch

from pdisvae.models import dcnn


def test_burgess_decoder():
    """Check BurgessDecoder output shape and that its values lie in [0, 1]."""
    # Seed so any failure is reproducible from run to run.
    torch.manual_seed(0)
    img_size = (1, 32, 32)
    n_components = 4
    # A batch larger than 1 exposes bugs that collapse or mix the batch
    # dimension (e.g. a reshape that ignores it), which batch size 1 hides.
    batch_size = 3
    decoder = dcnn.BurgessDecoder(img_size, n_components)

    # Test forward
    z = torch.randn(batch_size, n_components)
    x_pred_mean = decoder(z)
    assert x_pred_mean.shape == (batch_size, *img_size)
    assert x_pred_mean.max() <= 1
    assert x_pred_mean.min() >= 0


def test_burgess_encoder():
    """Check the shapes and finiteness of both BurgessEncoder outputs for a batch of images."""
    # Seed so any failure is reproducible from run to run.
    torch.manual_seed(0)
    img_size = (1, 32, 32)
    n_components = 4
    # A batch larger than 1 exposes bugs that collapse or mix the batch dimension.
    batch_size = 3
    encoder = dcnn.BurgessEncoder(img_size, n_components)

    # Test forward
    x = torch.randn(batch_size, *img_size)
    z_pred_mean, z_pred_log_std = encoder(x)
    assert z_pred_mean.shape == (batch_size, n_components)
    # Assumes a diagonal posterior: one log-std per latent component,
    # i.e. the same shape as the mean.
    assert z_pred_log_std.shape == z_pred_mean.shape
    assert torch.isfinite(z_pred_mean).all()
    assert torch.isfinite(z_pred_log_std).all()
