import numpy as np
from scipy.stats import kendalltau
import umap
from sklearn.manifold import TSNE
import time
import warnings
from util_functions import spectral_lin_reg, fiedler_permutation, spectral_ordering, generate_sin_data, gaussian_kernel


#########################
# Comparison of Methods #
#########################
# FutureWarnings raised during the sweep are noise for this comparison; hide them.
warnings.filterwarnings("ignore", category=FutureWarning)

# Seed NumPy's global random state so the synthetic datasets are reproducible.
np.random.seed(100)

# Sample sizes to sweep over.
N_POINTS_LIST = np.array([500, 1000, 3000])
# Noise standard deviations 0.05, 0.10, ..., 0.50. These float values are used
# verbatim as dict keys and inside output filenames, so products that are not
# exactly representable may show trailing digits there.
NOISE_STD_LIST = 0.05 * np.arange(1, 11)
# Rotation applied by generate_sin_data to the sampled curve.
ROTATION_ANGLE = -np.pi / 3

# Parameters of spectral_lin_reg.
H_RADIUS = 0.5       # fixed neighbourhood radius
N_MIN = 3            # minimum neighbours for the local PCA (d + 1 with d = 2)

# Independent replicates per (N_POINTS, NOISE_STD) setting.
N_RUNS = 10

# Method names in the order used for result dicts and saved files.
ALGS = ["fiedler", "UMAP", "t_SNE", "spectral_lin_reg", "spectral_order"]
# Output file name: <method>_N_<n points>_NOISE_<noise std>_v3.txt
OUTPUT_FILENAME_TEMPLATE = "{0}_N_{1}_NOISE_{2}_v3.txt"


def _order_by_umap(X):
    """Return point indices sorted by their 1-D UMAP embedding coordinate."""
    embedding = umap.UMAP(n_components=1).fit_transform(X)
    return np.argsort(np.ravel(embedding))


def _order_by_fiedler(X):
    """Return the ordering produced by fiedler_permutation with the Gaussian kernel."""
    _, order_indices = fiedler_permutation(X, gaussian_kernel)
    return order_indices


def _order_by_tsne(X):
    """Return point indices sorted by their 1-D t-SNE embedding coordinate."""
    embedding = TSNE(n_components=1).fit_transform(X)
    return np.argsort(np.ravel(embedding))


def _order_by_spectral_lin_reg(X):
    """Return the ordering produced by spectral_lin_reg with the configured radius."""
    return spectral_lin_reg(X, h=H_RADIUS, N_MIN=N_MIN)


def _order_by_spectral_ordering(X):
    """Return the ordering produced by spectral_ordering with the Gaussian kernel."""
    return spectral_ordering(X, gaussian_kernel)


# Methods in the order they are EXECUTED within a run. This deliberately matches
# the original call sequence (UMAP first): methods may draw from NumPy's global
# RNG, so reordering them would change the seeded results.
_ORDERING_METHODS = {
    "UMAP": _order_by_umap,
    "fiedler": _order_by_fiedler,
    "t_SNE": _order_by_tsne,
    "spectral_lin_reg": _order_by_spectral_lin_reg,
    "spectral_order": _order_by_spectral_ordering,
}


def _abs_kendall_tau(true_order_indices, est_order_indices):
    """Return |Kendall tau| between two orderings of the same points.

    Both arguments are treated as argsort-style permutations (entry k is the
    index of the k-th point in the ordering) -- assumed for
    ``true_order_indices`` as well; confirm against generate_sin_data.
    They are compared by rank, i.e. via their inverse permutations
    (``np.argsort`` of a permutation). Correlating the raw permutations
    position by position would correlate point labels instead, which agrees
    with the rank-based value only when the true order is the identity. The
    absolute value makes a fully reversed ordering score 1, since a 1-D
    ordering's direction is arbitrary.
    """
    true_ranks = np.argsort(true_order_indices)
    est_ranks = np.argsort(est_order_indices)
    tau, _ = kendalltau(true_ranks, est_ranks)
    return np.abs(tau)


# Sweep over (sample size, noise level). Run every ordering method N_RUNS times
# per setting, summarise |tau| as mean / sd / min, and save every run's
# orderings (plus the true ordering) to one text file per method.
results = {}
for N_POINTS in N_POINTS_LIST:
    results[N_POINTS] = {}
    for NOISE_STD in NOISE_STD_LIST:
        print("N = ", N_POINTS)
        print("noise = ", NOISE_STD)
        start = time.perf_counter()

        taus = {alg: np.empty(N_RUNS) for alg in ALGS}
        outputs = {alg: np.empty((N_RUNS, N_POINTS)) for alg in ALGS}
        outputs["true_labels"] = np.empty((N_RUNS, N_POINTS))

        for i in range(N_RUNS):
            if i % 10 == 0:
                print(i)
            X, _, true_order_indices = generate_sin_data(N_POINTS, NOISE_STD, ROTATION_ANGLE)
            outputs["true_labels"][i, :] = true_order_indices

            for alg, order_fn in _ORDERING_METHODS.items():
                order_indices = order_fn(X)
                taus[alg][i] = _abs_kendall_tau(true_order_indices, order_indices)
                outputs[alg][i, :] = order_indices

        results[N_POINTS][NOISE_STD] = {
            alg: {"mean": np.mean(taus[alg]), "sd": np.std(taus[alg]), "min": np.min(taus[alg])}
            for alg in ALGS
        }
        for name in [*ALGS, "true_labels"]:
            np.savetxt(OUTPUT_FILENAME_TEMPLATE.format(name, N_POINTS, NOISE_STD), outputs[name])

        print("Time taken: ", time.perf_counter() - start)

# Persist the nested {n_points: {noise_std: {method: stats}}} summary as the
# Python repr of the dict, followed by a newline.
with open("results.txt", "w") as results_file:
    results_file.write(f"{results}\n")