'''
This module contains classes for various parameterizations of multivariate normal distributions.

This is modified from the original version in the vbll package.

Reference:

    https://arxiv.org/abs/2404.11599
    
    https://github.com/VectorInstitute/vbll

'''
import numpy as np
import torch
import warnings


class Normal(torch.distributions.Distribution):
    '''
    Abstract interface for Gaussian distributions.

    Every property and method declared here raises ``NotImplementedError``;
    a concrete parameterization is expected to override them.

    Parameters
    -----------
    batch_shape: torch.Size
        Shape over which the parameters are batched. Defaults to an empty shape.

    event_shape: torch.Size
        Shape of a single draw. Defaults to an empty shape.

    validate_args: bool or None
        Forwarded unchanged to ``torch.distributions.Distribution``.
    '''
    def __init__(self, batch_shape = torch.Size(), event_shape = torch.Size(), validate_args = None):
        # The defaults must be real (empty) shapes: a literal ``...`` default
        # would be stored as the distribution's shape and break any later
        # shape arithmetic performed by the base class.
        super().__init__(batch_shape, event_shape, validate_args)

    @property
    def mean(self) -> torch.Tensor:
        '''
        Mean vector of the distribution.
        '''
        raise NotImplementedError

    @property
    def var(self) -> torch.Tensor:
        '''
        Variance vector of the distribution.
        '''
        raise NotImplementedError

    @property
    def chol_covariance(self) -> torch.Tensor:
        '''
        The lower triangular Cholesky decomposition of the covariance matrix.
        '''
        raise NotImplementedError

    @property
    def covariance_diagonal(self) -> torch.Tensor:
        '''
        Diagonal (vector) of the covariance matrix.
        '''
        raise NotImplementedError

    @property
    def covariance(self) -> torch.Tensor:
        '''
        Covariance matrix.
        '''
        raise NotImplementedError

    @property
    def precision(self) -> torch.Tensor:
        '''
        Precision matrix. It is the inverse of the covariance matrix.
        '''
        raise NotImplementedError

    @property
    def logdet_covariance(self) -> torch.Tensor:
        '''
        Log determinant of the covariance matrix.
        '''
        raise NotImplementedError

    @property
    def logdet_precision(self) -> torch.Tensor:
        '''
        Log determinant of the precision matrix.
        '''
        raise NotImplementedError

    @property
    def trace_covariance(self) -> torch.Tensor:
        '''
        Trace of the covariance matrix.
        '''
        raise NotImplementedError

    @property
    def trace_precision(self) -> torch.Tensor:
        '''
        Trace of the precision matrix.
        '''
        raise NotImplementedError

    def covariance_weighted_inner_prod(self, b: torch.Tensor, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the covariance weighted inner product of a tensor b,
        i.e., b^T @ Cov @ b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the covariance weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        raise NotImplementedError

    def precision_weighted_inner_prod(self, b: torch.Tensor, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the precision weighted inner product of a tensor b,
        i.e., b^T @ Cov^{-1} @ b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the precision weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        raise NotImplementedError

    def __add__(self, inp):
        '''
        Add another object to this distribution. Must be overridden.
        '''
        raise NotImplementedError

    def __matmul__(self, inp):
        '''
        Right-multiply this distribution by ``inp``. Must be overridden.
        '''
        raise NotImplementedError

    def squeeze(self, idx: int):
        '''
        Squeeze the distribution along the specified dimension.
        '''
        raise NotImplementedError


def tp(M: np.ndarray) -> np.ndarray:
    '''
    Transpose the last two dimensions of a matrix (or batch of matrices).

    Any leading batch dimensions are left untouched. Besides NumPy arrays,
    this module also passes torch tensors to this helper; both types
    provide ``swapaxes``.

    Parameters
    -----------
    M: array [..., a, b]
        Input with at least two dimensions.

    Returns
    -----------
    array [..., b, a]
        ``M`` with its last two axes exchanged.
    '''
    # ``ndarray.transpose(-1, -2)`` treats its arguments as a full axis
    # permutation and raises for anything that is not 2-D, whereas
    # ``swapaxes`` exchanges exactly the two named axes for any rank >= 2.
    return M.swapaxes(-1, -2)

def sym(M: np.ndarray) -> np.ndarray:
    '''
    Return the symmetric part of a matrix.

    The result is the average of ``M`` and its transpose over the last
    two dimensions, as produced by ``tp``.
    '''
    transposed = tp(M)
    return (M + transposed) / 2.

def gaussian_kl(p: Normal, q_mean=0.0, q_variance=1.0) -> torch.Tensor:
    '''
    KL divergence KL(p || q) between a Gaussian ``p`` and an isotropic Gaussian ``q``.

    ``p`` supplies its own mean, covariance trace and covariance log-determinant;
    ``q`` is described by a mean and a single scalar variance.
    The additive constant of the closed-form expression is left out.

    Parameters
    -----------
    p: Normal
        the first Gaussian distribution

    q_mean: float or torch.Tensor [n,d] or [d]
        the mean of the second Gaussian distribution

    q_variance: float
        the variance of the second Gaussian distribution

    Returns
    -----------
    torch.Tensor
        Half the sum of the squared-mean, trace and log-determinant terms.
    '''
    p_mean = p.mean
    n_features = p_mean.shape[-1]

    # Squared distance between the means, summed over the last two dimensions.
    squared_diff = (p_mean - q_mean) ** 2
    mean_part = squared_diff.sum(-1).sum(-1) / q_variance

    # tr(Cov_p) / q_variance, summed over the last dimension.
    trace_part = (p.trace_covariance / q_variance).sum(-1)

    # log|Cov_q| - log|Cov_p|, where log|Cov_q| = d * log(q_variance).
    q_logdet = n_features * np.log(q_variance)
    logdet_part = (q_logdet - p.logdet_covariance).sum(-1)

    # The constant term of the KL divergence is currently excluded.
    return 0.5 * (mean_part + trace_part + logdet_part)


class DiagonalNormal(torch.distributions.Normal):
    '''
    Gaussian distribution whose covariance matrix is diagonal,
    described by a per-element mean (loc) and standard deviation (scale).

    Parameters
    -----------
    loc: float or torch.Tensor [n, d] or [d]
        the tensor of per-element means

    scale: float or torch.Tensor [n, d] or [d]
        the tensor of per-element standard deviations

    Attributes
    -----------
    loc: torch.Tensor [n, d] or [d]
        mean of the distribution

    scale: torch.Tensor [n, d] or [d]
        standard deviation of the distribution
    '''
    def __init__(self, loc, scale):
        super().__init__(loc, scale)

    @property
    def mean(self) -> torch.Tensor:
        '''
        Mean vector, identical to ``loc``.
        '''
        return self.loc

    @property
    def var(self) -> torch.Tensor:
        '''
        Variance vector, the element-wise square of ``scale``.
        '''
        return self.scale.pow(2)

    @property
    def chol_covariance(self) -> torch.Tensor:
        '''
        Lower triangular Cholesky factor of the covariance matrix;
        for a diagonal covariance this is the matrix with ``scale`` on its diagonal.
        '''
        return self.scale.diag_embed()

    @property
    def covariance_diagonal(self) -> torch.Tensor:
        '''
        Diagonal (vector) of the covariance matrix, i.e. the variances.
        '''
        return self.var

    @property
    def covariance(self) -> torch.Tensor:
        '''
        Covariance matrix with the variances on its diagonal.
        '''
        return self.var.diag_embed()

    @property
    def precision_diagonal(self) -> torch.Tensor:
        '''
        Diagonal (vector) of the precision matrix, i.e. the reciprocal variances.
        '''
        return 1. / self.var

    @property
    def precision(self) -> torch.Tensor:
        '''
        Precision matrix, the inverse of the covariance matrix.
        '''
        return self.precision_diagonal.diag_embed()

    @property
    def logdet_covariance(self) -> torch.Tensor:
        '''
        Log determinant of the covariance matrix: twice the sum of log(scale).
        '''
        log_scale_total = self.scale.log().sum(dim=-1)
        return 2 * log_scale_total

    @property
    def logdet_precision(self) -> torch.Tensor:
        '''
        Log determinant of the precision matrix: minus twice the sum of log(scale).
        '''
        log_scale_total = self.scale.log().sum(dim=-1)
        return -2 * log_scale_total

    @property
    def trace_covariance(self) -> torch.Tensor:
        '''
        Trace of the covariance matrix: the sum of the variances.
        '''
        return self.var.sum(dim=-1)

    @property
    def trace_precision(self) -> torch.Tensor:
        '''
        Trace of the precision matrix: the sum of the reciprocal variances.
        '''
        return self.precision_diagonal.sum(dim=-1)

    def covariance_weighted_inner_prod(self, b: torch.Tensor, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the covariance weighted inner product b^T @ Cov @ b.

        With a diagonal covariance this reduces to sum_i var_i * b_i^2,
        where the sum runs over the second-to-last dimension of b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the covariance weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        assert b.shape[-1] == 1
        variances = self.var.unsqueeze(-1)
        quad_form = (variances * b.pow(2)).sum(dim=-2)
        if reduce_dim:
            return quad_form.squeeze(-1)
        return quad_form

    def precision_weighted_inner_prod(self, b, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the precision weighted inner product b^T @ Cov^{-1} @ b.

        With a diagonal covariance this reduces to sum_i b_i^2 / var_i,
        where the sum runs over the second-to-last dimension of b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the precision weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        assert b.shape[-1] == 1
        variances = self.var.unsqueeze(-1)
        quad_form = (b.pow(2) / variances).sum(dim=-2)
        if reduce_dim:
            return quad_form.squeeze(-1)
        return quad_form

    def __add__(self, inp):
        '''
        Add another DiagonalNormal (means and variances add) or a tensor
        (shifts the mean only). Any other operand raises NotImplementedError.
        '''
        if isinstance(inp, DiagonalNormal):
            summed_var = self.var + inp.var
            # Clamp before the square root so the new scale stays strictly positive.
            summed_scale = summed_var.clamp(min = 1e-12).sqrt()
            return DiagonalNormal(self.mean + inp.mean, summed_scale)
        if isinstance(inp, torch.Tensor):
            return DiagonalNormal(self.mean + inp, self.scale)
        raise NotImplementedError('Distribution addition only implemented for diag covariances')

    def __matmul__(self, inp):
        '''
        Project the distribution onto a column vector ``inp`` of shape [..., d, 1].

        The result is a DiagonalNormal with mean ``loc @ inp`` and variance
        given by the covariance weighted inner product of ``inp``.
        '''
        assert inp.shape[-2] == self.loc.shape[-1]
        assert inp.shape[-1] == 1
        projected_var = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim = False)
        # Clamp before the square root so the new scale stays strictly positive.
        projected_scale = projected_var.clamp(min = 1e-12).sqrt()
        return DiagonalNormal(self.loc @ inp, projected_scale)

    def squeeze(self, idx: int):
        '''
        Return a new distribution with dimension ``idx`` squeezed out of
        both ``loc`` and ``scale``.
        '''
        squeezed_loc = self.loc.squeeze(idx)
        squeezed_scale = self.scale.squeeze(idx)
        return DiagonalNormal(squeezed_loc, squeezed_scale)


class DenseNormal(torch.distributions.MultivariateNormal):
    '''
    Creates a multivariate Gaussian distribution with a dense covariance matrix,
    parameterized by a mean vector and the lower triangular Cholesky factor
    of the covariance matrix.

    Parameters
    -----------
    loc: torch.Tensor [n,d] or [d]
        the tensor of per-element means

    cholesky: torch.Tensor [n,d,d] or [d,d]
        the lower triangular Cholesky factor L of the covariance matrix,
        such that Cov = L @ L^T. It is passed to the parent class as ``scale_tril``.

    '''
    def __init__(self, loc, cholesky):
        super(DenseNormal, self).__init__(loc, scale_tril=cholesky)

    @property
    def mean(self) -> torch.Tensor:
        '''
        Mean vector of the distribution.
        '''
        return self.loc

    @property
    def chol_covariance(self) -> torch.Tensor:
        '''
        The lower triangular Cholesky decomposition of the covariance matrix.
        '''
        return self.scale_tril

    @property
    def covariance(self) -> torch.Tensor:
        '''
        Covariance matrix.

        It is computed as the product of the Cholesky factor with its transpose.
        '''
        return self.scale_tril @ tp(self.scale_tril)

    @property
    def covariance_diagonal(self) -> torch.Tensor:
        '''
        Diagonal (vector) of the covariance matrix.
        '''
        return torch.diagonal(self.covariance, dim1=-2, dim2=-1)

    @property
    def inverse_covariance(self) -> torch.Tensor:
        '''
        Inverse of the covariance matrix.

        It is computed as L^{-T} @ L^{-1}, where L is the Cholesky factor.
        A warning is emitted on every access because the explicit inverse is expensive.
        '''
        warnings.warn("Direct matrix inverse for dense covariances is O(N^3), consider using eg inverse weighted inner product")
        a = torch.linalg.inv(self.scale_tril)
        return tp(a) @ a

    @property
    def logdet_covariance(self) -> torch.Tensor:
        '''
        Log determinant of the covariance matrix,
        i.e. twice the sum of the log of the Cholesky factor's diagonal.
        '''
        return 2. * torch.diagonal(self.scale_tril, dim1=-2, dim2=-1).log().sum(-1)

    @property
    def trace_covariance(self) -> torch.Tensor:
        '''
        Trace of the covariance matrix.
        '''
        return (self.scale_tril**2).sum(-1).sum(-1) # compute as frob norm squared

    def covariance_weighted_inner_prod(self, b: torch.Tensor, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the covariance weighted inner product of a tensor b with the covariance matrix,
        i.e., b^T @ Cov @ b, evaluated as the squared norm of L^T @ b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the covariance weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        assert b.shape[-1] == 1
        prod = ((tp(self.scale_tril) @ b)**2).sum(-2)
        return prod.squeeze(-1) if reduce_dim else prod

    def precision_weighted_inner_prod(self, b: torch.Tensor, reduce_dim=True) -> torch.Tensor:
        '''
        Compute the precision weighted inner product of a tensor b with the precision matrix,
        i.e., b^T @ Cov^{-1} @ b, evaluated as the squared norm of L^{-1} @ b.

        Parameters
        -----------
        b: torch.Tensor [n,d,1]
            The tensor to compute the precision weighted inner product with.

        reduce_dim: bool
            Whether to reduce the last dimension of the result.
        '''
        assert b.shape[-1] == 1
        prod = (torch.linalg.solve(self.scale_tril, b)**2).sum(-2)
        return prod.squeeze(-1) if reduce_dim else prod

    def __matmul__(self, inp):
        '''
        Project the distribution onto a column vector ``inp`` of shape [..., d, 1].

        The result is a DiagonalNormal with mean ``loc @ inp`` and variance
        given by the covariance weighted inner product of ``inp``.
        '''
        assert inp.shape[-2] == self.loc.shape[-1]
        assert inp.shape[-1] == 1
        new_cov = self.covariance_weighted_inner_prod(inp.unsqueeze(-3), reduce_dim = False)
        # Clip before the square root so the resulting scale stays strictly positive.
        return DiagonalNormal(self.loc @ inp, torch.sqrt(torch.clip(new_cov, min = 1e-12)))

    def squeeze(self, idx: int):
        '''
        Squeeze the distribution along the specified dimension.

        Parameters
        -----------
        idx: int
            Dimension to squeeze, interpreted relative to ``loc``.
            Negative values count from the end of ``loc``.
        '''
        # scale_tril carries one more trailing (matrix) dimension than loc, so a
        # negative index has to be shifted by one to address the same dimension
        # in both tensors; non-negative indices already line up.
        chol_idx = idx - 1 if idx < 0 else idx
        return DenseNormal(self.loc.squeeze(idx), self.scale_tril.squeeze(chol_idx))

