"""
Visualize STAGE behavior on anisotropic line data
==================================================

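Simple script that generates one instance of anisotropic line data (points on
the segment gamma(t) = (t, 0), t in [0, 1], with small noise along the line and
large noise perpendicular to it) and visualizes:
1. The original 2D points
2. The STAGE embedding
3. A comparison with Recanati's spectral ordering
4. The similarity matrix returned by the spectral ordering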
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata
from stage3 import stage_embedding, evaluate_kendall_abs
from experiment_utils import spectral_ordering


# Fixed experiment parameters
n_points = 500         # number of samples along the line
sigma_parallel = 0.01  # noise std along the curve (x-direction)
sigma_perp = 0.3       # noise std perpendicular to the curve (y-direction)
seed = 42              # seed for NumPy's global (legacy) RNG
k = 50                 # passed as `k` to stage_embedding (presumably a neighbourhood size)

# Generate the data: the straight line gamma(t) = (t, 0) plus anisotropic noise.
np.random.seed(seed)
# Evenly spaced parameters. For randomly sampled (unsorted) positions along the
# line, use np.random.uniform(0, 1, n_points) instead.
t = np.linspace(0.0, 1.0, num=n_points)
# Ground-truth ranks of t (1..n_points for the evenly spaced grid).
true_order = rankdata(t, method="average")

curve_points = np.stack([t, np.zeros_like(t)], axis=1)

# `scale` broadcasts per column: column 0 gets sigma_parallel, column 1 sigma_perp.
noise = np.random.normal(0.0, (sigma_parallel, sigma_perp), (n_points, 2))
X = curve_points + noise

_summary = (
    ("Data shape", X.shape),
    ("n_points", n_points),
    ("σ_parallel", sigma_parallel),
    ("σ_perp", sigma_perp),
)
# One line per entry, followed by a blank separator line.
print("\n".join(f"{name}: {value}" for name, value in _summary), end="\n\n")

# --- STAGE ------------------------------------------------------------------
# Laplacian-embedding variant of STAGE, scored against the ground-truth ranks.
y_stage, stage_order = stage_embedding(X, k=k, embedding="laplacian")
stage_tau = evaluate_kendall_abs(true_order, stage_order)
print(f"STAGE Kendall's τ: {stage_tau:.4f}")

# --- Recanati spectral ordering ----------------------------------------------
# Bandwidth = total noise scale sqrt(sigma_parallel² + sigma_perp²).
# NOTE(review): assumes `sigma` is a kernel bandwidth in spectral_ordering — confirm.
noise_variance = sigma_parallel**2 + sigma_perp**2
sigma_kernel = np.sqrt(noise_variance)
recanati_result, sim_mat = spectral_ordering(
    X, return_sim_mat=True, sigma=sigma_kernel
)
# Convert the returned values to ranks before scoring against the true ranks.
recanati_order = rankdata(recanati_result, method="average")
recanati_tau = evaluate_kendall_abs(true_order, recanati_order)
print(f"Recanati Kendall's τ: {recanati_tau:.4f}", end="\n\n")

# Plot the Recanati similarity matrix with rows/columns sorted by the true
# parameter t. For the evenly spaced t generated above this permutation is the
# identity, but sorting keeps the expected banded structure visible when t is
# sampled in random (unsorted) order, e.g. with np.random.uniform.
# NOTE(review): assumes sim_mat rows/columns follow the row order of X — confirm
# against spectral_ordering.
t_perm = np.argsort(t, kind="stable")
sim_sorted = np.asarray(sim_mat)[np.ix_(t_perm, t_perm)]

fig_sim, ax_sim = plt.subplots(figsize=(10, 8))
sim_image = ax_sim.imshow(sim_sorted, cmap='viridis')
fig_sim.colorbar(sim_image, ax=ax_sim, label='Similarity')
ax_sim.set_title('Similarity Matrix')
ax_sim.set_xlabel('Point index (sorted by true t)')
ax_sim.set_ylabel('Point index (sorted by true t)')
fig_sim.tight_layout()
fig_sim.savefig('similarity_matrix.png', dpi=150, bbox_inches='tight')
print("Saved similarity matrix to 'similarity_matrix.png'")
# Close so this figure is not shown by the final plt.show().
plt.close(fig_sim)

def _points_panel(panel_ax, points, colors, cbar_label, panel_title):
    """Scatter 2D `points` on `panel_ax`, colored by `colors`, with a colorbar."""
    collection = panel_ax.scatter(
        points[:, 0], points[:, 1], c=colors, cmap='viridis', s=30, alpha=0.7
    )
    panel_ax.set_xlabel('x₁ (along curve)')
    panel_ax.set_ylabel('x₂ (perpendicular)')
    panel_ax.set_title(panel_title)
    panel_ax.set_aspect('equal')
    plt.colorbar(
        collection, ax=panel_ax, label=cbar_label, fraction=0.046, pad=0.04
    )


def _versus_t_panel(panel_ax, params, values, color, ylabel, panel_title):
    """Scatter `values` against the true parameter `params` on a gridded panel."""
    panel_ax.scatter(params, values, s=20, alpha=0.6, color=color)
    panel_ax.set_xlabel('True parameter t')
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)


# Create visualization: a 2x2 grid.
#   top row    - the 2D points colored by the true t and by the STAGE embedding
#   bottom row - STAGE and Recanati outputs plotted against the true t
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

_points_panel(
    axes[0, 0], X, t, 'True parameter t',
    f'Original Data (n={n_points}, σ∥={sigma_parallel}, σ⊥={sigma_perp})',
)
_points_panel(
    axes[0, 1], X, y_stage, 'STAGE y',
    f'Colored by STAGE embedding (τ={stage_tau:.3f})',
)
_versus_t_panel(
    axes[1, 0], t, y_stage, 'blue', 'STAGE embedding y',
    f'STAGE: y vs t (τ={stage_tau:.3f})',
)
_versus_t_panel(
    axes[1, 1], t, recanati_result, 'red', 'Recanati embedding',
    f'Recanati: embedding vs t (τ={recanati_tau:.3f})',
)

plt.tight_layout()
plt.savefig('anisotropic_visualization.png', dpi=150, bbox_inches='tight')
print("Saved visualization to 'anisotropic_visualization.png'")
plt.show()
