"""
BW-DAM Retrieval on Gaussian Image Embeddings (VAE).

This experiment:
1. Trains a full-covariance VAE on a small random subset of CIFAR-10 training images
2. Encodes images as Gaussian distributions in latent space
3. Creates masked (corrupted) versions of images
4. Runs three DAM variants for retrieval:
   - BW-DAM: Bures-Wasserstein geometry on Gaussians
   - Euclidean DAM: Flat geometry on vectorized Gaussian parameters
   - Pixel DAM: Flat geometry on raw pixel values
5. Compares retrieval accuracy across methods and β values

Requirements:
    - torch
    - torchvision
    - numpy
    - matplotlib
"""

import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import warnings

# Silence every warning category for the whole process (no category filter is
# given), including numerical RuntimeWarnings from NumPy.
# NOTE(review): consider narrowing this to the specific noisy warnings so
# genuine numerical problems stay visible.
warnings.filterwarnings("ignore")


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class VAEConfig:
    """Hyperparameters for the full-covariance VAE and its Adam training loop."""

    # Architecture: latent Gaussian dimension and encoder conv widths. The
    # widths only take effect where they are forwarded as
    # FullCovarianceVAE(hidden_dims=...).
    latent_dim: int = 32
    hidden_dims: List[int] = field(default_factory=[32, 64, 128, 256, 512].copy)

    # Optimization: number of passes over the images, Adam step size and
    # mini-batch size.
    epochs: int = 200
    learning_rate: float = 1e-3
    batch_size: int = 16


@dataclass
class ExperimentConfig:
    """Settings for the DAM retrieval experiment: data, corruption, dynamics, sweep."""

    # --- Dataset: number of stored CIFAR-10 images, their side length after
    # resizing, and where data / results are written.
    num_images: int = 100
    image_size: int = 64
    data_dir: str = "./data"
    output_dir: str = "./output"

    # --- Perturbation: fraction of pixels overwritten in each query image and
    # the value written into them.
    mask_fraction: float = 0.20
    mask_value: float = 0.0

    # --- DAM dynamics: upper bound on fixed-point iterations per query.
    max_iters: int = 100

    # --- Sweep: inverse temperatures to test, each repeated num_trials times.
    beta_values: List[float] = field(
        default_factory=[0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 30.0, 100.0].copy
    )
    num_trials: int = 5

    # --- Base random seed for the experiment.
    seed: int = 42


# =============================================================================
# Full Covariance VAE
# =============================================================================

