import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata
from stage5 import STAGE, evaluate_kendall_abs
from experiment_utils import spectral_ordering_robust

# Seed NumPy's global RNG at import time. The sweep loop further down reseeds
# per run, so this only affects code that draws random numbers before it.
np.random.seed(42)

# --- Experiment configuration -------------------------------------------------
n_points = 500          # number of samples placed along the curve
sigma_parallel = 0.001  # noise std applied to the first coordinate
k = 400                 # forwarded to STAGE as its `k` argument
n_runs = 5              # seeded repetitions per noise level

# Ten evenly spaced noise levels for the second coordinate, 0.1 through 0.35.
sigma_perp_values = np.linspace(0.1, 0.35, num=10)

# --- Result accumulators (one entry appended per noise level) -----------------
stage_means, stage_stds = [], []
recanati_means, recanati_stds = [], []

# Sweep the noise level of the second coordinate and, at each level, score
# STAGE and the Recanati spectral ordering over `n_runs` seeded repetitions.
#
# The noise-free curve and its ground-truth ranks depend on neither the noise
# level nor the run, and building them draws no random numbers, so they are
# constructed once here instead of once per run.
t = np.linspace(0, 1, n_points)
true_order = rankdata(t, method='average')
curve_points = np.column_stack([t, np.zeros(n_points)])

for sigma_perp in sigma_perp_values:
    # Value handed to spectral_ordering_robust as `sigma`: the two noise stds
    # combined in quadrature. It is constant across runs at this noise level.
    sigma = np.sqrt(sigma_parallel**2 + sigma_perp**2)

    stage_taus = []
    recanati_taus = []

    for run in range(n_runs):
        # Reseed with the run index, so every noise level reuses the same
        # seeds 0 .. n_runs - 1.
        np.random.seed(run)

        # Independent Gaussian noise per coordinate: std `sigma_parallel` in
        # column 0 and `sigma_perp` in column 1 (the scale list broadcasts
        # across the (n_points, 2) output).
        noise = np.random.normal(0, [sigma_parallel, sigma_perp], (n_points, 2))
        X = curve_points + noise

        model = STAGE(k=k, embedding="linreg", pca_full_dim=True)
        model.fit(X)
        stage_tau = evaluate_kendall_abs(true_order, model.order)
        stage_taus.append(stage_tau)

        recanati_result = spectral_ordering_robust(X, sigma=sigma)
        # NOTE(review): argsort turns the result into per-point ranks only if
        # spectral_ordering_robust returns a permutation of indices; if it
        # returns per-point scores, rankdata would be the matching conversion.
        # Confirm against experiment_utils.
        recanati_order = np.argsort(recanati_result)
        recanati_tau = evaluate_kendall_abs(true_order, recanati_order)
        recanati_taus.append(recanati_tau)

    # Summarise this noise level once and reuse the numbers for both the
    # accumulators and the log line. np.std uses its default ddof=0 here.
    stage_mean, stage_std = np.mean(stage_taus), np.std(stage_taus)
    recanati_mean, recanati_std = np.mean(recanati_taus), np.std(recanati_taus)

    stage_means.append(stage_mean)
    stage_stds.append(stage_std)
    recanati_means.append(recanati_mean)
    recanati_stds.append(recanati_std)

    print(f"σ_perp={sigma_perp:.2f}: STAGE={stage_mean:.3f}±{stage_std:.3f}, Recanati={recanati_mean:.3f}±{recanati_std:.3f}")

# Plot the mean score of each method against the noise level, with one
# standard deviation as error bars, and write the figure to disk.
fig, ax = plt.subplots(figsize=(8, 5))

ax.errorbar(
    sigma_perp_values,
    stage_means,
    yerr=stage_stds,
    marker='o',
    label='STAGE',
    capsize=4,
)
ax.errorbar(
    sigma_perp_values,
    recanati_means,
    yerr=recanati_stds,
    marker='s',
    label='Recanati',
    capsize=4,
)

ax.set_xlabel('σ_perp')
ax.set_ylabel('Kendall τ')
ax.legend()
ax.grid(alpha=0.3)

fig.tight_layout()
fig.savefig("stage_recanati_anisotropy.png")
# Interactive display is left disabled; the figure is only saved to file.

