"""
    3rd Party Libraries
"""
import numpy as np
from scipy.stats import multivariate_normal

"""
    Files
"""
from belief_assessment.distributions.base import Distribution


class GaussianDistribution(Distribution):
    """
        Multivariate Gaussian distribution.
    
        Log-densities are delegated to scipy.stats.multivariate_normal, sampling
        uses NumPy's global random generator, and the score function is the
        closed form -Sigma^{-1} (x - mu) using a precomputed precision matrix.
    """
    def __init__(self, mean, cov, name=None, seed=None):
        """
            Initialize a Gaussian distribution.
        
            Args:
                mean (np.ndarray): Mean vector; its length defines dim and it must
                    have shape (dim,)
                cov (np.ndarray): Covariance matrix of shape (dim, dim)
                name (str, optional): Name identifier for the distribution
                seed (int, optional): Random seed, forwarded to the Distribution
                    base class
        
            Raises:
                ValueError: If mean or cov do not have the shapes implied by
                    len(mean). Constructing scipy's multivariate_normal may also
                    raise if cov is not an acceptable covariance matrix.
        """
        dim = len(mean)
        super().__init__(dim, name, seed)

        # & Validate parameters
        self.mean = np.asarray(mean, dtype=np.float64)
        self.cov = np.asarray(cov, dtype=np.float64)
        if self.mean.shape != (dim,):
            raise ValueError(f"Mean shape {self.mean.shape} doesn't match dim={dim}")
        if self.cov.shape != (dim, dim):
            raise ValueError(f"Covariance shape {self.cov.shape} doesn't match dim={dim}")
        
        # & Create scipy distribution for log_prob calculations
        self.dist = multivariate_normal(mean=self.mean, cov=self.cov)

        # & Precompute precision matrix (inverse covariance) once so score()
        # & does not re-invert the covariance on every call
        self.precision = np.linalg.inv(self.cov)


    def sample(self, n_samples):
        """
            Generate samples from the Gaussian distribution.
        
            Samples are drawn from NumPy's global random state (np.random), not
            from a generator owned by this instance.
        
            Args:
                n_samples (int): Number of samples to generate
                
            Returns:
                np.ndarray: Generated samples with shape (n_samples, dim)
        """
        # NOTE(review): `seed` is only forwarded to Distribution.__init__; confirm
        # the base class seeds np.random's global state, otherwise the seed has no
        # effect on these draws.
        return np.random.multivariate_normal(
            mean=self.mean, 
            cov=self.cov, 
            size=n_samples
        )
    

    def log_prob(self, x):
        """
            Compute log probability density at the given points.
        
            Args:
                x (np.ndarray): A single point of shape (dim,) or a batch of
                    shape (n_points, dim). Input is promoted with np.atleast_2d,
                    so when dim == 1 a batch must be shaped (n_points, 1).
                
            Returns:
                np.ndarray: Log probability values of shape (n_points,). scipy
                    squeezes its output, so a single point yields a scalar.
        """
        x = np.atleast_2d(x)
        return self.dist.logpdf(x)
    

    def score(self, x):
        """
            Compute score function (gradient of log probability) at the given points.
            
            For a Gaussian, grad_x log p(x) = -Sigma^{-1} (x - mu). All points are
            evaluated in a single matrix product rather than a per-row loop.
            
            Args:
                x (np.ndarray): A single point of shape (dim,) or a batch of
                    shape (n_points, dim). Input is promoted with np.atleast_2d,
                    so when dim == 1 a batch must be shaped (n_points, 1).
                
            Returns:
                np.ndarray: Score vectors with shape (n_points, dim); an empty
                    (0, dim) batch yields an empty (0, dim) array.
        """
        x = np.atleast_2d(x)
        # & Stacking precision @ (x_i - mean) over rows is (x - mean) @ precision.T;
        # & one matmul keeps the (n_points, dim) shape even when n_points == 0
        return -((x - self.mean) @ self.precision.T)