class FullCovarianceVAE(nn.Module):
    """VAE whose approximate posterior is a full-covariance Gaussian.

    The encoder outputs a mean ``mu`` and a lower-triangular Cholesky factor
    ``L`` (diagonal ``exp(log_sigma)``, free strictly-lower entries), so the
    posterior covariance is ``Sigma = L L^T`` plus a small diagonal jitter
    added in :meth:`get_covariance`.

    Spatial layout: every encoder block is a stride-2 conv, and the flattened
    encoder output is assumed to be ``hidden_dims[-1] x 2 x 2`` (e.g. 64x64
    inputs with the default five blocks, 8x8 inputs with two blocks). The
    decoder mirrors the encoder and emits ``3 x S x S`` images with values in
    ``[-1, 1]`` (Tanh), where ``S = 2 * 2 ** len(hidden_dims)``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_dim: int = 32,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        self.latent_dim = latent_dim

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy so later mutation of the caller's list cannot affect the model.
        hidden_dims = list(hidden_dims)
        # Channel count at the 2x2 bottleneck; decode() reshapes to it
        # (previously hard-coded as 512, which broke custom hidden_dims).
        self.bottleneck_channels = hidden_dims[-1]

        # Encoder: stride-2 conv blocks, each halving the spatial size.
        modules = []
        channels = in_channels
        for h_dim in hidden_dims:
            modules.append(nn.Sequential(
                nn.Conv2d(channels, h_dim, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU()
            ))
            channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # Flattened size of the (hidden_dims[-1], 2, 2) bottleneck.
        self.encoder_out_dim = hidden_dims[-1] * 4

        # Latent space: mean and log of the Cholesky diagonal.
        self.fc_mu = nn.Linear(self.encoder_out_dim, latent_dim)
        self.fc_log_sigma = nn.Linear(self.encoder_out_dim, latent_dim)

        # Zero init => log_sigma = 0, i.e. a unit Cholesky diagonal at the start.
        nn.init.zeros_(self.fc_log_sigma.weight)
        nn.init.zeros_(self.fc_log_sigma.bias)

        # Strictly-lower-triangular entries of the Cholesky factor.
        self.n_off_diag = (latent_dim * (latent_dim - 1)) // 2
        self.fc_L_offdiag = nn.Linear(self.encoder_out_dim, self.n_off_diag)

        # Small weights keep the initial off-diagonal Cholesky entries small.
        nn.init.normal_(self.fc_L_offdiag.weight, mean=0, std=0.01)
        nn.init.zeros_(self.fc_L_offdiag.bias)

        # (row, col) positions of the strictly-lower triangle, filled in
        # construct_L in the same order as fc_L_offdiag's outputs.
        tri_indices = torch.tril_indices(row=latent_dim, col=latent_dim, offset=-1)
        self.register_buffer("tri_row_idx", tri_indices[0])
        self.register_buffer("tri_col_idx", tri_indices[1])

        # Decoder: mirror of the encoder, upsampling by 2 per block.
        hidden_dims_rev = hidden_dims[::-1]
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        modules = []
        for i in range(len(hidden_dims_rev) - 1):
            modules.append(nn.Sequential(
                nn.ConvTranspose2d(
                    hidden_dims_rev[i], hidden_dims_rev[i + 1],
                    kernel_size=3, stride=2, padding=1, output_padding=1
                ),
                nn.BatchNorm2d(hidden_dims_rev[i + 1]),
                nn.LeakyReLU()
            ))

        self.decoder = nn.Sequential(*modules)

        # Last upsampling block, then a 3-channel conv squashed to [-1, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                hidden_dims_rev[-1], hidden_dims_rev[-1],
                kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(hidden_dims_rev[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims_rev[-1], 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a batch of images into posterior parameters.

        Returns:
            ``(mu, log_sigma, L_offdiag)`` with shapes ``(B, latent_dim)``,
            ``(B, latent_dim)`` and ``(B, latent_dim * (latent_dim - 1) // 2)``.
        """
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_sigma = self.fc_log_sigma(result)
        L_offdiag = self.fc_L_offdiag(result)
        return mu, log_sigma, L_offdiag

    def construct_L(self, log_sigma: torch.Tensor, L_offdiag: torch.Tensor) -> torch.Tensor:
        """Build the ``(B, d, d)`` lower-triangular Cholesky factor.

        The diagonal is ``exp(log_sigma)`` (strictly positive); the strictly
        lower triangle is filled from ``L_offdiag``.
        """
        batch_size = log_sigma.size(0)
        L = torch.zeros(
            batch_size, self.latent_dim, self.latent_dim,
            device=log_sigma.device, dtype=log_sigma.dtype,
        )
        L[:, self.tri_row_idx, self.tri_col_idx] = L_offdiag
        L += torch.diag_embed(torch.exp(log_sigma))
        return L

    def get_covariance(self, L: torch.Tensor, min_eigenvalue: float = 1e-4) -> torch.Tensor:
        """Return ``Sigma = L L^T + min_eigenvalue * I`` for a batch of factors.

        The jitter keeps every eigenvalue at least ``min_eigenvalue``, so the
        covariances are safely positive definite for the downstream matrix
        square roots.
        """
        Sigma = torch.bmm(L, L.transpose(1, 2))
        Sigma = Sigma + min_eigenvalue * torch.eye(
            self.latent_dim, device=Sigma.device, dtype=Sigma.dtype
        ).unsqueeze(0)
        return Sigma

    def reparameterize(self, mu: torch.Tensor, L: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw ``z = mu + L @ eps`` with ``eps ~ N(0, I)``.

        Returns:
            ``(eps, z)``; ``eps`` is returned because the loss uses it to
            evaluate ``log q(z | x)``.
        """
        batch_size = mu.size(0)
        eps = torch.randn(batch_size, self.latent_dim, device=mu.device, dtype=mu.dtype)
        z = mu + torch.bmm(L, eps.unsqueeze(-1)).squeeze(-1)
        return eps, z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``(B, latent_dim)`` into images in ``[-1, 1]``."""
        result = self.decoder_input(z)
        # Reshape to the 2x2 bottleneck using the configured channel count.
        result = result.view(-1, self.bottleneck_channels, 2, 2)
        result = self.decoder(result)
        return self.final_layer(result)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode, sample with the reparameterization trick, and decode.

        Returns:
            Dict with keys ``reconstruction``, ``mu``, ``log_sigma``, ``L``,
            ``Sigma`` (jittered covariance), ``z`` and ``eps``.
        """
        mu, log_sigma, L_offdiag = self.encode(x)
        L = self.construct_L(log_sigma, L_offdiag)
        eps, z = self.reparameterize(mu, L)
        reconstruction = self.decode(z)
        return {
            "reconstruction": reconstruction,
            "mu": mu,
            "log_sigma": log_sigma,
            "L": L,
            "Sigma": self.get_covariance(L),
            "z": z,
            "eps": eps
        }


def vae_loss(recon_x, x, mu, log_sigma, eps, z):
    """Single-sample negative ELBO for the full-covariance VAE (summed over the batch).

    ``recon_x`` / ``x`` enter through a summed squared error. The KL term is
    estimated as ``log q(z|x) - log p(z)`` from the sample ``z = mu + L eps``:

    * ``log p(z)`` is the standard-normal log density of ``z``.
    * ``log q(z|x) = log N(eps; 0, I) - sum(log_sigma)``, since ``log_sigma``
      holds the log of the Cholesky diagonal (``log|det L|``).

    ``mu`` is accepted for interface symmetry but is not used.
    """
    log_two_pi = math.log(2 * math.pi)
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction="sum")
    log_prior = -0.5 * torch.sum(z ** 2 + log_two_pi)
    log_posterior = -torch.sum(0.5 * (eps ** 2 + log_two_pi) + log_sigma)
    return reconstruction - log_prior + log_posterior


def train_vae(
    model: FullCovarianceVAE,
    images: torch.Tensor,
    device: torch.device,
    config: VAEConfig,
    verbose: bool = False,
) -> FullCovarianceVAE:
    """Fit ``model`` to ``images`` with Adam on ``vae_loss`` and return it in eval mode.

    Every epoch shuffles the images with ``torch.randperm`` and walks them in
    mini-batches of ``config.batch_size`` (the last batch may be smaller).
    Gradients are clipped to a global norm of 1.0 before each step. With
    ``verbose``, the epoch's summed loss divided by the number of images is
    printed every 50 epochs.
    """
    model.train()
    model = model.to(device)
    data = images.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    total = data.shape[0]
    step = config.batch_size
    num_batches = -(-total // step)  # ceiling division

    for epoch in range(config.epochs):
        running = 0.0
        order = torch.randperm(total)

        for b in range(num_batches):
            lo = b * step
            batch = data[order[lo:lo + step]]  # slicing clamps at the end

            opt.zero_grad()
            out = model(batch)
            loss = vae_loss(
                out["reconstruction"], batch,
                out["mu"], out["log_sigma"],
                out["eps"], out["z"],
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            running += loss.item()

        if verbose and (epoch + 1) % 50 == 0:
            print(f"  Epoch {epoch+1}/{config.epochs}, Loss: {running/total:.4f}")

    model.eval()
    return model


# =============================================================================
# Data Loading
# =============================================================================

def load_cifar10_images(
    num_images: int,
    image_size: int,
    random_seed: int,
    data_dir: str,
) -> torch.Tensor:
    """Load a reproducible random subset of the CIFAR-10 training split.

    Each image is resized with ``Resize(image_size)`` and normalized per
    channel from [0, 1] to [-1, 1] (mean 0.5, std 0.5). The dataset is
    downloaded into ``data_dir`` if it is not already there (network I/O).

    The subset is drawn without replacement from a private
    ``np.random.RandomState(random_seed)``. This selects exactly the images
    that seeding the global NumPy RNG would, but leaves the global RNG state
    untouched for the rest of the program.

    Returns:
        Tensor of stacked images, shape ``(num_images, 3, H, W)``.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, transform=transform, download=True
    )

    rng = np.random.RandomState(random_seed)
    indices = rng.choice(len(dataset), size=num_images, replace=False)

    return torch.stack([dataset[int(idx)][0] for idx in indices])


