import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from scipy.stats import kendalltau
from scipy.spatial.distance import pdist, sqeuclidean, euclidean, squareform, cdist
from scipy.linalg import eigh # For symmetric matrix eigenvalues/vectors
from scipy.sparse import lil_matrix, csr_matrix, vstack as sp_vstack
from scipy.sparse.linalg import lsqr # Solver for sparse least squares
import time
import umap
from sklearn.manifold import TSNE

def get_neighbors_and_pca(point_idx, X, h, N_min):
    """
    Collect the radius-`h` neighbourhood of one point and estimate its
    local tangent direction by PCA.

    Parameters
    ----------
    point_idx : int
        Row index into `X` of the query point.
    X : ndarray, 2-D
        One point per row.
    h : float
        Neighbourhood radius; a row is a neighbour when its Euclidean
        distance to the query point is <= h (the query point itself
        is therefore always included).
    N_min : int
        Minimum neighbourhood size required to attempt PCA.

    Returns
    -------
    (neighbor_indices, neighbor_mean, tangent_vector, valid)
        neighbor_indices : integer ndarray of row indices within the radius.
        neighbor_mean    : mean of the neighbours, or None when there are
                           fewer than `N_min` of them.
        tangent_vector   : first right singular vector of the centred
                           neighbours (unit norm, sign arbitrary), or None
                           when PCA was not possible.
        valid            : True only when a tangent vector was produced.
    """
    dist_to_all = cdist(X[[point_idx]], X)[0]
    ball = np.flatnonzero(dist_to_all <= h)

    # Too few neighbours: nothing else is computed.
    if ball.size < N_min:
        return ball, None, None, False

    patch = X[ball]
    centroid = patch.mean(axis=0)
    residuals = patch - centroid

    # PCA is only attempted with at least as many samples as dimensions.
    n_samples, n_features = residuals.shape
    if n_samples < n_features:
        return ball, centroid, None, False

    try:
        right_vecs = np.linalg.svd(residuals, full_matrices=False)[2]
    except np.linalg.LinAlgError:
        return ball, centroid, None, False

    return ball, centroid, right_vecs[0, :], True

