import pytest
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

from ccvae.metrics import compute_mmd
from ccvae.pl.trvae import TrVAE
from ccvae.nn.utils import Encoder, GaussianDecoder, MultinomialDecoder


def make_dataset(num_samples, delta, distribution_type):
    """Build a synthetic two-class, two-group dataset for the TrVAE tests.

    The first ``num_samples`` rows are drawn at random (standard normal for
    ``'normal'``, integers in ``[0, 100)`` stored as float32 for
    ``'multinomial'``); the next ``num_samples`` rows are a copy of them
    shifted by ``delta``. Every row has 10 features.

    Args:
        num_samples: Number of rows per class (the dataset has twice as many).
        delta: Offset added to the copied rows to form the second class.
        distribution_type: Either ``'normal'`` or ``'multinomial'``.

    Returns:
        A ``TensorDataset`` of ``(X, y, d)`` where ``y`` is the one-hot class
        label (first half class 0, second half class 1) and ``d`` is the
        one-hot group label alternating row by row, starting with group 0.
    """
    x_dim = 10
    assert distribution_type in ('normal', 'multinomial')
    total_rows = 2 * num_samples

    if distribution_type == 'multinomial':
        base = torch.randint(0, 100, (num_samples, x_dim), dtype=torch.float32)
    else:
        base = torch.randn(num_samples, x_dim)

    # Second class: same points, translated in place by delta.
    shifted = base.clone().detach()
    shifted += delta
    features = torch.cat([base, shifted], dim=0)

    # Class labels: first half -> [1, 0], second half -> [0, 1].
    class_onehot = torch.zeros(total_rows, 2)
    class_onehot[:num_samples, 0] = 1
    class_onehot[num_samples:, 1] = 1

    # Group labels alternate per row: [1, 0], [0, 1], [1, 0], ...
    group_onehot = torch.zeros(total_rows, 2)
    group_onehot[0::2, 0] = 1.0
    group_onehot[1::2, 1] = 1.0

    return TensorDataset(features, class_onehot, group_onehot)


@pytest.mark.parametrize("decoder_class", [GaussianDecoder, MultinomialDecoder])
@pytest.mark.parametrize("learn_sigma", ["fix_all", "learn_all", "learn_but_decouple", "learn_elbo_fix_penalty"])
@pytest.mark.parametrize("n_layers", [0, 1, 2])
@pytest.mark.parametrize("penalise_z", [True, False])
@pytest.mark.parametrize("forward_pass_use_groups", [False, True])
def test_trvae(decoder_class, learn_sigma, n_layers, penalise_z, forward_pass_use_groups):
    """Smoke-test TrVAE: fit for one epoch, then check the sampled output shape.

    Runs across decoder types, sigma-learning modes, encoder depths, and with
    or without the group labels being fed through the forward pass.
    """
    num_samples = 100
    z_dim = 2
    hidden_dim = 5

    assert decoder_class in [GaussianDecoder, MultinomialDecoder]
    if decoder_class == GaussianDecoder:
        distribution_type = 'normal'
    else:
        distribution_type = 'multinomial'

    dataset = make_dataset(num_samples // 2, 5, distribution_type)
    features, class_labels, group_labels = dataset.tensors
    input_dim = features.shape[1]
    num_classes = class_labels.shape[1]
    num_groups = group_labels.shape[1]

    # The networks' input widths grow by the group one-hot width only when
    # the groups take part in the forward pass.
    group_width = num_groups if forward_pass_use_groups else 0
    encoder_input_dim = input_dim + num_classes + group_width
    decoder_input_dim = z_dim + num_classes + group_width

    encoder = Encoder(
        encoder_input_dim,
        z_dim,
        hidden_dim,
        learn_sigma=learn_sigma,
        n_layers=n_layers,
        use_batchnorm=True,
    )
    decoder = decoder_class(
        input_dim,
        decoder_input_dim,
        hidden_dim,
        return_hidden=True,
        use_batchnorm=True,
    )
    model = TrVAE(
        encoder,
        decoder,
        z_dim,
        beta=1.2,
        n_groups=num_groups,
        forward_pass_use_groups=forward_pass_use_groups,
        penalise_z=penalise_z,
    )

    # 90/10 split; only the training part is used here.
    holdout_size = num_samples // 10
    trainset, _ = random_split(dataset, [num_samples - holdout_size, holdout_size])
    train_loader = DataLoader(trainset, batch_size=num_samples, shuffle=True, num_workers=0)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_loader)

    sampled = model.forward(features, class_labels, group_labels).sample()
    assert sampled.shape == torch.Size([num_samples, z_dim])


