import pytest
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

from ccvae.pl.vae import VAE
from ccvae.nn.utils import Encoder, GaussianDecoder, MultinomialDecoder

def make_two_cluster_dataset(num_samples=100, x_dim=10, shift=5.0, seed=0):
    """Build a reproducible two-cluster toy dataset with one-hot labels.

    The first ``num_samples`` rows are standard-normal draws (class 0); the
    next ``num_samples`` rows are the same draws offset by ``shift`` in every
    coordinate (class 1), so the two classes are trivially separable.

    Args:
        num_samples: Number of rows per class (total rows = 2 * num_samples).
        x_dim: Feature dimension of each row.
        shift: Constant added to every feature of the class-1 rows.
        seed: Seed for a private ``torch.Generator``; the global RNG is not
            touched, so results are deterministic regardless of test order.

    Returns:
        A ``TensorDataset(X, y)`` where ``X`` has shape
        ``(2 * num_samples, x_dim)`` and ``y`` has shape
        ``(2 * num_samples, 2)`` with one-hot class labels.
    """
    generator = torch.Generator().manual_seed(seed)
    x0 = torch.randn(num_samples, x_dim, generator=generator)
    x1 = x0 + shift  # new tensor; x0 is left untouched
    X = torch.vstack([x0, x1])
    y = torch.zeros(2 * num_samples, 2)
    y[:num_samples, 0] = 1
    y[num_samples:, 1] = 1
    return TensorDataset(X, y)


@pytest.fixture
def dataset():
    """Default toy dataset: 2 x 100 rows of 10-dim data, classes shifted by 5."""
    return make_two_cluster_dataset()


@pytest.mark.parametrize("decoder_class", [GaussianDecoder, MultinomialDecoder])
@pytest.mark.parametrize("learn_sigma", ["fix_all", "learn_all", "learn_but_decouple", "learn_elbo_fix_penalty"])
@pytest.mark.parametrize("n_layers", [0, 1, 2])
def test_vae(dataset, decoder_class, learn_sigma, n_layers, tmp_path):
    """Smoke-test one epoch of VAE training followed by posterior sampling.

    Trains ``VAE(encoder, decoder, z_dim)`` for one epoch on 90% of the
    ``dataset`` fixture. It then encodes the full input matrix and checks that
    one sample from the returned object has shape ``(num_samples, z_dim)``.

    Args:
        dataset: ``TensorDataset(X, y)`` fixture; only ``X`` is used here.
        decoder_class: Decoder constructor called as
            ``decoder_class(input_dim, z_dim, hidden_dim)``.
        learn_sigma: Sigma-learning mode string forwarded to ``Encoder``.
        n_layers: Encoder depth forwarded to ``Encoder``.
        tmp_path: pytest built-in; receives Lightning logs/checkpoints so the
            test does not litter the working directory.
    """
    # Seed weight init and random_split so a failure is reproducible.
    torch.manual_seed(0)

    z_dim = 2
    hidden_dim = 5
    X = dataset.tensors[0]
    num_samples = len(dataset)
    input_dim = X.shape[1]

    encoder = Encoder(input_dim, z_dim, hidden_dim, learn_sigma, n_layers)
    # TODO: decoders are built without n_layers. Per the original note,
    # n_layers=0 is not implemented for them, so only encoder depth is varied.
    decoder = decoder_class(input_dim, z_dim, hidden_dim)
    model = VAE(encoder, decoder, z_dim)

    # Hold out 10% of the data; the held-out part is currently unused.
    num_valid = num_samples // 10
    trainset, _ = random_split(dataset, [num_samples - num_valid, num_valid])
    train_loader = DataLoader(trainset, batch_size=5, shuffle=True, num_workers=0)
    trainer = pl.Trainer(max_epochs=1, default_root_dir=str(tmp_path))
    trainer.fit(model, train_loader)

    # Inference: eval mode, no autograd, and call the module (not .forward)
    # so any registered hooks run.
    model.eval()
    with torch.no_grad():
        z_dist = model(X)  # presumably a torch.distributions object; TODO confirm
        z = z_dist.sample()
    assert z.shape == torch.Size([num_samples, z_dim])
