import unittest
from nfmc_jax.DLA.posterior import JaxPosterior
import jax.numpy as jnp
import jax.scipy
import scipy.stats
import numpy as np
import torch


class JaxPosteriorTestCase(unittest.TestCase):
    """Tests for the torch-facing outputs of ``JaxPosterior``.

    Every test runs on the same small problem: a zero-mean Gaussian
    log-likelihood whose precision matrix is a seeded Wishart draw, and a
    log-prior that sums independent uniform log-densities on
    ``[LOWER, UPPER]`` over the dimensions. Both are written for a single
    point and batched with ``jax.vmap`` before being handed to
    ``JaxPosterior``.
    """

    N_SAMPLES = 100
    N_DIM = 3
    SCALE = 0.2
    LOWER = -1
    UPPER = 1

    def _build_problem(self):
        """Build the shared test problem.

        Seeds are reset on every call, so each test sees identical inputs.

        Returns:
            A ``(log_likelihood, log_prior, x0)`` tuple. The two callables
            map one ``(N_DIM,)`` point to a scalar; ``x0`` is a torch tensor
            of shape ``(N_SAMPLES, N_DIM)`` drawn uniformly from
            ``[LOWER, UPPER)``.
        """
        np.random.seed(0)
        torch.manual_seed(0)

        n_dim = self.N_DIM
        lower = self.LOWER
        upper = self.UPPER

        mu = jnp.zeros(n_dim)
        # The Wishart draw carries its own seed, so it does not consume the
        # global NumPy random stream used for x0 below.
        icov = scipy.stats.wishart(
            df=n_dim,
            scale=np.eye(n_dim) * self.SCALE,
            seed=0,
        ).rvs(size=1)

        # Log of the Gaussian normalisation constant, expressed through the
        # precision matrix: det(cov) == 1 / det(icov).
        log_norm = jnp.log(1 / np.sqrt(((2 * np.pi) ** n_dim) / jnp.linalg.det(icov)))

        def log_likelihood(x):
            return log_norm - 0.5 * jnp.dot((x - mu), jnp.dot(icov, x - mu))

        def log_prior(x):
            return jnp.sum(jax.scipy.stats.uniform.logpdf(x, loc=lower, scale=upper - lower))

        x0 = torch.tensor(
            np.random.uniform(low=lower, high=upper, size=(self.N_SAMPLES, n_dim))
        )
        return log_likelihood, log_prior, x0

    @staticmethod
    def _new_posterior(log_likelihood, log_prior):
        """Wrap the per-point callables with ``jax.vmap`` in a fresh ``JaxPosterior``."""
        return JaxPosterior(
            log_likelihood=jax.vmap(log_likelihood),
            log_prior=jax.vmap(log_prior),
        )

    def test_output_shape(self):
        """Values have shape ``(n_samples,)``; gradients ``(n_samples, n_dim)``."""
        log_likelihood_fn, log_prior_fn, x0 = self._build_problem()
        posterior = self._new_posterior(log_likelihood_fn, log_prior_fn)

        log_likelihood = posterior.log_likelihood(x0)
        log_likelihood_gradient = posterior.log_likelihood_gradient(x0)

        log_prior = posterior.log_prior(x0)
        log_prior_gradient = posterior.log_prior_gradient(x0)

        # Check for correct shapes
        self.assertEqual(log_likelihood.shape, (self.N_SAMPLES,))
        self.assertEqual(log_likelihood_gradient.shape, (self.N_SAMPLES, self.N_DIM))
        self.assertEqual(log_prior.shape, (self.N_SAMPLES,))
        self.assertEqual(log_prior_gradient.shape, (self.N_SAMPLES, self.N_DIM))

    def test_output_type(self):
        """Every public evaluation method returns a ``torch.Tensor``."""
        log_likelihood_fn, log_prior_fn, x0 = self._build_problem()
        posterior = self._new_posterior(log_likelihood_fn, log_prior_fn)

        outputs = {
            "log_likelihood": posterior.log_likelihood(x0),
            "log_likelihood_gradient": posterior.log_likelihood_gradient(x0),
            "log_prior": posterior.log_prior(x0),
            "log_prior_gradient": posterior.log_prior_gradient(x0),
        }

        # Check for correct types
        for name, value in outputs.items():
            with self.subTest(output=name):
                self.assertIsInstance(value, torch.Tensor)

    def test_compute_all(self):
        """``compute_all`` fills the caches with the values the individual methods return."""
        log_likelihood_fn, log_prior_fn, x0 = self._build_problem()

        # Reference values, obtained one call at a time.
        posterior0 = self._new_posterior(log_likelihood_fn, log_prior_fn)

        log_likelihood = posterior0.log_likelihood(x0)
        log_likelihood_gradient = posterior0.log_likelihood_gradient(x0)

        log_prior = posterior0.log_prior(x0)
        log_prior_gradient = posterior0.log_prior_gradient(x0)

        # A second, independent posterior evaluated in a single call.
        posterior1 = self._new_posterior(log_likelihood_fn, log_prior_fn)
        posterior1.compute_all(x0, compute_gradients=True)

        # Check that the cached results match the directly computed ones
        self.assertTrue(torch.equal(posterior1._log_likelihood_cache, log_likelihood))
        self.assertTrue(torch.equal(posterior1._log_likelihood_gradient_cache, log_likelihood_gradient))
        self.assertTrue(torch.equal(posterior1._log_prior_cache, log_prior))
        self.assertTrue(torch.equal(posterior1._log_prior_gradient_cache, log_prior_gradient))


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
