import torch
import math


class Normal:
    '''
    Multivariate normal distribution with a dense covariance matrix.

    The constructor precomputes the precision matrix, the lower Cholesky
    factor of the covariance and the log-determinant of the covariance, so
    that log_prob and sample reduce to matrix products.
    '''

    def __init__(self, mean, covariance):
        '''
        :param mean: column vector of shape (d, 1)
        :param covariance: matrix of shape (d, d); it must be symmetric
            positive-definite for the Cholesky factorization to succeed
        '''
        assert mean.dim() == 2 and covariance.dim() == 2 and mean.shape[1] == 1, \
            'mean must be a (d, 1) column vector and covariance a 2-D matrix'
        self.d = mean.shape[0]
        # Catch a mismatched covariance here rather than as an opaque
        # matmul shape error later in log_prob / sample.
        assert covariance.shape == (self.d, self.d), \
            'covariance must have shape (d, d) matching the mean'
        self.mean = mean
        self.covariance = covariance
        self.precision = torch.inverse(covariance)
        # torch.linalg.cholesky returns the lower-triangular factor L with
        # L @ L.T == covariance; it replaces the deprecated torch.cholesky.
        self.cholesky_lower_covariance = torch.linalg.cholesky(covariance)
        self.log_det_covariance = torch.logdet(covariance)

    def log_prob(self, x):
        '''
        Evaluate the log-density at each row of x
        :param x: (n, d) tensor holding one point per row
        :return: (n,) tensor of log-densities
        '''
        assert x.shape[1] == self.d
        diff = x - self.mean.T
        # Row-wise quadratic form diff_i^T * precision * diff_i
        M = torch.sum(diff.mm(self.precision) * diff, 1)
        return -0.5 * (self.d * math.log(2 * math.pi) + M + self.log_det_covariance)

    def sample(self, n):
        '''
        Sample via reparameterization
        :param n: number of desired samples
        :return: n samples in an (n, d) tensor
        '''
        # Standard-normal noise, one column per sample, created with the
        # dtype and device of the mean.
        Eps = torch.empty((self.d, n), dtype=self.mean.dtype, device=self.mean.device).normal_()
        # Each row becomes L @ eps + mean, written in row-vector form.
        return Eps.T.mm(self.cholesky_lower_covariance.T) + self.mean.T


if __name__ == '__main__':
    # Self-check: compare Normal against torch.distributions.MultivariateNormal
    # and verify that the empirical statistics of its samples match the parameters.
    from torch.distributions import MultivariateNormal

    mean = torch.tensor([[2.], [0.]])
    covariance = torch.tensor([[2., 1.],
                               [1., 2.]])
    p = Normal(mean, covariance)

    # check decomposition: L @ L^T must reproduce the covariance
    assert torch.allclose(p.cholesky_lower_covariance.mm(p.cholesky_lower_covariance.T), covariance)

    # check log_prob against the reference implementation
    x = torch.tensor([[1., 1.]], requires_grad=True)
    q = MultivariateNormal(loc=mean.squeeze(), covariance_matrix=covariance)
    assert torch.allclose(p.log_prob(x), q.log_prob(x.squeeze()))

    # check sampling: empirical mean and unbiased empirical covariance
    n = 1000000
    samples = p.sample(n)
    estimated_mean = samples.mean(0, keepdim=True)
    diff = samples - estimated_mean
    estimated_covariance = diff.T.mm(diff) / (n - 1)

    print(mean, estimated_mean)
    print(covariance, estimated_covariance)

    # With n = 1e6 the Monte Carlo error of these estimates is on the order
    # of 1e-3, so an absolute tolerance of 0.05 leaves a wide safety margin.
    assert torch.allclose(estimated_mean, mean.T, atol=0.05)
    assert torch.allclose(estimated_covariance, covariance, atol=0.05)