import math
import torch
import numpy as np
from scipy.special import gamma, hyp1f1
from scipy.stats import ncx2
from scipy.linalg import sqrtm

class MultivariateNormalMixture:
    """
    Gaussian mixture model p(x) = sum_k w_k * N(x; mu_k, Sigma_k) backed by torch tensors.

    Per-component Cholesky factors, precision matrices and log-determinants are
    cached on construction and recomputed whenever the covariances are rescaled
    by ``scale_to_unit_marginal_variance``, so density-based methods always use
    the current parameters.
    """

    def __init__(self, means, covariances, weights):
        """
        Initializes a multivariate normal mixture model.

        Args:
            means (tensor[K, D]): Mean vectors for each of the K components.
            covariances (tensor[K, D, D]): Covariance matrices for each of the K components.
                They are symmetrized here and must be positive definite
                (torch.linalg.cholesky raises otherwise).
            weights (tensor[K]): Mixture weights (should sum to 1; they are used
                as-is in log_prob, so unnormalized weights give an unnormalized density).
        """
        self.means = torch.as_tensor(means, dtype=torch.float32)
        self.covariances = torch.as_tensor(covariances, dtype=torch.float32)
        self.weights = torch.as_tensor(weights, dtype=torch.float32)
        self.K, self.D = self.means.shape

        # Ensure symmetric covariance (removes round-off asymmetry in user input).
        self.covariances = 0.5 * (self.covariances + self.covariances.transpose(-1, -2))

        # Precompute Cholesky factors, precisions and log-determinants per component.
        self._refresh_factorizations()

    @staticmethod
    def _factorize(covariances):
        """Return (Cholesky factors, inverses, log-determinants) for a [K, d, d] covariance batch."""
        chol = torch.linalg.cholesky(covariances)
        cov_invs = torch.cholesky_inverse(chol)
        # For Sigma = L L^T, log|Sigma| = 2 * sum(log(diag(L))).
        log_dets = 2 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1)), dim=-1)
        return chol, cov_invs, log_dets

    @staticmethod
    def _gaussian_log_probs(x, means, cov_invs, log_dets):
        """Per-component Gaussian log-densities: x [N, d], means [K, d] -> [N, K]."""
        dim = means.shape[-1]
        diff = x.unsqueeze(1) - means.unsqueeze(0)  # [N, K, d]
        mahal = torch.einsum("nkd,kde,nke->nk", diff, cov_invs, diff)  # [N, K]
        return -0.5 * (dim * math.log(2 * math.pi) + log_dets + mahal)

    def _refresh_factorizations(self):
        """Recompute the cached factorizations from the current ``self.covariances``."""
        self.chol_covariances, self.cov_invs, self.log_det_covs = self._factorize(self.covariances)

    def sample(self, n_samples):
        """
        Generates samples from the mixture model using ancestral sampling.

        Args:
            n_samples (int): Number of samples to generate.

        Returns:
            samples (tensor[n_samples, D]): Generated samples.
            component_indices (tensor[n_samples]): Indices of mixture components used.
        """
        # 1. Sample mixture component indices according to weights.
        component_indices = torch.multinomial(self.weights, n_samples, replacement=True)  # [N]

        # 2. Prepare output container.
        samples = torch.empty(n_samples, self.D, dtype=torch.float32)

        # 3. Fill each component's slots. Reuse the cached Cholesky factor instead
        #    of re-factorizing the covariance on every call.
        for k in range(self.K):
            mask = component_indices == k
            n_k = int(mask.sum().item())
            if n_k > 0:
                dist = torch.distributions.MultivariateNormal(
                    self.means[k], scale_tril=self.chol_covariances[k]
                )
                samples[mask] = dist.sample((n_k,))

        return samples, component_indices

    def component_log_prob(self, x):
        """
        Compute log probability for each component at each point in x.

        Args:
            x (tensor[N, D]): Points to evaluate.

        Returns:
            tensor[N, K]: Log-probabilities under each component.
        """
        return self._gaussian_log_probs(x, self.means, self.cov_invs, self.log_det_covs)

    def log_prob(self, x):
        """
        Log-probability under the full mixture model.

        Args:
            x (tensor[N, D])

        Returns:
            tensor[N]: Log-probability under the mixture.
        """
        log_probs = self.component_log_prob(x)  # [N, K]
        weighted_log_probs = log_probs + torch.log(self.weights)  # [N, K]
        return torch.logsumexp(weighted_log_probs, dim=1)  # [N]

    def score(self, x):
        """
        Score function: gradient of the log-probability w.r.t x.

        Gradient tracking is enabled locally, so this also works when called
        inside a ``torch.no_grad()`` context.

        Args:
            x (tensor[N, D]): Inputs

        Returns:
            tensor[N, D]: Score for each input.
        """
        with torch.enable_grad():
            x = torch.as_tensor(x, dtype=self.means.dtype).detach().clone().requires_grad_(True)
            logp = self.log_prob(x)
            (grad,) = torch.autograd.grad(logp.sum(), x)
        return grad

    @staticmethod
    def _median_bandwidth(particles):
        """Median pairwise distance; falls back to 1.0 when it is zero or undefined (N < 2)."""
        pairwise = torch.pdist(particles)
        if pairwise.numel() == 0:
            return 1.0
        bandwidth = torch.median(pairwise).item()
        return bandwidth if bandwidth > 0 else 1.0

    @staticmethod
    def _stein_base_kernel(diff, sq_dist, bandwidth, kernel_type):
        """
        Evaluate a translation-invariant base kernel for the Stein operator.

        Args:
            diff (tensor[N, N, D]): diff[i, j] = x_i - x_j.
            sq_dist (tensor[N, N]): Squared Euclidean distances.
            bandwidth (float): Kernel bandwidth (h for RBF, c for IMQ).
            kernel_type (str): 'rbf' or 'imq'.

        Returns:
            (k [N, N], grad_x k [N, N, D], tr(grad_x grad_y k) [N, N]).

        Raises:
            ValueError: For an unknown kernel type.
        """
        dim = diff.shape[-1]
        if kernel_type == 'rbf':
            # k(x, y) = exp(-||x - y||^2 / (2 h^2))
            h2 = bandwidth ** 2
            kernel_matrix = torch.exp(-sq_dist / (2 * h2))
            # grad_x k = -k (x - y) / h^2
            grad_x = -kernel_matrix.unsqueeze(2) * diff / h2
            # tr(grad_x grad_y k) = k (D / h^2 - ||x - y||^2 / h^4)  (= -Laplacian of k)
            trace_xy = kernel_matrix * (dim / h2 - sq_dist / h2 ** 2)
        elif kernel_type == 'imq':
            # Inverse multi-quadric: k(x, y) = (c^2 + ||x - y||^2)^(-beta)
            c = bandwidth
            beta = 0.5
            base = c ** 2 + sq_dist  # [N, N]
            kernel_matrix = base ** (-beta)
            # grad_x k = -2 beta k (x - y) / base
            grad_x = -2 * beta * kernel_matrix.unsqueeze(2) * diff / base.unsqueeze(2)
            # tr(grad_x grad_y k) = k (2 beta D / base - 4 beta (beta + 1) ||x - y||^2 / base^2)
            trace_xy = kernel_matrix * (
                2 * beta * dim / base - 4 * beta * (beta + 1) * sq_dist / base ** 2
            )
        else:
            raise ValueError(f"Unknown kernel type: {kernel_type}")
        return kernel_matrix, grad_x, trace_xy

    def kernelized_stein_discrepancy(self, particles, kernel_type='rbf', bandwidth=None):
        """
        Computes the kernelized Stein discrepancy between the mixture density and empirical distribution.

        Uses the V-statistic KSD^2 = (1/N^2) sum_ij u_p(x_i, x_j), with Stein kernel
        u_p(x, y) = s(x).s(y) k + s(x).grad_y k + s(y).grad_x k + tr(grad_x grad_y k),
        where s = grad log p.

        Args:
            particles (tensor[N, D]): Sample particles from empirical distribution
            kernel_type (str): Type of kernel ('rbf' or 'imq')
            bandwidth (float): Kernel bandwidth. If None, uses median heuristic
                (falling back to 1.0 for a single particle or zero median).

        Returns:
            float: Kernelized Stein discrepancy value (sqrt of the clamped V-statistic)

        Raises:
            ValueError: For an unknown kernel type.
        """
        particles = torch.as_tensor(particles, dtype=torch.float32)

        if bandwidth is None:
            bandwidth = self._median_bandwidth(particles)

        scores = self.score(particles)  # [N, D]

        diff = particles.unsqueeze(1) - particles.unsqueeze(0)  # [N, N, D]
        sq_dist = torch.sum(diff ** 2, dim=2)  # [N, N]

        kernel_matrix, grad_x, trace_xy = self._stein_base_kernel(diff, sq_dist, bandwidth, kernel_type)

        # For translation-invariant kernels grad_y k(x_i, x_j) = -grad_x k(x_i, x_j).
        score_score = (scores @ scores.T) * kernel_matrix  # s_i . s_j k
        first_cross = -torch.einsum("id,ijd->ij", scores, grad_x)  # s_i . grad_y k
        second_cross = torch.einsum("jd,ijd->ij", scores, grad_x)  # s_j . grad_x k
        stein_matrix = score_score + first_cross + second_cross + trace_xy  # [N, N]

        ksd_squared = torch.mean(stein_matrix)
        return torch.sqrt(torch.clamp(ksd_squared, min=0.0)).item()

    def hessian_tensor(self, x):
        """
        Negative Hessian of the mixture log-density at each point x.

        Computed in closed form: with responsibilities r_k, precisions P_k and
        component scores g_k = -P_k (x - mu_k), g = sum_k r_k g_k,
            -H = sum_k r_k P_k - sum_k r_k g_k g_k^T + g g^T.

        Args:
            x (tensor[N, D])

        Returns:
            tensor[N, D, D]: -Hessian of log p at each point.
        """
        x = torch.as_tensor(x, dtype=self.means.dtype)
        resp = torch.softmax(self.component_log_prob(x) + torch.log(self.weights), dim=1)  # [N, K]
        diff = x.unsqueeze(1) - self.means.unsqueeze(0)  # [N, K, D]
        comp_scores = -torch.einsum("kde,nke->nkd", self.cov_invs, diff)  # [N, K, D]
        mix_score = torch.einsum("nk,nkd->nd", resp, comp_scores)  # [N, D]
        precision_term = torch.einsum("nk,kde->nde", resp, self.cov_invs)
        spread_term = torch.einsum("nk,nkd,nke->nde", resp, comp_scores, comp_scores)
        return precision_term - spread_term + torch.einsum("nd,ne->nde", mix_score, mix_score)

    def hessian_mixture(self, x):
        """
        Average negative Hessian (as returned by hessian_tensor) over a batch of inputs.

        Args:
            x (tensor[N, D])

        Returns:
            tensor[D, D]
        """
        return self.hessian_tensor(x).mean(dim=0)

    def mean_vector(self):
        """Computes the overall mixture mean."""
        return torch.sum(self.weights[:, None] * self.means, dim=0)  # shape [D]

    def marginal_covariance(self):
        """
        Computes the full marginal covariance matrix of the mixture.
        Returns:
            cov_total: tensor of shape [D, D]
        """
        mean_mix = self.mean_vector()  # [D]
        centered_means = self.means - mean_mix  # [K, D]

        covs = self.covariances  # [K, D, D]
        outer_products = torch.einsum("kd,ke->kde", centered_means, centered_means)  # [K, D, D]

        cov_total = torch.sum(self.weights[:, None, None] * (covs + outer_products), dim=0)  # [D, D]
        return cov_total

    def averaged_marginal_variance(self):
        """
        Computes the dimension-averaged marginal variance of the mixture.
        Returns:
            scalar: averaged variance
        """
        cov_total = self.marginal_covariance()  # [D, D]
        return torch.trace(cov_total) / self.D  # scalar

    def scale_to_unit_marginal_variance(self):
        """
        Scales the component covariances so that the total dimension-averaged marginal variance is 1.

        The averaged marginal variance splits into a between-component part
        (spread of the means; left unchanged) and a within-component part
        sum_k w_k tr(Sigma_k) / D. All covariances are multiplied by the common
        factor (1 - between) / within, and the cached factorizations used by
        log_prob / score / sample are refreshed.

        Raises:
            ValueError: If the between-component variance is already >= 1.
        """
        mu = self.mean_vector()  # (D,)

        # Between-component contribution to the averaged marginal variance.
        squared_dists = torch.sum((self.means - mu) ** 2, dim=1)  # (K,)
        between_var = torch.sum(self.weights * squared_dists) / self.D  # scalar

        # Required within-component contribution.
        s2 = 1.0 - between_var
        if s2 <= 0:
            raise ValueError(f"Cannot scale: between-component variance is too large ({between_var:.4f}), "
                             f"leaving no room for within-component variance.")

        # Current within-component contribution; rescale it to exactly s2.
        traces = torch.diagonal(self.covariances, dim1=-2, dim2=-1).sum(dim=-1)  # (K,)
        within_var = torch.sum(self.weights * traces) / self.D

        self.covariances = self.covariances * (s2 / within_var)
        self._refresh_factorizations()

    def marginal_log_prob(self, x, dims):
        """
        Compute log-probability of the mixture marginalized onto a subset of dimensions.

        Args:
            x (tensor[N, len(dims)]): Points in the lower-dimensional space.
            dims (list or tensor): Indices of dimensions to keep (e.g., [0, 1]).

        Returns:
            tensor[N]: Log-probability under the marginal mixture.
        """
        dims = torch.as_tensor(dims, dtype=torch.long)
        # Gaussian marginals: keep the selected rows/cols of each mean and covariance.
        means_marginal = self.means[:, dims]  # [K, len(dims)]
        covs_marginal = self.covariances[:, dims][:, :, dims]  # [K, len(dims), len(dims)]

        _, cov_invs, log_dets = self._factorize(covs_marginal)
        log_probs = self._gaussian_log_probs(x, means_marginal, cov_invs, log_dets)  # [N, K]

        return torch.logsumexp(log_probs + torch.log(self.weights), dim=1)  # [N]

    def energy_distance_mc(self, samples_Y, n_mc=1000):
        """
        Monte Carlo approximation of the energy distance between this mixture
        and a set of samples.

        Computes 2 E||X - Y|| - E||X - X'|| - E||Y - Y'|| with plain means over all
        pairs (including the zero diagonal, i.e. a V-statistic).

        Args:
            samples_Y (tensor or array [n_samples_Y, D]): Samples from the other distribution Q.
                Non-float input is cast to float32. The mixture samples are moved to
                samples_Y's device and floating dtype.
            n_mc (int): Number of Monte Carlo samples to draw from this mixture.

        Returns:
            energy (float): Approximated energy distance.
        """
        samples_Y = torch.as_tensor(samples_Y)
        if not samples_Y.is_floating_point():
            samples_Y = samples_Y.to(torch.float32)

        # 1. Sample from this mixture; cdist requires matching device and dtype.
        X_samples, _ = self.sample(n_mc)
        X_samples = X_samples.to(device=samples_Y.device, dtype=samples_Y.dtype)

        # 2. Pairwise distance expectations.
        d_xx = torch.cdist(X_samples, X_samples, p=2).mean()  # E||X - X'||
        d_yy = torch.cdist(samples_Y, samples_Y, p=2).mean()  # E||Y - Y'||
        d_xy = torch.cdist(X_samples, samples_Y, p=2).mean()  # E||X - Y||

        energy = 2 * d_xy - d_xx - d_yy
        return energy.item()