import numpy as np
from scipy.linalg import qr
from scipy.stats import special_ortho_group, wasserstein_distance


class GSWD:
    """
    Generalized Sliced Wasserstein Distance implementation.

    The GSWD extends the standard Sliced Wasserstein Distance by choosing the
    projection directions more carefully than plain random sampling.

    Note: every projection used here is a unit-norm *linear* direction. The
    directions are drawn at random, taken from a PCA of the data, or refined
    by numerical gradient ascent (``projection_method``). With
    ``correlation_aware``, eigenvectors of the correlation-matrix difference
    can seed part of the optimized set.
    """

    def __init__(self, n_projections=50, projection_method='random',
                 optimization_steps=10, correlation_aware=True):
        """
        Initialize GSWD calculator

        Args:
            n_projections (int): Number of projection directions to use
            projection_method (str): Method for generating projections:
                'random', 'pca', or 'optimized'. Any other value is treated
                as 'random' by fit().
            optimization_steps (int): Number of gradient-ascent steps used when
                projection_method == 'optimized'
            correlation_aware (bool): Whether to seed part of the optimized
                projections with eigenvectors of the correlation difference
        """
        self.n_projections = n_projections
        self.projection_method = projection_method
        self.optimization_steps = optimization_steps
        self.correlation_aware = correlation_aware
        # Cached (n_projections, dim) directions; populated by fit().
        self.projections = None

    def _generate_random_projections(self, dim):
        """
        Generate random unit-norm projection directions.

        For dim > 1, Gaussian vectors are normalized to unit length. The first
        min(dim, n_projections) of them are replaced by rows of a random
        rotation matrix, so those rows are mutually orthogonal. For dim == 1,
        the only unit directions are +1 and -1: the first n_projections // 2
        rows are +1 and the remaining rows are -1.

        Args:
            dim (int): Dimensionality of the space

        Returns:
            np.ndarray: Set of projection directions with shape (n_projections, dim)
        """
        if dim > 1:
            # & Gaussian directions, normalized onto the unit sphere
            projections = np.random.normal(0, 1, (self.n_projections, dim))
            norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
            projections = projections / norms

            # & Make the leading block orthogonal when that is possible
            n_ortho = min(dim, self.n_projections)
            if n_ortho > 1:
                try:
                    random_orthogonal = special_ortho_group.rvs(dim)
                    projections[:n_ortho] = random_orthogonal[:n_ortho]
                except Exception:
                    # Best effort: keep the Gaussian directions already drawn.
                    # (Not a bare ``except`` so KeyboardInterrupt/SystemExit propagate.)
                    pass
        else:
            projections = np.ones((self.n_projections, 1))
            projections[self.n_projections // 2:] = -1

        # & Final normalization; the epsilon guards against zero-length rows
        norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
        return projections / (norms + 1e-10)

    def _generate_pca_projections(self, samples):
        """
        Generate projection vectors using PCA on the samples.

        The leading principal axes, sorted by descending variance, fill the
        first min(dim, n_projections) rows. Any remaining rows are random
        directions.

        Args:
            samples (np.ndarray): Samples to analyze with shape (n_samples, dim)

        Returns:
            np.ndarray: Set of projection directions with shape (n_projections, dim)
        """
        # & Center the data (np.cov also centers, this keeps intent explicit)
        centered = samples - np.mean(samples, axis=0, keepdims=True)

        # & np.cov returns a 0-d array for a single feature; eigh needs 2-D
        cov = np.atleast_2d(np.cov(centered, rowvar=False))

        eigenvalues, eigenvectors = np.linalg.eigh(cov)

        # & eigh returns ascending eigenvalues; reverse to descending variance
        idx = np.argsort(eigenvalues)[::-1]
        eigenvectors = eigenvectors[:, idx]

        dim = samples.shape[1]
        n_vectors = min(dim, self.n_projections)

        projections = np.zeros((self.n_projections, dim))
        projections[:n_vectors] = eigenvectors[:, :n_vectors].T

        # & Fewer principal axes than requested projections: pad with random ones
        if self.n_projections > n_vectors:
            additional = self._generate_random_projections(dim)
            projections[n_vectors:] = additional[:(self.n_projections - n_vectors)]

        norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
        return projections / (norms + 1e-10)  # Avoid division by zero

    def _optimize_projections(self, samples_p, samples_q):
        """
        Optimize projection directions to maximize the sliced W1 distance.

        Initial directions come from PCA (if projection_method == 'pca') or
        from random sampling. When correlation_aware is set, about a third of
        them are replaced by eigenvectors of corr(p) - corr(q), taken in order
        of |eigenvalue|. Each direction is then improved by
        ``optimization_steps`` steps of gradient ascent. The gradient is a
        forward finite difference of scipy's 1D W1, the step size starts at
        0.1 and decays by 0.9 per step, and every row is renormalized after
        each step.

        Args:
            samples_p (np.ndarray): First set of samples with shape (n_samples_p, dim)
            samples_q (np.ndarray): Second set of samples with shape (n_samples_q, dim)

        Returns:
            np.ndarray: Optimized projection directions with shape (n_projections, dim)
        """
        dim = samples_p.shape[1]

        # & Start with random or PCA projections
        if self.projection_method == 'pca':
            combined = np.vstack([samples_p, samples_q])
            projections = self._generate_pca_projections(combined)
        else:
            projections = self._generate_random_projections(dim)

        # & Defensive shape check (the generators above always return n_projections rows)
        if projections.shape[0] != self.n_projections:
            print(f"Warning: Expected {self.n_projections} projections, got {projections.shape[0]}. Fixing...")
            if projections.shape[0] < self.n_projections:
                additional = self._generate_random_projections(dim)
                additional = additional[:(self.n_projections - projections.shape[0])]
                projections = np.vstack([projections, additional])
            else:
                projections = projections[:self.n_projections]

        # & Seed some directions with the principal correlation differences.
        # & Slice assignment below cannot change the row count, so no re-check is needed.
        if self.correlation_aware and dim > 1:
            try:
                # Constant features give NaN correlations (0/0); treat them as
                # "no correlation difference" instead of propagating NaN.
                with np.errstate(divide='ignore', invalid='ignore'):
                    corr_p = np.corrcoef(samples_p, rowvar=False)
                    corr_q = np.corrcoef(samples_q, rowvar=False)
                corr_diff = np.nan_to_num(corr_p - corr_q, nan=0.0)
                eigenvalues, eigenvectors = np.linalg.eigh(corr_diff)

                idx = np.argsort(np.abs(eigenvalues))[::-1]
                eigenvectors = eigenvectors[:, idx]

                n_corr = min(dim, self.n_projections // 3)  # Use about 1/3 for correlation
                if n_corr > 0:
                    projections[:n_corr] = eigenvectors[:, :n_corr].T
            except Exception as e:
                print(f"Warning: Could not use correlation-aware projections: {e}")

        # & Gradient ascent with a forward finite-difference gradient
        learning_rate = 0.1
        epsilon = 1e-6
        for _ in range(self.optimization_steps):
            distances = np.array([
                wasserstein_distance(np.dot(samples_p, direction),
                                     np.dot(samples_q, direction))
                for direction in projections
            ])

            gradients = np.zeros_like(projections)
            for i in range(self.n_projections):
                for j in range(dim):
                    perturbed = projections[i].copy()
                    perturbed[j] += epsilon
                    perturbed = perturbed / np.linalg.norm(perturbed)

                    distance_perturbed = wasserstein_distance(
                        np.dot(samples_p, perturbed), np.dot(samples_q, perturbed))
                    gradients[i, j] = (distance_perturbed - distances[i]) / epsilon

            # & Ascent step (maximize distance), then back onto the unit sphere
            projections += learning_rate * gradients
            norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
            projections = projections / (norms + 1e-10)

            learning_rate *= 0.9

        return projections

    def _project_samples(self, samples, projections=None):
        """
        Project samples onto the set of projection directions

        Args:
            samples (np.ndarray): Samples to project with shape (n_samples, dim)
            projections (np.ndarray, optional): Projection directions with
                shape (n_projections, dim). If None, use stored projections.

        Returns:
            np.ndarray: Projected samples with shape (n_projections, n_samples)

        Raises:
            ValueError: If projections is None and fit() has not been called.
        """
        if projections is None:
            if self.projections is None:
                raise ValueError("No projections available. Call fit() first.")
            projections = self.projections

        # & Row i holds every sample's coordinate along direction i
        return np.asarray(projections, dtype=float) @ np.asarray(samples, dtype=float).T

    @staticmethod
    def _wasserstein_1d(u_values, v_values, p):
        """
        Exact p-Wasserstein distance between two 1D empirical distributions.

        Both inputs are treated as uniformly weighted samples. When the sample
        counts match, this is the sorted-pairs formula
        mean(|u_(k) - v_(k)|^p)^(1/p). Otherwise, the quantile functions are
        integrated exactly over the union of their breakpoints {k/n} and
        {k/m}.

        Args:
            u_values (np.ndarray): 1D samples of the first distribution
            v_values (np.ndarray): 1D samples of the second distribution
            p (float): Order of the distance (p > 0)

        Returns:
            float: W_p(u, v)

        Raises:
            ValueError: If either sample array is empty.
        """
        sorted_u = np.sort(u_values)
        sorted_v = np.sort(v_values)
        n_u, n_v = sorted_u.size, sorted_v.size
        if n_u == 0 or n_v == 0:
            raise ValueError("Distribution can't be empty.")

        if n_u == n_v:
            if p == 2:
                return np.sqrt(np.mean((sorted_u - sorted_v) ** 2))
            return np.mean(np.abs(sorted_u - sorted_v) ** p) ** (1.0 / p)

        # Right endpoints of the constant pieces of both quantile functions.
        levels = np.union1d(np.arange(1, n_u + 1) / n_u, np.arange(1, n_v + 1) / n_v)
        widths = np.diff(levels, prepend=0.0)
        # Evaluate each quantile function strictly inside its piece.
        mids = levels - widths / 2.0
        quantile_u = sorted_u[np.minimum((mids * n_u).astype(int), n_u - 1)]
        quantile_v = sorted_v[np.minimum((mids * n_v).astype(int), n_v - 1)]
        return np.sum(widths * np.abs(quantile_u - quantile_v) ** p) ** (1.0 / p)

    def fit(self, samples_p, samples_q):
        """
        Compute and store projections for the given samples.

        Any exception raised while building the projections is printed, and
        random projections are stored instead.

        Args:
            samples_p (np.ndarray): First set of samples with shape (n_samples_p, dim)
            samples_q (np.ndarray): Second set of samples with shape (n_samples_q, dim)
        """
        try:
            if self.projection_method == 'optimized':
                self.projections = self._optimize_projections(samples_p, samples_q)
            elif self.projection_method == 'pca':
                combined = np.vstack([samples_p, samples_q])
                self.projections = self._generate_pca_projections(combined)
            else:  # random projections (also the fallback for unknown methods)
                dim = samples_p.shape[1]
                self.projections = self._generate_random_projections(dim)

            # Final safety check to ensure projections have the correct shape
            dim = samples_p.shape[1]
            if self.projections.shape != (self.n_projections, dim):
                print(f"Warning: Projections have incorrect shape: {self.projections.shape}, expected {(self.n_projections, dim)}. Fixing...")

                new_projections = np.zeros((self.n_projections, dim))

                common_rows = min(self.projections.shape[0], self.n_projections)
                new_projections[:common_rows] = self.projections[:common_rows]

                if common_rows < self.n_projections:
                    random_projections = self._generate_random_projections(dim)
                    new_projections[common_rows:] = random_projections[:(self.n_projections - common_rows)]

                norms = np.sqrt(np.sum(new_projections**2, axis=1, keepdims=True))
                new_projections = new_projections / (norms + 1e-10)  # Avoid division by zero

                self.projections = new_projections
        except Exception as e:
            # Deliberate best-effort: any failure falls back to random projections.
            print(f"Error in GSWD.fit: {e}. Falling back to random projections.")
            dim = samples_p.shape[1]
            self.projections = self._generate_random_projections(dim)

    def compute_distance(self, samples_p, samples_q, projections=None, p=1, return_per_projection=False):
        """
        Compute the Generalized Sliced Wasserstein Distance

        The distance is the mean, over projection directions, of the 1D
        p-Wasserstein distance between the projected samples. The two sample
        sets may have different sizes for every p.

        Args:
            samples_p (np.ndarray): First set of samples with shape (n_samples_p, dim)
            samples_q (np.ndarray): Second set of samples with shape (n_samples_q, dim)
            projections (np.ndarray, optional): Projection directions to use.
                If None, use stored projections, calling fit() first if none exist.
            p (int): Power parameter for the Wasserstein distance. Default is 1.
            return_per_projection (bool): If True, also return distances per projection

        Returns:
            float or tuple: GSWD value, or (GSWD, per_projection_distances) if
                return_per_projection is True
        """
        if projections is None:
            if self.projections is None:
                self.fit(samples_p, samples_q)
            projections = self.projections

        projected_p = self._project_samples(samples_p, projections)
        projected_q = self._project_samples(samples_q, projections)

        n_directions = projections.shape[0]
        distances = np.zeros(n_directions)
        for i in range(n_directions):
            if p == 1:
                distances[i] = wasserstein_distance(projected_p[i], projected_q[i])
            else:
                # Exact quantile-based W_p; also valid for unequal sample counts
                distances[i] = self._wasserstein_1d(projected_p[i], projected_q[i], p)

        gswd = np.mean(distances)

        if return_per_projection:
            return gswd, distances
        return gswd

    def compute_max_distance(self, samples_p, samples_q, projections=None, p=1):
        """
        Compute the Maximum Sliced Wasserstein Distance (max-SWD)

        Args:
            samples_p (np.ndarray): First set of samples with shape (n_samples_p, dim)
            samples_q (np.ndarray): Second set of samples with shape (n_samples_q, dim)
            projections (np.ndarray, optional): Projection directions to use.
                If None, use stored or generate new projections.
            p (int): Power parameter for the Wasserstein distance. Default is 1.

        Returns:
            tuple: (max_distance, max_projection_index)
        """
        _, distances = self.compute_distance(
            samples_p, samples_q, projections, p, return_per_projection=True)

        max_idx = np.argmax(distances)
        return distances[max_idx], max_idx

    def get_gradient(self, samples_p, samples_q, samples_x):
        """
        Compute a per-particle update direction from the sliced matching (for SVGD).

        For each stored direction, every projected x is ranked among the
        projected samples_q. The quantile level of that rank, (rank + 0.5) / n_q,
        is mapped to the matching order statistic of samples_p. The term
        0.002 * (target - x) * direction is accumulated per direction, and the
        sum is divided by the number of directions. When n_p == n_q this
        reduces to matching the rank-th sorted p-sample.

        NOTE(review): the returned vector points from each x toward its matched
        target, i.e. it is a descent direction (negative gradient) of the
        distance, scaled by the hand-tuned factor 0.002.

        Args:
            samples_p (np.ndarray): Target samples with shape (n_p, dim)
            samples_q (np.ndarray): Samples that define the ranks, shape (n_q, dim)
            samples_x (np.ndarray): Points to compute directions for, shape (n_x, dim)

        Returns:
            np.ndarray: Update directions with shape (n_x, dim)
        """
        if self.projections is None:
            self.fit(samples_p, samples_q)

        projected_p = self._project_samples(samples_p)
        projected_q = self._project_samples(samples_q)
        projected_x = self._project_samples(samples_x)

        n_p = projected_p.shape[1]
        n_q = projected_q.shape[1]
        gradients = np.zeros((samples_x.shape[0], samples_x.shape[1]))
        grad_scale = 0.002  # This scaling factor needs to be tuned

        n_proj = self.projections.shape[0]
        for i in range(n_proj):
            sorted_p = np.sort(projected_p[i])
            sorted_q = np.sort(projected_q[i])

            # & Rank of each x among q, rescaled to an index into sorted p
            ranks = np.searchsorted(sorted_q, projected_x[i])
            target_idx = np.clip(((2 * ranks + 1) * n_p) // (2 * n_q), 0, n_p - 1)
            targets = sorted_p[target_idx]

            gradients += np.outer(grad_scale * (targets - projected_x[i]), self.projections[i])

        gradients /= n_proj
        return gradients

    def get_regularizer(self, samples_p, samples_x, lambda_reg=0.1):
        """
        Compute GSWD regularization term for SVGD updates

        samples_x is used both as the rank-defining set and as the points to
        move, so each particle is pulled toward the target quantile that
        matches its own rank.

        Args:
            samples_p (np.ndarray): Target distribution samples with shape (n_samples_p, dim)
            samples_x (np.ndarray): Current distribution samples with shape (n_x, dim)
            lambda_reg (float): Regularization strength

        Returns:
            np.ndarray: Regularization gradient with shape (n_x, dim)
        """
        gswd_gradient = self.get_gradient(samples_p, samples_x, samples_x)
        return lambda_reg * gswd_gradient
