import numpy as np
from scipy.stats import rankdata
from scipy.spatial.distance import cdist, pdist, euclidean, sqeuclidean, squareform # Added euclidean for get_neighbors_and_pca
from scipy.linalg import eigh # For symmetric matrix eigenvalues/vectors
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import lsqr, eigsh # Added eigsh for the except block
import mdso # https://github.com/antrec/mdso/tree/master (Recanati et al.)
from scipy.interpolate import interp1d


def get_neighbors_and_pca(point_idx, X, h = None, k = None, N_MIN = 3):
    """
    Find the neighbors of one point and estimate a local tangent by PCA.

    Parameters
    ----------
    point_idx : int
        Row index of the query point in ``X``.
    X : ndarray of shape (n_points, n_dims)
        Data matrix.
    h : float, optional
        Neighborhood radius. When given, every point within distance ``h``
        is a neighbor; this includes the query point itself (distance 0).
        Takes precedence over ``k``.
    k : int, optional
        Number of nearest neighbors, used only when ``h`` is None. The first
        entry of the distance ordering is dropped (normally the query point).
    N_MIN : int
        Minimum number of neighbors needed to attempt PCA.

    Returns
    -------
    (neighbor_indices, neighbor_mean, tangent_vector, valid)
        ``tangent_vector`` is the first right singular vector of the centered
        neighbors (unit norm, sign arbitrary). When ``valid`` is False the
        tangent is None, and the mean is None as well if there were fewer
        than ``N_MIN`` neighbors.

    Raises
    ------
    ValueError
        If neither ``h`` nor ``k`` is provided.
    """
    if h is None and k is None:
        raise ValueError("Either a radius h or a neighbor count k must be given.")

    distances = cdist(X[[point_idx]], X)[0]
    if h is not None:
        neighbor_indices = np.where(distances <= h)[0]
    else:
        # Skip position 0 of the ordering, which is the query point itself
        # unless another point sits at distance exactly 0 and sorts first.
        neighbor_indices = np.argsort(distances)[1:(k + 1)]

    if len(neighbor_indices) < N_MIN:
        return neighbor_indices, None, None, False

    neighbors = X[neighbor_indices]
    neighbor_mean = np.mean(neighbors, axis=0)
    centered_neighbors = neighbors - neighbor_mean

    # Fewer neighbors than dimensions is treated as a failed local PCA.
    if centered_neighbors.shape[0] < centered_neighbors.shape[1]:
        return neighbor_indices, neighbor_mean, None, False

    try:
        _, _, Vh = np.linalg.svd(centered_neighbors, full_matrices=False)
    except np.linalg.LinAlgError:
        return neighbor_indices, neighbor_mean, None, False

    return neighbor_indices, neighbor_mean, Vh[0, :], True

#########
# STAGE #
#########

