"""
BW-DAM Retrieval on Synthetic Gaussian Data.

This experiment:
1. Samples N Gaussian distributions from a Wasserstein sphere
2. Perturbs a subset to create initial query states
3. Runs BW-DAM dynamics for a fixed number of iterations
4. Plots mean W₂ distance to original pattern vs iteration

This demonstrates the convergence properties of BW-DAM on controlled synthetic data.
"""

import numpy as np
from numpy.linalg import qr, norm, eigh
import matplotlib.pyplot as plt
from tqdm import tqdm
from dataclasses import dataclass
from typing import Tuple, List


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class SyntheticConfig:
    """Hyper-parameters of the synthetic BW-DAM retrieval experiment.

    The declaration order of the fields below is the positional order of the
    generated ``__init__``; the two derived quantities are exposed as
    read-only properties.
    """

    # --- Stored Gaussian patterns ------------------------------------------
    dim: int = 20               # ambient dimension d
    lambda_min: float = 0.8     # lower bound for sampled covariance eigenvalues
    lambda_max: float = 1.2     # upper bound for sampled covariance eigenvalues
    num_patterns: int = 1000    # number N of stored patterns

    # --- BW-DAM dynamics ---------------------------------------------------
    beta: float = 0.1           # β used in the softmax weights
    num_iters: int = 5          # number of update steps per query

    # --- Experiment --------------------------------------------------------
    perturb_fraction: float = 0.75  # fraction of patterns used as queries
    seed: int = 42                  # seed for the NumPy random generator

    @property
    def sphere_radius(self) -> float:
        """Wasserstein sphere radius, R = √(2d)."""
        return np.sqrt(2 * self.dim)

    @property
    def perturbation_distance(self) -> float:
        """W₂ distance used for query perturbations, √λ_min (contractive radius)."""
        return np.sqrt(self.lambda_min)


# =============================================================================
# Matrix Operations
# =============================================================================

def matrix_sqrt(A: np.ndarray) -> np.ndarray:
    """Return the symmetric square root of a symmetric PSD matrix.

    The root is assembled from the eigendecomposition ``A = V diag(λ) V^T``
    as ``V diag(√λ) V^T``.
    """
    spectrum, basis = eigh(A)
    # Negative eigenvalues (e.g. from round-off) are clipped to zero so the
    # element-wise square root stays real.
    root_spectrum = np.sqrt(np.clip(spectrum, 0, None))
    return basis @ np.diag(root_spectrum) @ basis.T


def matrix_sqrt_inv(A: np.ndarray, reg: float = 1e-10) -> np.ndarray:
    """Return the inverse square root of a symmetric positive definite matrix.

    Args:
        A: Symmetric matrix to invert-and-root.
        reg: Floor applied to the eigenvalues before inversion, which keeps
            the result finite for (near-)singular input.
    """
    spectrum, basis = eigh(A)
    floored = np.clip(spectrum, reg, None)
    inv_root_spectrum = np.reciprocal(np.sqrt(floored))
    return basis @ np.diag(inv_root_spectrum) @ basis.T


# =============================================================================
# Wasserstein Distance
# =============================================================================

