# %%
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata
from stage5 import STAGE, evaluate_kendall_abs

# Synthetic data: points on the straight segment y = 0, x = t in [0, 1],
# perturbed with axis-dependent Gaussian noise. Noise along the segment
# (sigma_parallel, x-axis) is much smaller than noise across it
# (sigma_perp, y-axis). The ground-truth ordering is the rank of t.
n_points, sigma_parallel, sigma_perp = 500, 0.01, 0.3
# `k` is forwarded to STAGE(k=...) in the next cell; its meaning is defined
# in stage5 (presumably a neighbourhood size -- confirm there).
seed, k = 42, 400

# Legacy global RNG: the exact draws below depend on this seeding and on
# the order/shape of the normal() call, so keep them together.
np.random.seed(seed)
t = np.linspace(0, 1, n_points)
true_order = rankdata(t, method='average')  # t is increasing -> ranks 1..n_points
curve_points = np.stack([t, np.zeros_like(t)], axis=1)  # (n_points, 2)
noise_scales = np.array([sigma_parallel, sigma_perp])   # per-column std-dev
noise = np.random.normal(loc=0, scale=noise_scales, size=(n_points, 2))
X = curve_points + noise

print(f"Data: n={n_points}, σ∥={sigma_parallel}, σ⊥={sigma_perp}")

# %%
# Fit STAGE on the noisy points, score its ordering against the ground
# truth, then dump the shapes/ranges of the fitted attributes.
model = STAGE(k=k, embedding="linreg", pca_full_dim=True)
model.fit(X)

# Kendall-tau-based score between true ranks and STAGE's recovered order
# (per the helper's name, presumably |tau| -- see stage5).
stage_tau = evaluate_kendall_abs(true_order, model.order)
print(f"STAGE τ: {stage_tau:.4f}\n")

print("Debugging info:")
n_edges = model.edges.shape[0]
print(f"  edges: {n_edges}")
raw_shape = model.tangents_raw.shape
print(f"  tangents_raw: {raw_shape}")
aligned_shape = model.tangents.shape
print(f"  tangents (aligned): {aligned_shape}")
# Count strictly positive / strictly negative signs; zeros fall in neither.
n_pos, n_neg = np.sum(model.signs > 0), np.sum(model.signs < 0)
print(f"  signs: {n_pos} positive, {n_neg} negative")
y_lo, y_hi = model.y.min(), model.y.max()
print(f"  y range: [{y_lo:.3f}, {y_hi:.3f}]")

# Diagnostic figure: 2x3 grid of panels (the last one is left empty),
# saved to debug_anisotropic.png and then shown.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Panel (0,0): the raw data, coloured by the ground-truth parameter t.
ax = axes[0, 0]
ax.scatter(X[:, 0], X[:, 1], c=t, cmap='viridis', s=30, alpha=0.7)
ax.set_title('Data (colored by true t)')
ax.set_aspect('equal')

# Panel (0,1): raw tangents, split by model.signs ('Kept' if > 0,
# 'Flipped' if < 0). Points whose sign is exactly 0 (if any) get no arrow.
# `scale` shrinks each tangent into a short arrow. It is unrelated to
# quiver's own `scale=1` argument, which together with scale_units='xy'
# draws vectors in data units.
ax = axes[0, 1]
scale = 0.05
kept = model.signs > 0
flipped = model.signs < 0
ax.scatter(X[:, 0], X[:, 1], c=t, cmap='viridis', s=20, alpha=0.3)
for mask, color, label in ((kept, 'blue', 'Kept'), (flipped, 'red', 'Flipped')):
    if not np.any(mask):
        continue  # skip empty groups so they do not show up in the legend
    ax.quiver(X[mask, 0], X[mask, 1],
              model.tangents_raw[mask, 0]*scale,
              model.tangents_raw[mask, 1]*scale,
              color=color, alpha=0.6, width=0.003, scale=1,
              scale_units='xy', angles='xy', label=label)
ax.set_title('Raw tangents (before alignment)')
ax.set_aspect('equal')
# Only build a legend when a labelled quiver was actually drawn. Otherwise
# matplotlib warns that no labelled artists were found and draws nothing useful.
if np.any(kept) or np.any(flipped):
    ax.legend()

# Panel (0,2): tangents after alignment (model.tangents), same arrow scaling.
ax = axes[0, 2]
ax.scatter(X[:, 0], X[:, 1], c=t, cmap='viridis', s=20, alpha=0.3)
ax.quiver(X[:, 0], X[:, 1],
          model.tangents[:, 0]*scale, model.tangents[:, 1]*scale,
          color='crimson', alpha=0.7, width=0.003, scale=1,
          scale_units='xy', angles='xy')
ax.set_title('Aligned tangents')
ax.set_aspect('equal')

# Panel (1,0): the data coloured by STAGE's per-point output model.y.
ax = axes[1, 0]
ax.scatter(X[:, 0], X[:, 1], c=model.y, cmap='viridis', s=30, alpha=0.7)
ax.set_title(f'Colored by STAGE (τ={stage_tau:.3f})')
ax.set_aspect('equal')

# Panel (1,1): model.y against the true t; a monotone trend means good ordering.
ax = axes[1, 1]
ax.scatter(t, model.y, s=20, alpha=0.6, color='blue')
ax.set_xlabel('True t')
ax.set_ylabel('STAGE y')
ax.set_title(f'STAGE: y vs t (τ={stage_tau:.3f})')
ax.grid(True, alpha=0.3)

# Panel (1,2): unused.
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig('debug_anisotropic.png', dpi=150, bbox_inches='tight')
print("\nSaved debug_anisotropic.png")
plt.show()

# %%
