import torch
import torch.nn as nn
import numpy as np
from sklearn.mixture import GaussianMixture
from scipy.spatial.distance import pdist

class GMM_SteinSampler:
    """Sample particles from a Gaussian mixture fitted to data, using SVGD.

    A scikit-learn ``GaussianMixture`` is fitted to the training data and its
    parameters are copied into float32 PyTorch tensors.  Particles initialised
    from the data are then transported with Stein Variational Gradient Descent
    (SVGD), using the analytic score of the mixture and an RBF kernel
    ``k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))``.
    """

    def __init__(self, num_particles=50, bandwidth=0.5, learning_rate=0.01,
                 num_iterations=100, n_components=5, covariance_type='full',
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        GMM + Stein Sampling using SVGD framework (PyTorch GPU version)

        Args:
            num_particles (int): Number of particles to sample.
            bandwidth (float or None): Bandwidth (sigma) of the RBF kernel in
                SVGD. If None, it is chosen by the median heuristic on the
                initial particles the first time ``fit`` is called.
            learning_rate (float): Step size for particle updates.
            num_iterations (int): Number of SVGD iterations.
            n_components (int): Number of Gaussian components in GMM.
            covariance_type (str): Type of covariance parameters
                ('full', 'tied', 'diag', 'spherical').
            device (str): Device to run computations on ('cuda' or 'cpu').
        """
        self.num_particles = num_particles
        self.bandwidth = bandwidth
        self.learning_rate = learning_rate
        self.num_iterations = num_iterations
        self.n_components = n_components
        self.covariance_type = covariance_type
        self.device = torch.device(device)
        self.particles = None
        self.gmm = None

        # GMM parameters as tensors for GPU computation (filled by _fit_gmm).
        self.weights = None
        self.means = None
        self.covariances = None
        self.precision_matrices = None
        # Kept only so existing code reading this attribute keeps working;
        # nothing in this class populates it.
        self.precision_chol = None

    def _fit_gmm(self, data_np):
        """
        Fit GMM to data using sklearn and convert parameters to PyTorch tensors.

        Resulting ``precision_matrices`` shapes per covariance type:
        'full' -> (k, d, d), 'diag' -> (k, d), 'spherical' -> (k, 1),
        'tied' -> (1, d, d).

        Args:
            data_np (np.ndarray): Training data of shape (n_samples, dim)
        """
        self.gmm = GaussianMixture(
            n_components=self.n_components,
            covariance_type=self.covariance_type,
            random_state=42,
        )
        self.gmm.fit(data_np)

        def to_tensor(array):
            return torch.tensor(array, dtype=torch.float32, device=self.device)

        self.weights = to_tensor(self.gmm.weights_)
        self.means = to_tensor(self.gmm.means_)
        self.covariances = to_tensor(self.gmm.covariances_)

        # Precompute precisions once; they are reused on every SVGD iteration.
        if self.covariance_type == 'full':
            self.precision_matrices = torch.inverse(self.covariances)
        elif self.covariance_type == 'diag':
            self.precision_matrices = 1.0 / self.covariances
        elif self.covariance_type == 'spherical':
            self.precision_matrices = 1.0 / self.covariances.unsqueeze(-1)
        else:  # tied: one covariance matrix shared by all components
            self.precision_matrices = torch.inverse(self.covariances).unsqueeze(0)

    def _log_prob_single_gaussian(self, X, mean, precision, cov_det=None):
        """
        Compute log probability for a single Gaussian component.

        Args:
            X (torch.Tensor): Data points of shape (n, d)
            mean (torch.Tensor): Mean vector of shape (d,)
            precision (torch.Tensor): Precision of the component. A (d, d)
                matrix for 'full'/'tied' (a leading singleton batch dimension
                is accepted), a (d,) vector for 'diag', a scalar or
                one-element tensor for 'spherical'.
            cov_det: Unused. Accepted only for backward compatibility; the
                normalisation term is derived from ``precision``.

        Returns:
            torch.Tensor: Log probabilities of shape (n,)
        """
        d = X.shape[1]
        diff = X - mean.unsqueeze(0)  # (n, d)

        if self.covariance_type == 'diag':
            mahal_dist = torch.sum(diff ** 2 * precision.unsqueeze(0), dim=1)
            log_det = torch.sum(torch.log(precision))
        elif self.covariance_type == 'spherical':
            scalar_precision = precision.squeeze()
            mahal_dist = torch.sum(diff ** 2, dim=1) * scalar_precision
            log_det = d * torch.log(scalar_precision)
        else:  # 'full' or 'tied': dense precision matrix
            if precision.dim() == 3:
                # Drop an explicit batch dimension by indexing; squeeze(0)
                # would also remove a genuine size-1 matrix axis when d == 1.
                precision = precision[0]
            # Mahalanobis distance: (x-mu)^T Sigma^(-1) (x-mu)
            mahal_dist = torch.sum((diff @ precision) * diff, dim=1)
            log_det = torch.logdet(precision)

        # log N(x) = -0.5 * (d*log(2*pi) + log|Sigma| + mahal), and
        # log|Sigma| = -log|precision|.
        log_two_pi = float(np.log(2.0 * np.pi))
        return -0.5 * (d * log_two_pi - log_det + mahal_dist)

    def _component_log_probs(self, X):
        """
        Compute log(weight_k) + log N(x | mean_k, Sigma_k) for every component.

        Args:
            X (torch.Tensor): Shape (n, d)

        Returns:
            torch.Tensor: Weighted component log densities of shape (n, k)
        """
        shared_precision = self.covariance_type == 'tied'
        columns = [
            self._log_prob_single_gaussian(
                X,
                self.means[k],
                self.precision_matrices[0 if shared_precision else k],
            )
            for k in range(self.n_components)
        ]
        return torch.stack(columns, dim=1) + torch.log(self.weights).unsqueeze(0)

    def _log_prob(self, X):
        """
        Compute log density of X using GMM.

        Args:
            X (torch.Tensor): Shape (n, d)

        Returns:
            torch.Tensor: Log probabilities of shape (n,)
        """
        # logsumexp over components for numerical stability.
        return torch.logsumexp(self._component_log_probs(X), dim=1)

    def _grad_log_prob(self, X):
        """
        Compute gradient of log p(x) analytically using GMM.

        The score of a mixture is the responsibility-weighted sum of the
        component scores ``-Sigma_k^(-1) (x - mean_k)``.

        Args:
            X (torch.Tensor): Shape (n, d)

        Returns:
            torch.Tensor: Gradients of shape (n, d)
        """
        # Responsibilities (posterior component probabilities), shape (n, k).
        # softmax(l) == exp(l - logsumexp(l)).
        responsibilities = torch.softmax(self._component_log_probs(X), dim=1)

        grads = torch.zeros_like(X)
        for k in range(self.n_components):
            diff = X - self.means[k].unsqueeze(0)  # (n, d)

            if self.covariance_type == 'full':
                score_k = -diff @ self.precision_matrices[k]
            elif self.covariance_type == 'diag':
                score_k = -diff * self.precision_matrices[k].unsqueeze(0)
            elif self.covariance_type == 'spherical':
                score_k = -diff * self.precision_matrices[k].squeeze()
            else:  # tied
                score_k = -diff @ self.precision_matrices[0]

            grads += responsibilities[:, k].unsqueeze(1) * score_k

        return grads

    def _rbf_kernel(self, X, Y=None):
        """
        Compute RBF kernel matrix using PyTorch.

        Args:
            X (torch.Tensor): Shape (n, d)
            Y (torch.Tensor): Shape (m, d), if None uses X

        Returns:
            torch.Tensor: Kernel matrix of shape (n, m)
        """
        if Y is None:
            Y = X

        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y
        x_sq = (X ** 2).sum(dim=1, keepdim=True)  # (n, 1)
        y_sq = (Y ** 2).sum(dim=1, keepdim=True)  # (m, 1)
        # Round-off can make the expansion slightly negative; clamp so the
        # kernel never exceeds 1.
        dist_sq = (x_sq + y_sq.T - 2 * torch.mm(X, Y.T)).clamp_min(0.0)

        gamma = 1.0 / (2 * self.bandwidth ** 2)
        return torch.exp(-gamma * dist_sq)

    def _compute_kernel_grad(self, X, K=None):
        """
        Compute the SVGD repulsive term of the RBF kernel.

        For particle i this is ``sum_j grad_{x_j} k(x_j, x_i)``, which for the
        RBF kernel equals ``2 * gamma * sum_j K_ij (x_i - x_j)``.  It points
        away from neighbouring particles, which is what keeps the particle
        set from collapsing onto a mode.

        Args:
            X (torch.Tensor): Particles of shape (n, d)
            K (torch.Tensor, optional): Precomputed kernel matrix of shape
                (n, n) for X. Computed here when omitted.

        Returns:
            torch.Tensor: Kernel gradients of shape (n, d)
        """
        if K is None:
            K = self._rbf_kernel(X)  # (n, n)

        gamma = 1.0 / (2 * self.bandwidth ** 2)
        # sum_j K_ij (x_i - x_j) = x_i * sum_j K_ij - (K @ X)_i, which avoids
        # materialising the (n, n, d) tensor of pairwise differences.
        return 2 * gamma * (K.sum(dim=1, keepdim=True) * X - K @ X)

    def _svgd_update(self):
        """
        Perform SVGD update using GMM for log p(x).

        Runs ``num_iterations`` steps and modifies ``self.particles`` in place.
        """
        fudge_factor = 1e-6
        alpha = 0.9
        n_particles = self.particles.shape[0]
        running_sq = None

        with torch.no_grad():
            for step in range(self.num_iterations):
                # Score of the target at each particle (analytical).
                grad_logp = self._grad_log_prob(self.particles)

                # Kernel is computed once and shared with the repulsive term.
                K = self._rbf_kernel(self.particles)  # (n, n)
                repulsion = self._compute_kernel_grad(self.particles, K)  # (n, d)

                # phi(x_i) = mean_j [k(x_j, x_i) score(x_j) + grad_{x_j} k(x_j, x_i)]
                phi = (torch.mm(K, grad_logp) + repulsion) / n_particles

                # Per-coordinate step normalisation by an exponential moving
                # average of squared updates (seeded with the first update).
                phi_sq = phi ** 2
                if running_sq is None:
                    running_sq = phi_sq
                else:
                    running_sq = alpha * running_sq + (1 - alpha) * phi_sq

                adj_phi = phi / (fudge_factor + torch.sqrt(running_sq))

                # Update particles
                self.particles += self.learning_rate * adj_phi

    @staticmethod
    def median_heuristic_bandwidth(particles: torch.Tensor):
        """
        Median-heuristic kernel width ``h = median_dist^2 / log(N)``.

        ``h`` is the width in the ``exp(-||x - y||^2 / h)`` convention; the
        equivalent sigma for ``exp(-||x - y||^2 / (2 sigma^2))`` is
        ``sqrt(h / 2)``.

        Args:
            particles (torch.Tensor): Particles of shape (N, D), N >= 2.

        Returns:
            torch.Tensor: Scalar tensor holding h.

        Raises:
            ValueError: If fewer than two particles are given.
        """
        N = particles.shape[0]
        if N < 2:
            raise ValueError("median heuristic needs at least two particles.")

        # Distances over all unordered pairs i < j (no self-distances).
        dists = torch.pdist(particles)
        median_dist = torch.median(dists)

        log_n = torch.log(torch.tensor(N, dtype=particles.dtype, device=particles.device))
        return (median_dist ** 2) / log_n

    def _median_bandwidth_sigma(self):
        """Return the median-heuristic bandwidth as a sigma for `_rbf_kernel`.

        Falls back to 1.0 when the heuristic is undefined (fewer than two
        particles) or degenerate (all particles identical), since a zero or
        non-finite bandwidth would turn the kernel into NaNs.
        """
        if self.particles.shape[0] < 2:
            return 1.0
        h = float(self.median_heuristic_bandwidth(self.particles))
        if not np.isfinite(h) or h <= 0.0:
            return 1.0
        # _rbf_kernel uses exp(-d^2 / (2 sigma^2)); match exp(-d^2 / h).
        return float(np.sqrt(0.5 * h))

    def fit(self, data):
        """
        Fit GMM on data and sample particles using SVGD.

        Args:
            data (np.ndarray or torch.Tensor): Shape (n_samples, dim)

        Returns:
            torch.Tensor: Sampled particles of shape (num_particles, dim), or
                (n_samples, dim) when the data has fewer rows than
                ``num_particles``.
        """
        if isinstance(data, torch.Tensor):
            detached = data.detach()
            data_np = detached.cpu().numpy()
            # GMM tensors are float32; cast so matmuls do not mix dtypes.
            data_tensor = detached.to(device=self.device, dtype=torch.float32)
        else:
            data_np = np.asarray(data)
            data_tensor = torch.as_tensor(data_np, dtype=torch.float32, device=self.device)

        # Fit GMM to data
        self._fit_gmm(data_np)

        # Initialize particles from a random subset of the data rows.
        n_samples = data_tensor.shape[0]
        idx = torch.randperm(n_samples)[:self.num_particles]
        self.particles = data_tensor[idx].clone()
        if self.bandwidth is None:
            self.bandwidth = self._median_bandwidth_sigma()

        # Perform SVGD updates
        self._svgd_update()

        return self.particles

    def sample(self, n_samples=None):
        """
        Return current particles as samples.

        Args:
            n_samples (int): Number of samples to return. If None, return all
                particles. At most the number of available particles is
                returned.

        Returns:
            torch.Tensor: Samples of shape (n_samples, dim)

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self.particles is None:
            raise ValueError("Must call fit() first before sampling.")

        if n_samples is None:
            return self.particles.clone()

        # Use the real particle count: it is smaller than num_particles when
        # the training data had fewer rows.
        idx = torch.randperm(self.particles.shape[0])[:n_samples]
        return self.particles[idx].clone()

    def sample_from_gmm(self, n_samples):
        """
        Sample directly from the fitted GMM (without SVGD).

        Args:
            n_samples (int): Number of samples to generate.

        Returns:
            torch.Tensor: Samples of shape (n_samples, dim)

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self.gmm is None:
            raise ValueError("Must call fit() first before sampling from GMM.")

        samples_np, _ = self.gmm.sample(n_samples)
        return torch.tensor(samples_np, dtype=torch.float32, device=self.device)

    def to(self, device):
        """
        Move the sampler to a different device.

        Args:
            device (str or torch.device): Target device

        Returns:
            GMM_SteinSampler: self, to allow chaining.
        """
        self.device = torch.device(device)
        for name in ('particles', 'weights', 'means', 'covariances', 'precision_matrices'):
            value = getattr(self, name, None)
            if value is not None:
                setattr(self, name, value.to(self.device))
        return self

'''
# Example usage:
if __name__ == "__main__":
    # Generate some sample data
    np.random.seed(42)
    torch.manual_seed(42)
    
    # Create a mixture of Gaussians in higher dimension
    dim = 10
    data1 = np.random.normal(-2, 1, (200, dim))
    data2 = np.random.normal(2, 1, (200, dim))
    data3 = np.random.normal([0]*dim, 0.5, (200, dim))
    data = np.vstack([data1, data2, data3])
    
    # Initialize sampler
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sampler = GMM_SteinSampler(
        num_particles=100,
        bandwidth=1.0,
        learning_rate=0.01,
        num_iterations=50,
        n_components=3,
        covariance_type='full',
        device=device
    )
    
    print(f"Using device: {device}")
    print(f"Data shape: {data.shape}")
    
    # Fit and sample
    particles = sampler.fit(data)
    
    print(f"Generated {particles.shape[0]} particles with shape {particles.shape}")
    print(f"Particles are on device: {particles.device}")
    
    # Get samples from SVGD
    svgd_samples = sampler.sample(50)
    print(f"SVGD samples shape: {svgd_samples.shape}")
    
    # Get samples directly from GMM
    gmm_samples = sampler.sample_from_gmm(50)
    print(f"GMM samples shape: {gmm_samples.shape}")
    
    # Compare log probabilities
    with torch.no_grad():
        log_prob_particles = sampler._log_prob(particles[:10])
        print(f"Log probabilities of first 10 particles: {log_prob_particles}")
'''