"""
BW-DAM vs Euclidean DAM: 2D Visualization.

This experiment:
1. Samples 2D Gaussian distributions with enforced separation (Theorem 1)
2. Compares retrieval dynamics of BW-DAM vs Euclidean DAM
3. Visualizes the trajectories and fixed points
4. Demonstrates cases where BW-DAM succeeds but Euclidean DAM fails

This provides intuition for why Wasserstein geometry is advantageous for
retrieval in the space of Gaussian distributions.
"""

from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
import scipy.linalg as la
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class VisualizationConfig:
    """Configuration for the 2D BW-DAM vs Euclidean DAM visualization.

    Holds the sampling, temperature, seed and iteration settings, and derives
    from them the covariance condition number, the query perturbation radius,
    and the minimum mean separation enforced between stored patterns (the
    separation condition referred to as Theorem 1 in the module docstring).

    Raises:
        ValueError: if ``0 < lambda_min <= lambda_max`` does not hold (the
            condition number and its logarithm would be undefined or the
            radius NaN), or if ``num_patterns < 1`` (nothing to store).
    """
    
    # Dimension (fixed at 2 for visualization)
    dim: int = 2
    
    # Eigenvalue bounds for the stored covariance matrices
    lambda_min: float = 0.8
    lambda_max: float = 1.0
    
    # Number of stored Gaussians
    num_patterns: int = 5
    
    # Temperature parameter: scales the scores/distances inside the softmax
    beta: float = 2.0
    
    # Random seeds (pattern sampling, query perturbation)
    seed_gaussians: int = 42
    seed_perturbation: int = 123
    
    # Dynamics parameters
    max_iters: int = 50
    convergence_tol: float = 1e-10
    fixed_point_tol: float = 1e-14
    
    def __post_init__(self) -> None:
        # Validate eagerly so bad settings fail here instead of as NaNs later.
        if not 0 < self.lambda_min <= self.lambda_max:
            raise ValueError(
                "Require 0 < lambda_min <= lambda_max, got "
                f"lambda_min={self.lambda_min}, lambda_max={self.lambda_max}"
            )
        if self.num_patterns < 1:
            raise ValueError(f"num_patterns must be >= 1, got {self.num_patterns}")
    
    @property
    def condition_number(self) -> float:
        """Ratio ``lambda_max / lambda_min`` of the eigenvalue bounds."""
        return self.lambda_max / self.lambda_min
    
    @property
    def contractive_radius(self) -> float:
        """``sqrt(lambda_min)``; also used as the query perturbation size."""
        return np.sqrt(self.lambda_min)
    
    @property
    def min_mean_separation(self) -> float:
        """Minimum distance between stored means: ``sqrt(2 d lambda_max (log kappa + 2))``."""
        log_kappa = np.log(self.condition_number)
        return np.sqrt(2 * self.dim * self.lambda_max * (log_kappa + 2))
    
    @property
    def w2_perturbation(self) -> float:
        """W2 distance between each query and its target (equals ``contractive_radius``)."""
        return self.contractive_radius


# =============================================================================
# Matrix Utilities
# =============================================================================

def matrix_sqrt(A: np.ndarray) -> np.ndarray:
    """Return the symmetric square root of the symmetric matrix ``A``.

    ``A`` is diagonalised with ``scipy.linalg.eigh``. Eigenvalues below
    ``1e-10`` are raised to ``1e-10`` before taking roots, so a singular or
    slightly indefinite input still yields a finite result.
    """
    spectrum, basis = la.eigh(A)
    root_spectrum = np.sqrt(np.maximum(spectrum, 1e-10))
    # Scaling the columns of ``basis`` is the same as ``basis @ diag(root_spectrum)``.
    return (basis * root_spectrum) @ basis.T


def matrix_sqrt_inv(A: np.ndarray) -> np.ndarray:
    """Return the inverse symmetric square root ``A^{-1/2}`` of symmetric ``A``.

    Eigenvalues from ``scipy.linalg.eigh`` are floored at ``1e-10`` so the
    reciprocal root never divides by zero.
    """
    spectrum, basis = la.eigh(A)
    inv_root_spectrum = 1.0 / np.sqrt(np.maximum(spectrum, 1e-10))
    # Column scaling of ``basis`` is equivalent to ``basis @ diag(inv_root_spectrum)``.
    return (basis * inv_root_spectrum) @ basis.T


def is_positive_definite(Sigma: np.ndarray) -> bool:
    """Return whether the symmetric matrix ``Sigma`` is positive definite.

    Every eigenvalue from ``scipy.linalg.eigvalsh`` must be strictly positive,
    so a singular matrix is reported as not positive definite.
    """
    spectrum = la.eigvalsh(Sigma)
    return (spectrum > 0).all()