def w2_squared(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """
    Squared Wasserstein-2 distance between N(μ₁, Σ₁) and N(μ₂, Σ₂).

    W₂² = ||μ₁ - μ₂||² + B²(Σ₁, Σ₂), with the squared Bures metric

        B²(Σ₁, Σ₂) = tr Σ₁ + tr Σ₂ - 2 tr (Σ₁^{1/2} Σ₂ Σ₁^{1/2})^{1/2}.
    """
    # Covariance (Bures) part.
    root1 = matrix_sqrt(Sigma1)
    cross_trace = np.trace(matrix_sqrt(root1 @ Sigma2 @ root1))
    bures_sq = np.trace(Sigma1) + np.trace(Sigma2) - 2 * cross_trace
    if bures_sq < 0:
        # A negative value is replaced by zero, as the Bures term of the
        # formula above is a squared metric.
        bures_sq = 0

    # Mean part.
    mean_diff_sq = np.sum((mu1 - mu2) ** 2)

    return mean_diff_sq + bures_sq


def w2_distance(
    mu1: np.ndarray,
    Sigma1: np.ndarray,
    mu2: np.ndarray,
    Sigma2: np.ndarray,
) -> float:
    """Wasserstein-2 distance between N(μ₁, Σ₁) and N(μ₂, Σ₂) (root of ``w2_squared``)."""
    squared = w2_squared(mu1, Sigma1, mu2, Sigma2)
    return np.sqrt(squared)


# =============================================================================
# Sampling from Wasserstein Sphere
# =============================================================================

def sample_haar_orthogonal(d: int, rng: np.random.Generator) -> np.ndarray:
    """Draw a d×d orthogonal matrix from the Haar measure.

    A standard-normal matrix is QR-factorised and each column of Q is
    multiplied by the sign of the matching diagonal entry of R.
    """
    gaussian = rng.standard_normal((d, d))
    ortho, upper = qr(gaussian)
    column_signs = np.sign(np.diag(upper))
    # A zero diagonal entry has sign 0; treat it as +1 so no column is erased.
    column_signs = np.where(column_signs == 0, 1.0, column_signs)
    return ortho * column_signs


def sample_gaussians_from_sphere(
    config: SyntheticConfig,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draw N Gaussians N(μ_i, Σ_i) on the Wasserstein sphere of radius R around δ₀.

    For each pattern the covariance spectrum is sampled first; the mean is
    then placed on a Euclidean sphere of radius √(R² - tr Σ_i), so that
    ||μ_i||² + tr Σ_i = R² whenever tr Σ_i ≤ R².

    Returns:
        means: Array of shape (N, d)
        covariances: Array of shape (N, d, d)
    """
    num_patterns, dim = config.num_patterns, config.dim
    radius_sq = config.sphere_radius ** 2

    # Every row/slice is assigned in the loop below.
    means = np.empty((num_patterns, dim))
    covariances = np.empty((num_patterns, dim, dim))

    for k in range(num_patterns):
        # Spectrum: i.i.d. uniform on [λ_min, λ_max].
        spectrum = rng.uniform(config.lambda_min, config.lambda_max, size=dim)

        # Eigenbasis: Haar-distributed rotation.
        rotation = sample_haar_orthogonal(dim, rng)

        # Σ = Q diag(λ) Qᵀ, symmetrised to remove round-off asymmetry.
        cov = rotation @ np.diag(spectrum) @ rotation.T
        covariances[k] = (cov + cov.T) / 2

        # Mean: uniform direction, norm √(R² - τ) with τ = Σ λ (clamped at 0).
        gaussian_dir = rng.standard_normal(dim)
        mean_norm = np.sqrt(max(radius_sq - np.sum(spectrum), 0))
        means[k] = mean_norm * gaussian_dir / norm(gaussian_dir)

    return means, covariances


# =============================================================================
# Perturbation
# =============================================================================

def perturb_gaussian(
    mu: np.ndarray,
    Sigma: np.ndarray,
    target_w2: float,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Perturb a Gaussian N(μ, Σ) to get N(m, Ω) at W₂ distance ``target_w2``.

    Strategy: aim to spend half of the W₂² budget on the covariance by
    scaling it, Ω = cΣ, then give the mean shift whatever budget is left so
    that the total matches ``target_w2`` exactly.

    For Ω = cΣ with Σ symmetric PSD the Bures term has the closed form
    B²(Σ, cΣ) = tr(Σ) · (1 - √c)², so no eigendecomposition is required.

    Args:
        mu: Mean of the original Gaussian, shape (d,).
        Sigma: Covariance of the original Gaussian, shape (d, d).
        target_w2: Desired W₂ distance between original and perturbed Gaussian.
        rng: Random generator; exactly one ``standard_normal(d)`` draw is made.

    Returns:
        m: Perturbed mean, shape (d,).
        Omega: Perturbed covariance ``c * Sigma``, shape (d, d).
    """
    target_w2_sq = target_w2 ** 2
    cov_budget = 0.5 * target_w2_sq

    # Uniformly random unit direction for the mean shift.
    direction = rng.standard_normal(len(mu))
    direction = direction / norm(direction)

    # Covariance perturbation via scaling: pick √c so that
    # tr(Σ)(1 - √c)² equals the covariance budget, within clamps.
    trace_Sigma = np.trace(Sigma)
    if trace_Sigma > 0:
        ratio = cov_budget / trace_Sigma
        if ratio >= 1:
            # Budget exceeds what shrinking Σ can absorb; use a fixed c = 0.25.
            sqrt_c = 0.5
        else:
            # Never shrink below √c = 0.1 so Ω stays well away from singular.
            sqrt_c = max(1 - np.sqrt(ratio), 0.1)
    else:
        # Scaling a zero-trace covariance cannot move it; leave it untouched
        # (and avoid dividing by zero) so the whole budget goes to the mean.
        sqrt_c = 1.0

    Omega = sqrt_c ** 2 * Sigma

    # Exact Bures cost of the scaling, from the closed form above.
    bures_sq = trace_Sigma * (1 - sqrt_c) ** 2

    # The mean shift absorbs the remaining budget. Both clamps above only
    # reduce the Bures cost, so the remainder is non-negative; the max()
    # guards against round-off.
    mean_sq = max(target_w2_sq - bures_sq, 0.0)
    m = mu + np.sqrt(mean_sq) * direction

    return m, Omega


# =============================================================================
# BW-DAM Dynamics
# =============================================================================

def bwdam_step(
    m: np.ndarray,
    Omega: np.ndarray,
    means: np.ndarray,
    covariances: np.ndarray,
    beta: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    One step of BW-DAM update (Φ operator).

    For every stored pattern N(μ_i, Σ_i) the step needs

        M_i = Σ_i^{1/2} Ω Σ_i^{1/2}

    twice: tr(M_i^{1/2}) enters the W₂² distance that defines the softmax
    weight, and M_i^{-1/2} enters the transport-map coefficient
    A_i = Σ_i^{1/2} M_i^{-1/2} Σ_i^{1/2}. Both are read off a single batched
    eigendecomposition of M_i (and one of Σ_i) instead of being recomputed.

    Args:
        m: Current mean (d,)
        Omega: Current covariance (d, d)
        means: Stored pattern means (N, d)
        covariances: Stored pattern covariances (N, d, d)
        beta: Softmax sharpness; weights are proportional to exp(-β · W₂²)

    Returns:
        m_new: Updated mean, Σ_i w_i μ_i
        Omega_new: Updated covariance, Ã Ω Ãᵀ with Ã = Σ_i w_i A_i

    Raises:
        ValueError: If no stored patterns are given.
    """
    means = np.asarray(means)
    covariances = np.asarray(covariances)
    if len(means) == 0:
        raise ValueError("bwdam_step requires at least one stored pattern")

    # Same eigenvalue floor as the default of matrix_sqrt_inv.
    eig_floor = 1e-10

    # Σ_i^{1/2} for all patterns (eigh operates on the trailing two axes).
    cov_vals, cov_vecs = eigh(covariances)
    cov_root_vals = np.sqrt(np.maximum(cov_vals, 0))
    sqrt_covs = (cov_vecs * cov_root_vals[:, np.newaxis, :]) @ np.swapaxes(cov_vecs, -1, -2)

    # M_i = Σ_i^{1/2} Ω Σ_i^{1/2}, symmetrised against round-off.
    inner = sqrt_covs @ Omega @ sqrt_covs
    inner = (inner + np.swapaxes(inner, -1, -2)) / 2
    inner_vals, inner_vecs = eigh(inner)

    # W₂² distances to all stored patterns: ||μ_i - m||² + Bures²(Σ_i, Ω).
    mean_diff_sq = np.sum((means - m) ** 2, axis=1)
    trace_sqrt_inner = np.sum(np.sqrt(np.maximum(inner_vals, 0)), axis=1)
    bures_sq = (
        np.trace(covariances, axis1=-2, axis2=-1)
        + np.trace(Omega)
        - 2 * trace_sqrt_inner
    )
    D = mean_diff_sq + np.maximum(bures_sq, 0)

    # Softmax weights; subtracting the max keeps exp() from underflowing
    # for every pattern at once.
    log_weights = -beta * D
    log_weights = log_weights - np.max(log_weights)
    w = np.exp(log_weights)
    w = w / np.sum(w)

    # Transport map coefficients A_i = Σ_i^{1/2} M_i^{-1/2} Σ_i^{1/2}.
    inv_root_vals = 1.0 / np.sqrt(np.maximum(inner_vals, eig_floor))
    inner_sqrt_inv = (inner_vecs * inv_root_vals[:, np.newaxis, :]) @ np.swapaxes(inner_vecs, -1, -2)
    A = sqrt_covs @ inner_sqrt_inv @ sqrt_covs

    # Weighted updates.
    m_new = w @ means
    A_tilde = np.tensordot(w, A, axes=1)
    Omega_new = A_tilde @ Omega @ A_tilde.T
    Omega_new = (Omega_new + Omega_new.T) / 2

    return m_new, Omega_new


def run_bwdam_dynamics(
    m_init: np.ndarray,
    Omega_init: np.ndarray,
    means: np.ndarray,
    covariances: np.ndarray,
    beta: float,
    target_mu: np.ndarray,
    target_Sigma: np.ndarray,
    num_iters: int,
) -> List[float]:
    """
    Iterate the BW-DAM update ``num_iters`` times from (m_init, Omega_init).

    Returns:
        The W₂ distance to the target pattern before the first step and
        after every step, i.e. a list of length ``num_iters + 1``.
    """
    # Work on copies so the caller's initial state is left untouched.
    state_mean = m_init.copy()
    state_cov = Omega_init.copy()

    trajectory = [w2_distance(state_mean, state_cov, target_mu, target_Sigma)]

    for _step in range(num_iters):
        state_mean, state_cov = bwdam_step(
            state_mean, state_cov, means, covariances, beta
        )
        trajectory.append(
            w2_distance(state_mean, state_cov, target_mu, target_Sigma)
        )

    return trajectory


# =============================================================================
# Main Experiment
# =============================================================================

def run_experiment(config: SyntheticConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample patterns, perturb a subset, run BW-DAM on each query and summarise.

    Progress and a short result summary are printed to stdout.

    Returns:
        histories: Array of shape (n_perturbed, num_iters + 1)
        mean_distances: Mean W₂ distance at each iteration
        std_distances: Std of W₂ distance at each iteration
    """
    rng = np.random.default_rng(config.seed)

    banner = "=" * 70
    print(banner)
    print("BW-DAM Retrieval Dynamics on Synthetic Data")
    print(banner)
    print(f"\nConfiguration:")
    print(f"  Dimension (d):          {config.dim}")
    print(f"  Number of patterns (N): {config.num_patterns}")
    print(f"  Sphere radius (R):      {config.sphere_radius:.4f}")
    print(f"  Spectral bounds:        [{config.lambda_min}, {config.lambda_max}]")
    print(f"  Temperature (β):        {config.beta}")
    print(f"  Perturbation distance:  {config.perturbation_distance:.4f}")
    print(f"  Number of iterations:   {config.num_iters}")

    # Step 1: stored patterns.
    print(f"\n[1/4] Sampling {config.num_patterns} Gaussians from Wasserstein sphere...")
    means, covariances = sample_gaussians_from_sphere(config, rng)

    # Step 2: which patterns become queries (drawn without replacement).
    n_perturbed = int(config.perturb_fraction * config.num_patterns)
    print(f"[2/4] Selecting {n_perturbed} patterns to perturb...")
    perturbed_indices = rng.choice(config.num_patterns, size=n_perturbed, replace=False)

    # Step 3: one perturbed (mean, covariance) query per selected pattern.
    print(f"[3/4] Creating perturbations...")
    initial_states = [
        perturb_gaussian(
            means[idx], covariances[idx],
            config.perturbation_distance, rng
        )
        for idx in perturbed_indices
    ]

    # Step 4: retrieval dynamics, tracking distance to the unperturbed pattern.
    print(f"[4/4] Running BW-DAM dynamics...")
    all_histories = []
    progress = tqdm(perturbed_indices, desc="  Progress")
    for idx, (m_init, Omega_init) in zip(progress, initial_states):
        all_histories.append(
            run_bwdam_dynamics(
                m_init, Omega_init,
                means, covariances,
                config.beta,
                means[idx], covariances[idx],
                config.num_iters
            )
        )

    # Per-iteration statistics across queries.
    histories = np.array(all_histories)
    mean_distances = histories.mean(axis=0)
    std_distances = histories.std(axis=0)

    print(f"\nResults:")
    print(f"  Initial mean W₂: {mean_distances[0]:.4f} ± {std_distances[0]:.4f}")
    print(f"  Final mean W₂:   {mean_distances[-1]:.6f} ± {std_distances[-1]:.6f}")

    return histories, mean_distances, std_distances


def plot_convergence(
    mean_distances: np.ndarray,
    std_distances: np.ndarray,
    num_iters: int,
    output_path: str = "bwdam_synthetic_convergence.png",
) -> None:
    """Plot mean W₂ distance (± 1 std band) per iteration and save it to ``output_path``.

    The figure is left open, so a later ``plt.show()`` still displays it.
    """
    iterations = np.arange(num_iters + 1)
    lower_band = mean_distances - std_distances
    upper_band = mean_distances + std_distances

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(iterations, mean_distances, "b-", linewidth=2, label="Mean $W_2$ distance")
    ax.fill_between(
        iterations, lower_band, upper_band,
        alpha=0.3, color="blue", label="± 1 std"
    )

    ax.set_xlabel("Iteration", fontsize=14)
    ax.set_ylabel("$W_2$ distance to original pattern", fontsize=14)
    ax.legend(fontsize=11, loc="upper right")
    ax.grid(True, alpha=0.3)
    ax.set_xlim([0, num_iters])
    ax.set_ylim(bottom=0)
    ax.set_xticks(iterations)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\nSaved: {output_path}")


def main():
    """Run the experiment with default settings, save and show the plot.

    Returns the ``(histories, mean_distances, std_distances)`` triple
    produced by ``run_experiment``.
    """
    cfg = SyntheticConfig()
    histories, mean_dist, std_dist = run_experiment(cfg)
    plot_convergence(mean_dist, std_dist, cfg.num_iters)
    # Display the figure after it has been written to disk.
    plt.show()
    return histories, mean_dist, std_dist


# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