def spectral_lin_reg(X, H_RADIUS, N_MIN):
    """
    Recover a 1-D ordering of the rows of `X`.

    Steps:
      1. Local PCA around every point (via `get_neighbors_and_pca`) gives a
         tangent direction with arbitrary sign.
      2. The signs are aligned using the eigenvector for the largest
         eigenvalue of the symmetric matrix W, where W[i, j] is the dot
         product of the tangents of neighbouring points i and j and
         W[i, i] = 1.
      3. A sparse least-squares system is solved for 2 * n_points unknowns:
         a coordinate y_j per point followed by an offset b_i per point.
         Each (i, neighbour j) pair contributes the equation
         y_j - b_i = (X[j] - mean_i) . tangent_i.
      4. The points are ordered by the mean-centred coordinates y.

    Parameters
    ----------
    X : ndarray, 2-D
        One point per row.
    H_RADIUS : float
        Neighbourhood radius forwarded to `get_neighbors_and_pca`.
    N_MIN : int
        Minimum neighbourhood size forwarded to `get_neighbors_and_pca`.

    Returns
    -------
    order : ndarray of int
        `np.argsort` of the recovered coordinates.
    eig2 : float
        Entry at index 2 of the ascending eigenvalues of the neighbourhood
        graph Laplacian D - A.

    Raises
    ------
    ValueError
        If local PCA is not valid for any point.
    """
    n_points = X.shape[0]

    # --- Local PCA for every point ---
    local_pca_results = {}
    valid_pca_indices = []
    for i in range(n_points):
        neigh_idx, mean_i, v_i, valid = get_neighbors_and_pca(i, X, H_RADIUS, N_MIN)
        if valid:
            local_pca_results[i] = {'neighbors': neigh_idx, 'mean': mean_i, 'tangent': v_i}
            valid_pca_indices.append(i)
        else:
            local_pca_results[i] = None
    if not valid_pca_indices:
        raise ValueError("PCA failed.")

    # --- 4. Spectral Orientation ---
    # The adjacency matrix must start from zeros: only neighbour entries are
    # written below, so an uninitialised buffer (np.empty) would leave
    # arbitrary values in every non-neighbour slot and corrupt the Laplacian.
    A_adj = np.zeros((n_points, n_points))
    W = lil_matrix((n_points, n_points), dtype=float)
    for i in valid_pca_indices:
        res_i = local_pca_results[i]
        v_i = res_i['tangent']
        for j in res_i['neighbors']:
            A_adj[i, j] = 1
            if i == j:
                W[i, i] = 1.0
                continue
            # Neighbours whose own PCA failed contribute no tangent affinity.
            res_j = local_pca_results[j]
            if res_j is not None:
                W[i, j] = np.dot(v_i, res_j['tangent'])
    W = ((W + W.T) / 2.0).tocsr()

    try:
        # eigh returns eigenvalues in ascending order, so the last column
        # belongs to the largest eigenvalue.
        _, eigenvectors = eigh(W.toarray())
        u_max = eigenvectors[:, -1]
    except np.linalg.LinAlgError:
        from scipy.sparse.linalg import eigsh
        _, eigenvectors = eigsh(W, k=1, which='LA')
        u_max = eigenvectors[:, 0]
    signs_s = np.sign(u_max)
    signs_s[signs_s == 0] = 1  # treat an exactly-zero component as positive
    oriented_tangents = {
        i: signs_s[i] * local_pca_results[i]['tangent'] for i in valid_pca_indices
    }

    # --- Graph Laplacian spectrum ---
    # Rows of A_adj are only filled for points with a valid PCA.
    D = np.diag(np.sum(A_adj, axis=0))
    L = D - A_adj
    eigval, _ = np.linalg.eigh(L)
    # NOTE(review): eigenvalues are ascending, so index 2 is the third-smallest
    # value (index 1 would be the second-smallest). Kept as in the original;
    # confirm which one is intended. Needs n_points >= 3.
    eig2 = eigval[2]

    # --- 5. Build Linear System Az = T ---
    rows, cols, data_A, data_T = [], [], [], []
    num_constraints = 0
    for i in valid_pca_indices:
        res_i = local_pca_results[i]
        mean_i = res_i['mean']
        v_i_oriented = oriented_tangents[i]
        for j in res_i['neighbors']:
            if i == j:
                continue
            # One equation per (i, j): +1 on y_j, -1 on the offset of point i.
            rows.extend((num_constraints, num_constraints))
            cols.extend((j, n_points + i))
            data_A.extend((1.0, -1.0))
            data_T.append(np.dot(X[j] - mean_i, v_i_oriented))
            num_constraints += 1
    A = csr_matrix((data_A, (rows, cols)), shape=(num_constraints, 2 * n_points))
    T = np.array(data_T)

    # --- 6. Solve Linear System using LSQR ---
    solution_z = lsqr(A, T, show=False, iter_lim=min(A.shape) * 2)[0]
    y = solution_z[:n_points]
    y = y - np.mean(y)

    # --- 7. Get Final Order & Evaluate ---
    return np.argsort(y), eig2

