import torch

from pdisvae.models import lint


def test_decomposed_mixing():
    """Check that ``DecomposedMixing.right_inverse`` undoes ``forward``.

    A random ``right_weight`` is pushed through ``forward`` and the result is
    fed to ``right_inverse``; the reconstruction must match the original
    weight in shape and (within ``torch.testing.assert_close`` default
    tolerances) in value.
    """
    # Seed the global RNG so the random inputs, and therefore the
    # tolerance-based comparison below, are reproducible from run to run.
    torch.manual_seed(0)

    n_time_bins, n_neurons = 100, 10
    n_components = 6  # must be <= n_neurons in general
    convolved_history = torch.randn(n_time_bins, n_neurons)
    decomposed_mixing = lint.DecomposedMixing(convolved_history)
    right_weight = torch.randn(n_neurons, n_components)

    # Round trip: weight -> mixing -> reconstructed weight.
    mixing = decomposed_mixing.forward(right_weight)
    right_weight_recon = decomposed_mixing.right_inverse(mixing)

    assert right_weight_recon.shape == right_weight.shape
    torch.testing.assert_close(right_weight_recon, right_weight)


def test_lint_decoder():
    """Smoke-test ``LintDecoder``: build it and run one forward pass.

    The decoder is constructed from a random ``convolved_history`` and a
    component count, then evaluated on a random latent ``z`` with one row per
    time bin.
    """
    # Seed the global RNG so the random inputs (and any random initialisation
    # performed while constructing the decoder) are reproducible.
    torch.manual_seed(0)

    n_time_bins, n_neurons = 100, 10
    n_components = 6
    convolved_history = torch.randn(n_time_bins, n_neurons)
    lint_decoder = lint.LintDecoder(convolved_history, n_components)

    # Reuse n_time_bins instead of repeating the literal 100 so the latent's
    # leading dimension cannot drift away from convolved_history's.
    z = torch.randn(n_time_bins, n_components)
    x_pred_mean = lint_decoder.forward(z)

    # Minimal check that the forward pass actually produced an output.
    # NOTE(review): presumably x_pred_mean should be a tensor of shape
    # (n_time_bins, n_neurons) -- confirm against LintDecoder and replace this
    # with an explicit shape/value assertion.
    assert x_pred_mean is not None