# =============================================================================
# Wasserstein Distance Utilities
# =============================================================================

def matrix_sqrt_np(A: np.ndarray) -> np.ndarray:
    """Symmetric square root of a symmetric matrix via its eigendecomposition.

    Negative eigenvalues (typically round-off on a PSD matrix) are clamped to
    zero before taking the root, so the result is always PSD.
    """
    evals, evecs = np.linalg.eigh(A)
    root_evals = np.sqrt(np.maximum(evals, 0))
    return evecs @ np.diag(root_evals) @ evecs.T


def bures_metric_squared(Sigma1: np.ndarray, Sigma2: np.ndarray) -> float:
    """Squared Bures distance between two covariance matrices.

    B^2(S1, S2) = tr(S1) + tr(S2) - 2 tr((S1^{1/2} S2 S1^{1/2})^{1/2}).

    The result is not clamped: because it is a difference of traces,
    round-off can make it slightly negative for nearly equal inputs.
    """
    root_first = matrix_sqrt_np(Sigma1)
    cross_root = matrix_sqrt_np(root_first @ Sigma2 @ root_first)
    return np.trace(Sigma1) + np.trace(Sigma2) - 2 * np.trace(cross_root)


def w2_squared_full(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """Squared 2-Wasserstein distance between N(mu1, Sigma1) and N(mu2, Sigma2).

    Equals the squared Euclidean distance of the means plus the squared Bures
    distance of the covariances (not clamped at zero).
    """
    mean_term = np.sum(np.square(mu1 - mu2))
    return mean_term + bures_metric_squared(Sigma1, Sigma2)


def w2_distance_full(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """2-Wasserstein distance between N(mu1, Sigma1) and N(mu2, Sigma2).

    Negative squared values (round-off in the Bures term) are clamped to 0
    before the square root.
    """
    squared = w2_squared_full(mu1, Sigma1, mu2, Sigma2)
    return np.sqrt(max(0, squared))


def optimal_transport_map(Omega: np.ndarray, Sigma_i: np.ndarray) -> np.ndarray:
    """Matrix of the linear optimal-transport map from covariance Omega to Sigma_i.

    T = Omega^{-1/2} (Omega^{1/2} Sigma_i Omega^{1/2})^{1/2} Omega^{-1/2}.

    Omega's eigenvalues are clamped at 0 for the square root and at 1e-10 for
    the inverse square root. Both roots come from a single eigendecomposition.
    """
    evals, evecs = np.linalg.eigh(Omega)
    root = evecs @ np.diag(np.sqrt(np.maximum(evals, 0))) @ evecs.T
    inv_root = evecs @ np.diag(1.0 / np.sqrt(np.maximum(evals, 1e-10))) @ evecs.T
    middle = matrix_sqrt_np(root @ Sigma_i @ root)
    return inv_root @ middle @ inv_root


# =============================================================================
# Pixel Masking
# =============================================================================

def create_random_pixel_mask(
    image_size: int,
    mask_fraction: float,
    seed: Optional[int] = None,
) -> np.ndarray:
    """Random square 0/1 mask with ``int(mask_fraction * image_size**2)`` zeros.

    Ones mark kept pixels and zeros mark masked pixels. Any positive fraction
    masks at least one pixel; a non-positive fraction masks none. The positions
    are drawn without replacement from ``np.random.default_rng(seed)``.
    """
    generator = np.random.default_rng(seed)
    n_total = image_size * image_size

    if mask_fraction > 0:
        n_masked = max(1, int(n_total * mask_fraction))
    else:
        n_masked = 0

    flat = np.ones(n_total)
    if n_masked > 0:
        hidden = generator.choice(n_total, size=n_masked, replace=False)
        flat[hidden] = 0

    return flat.reshape(image_size, image_size)


def apply_mask_to_image(
    image: torch.Tensor,
    mask: np.ndarray,
    mask_value: float,
) -> torch.Tensor:
    """Return a new (C, H, W) image with masked pixels replaced by ``mask_value``.

    ``mask`` is an (H, W) array of ones (keep) and zeros (replace) that is
    shared by every channel. The result is
    ``image * mask + mask_value * (1 - mask)`` in ``image``'s dtype; the
    input tensor is not modified.
    """
    keep = torch.from_numpy(mask).float()
    fill = mask_value * (1 - keep)
    # (H, W) broadcasts over the leading channel axis.
    return (image * keep + fill).to(image.dtype)


# =============================================================================
# DAM Dynamics
# =============================================================================

def bwdam_step_full(
    m: np.ndarray,
    Omega: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    beta: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """One step of BW-DAM for full covariance Gaussians.

    The state N(m, Omega) is compared with every stored N(mu_i, Sigma_i) by
    squared W2 distance, and the memories are weighted by
    ``softmax(-beta * D_sq)``.

    * New mean: the weighted mean of the stored means.
    * New covariance: ``A Omega A^T`` with ``A = sum_i w_i T_i``, where ``T_i``
      is the optimal-transport matrix from Omega to Sigma_i (same formula and
      clamping as ``optimal_transport_map``).

    Returns:
        ``(m_new, Omega_new, w)``. ``Omega_new`` is symmetrized; if its smallest
        eigenvalue is below 1e-10, its diagonal is shifted up.
    """
    d = len(m)

    D_sq = np.array([
        w2_squared_full(mu_i, Sigma_i, m, Omega)
        for mu_i, Sigma_i in stored_gaussians
    ])

    # Numerically stable softmax of -beta * D_sq.
    log_w = -beta * D_sq
    log_w = log_w - np.max(log_w)
    w = np.exp(log_w)
    w = w / np.sum(w)

    m_new = sum(w_i * mu_i for w_i, (mu_i, _) in zip(w, stored_gaussians))

    # Omega^{1/2} and Omega^{-1/2} are shared by every transport map, so
    # decompose Omega once per step instead of twice per memory. The clamping
    # (0 for the root, 1e-10 for the inverse root) matches
    # optimal_transport_map, so each map is numerically identical.
    omega_evals, omega_evecs = np.linalg.eigh(Omega)
    sqrt_Omega = omega_evecs @ np.diag(np.sqrt(np.maximum(omega_evals, 0))) @ omega_evecs.T
    inv_sqrt_Omega = (
        omega_evecs
        @ np.diag(1.0 / np.sqrt(np.maximum(omega_evals, 1e-10)))
        @ omega_evecs.T
    )

    # Memories whose weight underflowed to exactly 0 add nothing to the
    # weighted map, so their (expensive) transport maps are skipped. At least
    # one weight is nonzero because the largest shifted log-weight is 0.
    A_weighted = sum(
        w_i * (inv_sqrt_Omega @ matrix_sqrt_np(sqrt_Omega @ Sigma_i @ sqrt_Omega) @ inv_sqrt_Omega)
        for w_i, (_, Sigma_i) in zip(w, stored_gaussians)
        if w_i != 0
    )
    Omega_new = A_weighted @ Omega @ A_weighted.T

    # Re-symmetrize and keep Omega_new safely positive definite.
    Omega_new = (Omega_new + Omega_new.T) / 2
    eigenvalues = np.linalg.eigvalsh(Omega_new)
    if np.min(eigenvalues) < 1e-10:
        Omega_new = Omega_new + (1e-10 - np.min(eigenvalues) + 1e-6) * np.eye(d)

    return m_new, Omega_new, w


def find_fixed_point_bwdam(
    mu_init: np.ndarray,
    Sigma_init: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    beta: float,
    max_iters: int,
    tol: float = 1e-10,
) -> Tuple[np.ndarray, np.ndarray]:
    """Iterate ``bwdam_step_full`` from N(mu_init, Sigma_init).

    Returns the first new state whose W2 distance from the previous state is
    below ``tol``. If that never happens, returns the state after
    ``max_iters`` steps. The inputs are copied, never modified.

    NOTE(review): the W2 distance is a square root of a difference of traces.
    Near a fixed point it is dominated by round-off, so with tol=1e-10 the
    early exit may fire somewhat arbitrarily. Consider a parameter-space
    criterion if exact convergence detection matters.
    """
    state = (mu_init.copy(), Sigma_init.copy())
    for _ in range(max_iters):
        candidate = bwdam_step_full(*state, stored_gaussians, beta)[:2]
        if w2_distance_full(*state, *candidate) < tol:
            return candidate
        state = candidate
    return state


def euclidean_dam_step(
    xi: np.ndarray,
    X: np.ndarray,
    beta: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """One update of the dot-product (modern Hopfield) DAM.

    ``X`` holds one stored pattern per column. The weights are
    ``softmax(beta * X^T xi)``, computed with a max-shift for numerical
    stability. The new state is the weighted combination ``X @ w``.

    Returns:
        ``(new_state, weights)``.
    """
    logits = beta * (X.T @ xi)
    logits = logits - logits.max()
    weights = np.exp(logits)
    weights = weights / weights.sum()
    return X @ weights, weights


def find_fixed_point_euclidean(
    xi_init: np.ndarray,
    X: np.ndarray,
    beta: float,
    max_iters: int,
    tol: float = 1e-10,
) -> np.ndarray:
    """Iterate ``euclidean_dam_step`` from ``xi_init``.

    Returns the first new state whose L2 step from the previous state is below
    ``tol``. If that never happens, returns the state after ``max_iters``
    updates. ``xi_init`` is copied, never modified.
    """
    state = xi_init.copy()
    for _ in range(max_iters):
        candidate = euclidean_dam_step(state, X, beta)[0]
        step_norm = np.linalg.norm(candidate - state)
        if step_norm < tol:
            return candidate
        state = candidate
    return state


# =============================================================================
# Gaussian Vectorization
# =============================================================================

def vectorize_gaussian(mu: np.ndarray, Sigma: np.ndarray) -> np.ndarray:
    """Flatten a Gaussian into ``[mu, lower-triangle of Sigma]``.

    The lower triangle (diagonal included) is read in ``np.tril_indices``
    order, so the vector has length ``d + d * (d + 1) / 2``.
    """
    rows, cols = np.tril_indices(len(mu))
    return np.concatenate([mu, Sigma[rows, cols]])


def unvectorize_gaussian(xi: np.ndarray, d: int) -> Tuple[np.ndarray, np.ndarray]:
    """Inverse of ``vectorize_gaussian``: rebuild ``(mu, Sigma)`` from ``xi``.

    The first ``d`` entries form the mean. The rest fill the lower triangle
    (``np.tril_indices`` order) and are mirrored into the upper triangle. If
    the rebuilt matrix has an eigenvalue below 1e-10, ``Sigma`` is shifted by
    a multiple of the identity so its smallest eigenvalue is about
    ``1e-10 + 1e-6``.
    """
    mean = xi[:d]

    lower = np.zeros((d, d))
    lower[np.tril_indices(d)] = xi[d:]
    cov = lower + lower.T - np.diag(np.diag(lower))

    smallest = np.linalg.eigvalsh(cov).min()
    if smallest < 1e-10:
        cov = cov + (1e-10 - smallest + 1e-6) * np.eye(d)

    return mean, cov


# =============================================================================
# Single Trial
# =============================================================================

def _nearest_stored_gaussian(
    mu: np.ndarray,
    Sigma: np.ndarray,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
) -> int:
    """Index of the stored Gaussian closest to N(mu, Sigma) in W2 distance."""
    distances = [
        w2_distance_full(mu, Sigma, mu_i, Sigma_i)
        for mu_i, Sigma_i in stored_gaussians
    ]
    return int(np.argmin(distances))


def run_single_trial(
    images: torch.Tensor,
    model: FullCovarianceVAE,
    stored_gaussians: List[Tuple[np.ndarray, np.ndarray]],
    X_gaussian: np.ndarray,
    X_pixel: np.ndarray,
    mask_fraction: float,
    mask_value: float,
    beta: float,
    max_iters: int,
    latent_dim: int,
    trial_seed: int,
    device: torch.device,
) -> Dict[str, float]:
    """Run one retrieval trial with all three DAM methods.

    Each image ``idx`` is corrupted with its own random pixel mask (seed
    ``trial_seed + idx``) and used as a query. A retrieval counts as correct
    when the memory nearest to the method's fixed point is ``idx`` itself.
    Nearness is W2 distance to ``stored_gaussians`` for BW-DAM and Euclidean
    DAM, and L2 pixel distance to ``images`` for Pixel DAM.

    Returns:
        Dict with ``wass_accuracy``, ``eucl_accuracy`` and ``pixel_accuracy``
        in percent.
    """
    num_images = len(images)
    img_size = images[0].shape[1]

    # Corrupt every image with its own reproducible random pixel mask.
    masked_images = torch.stack([
        apply_mask_to_image(
            images[idx],
            create_random_pixel_mask(img_size, mask_fraction, seed=trial_seed + idx),
            mask_value,
        )
        for idx in range(num_images)
    ])

    # Encode all queries in one batched pass. In eval mode BatchNorm uses its
    # running statistics, so each row equals the per-image encoding.
    model.eval()
    with torch.no_grad():
        mu_enc, log_sigma, L_offdiag = model.encode(masked_images.to(device))
        Sigma_enc = model.get_covariance(model.construct_L(log_sigma, L_offdiag))
    mu_queries = mu_enc.cpu().numpy()
    Sigma_queries = Sigma_enc.cpu().numpy()

    # Flattened pixel vectors, one row per image, computed once instead of
    # re-flattening every stored image for every query.
    pixel_queries = masked_images.reshape(num_images, -1).numpy()
    stored_pixels = images.reshape(num_images, -1).numpy()

    wass_correct = 0
    eucl_correct = 0
    pixel_correct = 0

    for idx in range(num_images):
        # Row indexing keeps shapes (d,) and (d, d) even when latent_dim == 1.
        mu_init = mu_queries[idx]
        Sigma_init = Sigma_queries[idx]

        # BW-DAM
        mu_fp_wass, Sigma_fp_wass = find_fixed_point_bwdam(
            mu_init, Sigma_init, stored_gaussians, beta, max_iters
        )

        # Euclidean DAM on vectorized (mu, lower-triangle of Sigma)
        xi_fp_gaussian = find_fixed_point_euclidean(
            vectorize_gaussian(mu_init, Sigma_init), X_gaussian, beta, max_iters
        )
        mu_fp_eucl, Sigma_fp_eucl = unvectorize_gaussian(xi_fp_gaussian, latent_dim)

        # Pixel DAM
        xi_fp_pixel = find_fixed_point_euclidean(
            pixel_queries[idx], X_pixel, beta, max_iters
        )

        retrieved_wass = _nearest_stored_gaussian(mu_fp_wass, Sigma_fp_wass, stored_gaussians)
        retrieved_eucl = _nearest_stored_gaussian(mu_fp_eucl, Sigma_fp_eucl, stored_gaussians)
        retrieved_pixel = int(np.argmin(np.linalg.norm(stored_pixels - xi_fp_pixel, axis=1)))

        wass_correct += int(retrieved_wass == idx)
        eucl_correct += int(retrieved_eucl == idx)
        pixel_correct += int(retrieved_pixel == idx)

    return {
        "wass_accuracy": wass_correct / num_images * 100,
        "eucl_accuracy": eucl_correct / num_images * 100,
        "pixel_accuracy": pixel_correct / num_images * 100,
    }


# =============================================================================
# Main Experiment
# =============================================================================

def _prepare_trial_memories(
    exp_config: ExperimentConfig,
    vae_config: VAEConfig,
    trial: int,
    device: torch.device,
) -> Dict:
    """Load one trial's images, train a VAE on them, and build all memory matrices."""
    trial_seed = exp_config.seed + trial * 1000
    # Seed torch so VAE initialization, batch order and reparameterization
    # noise are reproducible for each trial.
    torch.manual_seed(trial_seed)

    images = load_cifar10_images(
        exp_config.num_images, exp_config.image_size,
        trial_seed, exp_config.data_dir
    )

    model = FullCovarianceVAE(
        in_channels=3,
        latent_dim=vae_config.latent_dim,
        hidden_dims=vae_config.hidden_dims,
    ).to(device)
    model = train_vae(model, images, device, vae_config, verbose=False)

    # Encode the clean images; these Gaussians are the stored memories.
    model.eval()
    with torch.no_grad():
        mu_all, log_sigma_all, L_offdiag_all = model.encode(images.to(device))
        L_all = model.construct_L(log_sigma_all, L_offdiag_all)
        Sigma_all = model.get_covariance(L_all)

    mu_all_np = mu_all.cpu().numpy()
    Sigma_all_np = Sigma_all.cpu().numpy()

    stored_gaussians = [
        (mu_all_np[i], Sigma_all_np[i])
        for i in range(exp_config.num_images)
    ]
    # Memory matrices for the dot-product DAMs: one stored pattern per column.
    X_gaussian = np.vstack([
        vectorize_gaussian(mu, Sigma) for mu, Sigma in stored_gaussians
    ]).T
    X_pixel = np.vstack([
        images[i].numpy().flatten() for i in range(exp_config.num_images)
    ]).T

    return {
        "trial_seed": trial_seed,
        "images": images,
        "model": model,
        "stored_gaussians": stored_gaussians,
        "X_gaussian": X_gaussian,
        "X_pixel": X_pixel,
    }


def run_beta_vs_accuracy_experiment(
    exp_config: ExperimentConfig,
    vae_config: VAEConfig,
    device: torch.device,
) -> Dict:
    """Sweep beta and measure retrieval accuracy for all three DAM methods.

    Each trial's images, trained VAE and memories depend only on the trial,
    not on beta. They are therefore built once per trial and reused across the
    whole beta sweep, instead of retraining a VAE for every (beta, trial) pair.

    Returns:
        Dict with ``beta_values`` and, per beta, the mean and population std
        (over trials) of each method's accuracy in percent:
        ``wass_mean/std``, ``eucl_mean/std``, ``pixel_mean/std``.
    """
    Path(exp_config.output_dir).mkdir(parents=True, exist_ok=True)

    results = {
        "beta_values": exp_config.beta_values,
        "wass_mean": [], "wass_std": [],
        "eucl_mean": [], "eucl_std": [],
        "pixel_mean": [], "pixel_std": [],
    }

    print("=" * 70)
    print("BW-DAM Image Retrieval: β vs Accuracy")
    print("=" * 70)

    trials = []
    for trial in range(exp_config.num_trials):
        print(f"Preparing trial {trial + 1}/{exp_config.num_trials} "
              f"(loading images, training VAE)...")
        trials.append(_prepare_trial_memories(exp_config, vae_config, trial, device))

    for beta in exp_config.beta_values:
        print(f"\nβ = {beta}")
        wass_accs, eucl_accs, pixel_accs = [], [], []

        for trial, prepared in enumerate(trials):
            print(f"  Trial {trial + 1}/{exp_config.num_trials}...", end=" ")

            trial_results = run_single_trial(
                prepared["images"], prepared["model"],
                prepared["stored_gaussians"], prepared["X_gaussian"], prepared["X_pixel"],
                exp_config.mask_fraction, exp_config.mask_value,
                beta, exp_config.max_iters, vae_config.latent_dim,
                prepared["trial_seed"], device
            )

            wass_accs.append(trial_results["wass_accuracy"])
            eucl_accs.append(trial_results["eucl_accuracy"])
            pixel_accs.append(trial_results["pixel_accuracy"])

            print(f"BW: {trial_results['wass_accuracy']:.1f}%, "
                  f"Eu: {trial_results['eucl_accuracy']:.1f}%, "
                  f"Px: {trial_results['pixel_accuracy']:.1f}%")

        results["wass_mean"].append(np.mean(wass_accs))
        results["wass_std"].append(np.std(wass_accs))
        results["eucl_mean"].append(np.mean(eucl_accs))
        results["eucl_std"].append(np.std(eucl_accs))
        results["pixel_mean"].append(np.mean(pixel_accs))
        results["pixel_std"].append(np.std(pixel_accs))

    return results


def plot_beta_vs_accuracy(
    results: Dict,
    output_path: str = "bwdam_images_accuracy.png",
) -> None:
    """Plot mean retrieval accuracy (with std error bars) against beta and save it.

    Draws one curve per method on a log-scaled beta axis, with dashed/dotted
    reference lines at 50% and 90%, and writes a 300-dpi PNG to
    ``output_path``.
    """
    _fig, ax = plt.subplots(figsize=(8, 5))
    betas = results["beta_values"]

    for level, line_style in ((50, "--"), (90, ":")):
        ax.axhline(y=level, color="gray", linestyle=line_style, linewidth=1, alpha=0.7)

    curves = (
        ("wass_mean", "wass_std", "b-o", "BW-DAM"),
        ("eucl_mean", "eucl_std", "r--s", "Euclidean DAM"),
        ("pixel_mean", "pixel_std", "g-.^", "Pixel DAM"),
    )
    for mean_key, std_key, fmt, label in curves:
        ax.errorbar(
            betas, results[mean_key], yerr=results[std_key],
            fmt=fmt, linewidth=2, markersize=8, capsize=5, capthick=2,
            label=label
        )

    ax.set_xlabel(r"$\beta$ (inverse temperature)", fontsize=12)
    ax.set_ylabel("Retrieval Accuracy (%)", fontsize=12)
    ax.set_ylim(0, 105)
    ax.set_xscale("log")
    ax.legend(fontsize=10, loc="upper left", framealpha=0.9)
    ax.grid(True, alpha=0.3, which="major")
    ax.grid(True, which="minor", alpha=0.15)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    print(f"\nSaved: {output_path}")


# =============================================================================
# Main
# =============================================================================

def main():
    """Run the full image retrieval experiment and save the accuracy plot.

    The plot is written into ``ExperimentConfig.output_dir`` rather than the
    current working directory, so all experiment outputs land in one place.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    exp_config = ExperimentConfig()
    vae_config = VAEConfig()

    results = run_beta_vs_accuracy_experiment(exp_config, vae_config, device)

    plot_path = Path(exp_config.output_dir) / "bwdam_images_accuracy.png"
    # Ensure the directory exists even if the experiment runner did not create it.
    plot_path.parent.mkdir(parents=True, exist_ok=True)
    plot_beta_vs_accuracy(results, output_path=str(plot_path))
    plt.show()


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()