def euler_spiral(n,
                 kappa_end,
                 *,
                 length=1.0,
                 start=(0.0, 0.0),
                 heading=0.0,
                 noise_std = 0.05):
    """
    Sample `n` noisy points along an Euler-spiral (clothoid) segment.

    The curvature grows linearly with arc length from 0 to `kappa_end`;
    positions are obtained by trapezoidal integration of the heading angle.

    Parameters
    ----------
    n : int
        Number of points (at least 2), uniformly spaced in arc length.
    kappa_end : float
        Signed curvature reached at the end of the segment.
    length : float, optional
        Total arc length (must be positive, default 1).
    start : sequence of two floats, optional
        (x, y) of the first point before noise is added (default (0, 0)).
    heading : float, optional
        Initial heading in radians, counter-clockwise from +x (default 0).
    noise_std : float, optional
        Standard deviation of the independent Gaussian noise added to every
        coordinate, drawn from NumPy's global random state (default 0.05).

    Returns
    -------
    noisy_points : ndarray, shape (n, 2)
        Curve points with noise added.
    s : ndarray, shape (n,)
        Arc-length position of each point.
    true_order_indices : ndarray, shape (n,)
        `np.arange(n)`, i.e. the order of the points along the curve.

    Raises
    ------
    ValueError
        If `n < 2` or `length <= 0`.
    """
    if n < 2:
        raise ValueError("Need at least two points.")
    if length <= 0:
        raise ValueError("length must be positive.")

    # Uniform arc-length grid and its spacing.
    arc = np.linspace(0.0, length, n)
    step = length / (n - 1)

    # kappa(s) = rate * s  =>  heading angle theta(s) = rate * s^2 / 2.
    curvature_rate = kappa_end / length
    angle = 0.5 * curvature_rate * arc**2

    def _cumulative_trapezoid(samples):
        # Running trapezoidal integral over the grid, starting from 0.
        total = np.zeros_like(arc)
        total[1:] = np.cumsum((samples[:-1] + samples[1:]) * 0.5 * step)
        return total

    local_x = _cumulative_trapezoid(np.cos(angle))
    local_y = _cumulative_trapezoid(np.sin(angle))

    # Rotate the local curve by the initial heading, then shift to `start`.
    cos_h = np.cos(heading)
    sin_h = np.sin(heading)
    clean_points = np.column_stack((
        cos_h * local_x - sin_h * local_y + start[0],
        sin_h * local_x + cos_h * local_y + start[1],
    ))

    jitter = np.random.normal(0, noise_std, size=clean_points.shape)
    return clean_points + jitter, arc, np.arange(n)



# --- Experiment configuration ---
N_MIN = 3            # Minimum neighbors for PCA (d+1 for d=2)
N_RUNS = 10          # Repetitions averaged for every parameter combination
H_RADIUS_LIST = np.arange(1, 6) * .1
N_POINTS_LIST = np.array([500, 1000])
KAPPA_LIST = np.array([10, 15])
NOISE_STD_LIST = np.arange(1, 6) * 0.01


def _new_result_tables():
    """Return {noise: {h: array}} with one (len(N_POINTS_LIST), len(KAPPA_LIST)) array per pair.

    The arrays are allocated uninitialised; every cell is written by the
    sweep below before it is read.
    """
    grid_shape = (len(N_POINTS_LIST), len(KAPPA_LIST))
    return {
        noise_key: {h_key: np.empty(grid_shape) for h_key in H_RADIUS_LIST}
        for noise_key in NOISE_STD_LIST
    }


# Averages over N_RUNS, indexed as table[NOISE_STD][H_RADIUS][i, j] with
# i over N_POINTS_LIST and j over KAPPA_LIST.
taus_avg_stage = _new_result_tables()
second_eig_L_avg_stage = _new_result_tables()

taus_avg_tsne = _new_result_tables()
second_eig_L_avg_tsne = _new_result_tables()


