"""
    3rd Party Libraries
"""
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Ellipse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import seaborn as sns
from sklearn.decomposition import PCA
import umap


class GMMVisualizer:
    """
        Visualizer class for Gaussian Mixture Models.
    
        This class provides methods to create publication-quality visualizations
        of GMM distributions with support for 1D, 2D, 3D, and high-dimensional data.
        All visualizations are designed to be suitable for academic papers.
    """
    def __init__(self, cmap='viridis', figsize=None, dpi=300):
        """
            Set up a visualizer with shared styling defaults.

            Args:
                cmap (str): Name of the colormap used by methods that colour components
                figsize (tuple, optional): When given, replaces every per-plot-kind
                    default figure size with this single size
                dpi (int): Resolution for saved figures

            Note:
                Applies a global matplotlib style and seaborn plotting context as a
                side effect, so later figures in the process are affected too.
        """
        self.cmap = cmap
        self.dpi = dpi

        # & Per-plot-kind default sizes; a caller-supplied figsize replaces all of them
        base_sizes = (
            ('1d', (10, 6)),
            ('2d', (12, 10)),
            ('3d', (15, 12)),
            ('high_dim', (15, 15)),
        )
        self._default_figsizes = {
            kind: (default_size if figsize is None else figsize)
            for kind, default_size in base_sizes
        }

        # & Process-wide styling aimed at publication-quality output
        plt.style.use('seaborn-v0_8-whitegrid')
        sns.set_context("paper", font_scale=1.5)


    def visualize(self, gmm, plot_type=None, **kwargs):
        """
            Main visualization method that selects the appropriate visualization based on dimensionality.

            When ``plot_type`` is None it is inferred from ``gmm.dim``:
            1 -> '1d', 2 -> '2d', 3 -> '3d', up to 10 -> 'pairplot', otherwise 'pca'.

            Args:
                gmm: The GMM distribution to visualize (must expose ``dim``)
                plot_type (str, optional): Force a specific plot type. Options:
                    '1d', '2d', '3d', 'pairplot', 'pca', 'umap'
                **kwargs: Additional arguments for the specific visualization method.
                    If ``figsize`` is absent, the instance default for the plot type is
                    used ('high_dim' for 'pairplot', 'pca' and 'umap'). For 'pca' and
                    'umap' the ``method`` keyword is set by this dispatcher.

            Returns:
                Whatever the selected visualization method returns
                (e.g. ``(fig, ax)`` for the 1D/2D/3D methods)

            Raises:
                ValueError: If ``plot_type`` is not one of the supported options
        """
        # & Infer the plot type from dimensionality when not forced by the caller
        if plot_type is None:
            if gmm.dim == 1:
                plot_type = '1d'
            elif gmm.dim == 2:
                plot_type = '2d'
            elif gmm.dim == 3:
                plot_type = '3d'
            elif gmm.dim <= 10:
                plot_type = 'pairplot'
            else:
                plot_type = 'pca'

        # & plot type -> (method name, value forced into kwargs['method'] or None).
        # & Names are resolved lazily via getattr so only the chosen method must exist.
        dispatch = {
            '1d': ('visualize_1d', None),
            '2d': ('visualize_2d', None),
            '3d': ('visualize_3d', None),
            'pairplot': ('visualize_pairplot', None),
            'umap': ('visualize_high_dim', 'umap'),
            'pca': ('visualize_high_dim', 'pca'),
        }
        # & Validate before touching kwargs; the isinstance check keeps unhashable
        # & inputs on the ValueError path instead of raising TypeError
        if not isinstance(plot_type, str) or plot_type not in dispatch:
            raise ValueError(f"Unknown plot type: {plot_type}")
        method_name, reduction = dispatch[plot_type]

        # & Fall back to the instance default figure size for this plot type
        if 'figsize' not in kwargs:
            size_key = plot_type if plot_type in ('1d', '2d', '3d') else 'high_dim'
            kwargs['figsize'] = self._default_figsizes[size_key]

        # & Dimensionality-reduction plots share one method selected by 'method'
        if reduction is not None:
            kwargs['method'] = reduction

        return getattr(self, method_name)(gmm, **kwargs)
        

    def visualize_1d(self, gmm, x_range=None, n_points=1000, n_samples=500, 
                     show_components=True, title=None, figsize=None, 
                     cmap=None, alpha=0.7, save_path=None):
        """
            Visualize a 1D Gaussian Mixture Model.

            Draws a histogram of samples from ``gmm.sample``, the mixture density
            evaluated via ``gmm.log_prob``, and (optionally) each component's
            weighted Gaussian density as a dashed line.

            Args:
                gmm: The GMM distribution to visualize
                x_range (tuple, optional): Range for x-axis as (min, max); defaults to
                    the span of every component's mean ± 3 std
                n_points (int): Number of points for density evaluation
                n_samples (int): Number of samples to draw
                show_components (bool): Whether to show individual components
                    (only drawn when there is more than one component)
                title (str, optional): Plot title
                figsize (tuple, optional): Figure size
                cmap (str, optional): Colormap for component visualization
                alpha (float): Transparency for histograms
                save_path (str, optional): Path to save the figure (saved at ``self.dpi``)

            Returns:
                fig, ax: Matplotlib figure and axis objects

            Raises:
                ValueError: If ``gmm.dim`` is not 1
        """
        # & Validate input dimensionality
        if gmm.dim != 1:
            raise ValueError(f"This method only visualizes 1D distributions, but got {gmm.dim}D")

        # & Use instance defaults if not specified
        if cmap is None:
            cmap = self.cmap
        if figsize is None:
            figsize = self._default_figsizes['1d']

        fig, ax = plt.subplots(figsize=figsize)

        # & Default range covers every component's mean ± 3*std
        if x_range is None:
            means = gmm.means.flatten()
            stds = np.sqrt(gmm.covs.reshape(gmm.n_components))
            x_range = (np.min(means - 3 * stds), np.max(means + 3 * stds))

        # & Evaluate the mixture density on a 1D grid; log_prob expects shape (n, 1)
        x = np.linspace(x_range[0], x_range[1], n_points)
        probs = np.exp(gmm.log_prob(x.reshape(-1, 1)))

        # & Histogram of drawn samples, normalised to a density for comparison
        samples = gmm.sample(n_samples)
        ax.hist(samples, bins=30, density=True, alpha=alpha, color='gray', label='Samples')

        # & Overall mixture density
        ax.plot(x, probs, linewidth=3, color='black', label='GMM Density')

        # & Individual weighted component densities
        if show_components and gmm.n_components > 1:
            # & plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9)
            colors = plt.get_cmap(cmap)(np.linspace(0, 1, gmm.n_components))
            for i in range(gmm.n_components):
                mean = gmm.means[i, 0]
                std = np.sqrt(gmm.covs[i, 0, 0])
                weight = gmm.weights[i]

                # & weight * N(x; mean, std^2)
                component_density = weight * (1 / (std * np.sqrt(2 * np.pi))) * \
                                    np.exp(-0.5 * ((x - mean) / std) ** 2)

                ax.plot(x, component_density, '--', color=colors[i],
                        linewidth=2, label=f'Component {i+1}')

        # & Labels and title
        ax.set_xlabel('x')
        ax.set_ylabel('Density')
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f'1D GMM with {gmm.n_components} Components')

        ax.legend(loc='best')

        # & Operate on this figure explicitly rather than pyplot's "current" figure
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')

        return fig, ax
        

    def visualize_2d(self, gmm, x_range=None, y_range=None, n_points=100, n_samples=1000, 
                 show_components=True, title=None, figsize=None, save_path=None,
                 cmap=None, alpha=0.8):
        """
            Visualize a 2D Gaussian Mixture Model with contour plots and samples.

            Draws a filled contour of the mixture density (from ``gmm.log_prob``),
            overlays a random ~10% subset of drawn samples as '+' markers, and
            optionally a 95% confidence ellipse plus a numbered label per component.

            Args:
                gmm: The GMM distribution to visualize
                x_range (tuple, optional): Range for x-axis as (min, max); defaults to
                    the span of every component's mean ± 3 std along x_1
                y_range (tuple, optional): Range for y-axis as (min, max); defaults to
                    the span of every component's mean ± 3 std along x_2
                n_points (int): Number of points for density evaluation grid (per dimension)
                n_samples (int): Number of samples to draw (about 10% are plotted)
                show_components (bool): Whether to show ellipses for individual components
                title (str, optional): Plot title
                figsize (tuple, optional): Figure size
                save_path (str, optional): Path to save the figure (saved at ``self.dpi``)
                cmap (str or Colormap, optional): Colormap for the density contours;
                    defaults to matplotlib's 'coolwarm' (``self.cmap`` is not used here)
                alpha (float): Transparency for sample markers

            Returns:
                fig, ax: Matplotlib figure and axis objects

            Raises:
                ValueError: If ``gmm.dim`` is not 2
        """
        # & Validate input dimensionality
        if gmm.dim != 2:
            raise ValueError(f"This method only visualizes 2D distributions, but got {gmm.dim}D")

        # & Use instance defaults if not specified
        if figsize is None:
            figsize = self._default_figsizes['2d']

        # & Lazily build the "academic" colormap once per instance.
        # NOTE(review): not used by this method; kept because other methods may read
        # self.academic_cmap — confirm before removing.
        if not hasattr(self, 'academic_cmap'):
            academic_colors = [(0.0, 0.2, 0.5), (0.2, 0.5, 0.8), (0.5, 0.8, 0.9), 
                               (0.9, 0.9, 0.5), (0.9, 0.6, 0.2), (0.8, 0.2, 0.0)]
            self.academic_cmap = LinearSegmentedColormap.from_list("academic", academic_colors)

        # & Global style change (also affects figures created later in the process)
        plt.style.use('seaborn-v0_8-white')

        fig, ax = plt.subplots(figsize=figsize, dpi=300)
        ax.set_facecolor('white')
        fig.patch.set_facecolor('white')

        # & Default ranges cover every component's mean ± 3*std per axis
        if x_range is None or y_range is None:
            means = gmm.means
            stds = np.array([np.sqrt(np.diag(cov)) for cov in gmm.covs])
            if x_range is None:
                x_range = (min(means[:, 0] - 3 * stds[:, 0]), max(means[:, 0] + 3 * stds[:, 0]))
            if y_range is None:
                y_range = (min(means[:, 1] - 3 * stds[:, 1]), max(means[:, 1] + 3 * stds[:, 1]))

        # & Evaluate the mixture density on an n_points x n_points grid
        x = np.linspace(x_range[0], x_range[1], n_points)
        y = np.linspace(y_range[0], y_range[1], n_points)
        X, Y = np.meshgrid(x, y)
        XY = np.column_stack([X.ravel(), Y.ravel()])
        probs = np.exp(gmm.log_prob(XY)).reshape(X.shape)

        # & Filled density contours
        plot_cmap = plt.cm.coolwarm if cmap is None else plt.get_cmap(cmap)
        contour = ax.contourf(X, Y, probs, levels=15, cmap=plot_cmap, alpha=0.9)

        # & Plot a random ~10% subset of samples as '+' markers (less visually dense
        # & than dots) in orange, which contrasts with the blue end of the colormap
        samples = gmm.sample(n_samples)
        n_drawn = len(samples)
        indices = np.random.choice(n_drawn, size=int(n_drawn * 0.1), replace=False)
        subset = samples[indices]
        marker_color = '#FF8C00'
        ax.plot(subset[:, 0], subset[:, 1], '+',
                color=marker_color, alpha=alpha, markersize=4,
                markeredgewidth=0.8, label='Samples')

        # & Per-component 95% confidence ellipses with numbered labels
        if show_components:
            for i in range(gmm.n_components):
                mean = gmm.means[i]
                cov = gmm.covs[i]

                # & Principal axes, largest eigenvalue first
                eigvals, eigvecs = np.linalg.eigh(cov)
                order = eigvals.argsort()[::-1]
                # & Clip tiny negative round-off so sqrt below stays real
                eigvals = np.clip(eigvals[order], 0.0, None)
                eigvecs = eigvecs[:, order]

                # & 5.991 is the 95% quantile of a chi-square with 2 degrees of freedom
                angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
                width, height = 2 * np.sqrt(5.991 * eigvals)

                ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle,
                                  edgecolor='#333333', facecolor='none',
                                  linewidth=1.0, alpha=0.8, linestyle='-')
                ax.add_patch(ellipse)

                ax.text(mean[0], mean[1], f"{i+1}", fontsize=10, ha='center', va='center',
                        color='black', fontweight='bold',
                        bbox=dict(boxstyle="circle", fc='white', ec="#333333",
                                  alpha=0.9, pad=0.2))

        # & Compact inset colorbar anchored at the lower right of the axes
        cax = inset_axes(ax, width="3%", height="40%", loc="lower right",
                         bbox_to_anchor=(0.05, 0, 1, 1),
                         bbox_transform=ax.transAxes, borderpad=0)
        cbar = fig.colorbar(contour, cax=cax)
        cbar.set_label('Density', rotation=270, labelpad=15, fontsize=10)
        cbar.ax.tick_params(labelsize=8)

        # & Labels and title
        ax.set_xlabel('$x_1$', fontsize=12)
        ax.set_ylabel('$x_2$', fontsize=12)
        if title:
            ax.set_title(title, fontsize=13, pad=10)
        else:
            ax.set_title(f'GMM with {gmm.n_components} Components', fontsize=13, pad=10)

        # & Minimal styling: no grid, outward ticks, only left/bottom spines
        ax.grid(False)
        ax.tick_params(axis='both', which='major', labelsize=10,
                       length=4, width=0.8, direction='out')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(0.8)
        ax.spines['bottom'].set_linewidth(0.8)

        ax.legend(loc='upper right', frameon=True, framealpha=0.9, fontsize=9)
        ax.set_aspect('equal')

        # & Operate on this figure explicitly rather than pyplot's "current" figure
        fig.tight_layout()

        # & Honour the instance-level dpi, as the other visualize_* methods do
        if save_path:
            fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')

        return fig, ax
    

    def visualize_3d(self, gmm, x_range=None, y_range=None, z_range=None, n_points=50, 
                 n_samples=1000, show_components=True, title=None, figsize=None, 
                 alpha_samples=0.3, alpha_surface=0.2, std_scale=2.0, 
                 show_scatter=True, cmap=None, custom_colors=None, save_path=None):
        """
            Visualize a 3D Gaussian Mixture Model with ellipsoids and scatter plots.

            Optionally scatters samples coloured by their most likely component and
            draws one translucent covariance ellipsoid (plus centre marker and
            numbered label) per component.

            Args:
                gmm: The GMM distribution to visualize
                x_range (tuple, optional): Range for x-axis as (min, max)
                y_range (tuple, optional): Range for y-axis as (min, max)
                z_range (tuple, optional): Range for z-axis as (min, max)
                    (each range defaults to every component's mean ± 3 std on that axis)
                n_points (int): Currently unused; kept for API compatibility
                n_samples (int): Number of samples to draw
                show_components (bool): Whether to show ellipsoids for individual components
                title (str, optional): Plot title; no title is drawn when omitted
                figsize (tuple, optional): Figure size
                alpha_samples (float): Transparency for sample points
                alpha_surface (float): Transparency for ellipsoid surfaces
                std_scale (float): Scale factor for standard deviation ellipsoids (1.0 = 1 std)
                show_scatter (bool): Whether to show scattered sample points
                cmap (str, optional): Colormap for component visualization
                custom_colors (sequence, optional): One colour per component; overrides ``cmap``
                save_path (str, optional): Path to save the figure (saved at ``self.dpi``)

            Returns:
                fig, ax: Matplotlib figure and axis objects

            Raises:
                ValueError: If ``gmm.dim`` is not 3, or if ``custom_colors`` has fewer
                    entries than there are components (checked only when something is drawn)
        """
        # & Validate input dimensionality
        if gmm.dim != 3:
            raise ValueError(f"This method only visualizes 3D distributions, but got {gmm.dim}D")

        # & Use instance defaults if not specified
        if cmap is None:
            cmap = self.cmap
        if figsize is None:
            figsize = self._default_figsizes['3d']

        # & Global style change (also affects figures created later in the process)
        plt.style.use('seaborn-v0_8-white')
        fig = plt.figure(figsize=figsize, dpi=300)
        ax = fig.add_subplot(111, projection='3d')
        ax.set_facecolor('white')
        fig.patch.set_facecolor('white')

        # & Default ranges cover every component's mean ± 3*std per axis
        if x_range is None or y_range is None or z_range is None:
            means = gmm.means
            stds = np.array([np.sqrt(np.diag(cov)) for cov in gmm.covs])
            ranges = [x_range, y_range, z_range]
            for axis in range(3):
                if ranges[axis] is None:
                    ranges[axis] = (min(means[:, axis] - 3 * stds[:, axis]),
                                    max(means[:, axis] + 3 * stds[:, axis]))
            x_range, y_range, z_range = ranges

        # & Component colours are needed by both the scatter and the ellipsoids, so
        # & resolve them up front (previously they only existed when show_scatter=True)
        colors = None
        if show_scatter or show_components:
            if custom_colors is not None:
                if len(custom_colors) < gmm.n_components:
                    raise ValueError(f"Not enough custom colors provided. Need {gmm.n_components}, got {len(custom_colors)}")
                colors = custom_colors
            else:
                # & plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9)
                colors = plt.get_cmap(cmap)(np.linspace(0, 1, gmm.n_components))

        if show_scatter:
            samples = gmm.sample(n_samples)

            # & Assign each sample to its most likely component. Work in log space:
            # & exp(logpdf) can underflow to 0 for every component, which made the old
            # & normalisation produce NaN rows. Normalising is unnecessary for argmax.
            log_resp = np.empty((len(samples), gmm.n_components))
            with np.errstate(divide='ignore'):  # zero weights give -inf, which is fine
                for k in range(gmm.n_components):
                    log_resp[:, k] = np.log(float(gmm.weights[k])) + gmm.components[k].logpdf(samples)
            component_assignments = np.argmax(log_resp, axis=1)

            # & Invisible black '+' scatter used only to create a single legend entry
            ax.scatter([], [], [], marker='+', color='black', s=40, linewidths=0.8,
                       label='Samples')

            for k in range(gmm.n_components):
                mask = component_assignments == k
                ax.scatter(samples[mask, 0], samples[mask, 1], samples[mask, 2],
                           alpha=alpha_samples, color=colors[k],
                           marker='+', s=40, linewidths=0.8)

        if show_components and gmm.n_components >= 1:
            # & Unit sphere, shape (3, 20, 20); shared by every component
            u = np.linspace(0, 2 * np.pi, 20)
            v = np.linspace(0, np.pi, 20)
            unit_sphere = np.stack([
                np.outer(np.cos(u), np.sin(v)),
                np.outer(np.sin(u), np.sin(v)),
                np.outer(np.ones_like(u), np.cos(v)),
            ])

            for i in range(gmm.n_components):
                mean = np.asarray(gmm.means[i], dtype=float)
                eigvals, eigvecs = np.linalg.eigh(np.asarray(gmm.covs[i], dtype=float))
                # & Clip tiny negative round-off so sqrt below stays real
                eigvals = np.clip(eigvals, 0.0, None)

                # & Map the unit sphere onto the ellipsoid:
                # &   p = mean + std_scale * V @ diag(sqrt(lambda)) @ s
                # & i.e. eigenvector COLUMN c is scaled by sqrt(lambda_c). The previous
                # & code scaled output ROW r by sqrt(lambda_r), distorting the axes.
                transform = std_scale * eigvecs * np.sqrt(eigvals)
                ellipsoid = np.tensordot(transform, unit_sphere, axes=1) + mean.reshape(3, 1, 1)

                ax.plot_surface(
                    ellipsoid[0], ellipsoid[1], ellipsoid[2],
                    rstride=1, cstride=1, color=colors[i], alpha=alpha_surface,
                    linewidth=0, antialiased=True)

                # & Centre marker and numbered label for the component
                ax.scatter([mean[0]], [mean[1]], [mean[2]],
                           color=colors[i], s=50, edgecolor='black')
                ax.text(mean[0], mean[1], mean[2], f"{i+1}",
                        color='black', fontsize=10, ha='center', va='center')

        # & Axis labels and limits
        ax.set_xlabel('$x_1$', fontsize=12)
        ax.set_ylabel('$x_2$', fontsize=12)
        ax.set_zlabel('$x_3$', fontsize=12)
        ax.set_xlim(x_range)
        ax.set_ylim(y_range)
        ax.set_zlim(z_range)

        ax.tick_params(axis='both', which='major', labelsize=10,
                       direction='out', length=4, width=0.8)
        ax.grid(False)

        # & The title parameter used to be ignored; apply it only when explicitly given
        # & so the default output (no title) is unchanged
        if title:
            ax.set_title(title, fontsize=13, pad=10)

        if show_scatter:
            ax.legend(loc='upper right', frameon=True, framealpha=0.9, fontsize=9)

        ax.view_init(elev=30, azim=30)

        # & Operate on this figure explicitly rather than pyplot's "current" figure
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')

        return fig, ax
    

    def visualize_pairplot(self, gmm, vars=None, n_samples=1000, figsize=None, 
                      grid_size=None, diag_kind='kde', off_diag_kind='scatter', 
                      cmap=None, alpha=0.7, show_components=True, title=None, 
                      save_path=None):
        """
            Visualize a high-dimensional GMM using pairwise scatter plots.
            
            Args:
                gmm: The GMM distribution to visualize
                vars (list, optional): List of dimensions to include in the plot.
                    If None, all dimensions are used.
                n_samples (int): Number of samples to draw from the GMM
                figsize (tuple, optional): Figure size
                grid_size (tuple, optional): Tuple (rows, cols) for custom grid layout
                diag_kind (str): Kind of plot for the diagonal ('hist', 'kde')
                off_diag_kind (str): Kind of plot for the off-diagonal ('scatter')
                cmap (str, optional): Colormap for component visualization
                alpha (float): Transparency for scatter points
                show_components (bool): Whether to highlight different components with colors
                title (str, optional): Plot title
                save_path (str, optional): Path to save the figure
                
            Returns:
                fig, axes: Matplotlib figure and axes objects
        """
        # & Use instance defaults if not specified
        if cmap is None:
            cmap = self.cmap
        if figsize is None:
            figsize = self._default_figsizes['high_dim']
        
        # & Determine which dimensions to include
        if vars is None:
            n_dims = gmm.dim
            vars = list(range(n_dims))
        else:
            n_dims = len(vars)
            
        if n_dims < 2:
            raise ValueError("Pairplot requires at least 2 dimensions")
        
        # & Create a custom grid layout if specified
        if grid_size is not None:
            rows, cols = grid_size
            if rows * cols < n_dims * n_dims:
                raise ValueError(f"Grid size {grid_size} too small for {n_dims} dimensions")
        else:
            rows, cols = n_dims, n_dims
            
        # & Generate samples from the GMM
        samples = gmm.sample(n_samples)
        
        # & Find component assignment for each sample
        component_probs = np.zeros((n_samples, gmm.n_components))
        for i in range(gmm.n_components):
            component_probs[:, i] = gmm.weights[i] * np.exp(gmm.components[i].logpdf(samples))
            
        # & Normalize probabilities
        component_probs = component_probs / component_probs.sum(axis=1, keepdims=True)
        
        # & Assign each sample to most likely component
        component_assignments = np.argmax(component_probs, axis=1)
        
        # & Set up the figure and axes
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
        
        # & Generate colors for components
        colors = cm.get_cmap(cmap)(np.linspace(0, 1, gmm.n_components))
        
        # & Plot each pair of dimensions
        for i, dim_i in enumerate(vars):
            for j, dim_j in enumerate(vars):
                # Get the correct axis
                if rows == 1 and cols == 1:
                    ax = axes
                elif rows == 1:
                    ax = axes[j]
                elif cols == 1:
                    ax = axes[i]
                else:
                    ax = axes[i, j]
                    
                # & Diagonal plots: show marginal distribution
                if i == j:
                    # Extract this dimension
                    dim_data = samples[:, dim_i]
                    
                    if diag_kind == 'hist':
                        # & Basic histogram of the data
                        ax.hist(dim_data, bins=30, alpha=alpha, color='gray', 
                            density=True)
                        
                        # & Add KDE for each component if requested
                        if show_components and gmm.n_components > 1:
                            x = np.linspace(min(dim_data), max(dim_data), 1000)
                            for k in range(gmm.n_components):
                                # Get samples for this component
                                component_data = dim_data[component_assignments == k]
                                if len(component_data) > 0:
                                    # Plot KDE
                                    sns.kdeplot(component_data, ax=ax, color=colors[k], 
                                            label=f'Component {k+1}' if i == 0 and j == 0 else None)
                    
                    elif diag_kind == 'kde':
                        # & Show KDE for the dimension
                        sns.kdeplot(dim_data, ax=ax, color='black', linewidth=2, 
                                label='Overall' if i == 0 and j == 0 else None)
                        
                        # & Add KDE for each component if requested
                        if show_components and gmm.n_components > 1:
                            for k in range(gmm.n_components):
                                # Get samples for this component
                                component_data = dim_data[component_assignments == k]
                                if len(component_data) > 0:
                                    # Plot KDE
                                    sns.kdeplot(component_data, ax=ax, color=colors[k], 
                                            label=f'Component {k+1}' if i == 0 and j == 0 else None)
                    
                    # & Clean up diagonal plots
                    ax.set_yticks([])
                    if i < n_dims - 1:
                        ax.set_xticks([])
                    else:
                        ax.set_xlabel(f'Dimension {dim_i+1}')
                        
                # & Off-diagonal plots: show pairwise relationships
                else:
                    # & Extract the two dimensions
                    x_data = samples[:, dim_j]
                    y_data = samples[:, dim_i]
                    
                    if off_diag_kind == 'scatter':
                        # & Plot each component with a different color if requested
                        if show_components and gmm.n_components > 1:
                            for k in range(gmm.n_components):
                                mask = component_assignments == k
                                ax.scatter(x_data[mask], y_data[mask], alpha=alpha, 
                                        color=colors[k], s=10, 
                                        label=f'Component {k+1}' if i == n_dims-1 and j == 0 else None)
                        else:
                            # & Plot all samples with the same color
                            ax.scatter(x_data, y_data, alpha=alpha, color='gray', s=10)
                        
                        # & Draw component ellipses if requested
                        if show_components and gmm.n_components > 1:
                            for k in range(gmm.n_components):
                                # Get mean and covariance for these two dimensions
                                mean = np.array([gmm.means[k, dim_j], gmm.means[k, dim_i]])
                                cov = np.array([
                                    [gmm.covs[k, dim_j, dim_j], gmm.covs[k, dim_j, dim_i]],
                                    [gmm.covs[k, dim_i, dim_j], gmm.covs[k, dim_i, dim_i]]
                                ])
                                
                                # & Calculate eigenvalues and eigenvectors for the ellipse
                                eigvals, eigvecs = np.linalg.eigh(cov)
                                
                                # & Sort eigenvalues and eigenvectors
                                idx = eigvals.argsort()[::-1]
                                eigvals = eigvals[idx]
                                eigvecs = eigvecs[:, idx]
                                
                                # & Ellipse parameters (95% confidence interval)
                                angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
                                width, height = 2 * np.sqrt(5.991 * eigvals)  # 95% confidence
                                
                                # & Draw the ellipse
                                ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle,
                                            edgecolor=colors[k], facecolor='none', 
                                            linewidth=1.5, alpha=0.8)
                                ax.add_patch(ellipse)
                    
                    # & Clean up off-diagonal plots
                    if i < n_dims - 1:
                        ax.set_xticks([])
                    else:
                        ax.set_xlabel(f'Dimension {dim_j+1}')
                        
                    if j > 0:
                        ax.set_yticks([])
                    else:
                        ax.set_ylabel(f'Dimension {dim_i+1}')
        
        # & Add a legend to the first diagonal plot
        if show_components and gmm.n_components > 1:
            handles, labels = axes[0, 0].get_legend_handles_labels()
            if handles:
                fig.legend(handles, labels, loc='upper right', frameon=True, 
                        framealpha=0.9, fontsize=10)
        
        # & Add a title if provided
        if title:
            fig.suptitle(title, fontsize=16)
        else:
            fig.suptitle(f'Pairplot of {n_dims}D GMM with {gmm.n_components} Components', 
                    fontsize=16)
        
        # & Adjust layout
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        
        # & Save if requested
        if save_path:
            plt.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        
        return fig, axes
    

    def visualize_high_dim(self, gmm, method='umap', n_components=2, n_samples=2000,
                   show_components=True, perplexity=30, n_neighbors=15, min_dist=0.1,
                   random_state=None, title=None, figsize=None, cmap=None, 
                   alpha_samples=0.7, show_density=True, density_levels=15,
                   custom_colors=None, save_path=None):
        """
            Visualize a high-dimensional GMM using dimensionality reduction techniques.

            Samples are drawn from the GMM and each sample is hard-assigned to its most
            likely component. The samples are then projected to 2D or 3D with PCA or UMAP.
            For PCA the component means and covariances are projected exactly through the
            linear map. For UMAP (a non-linear embedding) they are estimated from the
            embedded samples assigned to each component.

            Args:
                gmm: The GMM distribution to visualize. This method reads ``sample``,
                    ``n_components``, ``weights``, ``components[i].logpdf``, ``means``,
                    ``covs`` and ``dim`` from it.
                method (str): Dimensionality reduction method: 'umap' or 'pca'
                n_components (int): Number of components for the projection (2 or 3)
                n_samples (int): Number of samples to draw from the GMM
                show_components (bool): Whether to show component ellipses and labels (2D)
                    or component centers (3D)
                perplexity (float): Accepted for API compatibility; currently unused
                    because t-SNE is not one of the supported methods
                n_neighbors (int): Number of neighbors for UMAP (if used)
                min_dist (float): Minimum distance for UMAP (if used)
                random_state (int): Random seed for the reducer; a random seed is drawn
                    when None
                title (str, optional): Plot title
                figsize (tuple, optional): Figure size
                cmap (str, optional): Colormap for component visualization
                alpha_samples (float): Transparency for scatter points
                show_density (bool): Whether to show density contours in 2D plots
                    (only drawn for PCA, where the projected mixture density is exact)
                density_levels (int): Number of contour levels for density plots
                custom_colors (list): Custom colors for components
                save_path (str, optional): Path to save the figure

            Returns:
                fig, ax: Matplotlib figure and axis objects

            Raises:
                ValueError: If ``method`` or ``n_components`` is unsupported, or if fewer
                    custom colors than GMM components are provided.
        """
        # & Validate the method
        if method not in ['umap', 'pca']:
            raise ValueError(f"Unknown dimensionality reduction method: {method}. Use 'umap' or 'pca'.")

        # & Validate the number of components
        if n_components not in [2, 3]:
            raise ValueError(f"n_components must be 2 or 3, got {n_components}")

        # & Use instance defaults if not specified
        if cmap is None:
            cmap = self.cmap
        if figsize is None:
            figsize = self._default_figsizes['high_dim']
        if random_state is None:
            random_state = np.random.randint(0, 10000)

        # & Chi-square quantile with 2 degrees of freedom at p=0.95: scales the
        # & covariance eigenvalues to a 95% confidence ellipse in 2D
        chi2_95_2d = 5.991

        # & Generate samples from the GMM
        samples = gmm.sample(n_samples)

        # & Assign each sample to its most likely component, working in log space.
        # & In high dimensions exp(logpdf) underflows to 0 for every component, so
        # & normalizing the raw densities would give 0/0 = NaN rows and argmax would
        # & silently return component 0. The row-wise argmax of the unnormalized log
        # & posteriors is the same as that of the normalized posteriors.
        log_resp = np.empty((len(samples), gmm.n_components))
        with np.errstate(divide='ignore'):  # zero-weight components -> log(0) = -inf
            log_weights = np.log(np.asarray(gmm.weights, dtype=float))
        for i in range(gmm.n_components):
            log_resp[:, i] = log_weights[i] + gmm.components[i].logpdf(samples)
        component_assignments = np.argmax(log_resp, axis=1)

        # & Generate colors for components
        if custom_colors is not None:
            # & Use provided custom colors
            if len(custom_colors) < gmm.n_components:
                raise ValueError(f"Not enough custom colors provided. Need {gmm.n_components}, got {len(custom_colors)}")
            colors = custom_colors
        else:
            # & plt.get_cmap works across matplotlib versions (matplotlib.cm.get_cmap
            # & was deprecated in 3.7 and removed in 3.9)
            colors = plt.get_cmap(cmap)(np.linspace(0, 1, gmm.n_components))

        # & Apply dimensionality reduction
        if method == 'pca':
            reducer = PCA(n_components=n_components, random_state=random_state)
            reduced_samples = reducer.fit_transform(samples)

            # & Project component means to the reduced space (transform also centers)
            reduced_means = reducer.transform(gmm.means)

            # & A linear map W sends a covariance S to W S W^T; here W = components_
            projection = reducer.components_
            reduced_covs = np.array([projection @ gmm.covs[i] @ projection.T
                                     for i in range(gmm.n_components)])

        else:  # method == 'umap'
            reducer = umap.UMAP(n_components=n_components, 
                            n_neighbors=n_neighbors,
                            min_dist=min_dist,
                            random_state=random_state)
            reduced_samples = reducer.fit_transform(samples)

            # & UMAP is non-linear, so covariances cannot be projected directly.
            # & Estimate them from the embedded samples assigned to each component.
            reduced_means = np.zeros((gmm.n_components, n_components))
            reduced_covs = np.zeros((gmm.n_components, n_components, n_components))

            for i in range(gmm.n_components):
                component_samples = reduced_samples[component_assignments == i]
                if len(component_samples) > n_components:  # Need more points than dimensions
                    reduced_means[i] = np.mean(component_samples, axis=0)
                    reduced_covs[i] = np.cov(component_samples, rowvar=False)
                elif len(component_samples) > 0:
                    # Too few points for a covariance estimate: keep their centroid and
                    # fall back to an identity covariance
                    reduced_means[i] = np.mean(component_samples, axis=0)
                    reduced_covs[i] = np.eye(n_components)
                else:
                    # No sample was assigned to this component, so there is nothing to
                    # locate it by in the embedding; mark it missing (NaN) so it is
                    # skipped when drawing ellipses/centers instead of being drawn at
                    # a misleading position
                    reduced_means[i] = np.nan
                    reduced_covs[i] = np.eye(n_components)

        # & Create the plot based on dimensionality
        if n_components == 2:
            # & Prepare for 2D visualization with publication style
            plt.style.use('seaborn-v0_8-white')
            fig, ax = plt.subplots(figsize=figsize, dpi=self.dpi)

            # & Use a white background for publication quality
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')

            # & Plot each component's samples with assigned colors
            for i in range(gmm.n_components):
                mask = component_assignments == i
                ax.scatter(reduced_samples[mask, 0], reduced_samples[mask, 1], 
                        color=colors[i], alpha=alpha_samples, s=15, 
                        label=f'Component {i+1}')

            # & Density contours are only meaningful for PCA, where the projection
            # & of the mixture density is itself an exact Gaussian mixture
            if show_density and method == 'pca':
                from scipy.stats import multivariate_normal

                # & Define grid for density evaluation
                x_min, x_max = reduced_samples[:, 0].min() - 1, reduced_samples[:, 0].max() + 1
                y_min, y_max = reduced_samples[:, 1].min() - 1, reduced_samples[:, 1].max() + 1
                xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                                np.linspace(y_min, y_max, 100))
                positions = np.column_stack([xx.ravel(), yy.ravel()])

                # & Mixture density in the reduced space; allow_singular guards against
                # & components with zero variance along a principal direction
                density = np.zeros(len(positions))
                for i in range(gmm.n_components):
                    mvn = multivariate_normal(reduced_means[i], reduced_covs[i],
                                              allow_singular=True)
                    density += gmm.weights[i] * mvn.pdf(positions)

                # & Plot contours with a subdued colormap
                ax.contour(xx, yy, density.reshape(xx.shape), levels=density_levels, 
                        cmap='Blues', alpha=0.6, linewidths=0.5)

            # & Show component ellipses if requested
            if show_components:
                for i in range(gmm.n_components):
                    mean = reduced_means[i]
                    if not np.all(np.isfinite(mean)):
                        continue  # component has no location in the embedding

                    # & Principal axes of the covariance, largest eigenvalue first
                    eigvals, eigvecs = np.linalg.eigh(reduced_covs[i])
                    idx = eigvals.argsort()[::-1]
                    eigvals = eigvals[idx]
                    eigvecs = eigvecs[:, idx]

                    # & Ellipse parameters (95% confidence interval); clip tiny negative
                    # & eigenvalues from round-off so sqrt does not produce NaN
                    angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
                    width, height = 2 * np.sqrt(chi2_95_2d * np.clip(eigvals, 0, None))

                    # & Skip degenerate ellipses
                    if np.isfinite(width) and np.isfinite(height) and width > 0 and height > 0:
                        ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle,
                                    edgecolor=colors[i], facecolor='none', 
                                    linewidth=1.5, alpha=0.8)
                        ax.add_patch(ellipse)

                    # & Label components at their (projected/estimated) mean
                    ax.text(mean[0], mean[1], f"{i+1}", fontsize=10, ha='center', va='center',
                        color='black', fontweight='bold',
                        bbox=dict(boxstyle="circle", fc='white', ec=colors[i], 
                                alpha=0.9, pad=0.2))

            # & Label axes with method information
            if method == 'pca':
                ax.set_xlabel('Principal Component 1', fontsize=12)
                ax.set_ylabel('Principal Component 2', fontsize=12)
            else:  # UMAP
                ax.set_xlabel('UMAP Dimension 1', fontsize=12)
                ax.set_ylabel('UMAP Dimension 2', fontsize=12)

            # & Add title
            if title:
                ax.set_title(title, fontsize=13, pad=10)
            else:
                ax.set_title(f'{method.upper()} Projection of {gmm.dim}D GMM with {gmm.n_components} Components', 
                        fontsize=13, pad=10)

            # & Publication-quality styling
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_linewidth(0.8)
            ax.spines['bottom'].set_linewidth(0.8)
            ax.tick_params(axis='both', which='major', labelsize=10, direction='out')
            ax.grid(False)

            # & Add legend
            if gmm.n_components <= 10:  # Only show legend if not too many components
                ax.legend(loc='best', frameon=True, framealpha=0.9, fontsize=9)

        else:  # n_components == 3
            # & 3D visualization
            plt.style.use('seaborn-v0_8-white')
            fig = plt.figure(figsize=figsize, dpi=self.dpi)
            ax = fig.add_subplot(111, projection='3d')

            # & Use a white background for publication quality
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')

            # & Plot each component's samples
            for i in range(gmm.n_components):
                mask = component_assignments == i
                ax.scatter(reduced_samples[mask, 0], reduced_samples[mask, 1], reduced_samples[mask, 2],
                        color=colors[i], alpha=alpha_samples, s=15, label=f'Component {i+1}')

            # & Show component centers
            if show_components:
                for i in range(gmm.n_components):
                    mean = reduced_means[i]
                    if not np.all(np.isfinite(mean)):
                        continue  # component has no location in the embedding
                    ax.scatter([mean[0]], [mean[1]], [mean[2]], 
                            color=colors[i], s=100, edgecolor='black', alpha=1.0)

                    # & Label the component
                    ax.text(mean[0], mean[1], mean[2], f"{i+1}", 
                        color='black', fontsize=10, ha='center', va='center')

            # & Label axes with method information
            if method == 'pca':
                ax.set_xlabel('Principal Component 1', fontsize=12)
                ax.set_ylabel('Principal Component 2', fontsize=12)
                ax.set_zlabel('Principal Component 3', fontsize=12)
            else:  # UMAP
                ax.set_xlabel('UMAP Dimension 1', fontsize=12)
                ax.set_ylabel('UMAP Dimension 2', fontsize=12)
                ax.set_zlabel('UMAP Dimension 3', fontsize=12)

            # & Add title
            if title:
                ax.set_title(title, fontsize=13, pad=10)
            else:
                ax.set_title(f'{method.upper()} Projection of {gmm.dim}D GMM with {gmm.n_components} Components', 
                        fontsize=13, pad=10)

            # & Better tick parameters
            ax.tick_params(axis='both', which='major', labelsize=10, direction='out')
            ax.grid(False)

            # & Set optimal viewing angle
            ax.view_init(elev=30, azim=30)

            # & Add legend
            if gmm.n_components <= 10:  # Only show legend if not too many components
                ax.legend(loc='best', frameon=True, framealpha=0.9, fontsize=9)

        # & Adjust layout of this figure (not whatever pyplot considers current)
        fig.tight_layout()

        # & Save the figure if requested
        if save_path:
            fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')

        return fig, ax
