from stage3    import stage_embedding, evaluate_kendall_abs
from sklearn.manifold import TSNE
import time, numpy as np
from stage_original import spectral_lin_reg_knn
from curve import RandomFourierCurve
from scipy.stats import rankdata
import matplotlib.pyplot as plt

# Random Fourier curve; d=200 is presumably the ambient dimension (TODO confirm).
curve = RandomFourierCurve(d=200, K=10, alpha=2.3, span=0.25, seed=10)

# Rescale the curve so that its maximum curvature is about 2.0.
smooth = curve.stretch_to_curvature(kappa_max=2.0)

# 900 curve parameters from a unit-speed grid; only the first element of the
# returned pair is needed.
t = smooth.unit_speed_grid(900)[0]

# Sanity check: report the largest curvature actually attained on the grid.
kappa0 = smooth.curvature(t).max()
print("κ_max:", kappa0)

# Permute the parameters in place with NumPy's global RNG (no seed is set
# before this call) so that row order no longer follows the curve.
np.random.shuffle(t)

# Clean samples on the curve plus the 1-based average ranks of the shuffled
# parameters t; a sample's rank is its position along the curve.
X, rank = smooth.c(t), rankdata(t, method='average')
# Light Gaussian noise (standard deviation 0.1) for the visualisation below.
X_noisy = X + np.random.normal(scale=.1, size=X.shape)

# Convert to 0-based integer ranks. rankdata(method='average') returns floats;
# they are whole numbers only while t has no ties (the int cast truncates).
rank = np.int32(rank) - 1

# Permutation that visits the samples in curve order.
order = np.argsort(rank)

# Clean curve drawn as a polyline in curve order (first three coordinates).
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(X[order, 0], X[order, 1], X[order, 2])
ax.set_title("Example Fourier Curve")
fig.savefig("fourier_curve_3d.png")
plt.close(fig)  # pyplot keeps figures alive until they are closed explicitly

# Noisy samples coloured by their 0-based rank along the curve.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_title("Example Noisy Fourier Curve")
p = ax.scatter(X_noisy[:, 0], X_noisy[:, 1], X_noisy[:, 2], c=rank)
fig.colorbar(p, ax=ax, label="Rank")
fig.savefig("noisy_fourier_curve_3d.png")
plt.close(fig)


#%%
# Re-sample the clean curve at the (already shuffled) parameters t and keep
# their 1-based average ranks as the ground-truth ordering.
X = smooth.c(t)
true_order = rankdata(t, method='average')

# Stronger Gaussian noise (standard deviation 0.5) for the denoising experiment.
X_noisy = X + np.random.normal(scale=.5, size=X.shape)

fig, axes = plt.subplots(10, 10, figsize=(15, 15))
fig.suptitle('Scatter Plots of Noisy Coordinate Pairs')

# One panel per ordered pair of the first ten coordinates: the grid row picks
# the x-axis coordinate and the grid column picks the y-axis coordinate.
for (row, col), panel in np.ndenumerate(axes):
    panel.scatter(X_noisy[:, row], X_noisy[:, col], s=1)
    panel.set_title(f'({row}, {col})')
    panel.set_xticks([])
    panel.set_yticks([])

fig.tight_layout()
fig.savefig('scatter_plot_noisy.png')
plt.close(fig)

fig, axes = plt.subplots(10, 10, figsize=(15, 15))
fig.suptitle('Scatter Plots of True Coordinate Pairs')

# One panel per ordered pair of the first ten clean coordinates: grid row ->
# x-axis coordinate, grid column -> y-axis coordinate.
for (row, col), panel in np.ndenumerate(axes):
    panel.scatter(X[:, row], X[:, col], s=1)
    panel.set_title(f'({row}, {col})')
    panel.set_xticks([])
    panel.set_yticks([])

fig.tight_layout()
fig.savefig('scatter_plot_true.png')
plt.close(fig)

# Fit the "linreg" embedding on the noisy samples. k=50 is presumably the
# neighbourhood size (TODO confirm). Intermediates are requested because the
# per-sample tangents and the neighbour graph are reused for denoising below.
embedding = stage_embedding(
    X_noisy,
    k=50,
    pca_full_dim=True,
    embedding="linreg",
    return_intermediates=True,
)
y, order, tangents, edges = (
    embedding.y,
    embedding.order,
    embedding.tangents,
    embedding.graph_edges,
)

def _tangent_denoise(points, tangents, edges):
    """Denoise samples by projecting them onto tangent lines through neighbours.

    For every sample ``j`` the result is the mean, over its graph neighbours
    ``i``, of ``points[i] + t_j t_j^T (points[j] - points[i])``, where ``t_j`` is
    ``tangents[j]`` flattened. That expression is linear in ``points[i]``, so
    the mean equals ``c + t_j * (t_j . (points[j] - c))`` with ``c`` the
    neighbour centroid. This form is computed here: one dot product per sample
    instead of a ``d x d`` outer product per neighbour.

    If ``t_j`` has unit norm (presumably what the embedding returns — TODO
    confirm), this is the orthogonal projection of ``points[j]`` onto the line
    through ``c`` with direction ``t_j``.

    Args:
        points: array of shape (n, d), one sample per row.
        tangents: per-sample tangent vectors; ``tangents[j]`` is flattened and
            assumed to have length ``d`` (one tangent direction per sample).
        edges: integer array with at least two columns; a row ``(i, j, ...)``
            makes sample ``i`` a neighbour of sample ``j``. Duplicate rows are
            counted with multiplicity.

    Returns:
        Float array of shape (n, d). Samples with no incoming edges are
        returned unchanged.
    """
    points = np.asarray(points, dtype=float)
    edges = np.asarray(edges)
    denoised = points.copy()
    sources, targets = edges[:, 0], edges[:, 1]
    for j in range(points.shape[0]):
        neighbors = sources[targets == j]
        if neighbors.size == 0:
            continue  # isolated sample: keep it as observed
        centroid = points[neighbors].mean(axis=0)
        direction = np.ravel(tangents[j])
        denoised[j] = centroid + direction * (direction @ (points[j] - centroid))
    return denoised


# Denoise the *noisy* observations. The tangents and graph were estimated from
# X_noisy, so applying them to the clean X would leak the ground truth.
X_denoised = _tangent_denoise(X_noisy, tangents, edges)

fig, axes = plt.subplots(10, 10, figsize=(15, 15))
fig.suptitle('Scatter Plots of Denoised Coordinate Pairs')

# One panel per ordered pair of the first ten denoised coordinates: grid row ->
# x-axis coordinate, grid column -> y-axis coordinate.
for (row, col), panel in np.ndenumerate(axes):
    panel.scatter(X_denoised[:, row], X_denoised[:, col], s=1)
    panel.set_title(f'({row}, {col})')
    panel.set_xticks([])
    panel.set_yticks([])

fig.tight_layout()
fig.savefig('scatter_plot_denoised.png')
plt.close(fig)