def sample_haar_orthogonal(d: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a ``d x d`` orthogonal matrix from the Haar measure.

    A standard-normal matrix is QR-factorised. Each column of ``Q`` is then
    multiplied by the sign of the matching diagonal entry of ``R``. Without
    this sign fix the distribution would be biased by the QR convention.
    """
    gaussian = rng.standard_normal((d, d))
    q_factor, r_factor = np.linalg.qr(gaussian)
    column_signs = np.sign(np.diag(r_factor))
    # Broadcasting scales column j by column_signs[j], i.e. Q @ diag(signs).
    return q_factor * column_signs


# =============================================================================
# Wasserstein Distance
# =============================================================================

def w2_squared(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """Squared 2-Wasserstein distance between N(mu1, Sigma1) and N(mu2, Sigma2).

    Computes ``||mu1 - mu2||^2 + tr(Sigma1) + tr(Sigma2)
    - 2 tr((Sigma1^{1/2} Sigma2 Sigma1^{1/2})^{1/2})``. The covariance (Bures)
    part is clipped at zero because rounding can make it slightly negative.
    """
    delta = mu1 - mu2
    mean_term = np.sum(delta ** 2)
    root1 = matrix_sqrt(Sigma1)
    cross_trace = np.trace(matrix_sqrt(root1 @ Sigma2 @ root1))
    bures = np.trace(Sigma1) + np.trace(Sigma2) - 2 * cross_trace
    return mean_term + max(0, bures)


def w2_distance(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """2-Wasserstein distance between N(mu1, Sigma1) and N(mu2, Sigma2).

    This is the square root of ``w2_squared``.
    """
    squared = w2_squared(mu1, Sigma1, mu2, Sigma2)
    return np.sqrt(squared)


# =============================================================================
# Gaussian Sampling
# =============================================================================

def sample_separated_gaussians(
    config: VisualizationConfig,
    rng: np.random.Generator,
    max_attempts: int = 1000,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Sample ``config.num_patterns`` Gaussians whose means are well separated.

    Means are drawn by rejection sampling, uniformly in the box
    ``[-s*sqrt(N), s*sqrt(N)]^dim``, where ``s = config.min_mean_separation``
    and ``N = config.num_patterns``. A draw is accepted only if it lies at
    least ``s`` from every earlier mean.

    Each covariance is ``Q diag(lam) Q^T``. Here ``lam`` is drawn uniformly in
    ``[lambda_min, lambda_max]`` and ``Q`` is a Haar-random orthogonal matrix.

    Raises:
        RuntimeError: if no acceptable mean is found within ``max_attempts``
            draws for some pattern.
    """
    separation = config.min_mean_separation
    half_width = separation * np.sqrt(config.num_patterns)
    gaussians = []
    means = []

    def _far_from_existing(candidate: np.ndarray) -> bool:
        return all(np.linalg.norm(candidate - other) >= separation for other in means)

    for i in range(config.num_patterns):
        mu = None
        for _attempt in range(max_attempts):
            candidate = rng.uniform(-half_width, half_width, config.dim)
            if _far_from_existing(candidate):
                mu = candidate
                break
        if mu is None:
            raise RuntimeError(f"Could not find separated mean for Gaussian {i+1}")

        means.append(mu)
        # The RNG draw order (mean, eigenvalues, rotation) is kept fixed for reproducibility.
        spectrum = rng.uniform(config.lambda_min, config.lambda_max, config.dim)
        rotation = sample_haar_orthogonal(config.dim, rng)
        Sigma = rotation @ np.diag(spectrum) @ rotation.T
        gaussians.append((mu, Sigma))

    return gaussians


# =============================================================================
# Vectorization
# =============================================================================

def gaussian_to_vector(mu: np.ndarray, Sigma: np.ndarray) -> np.ndarray:
    """Flatten ``(mu, Sigma)`` into ``[mu, upper triangle of Sigma]``.

    The upper triangle, diagonal included, is read row by row in
    ``np.triu_indices`` order. The result has length ``d + d(d+1)/2``.
    """
    rows, cols = np.triu_indices(len(mu))
    return np.concatenate((mu, Sigma[rows, cols]))


def vector_to_gaussian(v: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Inverse of ``gaussian_to_vector``: rebuild ``(mu, Sigma)`` from ``v``.

    A vector of length ``L = d + d(d+1)/2`` gives ``d`` as the positive root
    of ``d**2 + 3d - 2L = 0``. The tail of ``v`` fills the upper triangle of
    ``Sigma``, which is then mirrored to make ``Sigma`` symmetric. The
    returned ``mu`` is a view into ``v``, not a copy.
    """
    total = len(v)
    dim = int((-3 + np.sqrt(9 + 8 * total)) / 2)
    rows, cols = np.triu_indices(dim)
    tail = v[dim:]
    Sigma = np.zeros((dim, dim))
    Sigma[rows, cols] = tail
    Sigma[cols, rows] = tail
    return v[:dim], Sigma


# =============================================================================
# DAM Dynamics
# =============================================================================

def wasserstein_dam_step(
    m: np.ndarray,
    Omega: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    beta: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply one BW-DAM update to the state N(m, Omega).

    1. Weights: ``w = softmax(-beta * W2^2(stored_i, state))``.
    2. New mean: ``sum_i w_i mu_i``.
    3. New covariance: ``A Omega A^T``, symmetrised, where
       ``A = sum_i w_i A_i``. Each ``A_i`` is
       ``Sigma_i^{1/2} (Sigma_i^{1/2} Omega Sigma_i^{1/2})^{-1/2} Sigma_i^{1/2}``,
       the linear optimal-transport map from covariance ``Omega`` to
       ``Sigma_i``.

    Returns:
        ``(m_new, Omega_new, w)``.
    """
    sq_dists = np.array([
        w2_squared(mu_i, Sigma_i, m, Omega)
        for mu_i, Sigma_i in stored_gaussians
    ])

    # Softmax with max-subtraction for numerical stability.
    logits = -beta * sq_dists
    weights = np.exp(logits - np.max(logits))
    weights = weights / np.sum(weights)

    # Accumulate in pattern order, starting from 0, exactly as sum() would.
    m_new = 0
    A_tilde = 0
    for w_i, (mu_i, Sigma_i) in zip(weights, stored_gaussians):
        root = matrix_sqrt(Sigma_i)
        transport = root @ matrix_sqrt_inv(root @ Omega @ root) @ root
        m_new = m_new + w_i * mu_i
        A_tilde = A_tilde + w_i * transport

    Omega_new = A_tilde @ Omega @ A_tilde.T
    # Remove rounding asymmetry.
    Omega_new = (Omega_new + Omega_new.T) / 2

    return m_new, Omega_new, weights


def euclidean_dam_step(
    xi: np.ndarray,
    X: np.ndarray,
    beta: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Apply one softmax-attention update of the Euclidean DAM.

    Args:
        xi: Current state vector.
        X: Stored pattern vectors as columns, so ``X.T @ xi`` scores each one.
        beta: Scale applied to the dot-product scores before the softmax.

    Returns:
        ``(X @ w, w)`` where ``w = softmax(beta * X.T @ xi)``.
    """
    logits = beta * X.T.dot(xi)
    # Subtract the max before exponentiating so large scores cannot overflow.
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return X.dot(weights), weights


def find_wasserstein_fixed_point(
    mu_init: np.ndarray,
    Sigma_init: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    beta: float,
    max_iters: int = 1000,
    tol: float = 1e-14,
) -> Tuple[np.ndarray, np.ndarray, int]:
    """Iterate BW-DAM from N(mu_init, Sigma_init) until successive states agree.

    Iteration stops when the W2 distance between consecutive states drops
    below ``tol``.

    Returns:
        ``(mu, Sigma, steps)``. On convergence this is the newest state and
        the number of steps taken. Otherwise it is the last state reached and
        ``max_iters``.
    """
    m, Omega = mu_init.copy(), Sigma_init.copy()
    step = 0
    while step < max_iters:
        step += 1
        m_next, Omega_next, _ = wasserstein_dam_step(m, Omega, stored_gaussians, beta)
        if w2_distance(m, Omega, m_next, Omega_next) < tol:
            return m_next, Omega_next, step
        m, Omega = m_next, Omega_next
    return m, Omega, max_iters


def find_euclidean_fixed_point(
    xi_init: np.ndarray,
    X: np.ndarray,
    beta: float,
    max_iters: int = 1000,
    tol: float = 1e-14,
) -> Tuple[np.ndarray, int]:
    """Iterate the Euclidean DAM from ``xi_init`` until successive states agree.

    Each step is checked in two ways:

    - If the new state decodes (via ``vector_to_gaussian``) to a covariance
      that is not positive definite, iteration stops and the previous, still
      valid, state is returned.
    - Otherwise, if the Euclidean step size drops below ``tol``, the new
      state is returned.

    Returns:
        ``(xi, steps)``. ``steps`` counts the updates attempted, or equals
        ``max_iters`` if neither stopping rule fired.
    """
    xi = xi_init.copy()
    for t in range(max_iters):
        xi_new, _ = euclidean_dam_step(xi, X, beta)
        # Only the covariance part matters for the validity check.
        _, Sigma_new = vector_to_gaussian(xi_new)
        if not is_positive_definite(Sigma_new):
            return xi, t + 1
        if np.linalg.norm(xi_new - xi) < tol:
            return xi_new, t + 1
        xi = xi_new
    return xi, max_iters


# =============================================================================
# Perturbation
# =============================================================================

def perturb_gaussian(
    mu_target: np.ndarray,
    Sigma_target: np.ndarray,
    w2_dist: float,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray, float]:
    """Shift the mean by ``w2_dist`` in a random direction; keep the covariance.

    The covariance is unchanged, so the W2 distance to the target equals the
    mean shift. That distance is recomputed and returned as a check.

    Returns:
        ``(mu_perturbed, Sigma_perturbed, actual_w2)``. ``Sigma_perturbed``
        is a copy of ``Sigma_target``.
    """
    unit = rng.standard_normal(len(mu_target))
    unit /= np.linalg.norm(unit)
    shifted_mu = mu_target + w2_dist * unit
    same_Sigma = Sigma_target.copy()
    measured = w2_distance(shifted_mu, same_Sigma, mu_target, Sigma_target)
    return shifted_mu, same_Sigma, measured


# =============================================================================
# Experiment
# =============================================================================

@dataclass
class ExperimentResult:
    """Results from a single retrieval experiment.

    Holds the target pattern, the perturbed query, each method's fixed point
    and recorded trajectory, and which stored pattern each method retrieved.
    """
    # Index of the stored pattern the query was generated from.
    target_idx: int
    # Target stored Gaussian (mean, covariance).
    mu_target: np.ndarray
    Sigma_target: np.ndarray
    # Perturbed query used as the initial state for both methods.
    mu_init: np.ndarray
    Sigma_init: np.ndarray
    # Fixed point reached by BW-DAM.
    mu_fp_wass: np.ndarray
    Sigma_fp_wass: np.ndarray
    # Fixed point reached by Euclidean DAM, decoded back to (mean, covariance).
    mu_fp_eucl: np.ndarray
    Sigma_fp_eucl: np.ndarray
    # BW-DAM states as (mean, covariance) pairs, starting at the query.
    traj_wasserstein: List[Tuple[np.ndarray, np.ndarray]]
    # Euclidean DAM states as flattened vectors, starting at the query.
    traj_euclidean: List[np.ndarray]
    # Index of the stored pattern nearest each method's fixed point.
    retrieved_wass: int
    retrieved_eucl: int
    # W2 distance from each method's fixed point to the target.
    w2_to_target_wass: float
    w2_to_target_eucl: float


def _record_wasserstein_trajectory(
    mu_init: np.ndarray,
    Sigma_init: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    beta: float,
    max_steps: int,
    mu_fp: np.ndarray,
    Sigma_fp: np.ndarray,
    tol: float,
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Run BW-DAM from the query, recording each state.

    Stops once a state is within ``tol`` (in W2) of the known fixed point.
    """
    trajectory = [(mu_init.copy(), Sigma_init.copy())]
    m, Omega = mu_init.copy(), Sigma_init.copy()
    for _step in range(max_steps):
        m, Omega, _weights = wasserstein_dam_step(m, Omega, stored_gaussians, beta)
        trajectory.append((m.copy(), Omega.copy()))
        if w2_distance(m, Omega, mu_fp, Sigma_fp) < tol:
            break
    return trajectory


def _record_euclidean_trajectory(
    xi_init: np.ndarray,
    X: np.ndarray,
    beta: float,
    max_steps: int,
    xi_fp: np.ndarray,
    tol: float,
) -> List[np.ndarray]:
    """Run Euclidean DAM from the query, recording each state vector.

    Stops after the first state whose covariance is not positive definite,
    or once a state is within ``tol`` of the known fixed point.
    """
    trajectory = [xi_init.copy()]
    xi = xi_init.copy()
    for _step in range(max_steps):
        xi, _weights = euclidean_dam_step(xi, X, beta)
        trajectory.append(xi.copy())
        _, Sigma = vector_to_gaussian(xi)
        if not is_positive_definite(Sigma):
            break
        if np.linalg.norm(xi - xi_fp) < tol:
            break
    return trajectory


def run_retrieval_experiment(
    config: VisualizationConfig,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    X: np.ndarray,
    target_idx: int,
    fixed_point_max_iters: int = 1000,
) -> ExperimentResult:
    """Run BW-DAM and Euclidean DAM retrieval for one stored pattern.

    Steps:

    1. Perturb stored pattern ``target_idx`` by ``config.w2_perturbation``
       using a per-target seed.
    2. Find each method's fixed point, with at most ``fixed_point_max_iters``
       steps and ``config.fixed_point_tol`` as the tolerance.
    3. Re-run both dynamics for at most ``config.max_iters`` steps to record
       trajectories for plotting.
    4. Report the stored pattern nearest each fixed point. Distance is
       measured in W2 for BW-DAM and by the vector norm for Euclidean DAM.

    Args:
        config: Experiment configuration.
        stored_gaussians: Stored ``(mu, Sigma)`` patterns.
        X: The same patterns vectorised with ``gaussian_to_vector``, as columns.
        target_idx: Index of the pattern to retrieve.
        fixed_point_max_iters: Iteration cap for the fixed-point searches.

    Returns:
        An ``ExperimentResult`` for this target.
    """
    mu_target, Sigma_target = stored_gaussians[target_idx]

    # Per-target seed keeps each query reproducible independently of the others.
    rng = np.random.default_rng(config.seed_perturbation + target_idx)
    mu_init, Sigma_init, _ = perturb_gaussian(
        mu_target, Sigma_target, config.w2_perturbation, rng
    )
    xi_init = gaussian_to_vector(mu_init, Sigma_init)

    # Find fixed points
    mu_fp_wass, Sigma_fp_wass, _ = find_wasserstein_fixed_point(
        mu_init, Sigma_init, stored_gaussians, config.beta,
        max_iters=fixed_point_max_iters, tol=config.fixed_point_tol,
    )
    xi_fp_eucl, _ = find_euclidean_fixed_point(
        xi_init, X, config.beta,
        max_iters=fixed_point_max_iters, tol=config.fixed_point_tol,
    )
    mu_fp_eucl, Sigma_fp_eucl = vector_to_gaussian(xi_fp_eucl)

    # Record trajectories (bounded by config.max_iters) for plotting
    traj_wasserstein = _record_wasserstein_trajectory(
        mu_init, Sigma_init, stored_gaussians, config.beta,
        config.max_iters, mu_fp_wass, Sigma_fp_wass, config.convergence_tol,
    )
    traj_euclidean = _record_euclidean_trajectory(
        xi_init, X, config.beta,
        config.max_iters, xi_fp_eucl, config.convergence_tol,
    )

    # Each method is scored in its own geometry.
    retrieved_wass = int(np.argmin([
        w2_distance(mu_fp_wass, Sigma_fp_wass, mu_i, Sigma_i)
        for mu_i, Sigma_i in stored_gaussians
    ]))
    retrieved_eucl = int(np.argmin([
        np.linalg.norm(xi_fp_eucl - gaussian_to_vector(mu_i, Sigma_i))
        for mu_i, Sigma_i in stored_gaussians
    ]))

    return ExperimentResult(
        target_idx=target_idx,
        mu_target=mu_target,
        Sigma_target=Sigma_target,
        mu_init=mu_init,
        Sigma_init=Sigma_init,
        mu_fp_wass=mu_fp_wass,
        Sigma_fp_wass=Sigma_fp_wass,
        mu_fp_eucl=mu_fp_eucl,
        Sigma_fp_eucl=Sigma_fp_eucl,
        traj_wasserstein=traj_wasserstein,
        traj_euclidean=traj_euclidean,
        retrieved_wass=retrieved_wass,
        retrieved_eucl=retrieved_eucl,
        w2_to_target_wass=w2_distance(mu_fp_wass, Sigma_fp_wass, mu_target, Sigma_target),
        w2_to_target_eucl=w2_distance(mu_fp_eucl, Sigma_fp_eucl, mu_target, Sigma_target),
    )


# =============================================================================
# Visualization
# =============================================================================

# Edge colours for stored patterns, indexed by pattern position.
PATTERN_COLORS = [
    "#3498DB",  # blue
    "#9B59B6",  # purple
    "#E67E22",  # orange
    "#1ABC9C",  # teal
    "#F1C40F",  # yellow
]
QUERY_COLOR = "#2ECC71"        # initial (perturbed) query
FIXED_POINT_COLOR = "#C0392B"  # state reached by the dynamics
SUCCESS_COLOR = "#27AE60"      # check mark for correct retrieval
FAILURE_COLOR = "#E74C3C"      # cross mark for incorrect retrieval


def draw_ellipse(
    ax: plt.Axes,
    mu: np.ndarray,
    Sigma: np.ndarray,
    color: str,
    lw: float = 2,
    ls: str = "-",
    label: Optional[str] = None,
    n_std: float = 2.0,
    alpha: float = 1.0,
) -> None:
    """Add an unfilled ``n_std``-sigma contour ellipse of N(mu, Sigma) to ``ax``.

    The ellipse is aligned with the eigenvectors of ``Sigma``. Each full axis
    length is ``2 * n_std * sqrt(eigenvalue)``. Eigenvalues are floored at
    ``1e-6`` so a degenerate or slightly indefinite ``Sigma`` still draws.
    Only the first two eigenpairs are used, so ``Sigma`` is presumably 2x2.
    """
    variances, directions = la.eigh(Sigma)
    variances = np.maximum(variances, 1e-6)
    # Rotation of the first eigenvector, in degrees, which is what Ellipse expects.
    angle_deg = np.degrees(np.arctan2(directions[1, 0], directions[0, 0]))
    diameters = 2 * n_std * np.sqrt(variances)
    patch = Ellipse(
        xy=mu,
        width=diameters[0],
        height=diameters[1],
        angle=angle_deg,
        facecolor="none",
        edgecolor=color,
        linestyle=ls,
        linewidth=lw,
        label=label,
        alpha=alpha,
    )
    ax.add_patch(patch)


def compute_axis_limits(
    gaussians: List[Tuple[np.ndarray, np.ndarray]],
    n_std: float = 2.0,
    margin: float = 0.5,
) -> Tuple[float, float, float, float]:
    """Return ``(x_min, x_max, y_min, y_max)`` enclosing every Gaussian's ellipse.

    Each Gaussian is bounded by a square of half-width
    ``n_std * sqrt(largest eigenvalue)`` centred on its mean. This fully
    contains its ``n_std`` ellipse. ``margin`` is then added on every side.
    A zero covariance therefore contributes just its mean as a point.

    Raises:
        ValueError: if ``gaussians`` is empty. Without any Gaussian the
            limits would be infinite.
    """
    gaussians = list(gaussians)
    if not gaussians:
        raise ValueError("compute_axis_limits requires at least one Gaussian")

    x_min = y_min = float("inf")
    x_max = y_max = float("-inf")

    for mu, Sigma in gaussians:
        # Clip at 0 so a numerically indefinite Sigma cannot produce NaN.
        largest = max(float(np.max(la.eigvalsh(Sigma))), 0.0)
        radius = n_std * np.sqrt(largest)
        x_min = min(x_min, mu[0] - radius)
        x_max = max(x_max, mu[0] + radius)
        y_min = min(y_min, mu[1] - radius)
        y_max = max(y_max, mu[1] + radius)

    return x_min - margin, x_max + margin, y_min - margin, y_max + margin


def plot_comparison(
    results: List[ExperimentResult],
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    output_path: str = "bwdam_vs_euclidean_2d.png",
    max_rows: int = 4,
) -> None:
    """Render a grid comparing BW-DAM and Euclidean DAM retrieval, then save it.

    Each row shows one experiment in three columns:

    1. All stored Gaussians, with the target highlighted.
    2. The BW-DAM trajectory and fixed point.
    3. The Euclidean DAM trajectory and fixed point.

    Row selection prefers up to two cases where both methods succeed and up
    to two where only BW-DAM succeeds. Remaining rows, up to ``max_rows``,
    are filled with the other results in their original order. Pattern
    colours cycle when there are more patterns than ``PATTERN_COLORS``.

    Args:
        results: Per-target experiment outcomes.
        stored_gaussians: The stored ``(mu, Sigma)`` patterns.
        output_path: Path the figure is saved to (300 dpi).
        max_rows: Maximum number of experiment rows to draw.

    The figure is left open so the caller may ``plt.show()`` it. If there is
    nothing to plot, a message is printed and no figure is created.
    """
    if not results or max_rows < 1:
        print("\nNo results to plot; skipping figure.")
        return

    n_colors = len(PATTERN_COLORS)

    # Select interesting cases
    both_correct = [
        r for r in results
        if r.retrieved_wass == r.target_idx and r.retrieved_eucl == r.target_idx
    ]
    only_wass = [
        r for r in results
        if r.retrieved_wass == r.target_idx and r.retrieved_eucl != r.target_idx
    ]
    selected = (both_correct[:2] + only_wass[:2])[:max_rows]
    if len(selected) < max_rows:
        # Top up with the remaining results rather than discarding the
        # interesting cases chosen above.
        chosen = {id(r) for r in selected}
        remaining = [r for r in results if id(r) not in chosen]
        selected += remaining[: max_rows - len(selected)]

    x_min_g, x_max_g, y_min_g, y_max_g = compute_axis_limits(stored_gaussians)

    n_rows = len(selected)
    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 5.5 * n_rows), squeeze=False)

    for row, res in enumerate(selected):
        target_idx = res.target_idx
        ax1, ax2, ax3 = axes[row]
        wass_correct = res.retrieved_wass == target_idx
        eucl_correct = res.retrieved_eucl == target_idx

        # Column 1: Stored Gaussians
        for i, (mu, Sigma) in enumerate(stored_gaussians):
            color = PATTERN_COLORS[i % n_colors]
            if i == target_idx:
                draw_ellipse(ax1, mu, Sigma, color, lw=3, label=f"$X_{i+1}$ (target)")
                ax1.plot(mu[0], mu[1], "o", color=color, ms=12, mec="black", mew=2)
            else:
                draw_ellipse(ax1, mu, Sigma, color, lw=1.5, ls="--", label=f"$X_{i+1}$", alpha=0.5)
                ax1.plot(mu[0], mu[1], "o", color=color, ms=8, mec="black", mew=1, alpha=0.5)

        ax1.set_xlabel("$x$", fontsize=12)
        ax1.set_ylabel("$y$", fontsize=12)
        ax1.set_title(f"Stored Gaussians (Target: $X_{target_idx+1}$)", fontsize=11, fontweight="bold")
        ax1.legend(loc="best", fontsize=8)
        ax1.set_aspect("equal")
        ax1.grid(True, alpha=0.3)
        ax1.set_xlim(x_min_g, x_max_g)
        ax1.set_ylim(y_min_g, y_max_g)

        traj_eucl_gaussians = [vector_to_gaussian(v) for v in res.traj_euclidean]

        # Shared limits for both dynamics panels. They cover the four key
        # ellipses plus every trajectory mean (as a zero-size point), so no
        # part of either path is clipped.
        point_cov = np.zeros_like(res.Sigma_target)
        trajectory_points = [
            (mu, point_cov)
            for mu, _ in list(res.traj_wasserstein) + traj_eucl_gaussians
        ]
        dynamics_gaussians = [
            (res.mu_target, res.Sigma_target),
            (res.mu_init, res.Sigma_init),
            (res.mu_fp_wass, res.Sigma_fp_wass),
            (res.mu_fp_eucl, res.Sigma_fp_eucl),
        ] + trajectory_points
        x_min_d, x_max_d, y_min_d, y_max_d = compute_axis_limits(dynamics_gaussians, margin=0.3)

        # Column 2: BW-DAM
        _plot_dynamics_panel(ax2, res, stored_gaussians, target_idx,
                             res.mu_fp_wass, res.Sigma_fp_wass, res.traj_wasserstein,
                             "BW-DAM", wass_correct, is_wasserstein=True)
        ax2.set_xlim(x_min_d, x_max_d)
        ax2.set_ylim(y_min_d, y_max_d)

        # Column 3: Euclidean DAM
        _plot_dynamics_panel(ax3, res, stored_gaussians, target_idx,
                             res.mu_fp_eucl, res.Sigma_fp_eucl, traj_eucl_gaussians,
                             "Euclidean DAM", eucl_correct, is_wasserstein=False)
        ax3.set_xlim(x_min_d, x_max_d)
        ax3.set_ylim(y_min_d, y_max_d)

    fig.tight_layout()
    fig.subplots_adjust(right=0.92, hspace=0.25)
    fig.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    print(f"\nSaved: {output_path}")


def _plot_dynamics_panel(
    ax: plt.Axes,
    res: ExperimentResult,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    target_idx: int,
    mu_fp: np.ndarray,
    Sigma_fp: np.ndarray,
    trajectory: List[Tuple[np.ndarray, np.ndarray]],
    title: str,
    is_correct: bool,
    is_wasserstein: bool,
) -> None:
    """Draw one retrieval-dynamics panel onto ``ax``.

    The panel shows:

    - the non-target stored patterns, faded;
    - the fixed point (star);
    - the path of the state's mean;
    - the initial query;
    - the target pattern (square);
    - a check or cross to the right of the axes marking success.

    Pattern colours cycle through ``PATTERN_COLORS``.

    Args:
        ax: Axes to draw on.
        res: Experiment result supplying the query ``(mu_init, Sigma_init)``.
        stored_gaussians: Stored ``(mu, Sigma)`` patterns.
        target_idx: Index of the target pattern.
        mu_fp: Mean of the fixed point to highlight.
        Sigma_fp: Covariance of the fixed point to highlight.
        trajectory: ``(mean, covariance)`` states, starting at the query.
        title: Panel title.
        is_correct: Whether retrieval succeeded (picks the mark and colour).
        is_wasserstein: Kept for call-site compatibility. Both dynamics pass
            ``(mean, covariance)`` pairs, so it does not change the drawing.
    """
    n_colors = len(PATTERN_COLORS)

    # Non-target patterns (faded)
    for i, (mu, Sigma) in enumerate(stored_gaussians):
        if i == target_idx:
            continue
        color = PATTERN_COLORS[i % n_colors]
        draw_ellipse(ax, mu, Sigma, color, lw=1, ls="--", alpha=0.3)
        ax.plot(mu[0], mu[1], "o", color=color, ms=6, mec="black", mew=0.5, alpha=0.3)

    # Fixed point
    draw_ellipse(ax, mu_fp, Sigma_fp, FIXED_POINT_COLOR, lw=3, label=r"Fixed point $\xi^*$")
    ax.plot(mu_fp[0], mu_fp[1], "*", color=FIXED_POINT_COLOR, ms=18, mec="black", mew=1.5)

    # Trajectory of the mean
    if len(trajectory) > 0:
        means = np.array([mu for mu, _ in trajectory])
        ax.plot(means[:, 0], means[:, 1], "-", color="gray", lw=2, alpha=0.6)
        # Markers only on intermediate states (first and last excluded).
        ax.plot(means[1:-1, 0], means[1:-1, 1], "o", color="gray", ms=4, alpha=0.7)

    # Query
    draw_ellipse(ax, res.mu_init, res.Sigma_init, QUERY_COLOR, lw=2.5, label=r"Query $\xi$")
    ax.plot(res.mu_init[0], res.mu_init[1], "o", color=QUERY_COLOR, ms=10, mec="black", mew=1.5)

    # Target
    mu_t, Sigma_t = stored_gaussians[target_idx]
    target_color = PATTERN_COLORS[target_idx % n_colors]
    draw_ellipse(ax, mu_t, Sigma_t, target_color, lw=2.5, ls="--", label=f"$X_{target_idx+1}$ (target)")
    ax.plot(mu_t[0], mu_t[1], "s", color=target_color, ms=12, mec="black", mew=1.5)

    ax.set_xlabel("$x$", fontsize=12)
    ax.set_ylabel("$y$", fontsize=12)
    ax.set_title(title, fontsize=11, fontweight="bold")
    ax.legend(loc="best", fontsize=8)
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    # Success indicator, placed just outside the right edge of the axes
    mark = "✓" if is_correct else "✗"
    color = SUCCESS_COLOR if is_correct else FAILURE_COLOR
    ax.text(1.02, 0.5, mark, transform=ax.transAxes, fontsize=36,
            fontweight="bold", color=color, va="center", ha="left")


# =============================================================================
# Main
# =============================================================================

def main():
    """Run every retrieval experiment, print the results, and plot the comparison.

    Samples the stored Gaussians and vectorises them for the Euclidean DAM.
    It then runs one retrieval per stored pattern and prints per-target marks
    and overall accuracy. Finally it saves the comparison figure and shows it.
    """
    config = VisualizationConfig()
    rule = "=" * 70

    print(rule)
    print("BW-DAM vs Euclidean DAM: 2D Visualization")
    print(rule)
    print("\nConfiguration:")
    print(f"  Number of patterns: {config.num_patterns}")
    print(f"  λ_min: {config.lambda_min}, λ_max: {config.lambda_max}")
    print(f"  Condition number κ: {config.condition_number:.4f}")
    print(f"  Temperature β: {config.beta}")
    print(f"  Min mean separation: {config.min_mean_separation:.4f}")

    stored_gaussians = sample_separated_gaussians(
        config, np.random.default_rng(config.seed_gaussians)
    )

    # Euclidean DAM works on flattened patterns stored as matrix columns.
    X = np.array([gaussian_to_vector(mu, Sigma) for mu, Sigma in stored_gaussians]).T

    print("\nRunning retrieval experiments...")
    results = []
    for target_idx in range(config.num_patterns):
        res = run_retrieval_experiment(config, stored_gaussians, X, target_idx)
        results.append(res)
        wass_mark, eucl_mark = (
            "✓" if retrieved == target_idx else "✗"
            for retrieved in (res.retrieved_wass, res.retrieved_eucl)
        )
        print(f"  Target X_{target_idx+1}: BW-DAM {wass_mark}, Euclidean {eucl_mark}")

    total = config.num_patterns
    wass_acc = sum(r.retrieved_wass == r.target_idx for r in results)
    eucl_acc = sum(r.retrieved_eucl == r.target_idx for r in results)
    print(f"\nAccuracy: BW-DAM {wass_acc}/{total}, Euclidean {eucl_acc}/{total}")

    plot_comparison(results, stored_gaussians)
    plt.show()


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