def spectral_lin_reg(X, h = None, k = None, N_MIN = 3):
    """
    Rank the rows of ``X`` along a 1-D curve using local PCA tangents.

    Steps:
      1. Local PCA at every point (via ``get_neighbors_and_pca``) gives a
         neighborhood mean and an unsigned tangent direction.
      2. Tangent signs are made consistent using the sign pattern of the
         top eigenvector of the symmetrized matrix of tangent inner products.
      3. A sparse least-squares system is solved in which, for every valid
         point ``i`` and neighbor ``j``, ``z_j - b_i`` should equal the
         projection of ``X[j] - mean_i`` onto the oriented tangent at ``i``.
      4. The ranks of the recovered coordinates ``z`` are returned.

    Parameters
    ----------
    X : ndarray of shape (n_points, n_dims)
    h : float, optional
        Neighborhood radius; takes precedence over ``k`` when both are given.
    k : int, optional
        Number of nearest neighbors, used when ``h`` is None.
    N_MIN : int
        Minimum neighborhood size for a local PCA to count as valid.

    Returns
    -------
    ndarray of shape (n_points,)
        ``scipy.stats.rankdata`` ranks of the recovered 1-D coordinate.

    Raises
    ------
    ValueError
        If neither ``h`` nor ``k`` is given, or if local PCA fails everywhere.
    """
    if h is None and k is None:
        raise ValueError("Either a radius h or a neighbor count k must be given.")

    n_points, _ = X.shape

    # --- Local PCA at every point ---
    local_pca_results = {}
    valid_pca_indices = []
    for i in range(n_points):
        # The helper itself prefers h over k, so both can be forwarded.
        neigh_idx, mean_i, v_i, valid = get_neighbors_and_pca(i, X, h = h, k = k, N_MIN = N_MIN)
        if valid:
            local_pca_results[i] = {'neighbors': neigh_idx, 'mean': mean_i, 'tangent': v_i}
            valid_pca_indices.append(i)
        else:
            local_pca_results[i] = None
    if not valid_pca_indices:
        raise ValueError("PCA failed.")

    # --- Spectral orientation ---
    W = lil_matrix((n_points, n_points), dtype=float)
    for i in valid_pca_indices:
        v_i = local_pca_results[i]['tangent']
        for j in local_pca_results[i]['neighbors']:
            if i == j:
                W[i, i] = 1.0
                continue
            res_j = local_pca_results[j]
            if res_j is not None:
                W[i, j] = np.dot(v_i, res_j['tangent'])
    # Neighborhoods need not be mutual, so symmetrize before the eigensolve.
    W = ((W + W.T) / 2.0).tocsr()

    try:
        _, eigenvectors = eigh(W.toarray())
        u_max = eigenvectors[:, -1]  # eigh sorts ascending; last is the largest
    except np.linalg.LinAlgError:
        # Fall back to the sparse solver if the dense one fails.
        _, eigenvectors = eigsh(W, k=1, which='LA')
        u_max = eigenvectors[:, 0]
    signs_s = np.sign(u_max)
    signs_s[signs_s == 0] = 1
    oriented_tangents = {i: signs_s[i] * local_pca_results[i]['tangent'] for i in valid_pca_indices}

    # --- Build linear system A z = T ---
    # Unknowns: z[0:n_points] are the 1-D coordinates, z[n_points:] are
    # per-point offsets b_i. One row per (i, j) pair: z_j - b_i = target.
    rows, cols, data_A, data_T = [], [], [], []
    num_constraints = 0
    for i in valid_pca_indices:
        mean_i = local_pca_results[i]['mean']
        v_i_oriented = oriented_tangents[i]
        for j in local_pca_results[i]['neighbors']:
            if i == j:
                continue
            rows.append(num_constraints)
            cols.append(j)
            data_A.append(1.0)
            rows.append(num_constraints)
            cols.append(n_points + i)
            data_A.append(-1.0)
            data_T.append(np.dot(X[j] - mean_i, v_i_oriented))
            num_constraints += 1
    A = csr_matrix((data_A, (rows, cols)), shape=(num_constraints, 2 * n_points))
    T = np.array(data_T)

    # --- Solve with LSQR ---
    solution_z = lsqr(A, T, show=False, iter_lim=min(A.shape)*2)[0]
    y = solution_z[:n_points]
    y = y - np.mean(y)

    # --- Final order ---
    return rankdata(y)

####################################
# Permutation using Fiedler Vector #
####################################

# basic similarity function using Euclidean distance, can be modified
def gaussian_kernel(x, y, sigma = .25):
    """
    Gaussian similarity between two vectors.

    Computes ``exp(-||x - y||^2 / (2 * sigma^2)) / (2 * pi * sigma^2)``.

    Parameters
    ----------
    x, y : array_like, 1-D
        Vectors to compare.
    sigma : float
        Kernel bandwidth.

    Returns
    -------
    float
        Similarity value; largest when ``x == y``.
    """
    # The divisor must be the whole 2*sigma^2; without parentheses the
    # expression would multiply by sigma^2 instead of dividing by it.
    two_sigma_sq = 2.0 * sigma**2
    return np.exp(-sqeuclidean(x, y) / two_sigma_sq) / (np.pi * two_sigma_sq)

def inverse_distance(x, y):
    """Return the similarity 1 / (1 + d), where d is the Euclidean distance between x and y."""
    gap = euclidean(x, y)
    return 1 / (1 + gap)

# Find Fiedler vector (eigenvector corresponding to smallest nonzero eigenval)
def fiedler_permutation(X, similarity_measure):
    """
    Order the rows of ``X`` by the Fiedler vector of their similarity graph.

    Builds the pairwise similarity matrix with ``similarity_measure``, forms
    the graph Laplacian ``L = D - S``, and takes the eigenvector belonging to
    the second-smallest eigenvalue. For a connected graph that is the
    smallest nonzero eigenvalue.

    Parameters
    ----------
    X : ndarray of shape (n_points, n_dims)
    similarity_measure : callable
        ``f(u, v) -> float`` applied to each pair of rows via ``pdist``.

    Returns
    -------
    (fiedler_vec, ranks)
        ``fiedler_vec`` is a real vector of length n_points (sign arbitrary);
        ``ranks[i]`` is the 0-based position of row ``i`` when rows are
        sorted by their Fiedler-vector entry.

    Raises
    ------
    ValueError
        If fewer than two points are supplied.
    """
    sim_mat = squareform(pdist(X, similarity_measure))
    if sim_mat.shape[0] < 2:
        raise ValueError("Need at least two points to compute a Fiedler vector.")

    laplacian = np.diag(np.sum(sim_mat, axis=0)) - sim_mat
    # The Laplacian is symmetric, so eigh applies: it returns real
    # eigenpairs with eigenvalues already sorted in ascending order.
    _, eigenvecs = np.linalg.eigh(laplacian)
    # Index 0 is the trivial constant eigenvector (eigenvalue 0).
    fiedler_vec = eigenvecs[:, 1]
    return fiedler_vec, np.argsort(np.argsort(fiedler_vec))

