import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import kendalltau, rankdata
from sklearn.manifold import TSNE
from util_functions import spectral_lin_reg, random_fourier_curve
from curve import RandomFourierCurve

###########################
# High-Dimension Examples #
###########################
# Example curve: its first three coordinates plotted in 3D

# Draw one example curve as a line through its first three coordinates,
# visiting the points in order of increasing true_rank.
# NOTE(review): presumably the three returned values are the clean points, the
# noisy points and the rank of each point along the curve -- confirm against
# util_functions.random_fourier_curve.
pts, noisy_pts, true_rank = random_fourier_curve(
    n_pts=1000, d=15, K=3, alpha=2.5, noise=0.1
)
order = np.argsort(true_rank)

# A single 3D axes; equivalent to plt.figure() + add_subplot(111, projection='3d').
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_title("Example Fourier Curve")
# Transposing the (n, 3) slice yields the x, y and z coordinate arrays.
ax.plot(*pts[order, :3].T)
plt.show()

# Scatter the first three coordinates of the noisy sample, colouring every
# point by its true_rank value and adding a colour bar labelled "Rank".
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# Transposing the (n, 3) slice yields the x, y and z coordinate arrays.
p = ax.scatter(*noisy_pts[:, :3].T, c=true_rank)
ax.set_title("Example Fourier Curve")
fig.colorbar(p, ax=ax, label="Rank")
plt.show()



# Benchmark: for every combination of sample size, dimension and noise level,
# generate N_RUNS random curves and score two orderings of the noisy points
# against the true ranks using |Kendall's tau|:
#   column 0 -- ranks of a 1-D t-SNE embedding
#   column 1 -- output of spectral_lin_reg (printed under the label "STAGE")
DIM_LIST = [10, 50, 100, 200]
NOISE_STD_LIST = [1, 2, 5, 10]
N_RUNS = 20
N_POINTS_LIST = [1000, 3000]
np.random.seed(100)
for N_POINTS in N_POINTS_LIST:
    for d in DIM_LIST:
        for NOISE_STD in NOISE_STD_LIST:
            # The noise passed to the generator is scaled down by sqrt(d).
            sigma = NOISE_STD / np.sqrt(d)
            taus = np.empty((N_RUNS, 2))
            for i in range(N_RUNS):
                _, X, true_rank_indices = random_fourier_curve(
                    n_pts=N_POINTS, d=d, K=8, noise=sigma
                )

                # Rank the points by their position in a 1-D t-SNE embedding.
                embedding = TSNE(n_components=1, perplexity=100).fit_transform(X)
                t_SNE_rank_indices = rankdata(embedding.ravel())
                tau, p_value = kendalltau(true_rank_indices, t_SNE_rank_indices)
                # Absolute value: a fully reversed ordering scores the same as
                # a matching one.
                taus[i, 0] = np.abs(tau)

                spectral_lin_reg_rank_indices = spectral_lin_reg(X, k=2 * d, N_MIN=d + 1)
                tau, p_value = kendalltau(true_rank_indices, spectral_lin_reg_rank_indices)
                taus[i, 1] = np.abs(tau)

            # Use the NaN-aware reduction for BOTH statistics: kendalltau can
            # return NaN, and a plain np.std would then report NaN for the
            # spread even though the mean silently skips that run.
            avg_taus = np.nanmean(taus, axis=0)
            sd_taus = np.nanstd(taus, axis=0)
            print("Sample size = ", N_POINTS)
            print("d = ", d)
            print("sigma = ", sigma)
            print("t-SNE: Mean =", avg_taus[0], "SD = ", sd_taus[0])
            print("STAGE: Mean =", avg_taus[1], "SD = ", sd_taus[1])