@pytest.mark.parametrize("n_layers", [1, 2])
@pytest.mark.parametrize("penalise_z", [True, False])
@pytest.mark.skip(reason="Too flaky")
def test_mmd_penalty(n_layers, penalise_z):
    """Check that a large MMD penalty lowers the between-class MMD.

    Two models are trained from the same seed, one with ``penalty_scale=0``
    and one with ``penalty_scale=100``. For each, the MMD between the two
    classes is computed within each group on the decoder's second output
    (index 1 of its return value) and averaged over the groups. The penalised
    model is expected to score lower.

    The test is currently skipped because the comparison is flaky.
    """
    num_samples = 100
    z_dim = 2
    hidden_dim = 5

    dataset = make_dataset(num_samples // 2, 5, 'normal')
    features, class_labels, group_labels = dataset.tensors
    input_dim = features.shape[1]
    num_classes = class_labels.shape[1]
    # Previously undefined in this function (NameError once the skip is lifted).
    num_groups = group_labels.shape[1]
    # The per-group MMD below is hand-written for exactly two groups, so
    # verify that before any of the group masks are used.
    assert num_groups == 2, 'mmd calc out of date wrt num_groups; expected 2'

    # Group labels are not fed to the networks in this test, so the input
    # widths only account for the class one-hot.
    encoder_input_dim = input_dim + num_classes
    decoder_input_dim = z_dim + num_classes

    class0_mask = class_labels[:, 0].to(torch.bool)
    class1_mask = class_labels[:, 1].to(torch.bool)
    group0_mask = group_labels[:, 0].to(torch.bool)
    group1_mask = group_labels[:, 1].to(torch.bool)

    mmd_scores = []
    for scale in [0.0, 100.0]:
        torch.manual_seed(0)  # Fix seed so both models start from the same parameters
        encoder = Encoder(encoder_input_dim, z_dim, hidden_dim, learn_sigma="fix_all",
                          n_layers=n_layers, use_batchnorm=True)
        decoder = GaussianDecoder(input_dim, decoder_input_dim, hidden_dim, n_layers=n_layers,
                                  return_hidden=True, use_batchnorm=True)
        # forward_pass_use_groups is passed explicitly so the model agrees
        # with the encoder/decoder widths chosen above.
        model = TrVAE(encoder, decoder, z_dim, penalty_scale=scale, learning_rate=0.1, beta=1.0,
                      n_groups=num_groups, forward_pass_use_groups=False, penalise_z=penalise_z)
        # 90/10 split; only the training part is used here.
        trainset, _ = random_split(
            dataset, [num_samples - num_samples // 10, num_samples // 10]
        )
        train_loader = DataLoader(trainset, batch_size=num_samples, shuffle=True, num_workers=0)
        trainer = pl.Trainer(max_epochs=100)
        trainer.fit(model, train_loader)

        z = model.forward(*dataset.tensors).sample()
        # Only the class labels are appended: the decoder was built for
        # z_dim + num_classes inputs, so adding the group labels as well
        # would not match its input width.
        zc = torch.cat([z, class_labels], dim=-1)
        # detach because compute_mmd operates on ndarrays and y is a tensor with requires_grad=True
        y = model.decoder.forward(zc)[1].detach()

        # MMD between the two classes, computed separately inside each group.
        mmd_0 = compute_mmd(y[group0_mask & class0_mask, :],
                            y[group0_mask & class1_mask, :])
        mmd_1 = compute_mmd(y[group1_mask & class0_mask, :],
                            y[group1_mask & class1_mask, :])
        mmd_scores.append((mmd_0 + mmd_1) / num_groups)

    # The unpenalised model (scale 0) should show the larger class discrepancy.
    assert mmd_scores[0] > mmd_scores[1]