#######################################
# Spectral Ordering (Recanati et al.) #
#######################################

def spectral_ordering(X, similarity_measure):
    """
    Order the rows of ``X`` with the ``mdso`` SpectralOrdering estimator.

    The pairwise similarity matrix is built by applying ``similarity_measure``
    to every pair of rows (its diagonal is zero, as produced by
    ``squareform``) and handed to ``mdso.SpectralOrdering().fit_transform``,
    whose result is returned unchanged.
    """
    condensed = pdist(X, similarity_measure)
    similarity = squareform(condensed)
    orderer = mdso.SpectralOrdering()
    return orderer.fit_transform(similarity)

###################
# Data Generation #
###################

def generate_sin_data(n, noise_std, angle):
    """
    Sample a noisy, rotated copy of the curve y = 2*sin(x) on [0, 2*pi].

    Parameters
    ----------
    n : int
        Number of samples, evenly spaced in x.
    noise_std : float
        Standard deviation of the Gaussian noise added to both coordinates
        (drawn from NumPy's global random state).
    angle : float
        Rotation angle in radians, applied after the noise is added.

    Returns
    -------
    (rotated_points, thetas_true, true_order_indices)
        An (n, 2) array of points, the x-parameters used to generate them,
        and ``np.arange(n)``.
    """
    thetas_true = np.linspace(0, 2 * np.pi, n)
    clean = np.array([thetas_true, 2 * np.sin(thetas_true)]).T
    perturbed = clean + np.random.normal(0, noise_std, size=clean.shape)

    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a],
                         [sin_a,  cos_a]])
    # Points are stored as rows, so rotate with the transposed matrix.
    return perturbed @ rotation.T, thetas_true, np.arange(n)

def generate_spiral_data(n,
                   kappa_end,
                   *,
                   length=1.0,
                   start=(0.0, 0.0),
                   heading=0.0,
                   noise_std = 0.05):
    """
    Sample a noisy planar curve whose curvature grows linearly with arc length.

    The curvature is 0 at s = 0 and ``kappa_end`` at s = ``length``, so the
    heading angle is ``0.5 * (kappa_end / length) * s**2``. Coordinates are
    obtained by cumulative trapezoidal integration of (cos, sin) of that
    angle, then rotated by ``heading`` and shifted to ``start``.

    Parameters
    ----------
    n : int
        Number of samples (at least 2), evenly spaced in arc length.
    kappa_end : float
        Curvature at the end of the curve.
    length : float
        Total arc length; must be positive.
    start : pair of float
        Coordinates of the first noiseless point.
    heading : float
        Initial direction in radians.
    noise_std : float
        Standard deviation of the Gaussian noise added to both coordinates
        (drawn from NumPy's global random state).

    Returns
    -------
    (noisy_points, s, true_order_indices)
        An (n, 2) array of points, the arc-length positions, and ``np.arange(n)``.

    Raises
    ------
    ValueError
        If ``n < 2`` or ``length <= 0``.
    """
    if n < 2:
        raise ValueError("Need at least two points.")
    if length <= 0:
        raise ValueError("length must be positive.")

    s = np.linspace(0.0, length, n)
    step = length / (n - 1)

    # Heading angle is the integral of the linear curvature profile.
    theta = 0.5 * (kappa_end / length) * s**2

    def _cumulative_trapezoid(values):
        # Running trapezoidal integral over s, starting from 0.
        total = np.zeros_like(s)
        total[1:] = np.cumsum((values[:-1] + values[1:]) * 0.5 * step)
        return total

    x_local = _cumulative_trapezoid(np.cos(theta))
    y_local = _cumulative_trapezoid(np.sin(theta))

    # Rigid motion: rotate by the initial heading, then translate.
    cos_h, sin_h = np.cos(heading), np.sin(heading)
    x_world = cos_h * x_local - sin_h * y_local + start[0]
    y_world = sin_h * x_local + cos_h * y_local + start[1]

    clean = np.column_stack((x_world, y_world))
    noisy_points = clean + np.random.normal(0, noise_std, size=clean.shape)

    return noisy_points, s, np.arange(n)

