#%% Imports and data generation
from datasets import sine_curve
from stage3    import stage_embedding, evaluate_kendall_abs
from sklearn.manifold import TSNE
import time, numpy as np
from stage_original import spectral_lin_reg_knn
from curve import RandomFourierCurve
from scipy.stats import rankdata

# Single seed shared by every random step below (curve construction, shuffle,
# noise) so repeated runs produce the same data set and comparable scores.
SEED = 0
N_POINTS = 900      # number of samples taken along the curve
KAPPA_MAX = 2.0     # target maximum curvature passed to stretch_to_curvature
NOISE_SCALE = 2.0   # std. dev. of the Gaussian noise added to every entry of X

rng = np.random.default_rng(SEED)

curve = RandomFourierCurve(d=200, K=10, alpha=2.3, span=0.25, seed=SEED)

# --- stretch so that κ_max ≈ KAPPA_MAX -------------------------------
smooth = curve.stretch_to_curvature(kappa_max=KAPPA_MAX)

t, _ = smooth.unit_speed_grid(N_POINTS)
kappa0 = smooth.curvature(t).max()
print("κ_max:", kappa0)

# Shuffle the parameters in place (seeded) so the row order of X does not
# reveal the position along the curve.
rng.shuffle(t)

# Presumably smooth.c evaluates the curve at each parameter, so row i of X
# corresponds to t[i]. true_order is the rank of each t[i], which is the
# ground truth the embeddings are scored against.
X, true_order = smooth.c(t), rankdata(t, method='average')

# Add i.i.d. Gaussian noise to every entry of X.
X += rng.normal(scale=NOISE_SCALE, size=X.shape)

#%%


#%%
import matplotlib.pyplot as plt

# Grid of pairwise scatter plots over the first coordinates of X, used as a
# quick visual check of the noisy data. The grid is capped at X.shape[1] so
# data with fewer than 10 coordinates does not raise IndexError.
n_grid = min(10, X.shape[1])

# squeeze=False keeps `axes` a 2-D array even when n_grid == 1.
fig, axes = plt.subplots(n_grid, n_grid, figsize=(15, 15), squeeze=False)
fig.suptitle('Scatter Plots of Coordinate Pairs')

# Panel (i, j) plots coordinate i (x-axis) against coordinate j (y-axis).
# Diagonal panels therefore show a coordinate against itself.
for i in range(n_grid):
    for j in range(n_grid):
        ax = axes[i, j]
        ax.scatter(X[:, i], X[:, j], s=1)
        ax.set_title(f'({i}, {j})')
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
plt.savefig('scatter_plot.png')
plt.close(fig)  # free the figure explicitly; nothing is shown interactively

#%%
# Benchmark: embed X in one dimension with each method, time it, and score the
# recovered ordering against `true_order` with evaluate_kendall_abs.

RUN_ORIGINAL = False  # set True to also benchmark spectral_lin_reg_knn


def _timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``; return ``(result, elapsed_seconds)``."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


results = []  # rows of (method name, wall-clock seconds, abs Kendall score)

# Run and time STAGE embedding.
(y_stage, ord_stage), stage_time = _timed(
    stage_embedding, X, k=50, pca_full_dim=True, embedding="linreg"
)
# NOTE(review): STAGE is scored on `ord_stage`, while t-SNE is scored on the
# ranks of its embedding. Confirm that both follow the convention
# evaluate_kendall_abs expects.
stage_score = evaluate_kendall_abs(ord_stage, true_order)
results.append(("STAGE", stage_time, stage_score))

# Run and time t-SNE. random_state fixes its otherwise random initialisation,
# so repeated runs report comparable scores.
y_tsne, tsne_time = _timed(
    lambda: TSNE(n_components=1, perplexity=100, random_state=0)
    .fit_transform(X)
    .ravel()
)
tsne_score = evaluate_kendall_abs(rankdata(y_tsne, method='average'), true_order)
results.append(("t-SNE", tsne_time, tsne_score))

# Optionally run and time the original method (disabled by default).
if RUN_ORIGINAL:
    y_original, orig_time = _timed(spectral_lin_reg_knn, X, K_NEIGHBORS=50, N_MIN=20)
    orig_score = evaluate_kendall_abs(y_original, true_order)
    results.append(("Original", orig_time, orig_score))

# Print results in a formatted table.
print("\nResults:")
print("-" * 50)
print(f"{'Method':<10} {'Time (s)':<12} {'Kendall Score (abs)':<15}")
print("-" * 50)
for name, seconds, score in results:
    print(f"{name:<10} {seconds:<12.3f} {score:<15.3f}")
print("-" * 50)


# %%