# --- Parameter sweep ---
for H_RADIUS in H_RADIUS_LIST:
    for NOISE_STD in NOISE_STD_LIST:
        print("Noise = ", NOISE_STD)
        print("h = ", H_RADIUS)
        # Same seed for every (h, noise) pair.
        np.random.seed(100)
        start = time.time()
        for i, N_POINTS in enumerate(N_POINTS_LIST):
            for j, KAPPA in enumerate(KAPPA_LIST):
                print("N = ", N_POINTS)
                print("kappa = ", KAPPA)
                # Per-run scores for this (N, kappa) cell.
                taus_stage = np.empty(N_RUNS)
                second_eig_L_stage = np.empty(N_RUNS)

                taus_tsne = np.empty(N_RUNS)
                second_eig_L_tsne = np.empty(N_RUNS)
                for k in range(N_RUNS):
                    X, _, true_order_indices = euler_spiral(
                        n=N_POINTS, kappa_end=KAPPA, noise_std=NOISE_STD
                    )

                    # Ordering from the spectral + least-squares method.
                    spectral_lin_reg_order_indices, eig2 = spectral_lin_reg(X, H_RADIUS, N_MIN)
                    tau, p_value = kendalltau(true_order_indices, spectral_lin_reg_order_indices)
                    taus_stage[k] = abs(tau)
                    second_eig_L_stage[k] = eig2 / N_POINTS

                    # Ordering from a 1-D t-SNE embedding of the same sample.
                    t_SNE_order_indices = np.argsort(
                        TSNE(n_components=1).fit_transform(X).flatten()
                    )
                    tau, p_value = kendalltau(true_order_indices, t_SNE_order_indices)
                    taus_tsne[k] = abs(tau)
                    # Reuses the eigenvalue returned by spectral_lin_reg above.
                    second_eig_L_tsne[k] = eig2 / N_POINTS

                taus_avg_stage[NOISE_STD][H_RADIUS][i, j] = taus_stage.mean()
                second_eig_L_avg_stage[NOISE_STD][H_RADIUS][i, j] = second_eig_L_stage.mean()

                taus_avg_tsne[NOISE_STD][H_RADIUS][i, j] = taus_tsne.mean()
                second_eig_L_avg_tsne[NOISE_STD][H_RADIUS][i, j] = second_eig_L_tsne.mean()
        print("Time taken:", time.time() - start)
        print()
        print("Avg. Kendall's Tau for STAGE:")
        print(taus_avg_stage[NOISE_STD][H_RADIUS])
        print("Avg. 2nd Largest Eigenvalues for STAGE")
        print(second_eig_L_avg_stage[NOISE_STD][H_RADIUS])
        # taus_avg_h = taus_avg_stage[NOISE_STD][H_RADIUS]
        # second_eig_L_h = second_eig_L_avg_stage[NOISE_STD][H_RADIUS]

        print("Time taken:", time.time() - start)
        print("Avg. Kendall's Tau for t-SNE:")
        print(taus_avg_tsne[NOISE_STD][H_RADIUS])
        # taus_avg_h = taus_avg_tsne[NOISE_STD][H_RADIUS]
        # second_eig_L_h = second_eig_L_avg_tsne[NOISE_STD][H_RADIUS]


    #     fig, ax = plt.subplots()
    #     im = ax.imshow(taus_avg_h, cmap = "cividis", interpolation="nearest")

    # # Show all ticks and label them with the respective list entries
    #     ax.set_xticks(range(len(KAPPA_LIST)), labels=KAPPA_LIST,
    #             rotation=45, ha="right", rotation_mode="anchor")
    #     ax.set_yticks(range(len(N_POINTS_LIST)), labels=N_POINTS_LIST)
    #     for i in range(len(N_POINTS_LIST)):
    #         for j in range(len(KAPPA_LIST)):
    #             text = ax.text(j, i, round(taus_avg_h[i, j], 3),
    #                         ha="center", va="center", color="w")
    #     ax.set_title("Kendall Tau, h = {0}, noise = {1}".format(round(H_RADIUS, 1), round(NOISE_STD, 2)))
    #     fig.tight_layout()
    #     plt.savefig("Kendall Tau, h = {0}, noise = {1}_v2.png".format(round(H_RADIUS, 1), round(NOISE_STD, 2)))

    #     fig, ax = plt.subplots()
    #     im = ax.imshow(second_eig_L_h, cmap = "cividis", interpolation="nearest")

    # # Show all ticks and label them with the respective list entries
    #     ax.set_xticks(range(len(KAPPA_LIST)), labels=KAPPA_LIST,
    #             rotation=45, ha="right", rotation_mode="anchor")
    #     ax.set_yticks(range(len(N_POINTS_LIST)), labels=N_POINTS_LIST)
    #     for i in range(len(N_POINTS_LIST)):
    #         for j in range(len(KAPPA_LIST)):
    #             text = ax.text(j, i, round(second_eig_L_h[i, j], 3),
    #                         ha="center", va="center", color="w")
    #     ax.set_title("2nd Eigeivalue of L (scaled by sample size), h = {0}, noise = {1}".format(round(H_RADIUS, 1), round(NOISE_STD, 2)))
    #     fig.tight_layout()
    #     plt.savefig("2nd Eigeivalue of L, h = {0}, noise = {1}_v2.png".format(round(H_RADIUS, 1), round(NOISE_STD, 2)))

    # Comparisons with 1-D UMAP, t-SNE, Recanati