def random_fourier_curve(n_pts=2000, d=10, K=8, alpha=1.5, noise=0.02, arclength=True, shuffle=True, endpoints="random", rng=None):
    """
    Sample a random smooth curve in ``d`` dimensions from a Fourier series.

    Parameters
    ----------
    n_pts : int
        Number of samples along the curve.
    d : int
        Ambient dimension.
    K : int
        Number of Fourier modes.
    alpha : float
        Decay exponent; mode ``k`` has coefficient scale ``k ** (-alpha)``.
    noise : float
        Scale of the isotropic Gaussian noise added to the curve.
    arclength : bool
        If True, resample the curve with cubic interpolation so samples are
        approximately evenly spaced in arc length.
    shuffle : bool
        If True, the returned rows are randomly permuted.
    endpoints : {"closed", "random"} or (p0, p1)
        "closed" returns the periodic Fourier signal itself. "random" draws
        two Gaussian endpoints. A pair of vectors pins the endpoints
        explicitly. In the open cases the oscillation is damped by
        ``sin(pi * t)`` and added to the straight line from p0 to p1.
    rng : numpy.random.Generator, optional
        Source of randomness; a fresh default generator is used when None.

    Returns
    -------
    (pts, noisy_pts, perm)
        ``pts`` and ``noisy_pts`` are (n_pts, d) arrays of the clean and
        noisy samples. ``perm[i]`` is the index along the curve of row ``i``;
        it is ``np.arange(n_pts)`` when ``shuffle`` is False.

    Raises
    ------
    ValueError
        If ``endpoints`` is neither a recognised string nor a (p0, p1) pair.
    """
    if rng is None:
        rng = np.random.default_rng()

    bad_endpoints_msg = ("'endpoints' must be 'closed', 'random', "
                         "or a (p0, p1) pair of length-d vectors")

    t = np.linspace(0.0, 1.0, n_pts)

    # ----------------------------------------------------------------------
    # 1. core oscillatory signal (periodic in t)
    Y = np.zeros((n_pts, d))
    for k in range(1, K + 1):
        sigma = k ** (-alpha)
        A = rng.normal(scale=sigma, size=d)
        B = rng.normal(scale=sigma, size=d)
        Y += np.outer(np.sin(2 * np.pi * k * t), A) \
           + np.outer(np.cos(2 * np.pi * k * t), B)

    # ----------------------------------------------------------------------
    # 2. wrap it into either a CLOSED or an OPEN trajectory
    # Check the type first: comparing an ndarray to a string with == does
    # not give a single boolean.
    if isinstance(endpoints, str) and endpoints == "closed":
        curve = Y
    else:
        if isinstance(endpoints, str):
            if endpoints != "random":
                raise ValueError(bad_endpoints_msg)
            p0 = rng.normal(size=d)
            p1 = rng.normal(size=d)
        else:
            try:
                p0, p1 = endpoints
                # Lists and tuples do not support the arithmetic below.
                p0 = np.asarray(p0, dtype=float)
                p1 = np.asarray(p1, dtype=float)
            except (TypeError, ValueError) as e:
                raise ValueError(bad_endpoints_msg) from e
        # wiggle envelope that vanishes at both ends
        w = np.sin(np.pi * t)[:, None]          # (n_pts, 1)
        # blend a straight line with the bounded oscillations
        curve = p0 + (p1 - p0) * t[:, None] + w * Y

    # ----------------------------------------------------------------------
    # 3. optional arc-length re-parameterisation
    # The interp1d check is kept from the original; it only matters if the
    # import is ever made optional.
    if arclength and interp1d is not None:
        seglen = np.linalg.norm(np.diff(curve, axis=0), axis=1)
        s = np.hstack(([0], np.cumsum(seglen)))
        s /= s[-1]
        curve = np.column_stack([
            interp1d(s, curve[:, j], kind='cubic',
                     assume_sorted=True)(t)
            for j in range(d)
        ])

    # ----------------------------------------------------------------------
    # 4. add isotropic noise
    noisy_curve = curve + noise * rng.standard_normal(size=curve.shape)

    # ----------------------------------------------------------------------
    # 5. (optionally) scramble rows; perm is always defined for the caller
    if shuffle:
        perm = rng.permutation(n_pts)
        pts = curve[perm]
        noisy_pts = noisy_curve[perm]
    else:
        perm = np.arange(n_pts)
        pts = curve
        noisy_pts = noisy_curve

    return pts, noisy_pts, perm