import numpy as np
import torch

from pdisvae.models import linear


class TestLinearDecoder:
    """Unit tests for ``linear.LinearDecoder``.

    Every test builds its decoder through ``instantiate_decoder``, which seeds
    the global RNGs first. That makes the decoder's parameter init and the
    random inputs drawn afterwards reproducible, so any failure can be replayed.
    """

    _SEED = 0  # seed applied to torch's (and numpy's) global RNG per test
    _N_SAMPLES = 100  # batch size used for the random test inputs

    def instantiate_decoder(self, obs_dim=10, n_components=5):
        """Seed the RNGs and build a small ``LinearDecoder``.

        Args:
            obs_dim: Number of observed dimensions passed to the decoder.
            n_components: Number of latent components passed to the decoder.

        Returns:
            A freshly constructed ``linear.LinearDecoder``.
        """
        torch.manual_seed(self._SEED)
        # NOTE(review): seeded in case the decoder (e.g. ICA initialization)
        # draws from numpy's global RNG; not visible from this file — confirm.
        np.random.seed(self._SEED)
        return linear.LinearDecoder(obs_dim, n_components)

    def test_initialize(self):
        """ICA initialization leaves the mixing weight/bias with expected shapes."""
        decoder = self.instantiate_decoder()
        obs_dim, n_components = decoder.obs_dim, decoder.n_components

        x = torch.randn(self._N_SAMPLES, obs_dim)
        decoder.initialize("ICA", x)

        assert decoder.mixing_and_bias.weight.shape == (obs_dim, n_components)
        assert decoder.mixing_and_bias.bias.shape == (obs_dim,)

    def test_forward(self):
        """``forward`` maps (n_samples, n_components) latents to (n_samples, obs_dim)."""
        decoder = self.instantiate_decoder()
        obs_dim, n_components = decoder.obs_dim, decoder.n_components
        n_samples = self._N_SAMPLES

        z = torch.randn(n_samples, n_components)
        x_pred_mean = decoder.forward(z)

        assert x_pred_mean.shape == (n_samples, obs_dim)

    def test_log_prob(self):
        """``log_prob`` matches a hand-written Gaussian log-density.

        The reference is the per-dimension normal log-density
        ``-0.5 * (log(2*pi) + 2*log_std + (x - mean)**2 / var)`` summed over
        the last axis, giving one value per sample.
        """
        decoder = self.instantiate_decoder()
        obs_dim, n_components = decoder.obs_dim, decoder.n_components
        n_samples = self._N_SAMPLES

        x = torch.randn(n_samples, obs_dim)
        z = torch.randn(n_samples, n_components)
        x_pred_mean = decoder.forward(z)
        log_prob = decoder.log_prob(x_pred_mean, x)

        # NOTE(review): the reference mixes ``log_std`` and ``var``; it is only
        # a valid Gaussian density if ``var == exp(2 * log_std)`` — confirm.
        log_prob_ref = (
            -0.5
            * (
                np.log(2 * np.pi)
                + 2 * decoder.log_std
                + (x - x_pred_mean) ** 2 / decoder.var
            )
        ).sum(dim=-1)

        assert log_prob.shape == (n_samples,)
        torch.testing.assert_close(log_prob, log_prob_ref)
