"""
    3rd Party Libraries
"""
import numpy as np
from scipy.stats import multivariate_normal

"""
    Files
"""
from belief_assessment.distributions import Distribution
from belief_assessment.utils import log_sum_exp


class GMMDistribution(Distribution):
    """
        Gaussian Mixture Model distribution.
    
        A mixture of multivariate Gaussians for representing multi-modal distributions.
    """
    def __init__(self, means, covs, weights, name=None, seed=None):
        """
            Initialize a Gaussian Mixture Model.
        
            Args:
                means (np.ndarray): Component means with shape (n_components, dim)
                covs (np.ndarray): Component covariance matrices with shape (n_components, dim, dim)
                weights (np.ndarray): Component weights with shape (n_components,)
                name (str, optional): Name identifier for the distribution
                seed (int, optional): Random seed for reproducibility
        """
        self.n_components = len(weights)
        dim = means.shape[1]
        super().__init__(dim, name, seed)

        # & Validate parameters
        self.means = np.asarray(means, dtype=np.float64)
        self.covs = np.asarray(covs, dtype=np.float64)
        self.weights = np.asarray(weights, dtype=np.float64)
        if self.means.shape != (self.n_components, dim):
            raise ValueError(f"Means shape {self.means.shape} doesn't match components={self.n_components}, dim={dim}")
        if self.covs.shape != (self.n_components, dim, dim):
            raise ValueError(f"Covariance shape {self.covs.shape} doesn't match components={self.n_components}, dim={dim}")
        if self.weights.shape != (self.n_components,):
            raise ValueError(f"Weights shape {self.weights.shape} doesn't match components={self.n_components}")
        if not np.isclose(np.sum(self.weights), 1.0):
            raise ValueError(f"Weights must sum to 1, got sum={np.sum(self.weights)}")
        
        # & Create component distributions
        self.components = []
        self.precisions = []
        for i in range(self.n_components):
            self.components.append(
                multivariate_normal(mean=self.means[i], cov=self.covs[i])
            )
            self.precisions.append(np.linalg.inv(self.covs[i]))


    def sample(self, n_samples):
        """
            Generate samples from the GMM distribution.
        
            Args:
                n_samples (int): Number of samples to generate
                
            Returns:
                np.ndarray: Generated samples with shape (n_samples, dim)
        """
        samples = np.zeros((n_samples, self.dim))
        component_indices = np.random.choice(
            self.n_components, size=n_samples, p=self.weights
        )

        for i in range(self.n_components):
            mask = component_indices == i
            n_from_component = np.sum(mask)
            if n_from_component > 0:
                samples[mask] = np.random.multivariate_normal(
                    mean=self.means[i], 
                    cov=self.covs[i], 
                    size=n_from_component
                )
        
        return samples
    

    def log_prob(self, x):
        """
            Compute log probability density at the given points.
            
            Args:
                x (np.ndarray): Points at which to evaluate log probability
                
            Returns:
                np.ndarray: Log probability values
        """
        x = np.atleast_2d(x)
        n_samples = x.shape[0]
        log_probs = np.zeros(n_samples)
        
        for i in range(n_samples):
            component_log_probs = np.array([
                np.log(self.weights[j]) + self.components[j].logpdf(x[i]) 
                for j in range(self.n_components)
            ])
            log_probs[i] = log_sum_exp(component_log_probs)
        
        return log_probs
    

    def score(self, x):
        """
        Compute score function (gradient of log probability) at the given points.
        
        Args:
            x (np.ndarray): Points at which to evaluate the score
            
        Returns:
            np.ndarray: Score vectors with shape (n_points, dim)
        """
        x = np.atleast_2d(x)
        n_samples = x.shape[0]
        scores = np.zeros_like(x)
        
        for i in range(n_samples):
            x_i = x[i]
            component_log_probs = np.array([
                np.log(self.weights[j]) + self.components[j].logpdf(x_i) 
                for j in range(self.n_components)
            ])
            max_log_prob = np.max(component_log_probs)
            component_probs = np.exp(component_log_probs - max_log_prob)
            normalized_probs = component_probs / np.sum(component_probs)
            
            score_i = np.zeros(self.dim)
            for j in range(self.n_components):
                score_i += normalized_probs[j] * np.dot(
                    self.precisions[j], 
                    (self.means[j] - x_i)
                )
            
            scores[i] = -score_i
        
        return scores
