import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from scipy.stats import kendalltau
from scipy.spatial.distance import pdist, sqeuclidean, euclidean, squareform, cdist
from scipy.linalg import eigh # For symmetric matrix eigenvalues/vectors
from scipy.sparse import lil_matrix, csr_matrix, vstack as sp_vstack
from scipy.sparse.linalg import lsqr # Solver for sparse least squares
import umap
import mdso
from sklearn.manifold import TSNE
from sklearn.decomposition import FastICA, PCA
import time
import warnings
import pathlib, urllib.request, gzip, pandas as pd
import GEOparse
import scanpy as sc
import pyreadr
from mpl_toolkits import mplot3d


def gaussian_kernel(x, y, sigma = 20):
    """
    Evaluate an isotropic Gaussian kernel between two points.

    Computes ``exp(-||x - y||^2 / (2 * sigma^2)) / (2 * pi * sigma^2)``.
    The ``1 / (2 * pi * sigma^2)`` prefactor is the normalising constant of a
    2-D Gaussian and is kept for every input dimension.

    Args:
        x: 1-D array-like, first point.
        y: 1-D array-like, second point (same length as ``x``).
        sigma (float): Kernel bandwidth; must be non-zero.

    Returns:
        float: Kernel value, maximal when ``x == y`` and decaying with distance.
    """
    variance = sigma ** 2
    # The full 2 * sigma^2 term must sit in the denominator of the exponent;
    # writing `d / 2 * sigma**2` would multiply by sigma^2 instead.
    return np.exp(-sqeuclidean(x, y) / (2 * variance)) / (2 * np.pi * variance)
def inverse_distance(x, y):
    """
    Similarity that decays with Euclidean distance.

    Returns ``1 / (1 + ||x - y||)``: 1.0 for identical points, approaching 0
    as the points move apart.
    """
    separation = euclidean(x, y)
    return 1 / (1 + separation)

def get_neighbors_and_pca(point_idx, X, h, N_min):
    """
    Local PCA around one point: collect neighbours within radius h and return
    the first principal direction of that neighbourhood.

    Args:
        point_idx (int): Row index of the query point in X.
        X (np.ndarray): Data matrix, one point per row.
        h (float): Neighbourhood radius (Euclidean, inclusive).
        N_min (int): Minimum neighbourhood size (the point itself counts).

    Returns:
        tuple: (neighbor_indices, neighbor_mean, tangent_vector, valid)
            - neighbor_indices: indices of all rows within distance h.
            - neighbor_mean: mean of those rows, or None if there were fewer
              than N_min of them.
            - tangent_vector: first right-singular vector of the centred
              neighbourhood (unit length, sign arbitrary), or None on failure.
            - valid: True only when a tangent was computed.
    """
    dist_to_all = cdist(X[[point_idx]], X)[0]
    in_ball = np.flatnonzero(dist_to_all <= h)
    if in_ball.size < N_min:
        return in_ball, None, None, False

    patch = X[in_ball]
    centroid = np.mean(patch, axis=0)
    residuals = patch - centroid

    # Require at least as many samples as dimensions before trusting the SVD.
    n_samples, n_features = residuals.shape
    if n_samples < n_features:
        return in_ball, centroid, None, False

    try:
        right_vectors = np.linalg.svd(residuals, full_matrices=False)[2]
    except np.linalg.LinAlgError:
        return in_ball, centroid, None, False

    return in_ball, centroid, right_vectors[0, :], True

#########
# STAGE #
#########
def spectral_lin_reg(X, H_RADIUS, N_MIN):
    """
    Recover a 1-D ordering of the rows of X from local tangent directions.

    Steps: local PCA at a fixed radius, spectral sign synchronisation of the
    tangents, then a sparse least-squares fit of one scalar coordinate per
    point.

    Args:
        X (np.ndarray): Data matrix (n_points, n_dims).
        H_RADIUS (float): Neighbourhood radius used for every local PCA.
        N_MIN (int): Minimum neighbourhood size for a local PCA to count.

    Returns:
        np.ndarray: Point indices sorted by the fitted scalar coordinate.

    Raises:
        ValueError: If local PCA is invalid for every point.
    """
    n_points, _ = X.shape

    # --- Local PCA at every point ---
    fits = {}
    usable = []
    for idx in range(n_points):
        members, centroid, direction, ok = get_neighbors_and_pca(idx, X, H_RADIUS, N_MIN)
        if not ok:
            fits[idx] = None
            continue
        fits[idx] = {'neighbors': members, 'mean': centroid, 'tangent': direction}
        usable.append(idx)
    if len(usable) == 0:
        raise ValueError("PCA failed.")

    # --- Spectral orientation ---
    # Affinity between neighbouring points is the signed dot product of their
    # (still arbitrarily signed) tangents; the diagonal is set to 1.
    affinity = lil_matrix((n_points, n_points), dtype=float)
    for idx in usable:
        direction = fits[idx]['tangent']
        for nbr in fits[idx]['neighbors']:
            if nbr == idx:
                affinity[idx, idx] = 1.0
            elif fits.get(nbr) is not None:
                affinity[idx, nbr] = np.dot(direction, fits[nbr]['tangent'])
    affinity = ((affinity + affinity.T) / 2.0).tocsr()

    # eigh returns eigenvalues in ascending order, so the last column is the
    # leading eigenvector; its entry signs give the per-point tangent flips.
    try:
        _, vecs = eigh(affinity.toarray())
        leading = vecs[:, -1]
    except np.linalg.LinAlgError:
        from scipy.sparse.linalg import eigsh
        _, vecs = eigsh(affinity.tocsr(), k=1, which='LA')
        leading = vecs[:, 0]
    flips = np.sign(leading)
    flips[flips == 0] = 1
    oriented = {idx: flips[idx] * fits[idx]['tangent'] for idx in usable}

    # --- Build the sparse system A z = T ---
    # Unknowns are z = [y_0..y_{n-1}, c_0..c_{n-1}]; each neighbour j of a
    # valid point i contributes the equation  y_j - c_i = (X[j] - mean_i) . v_i.
    row_ids, col_ids, coeffs, targets = [], [], [], []
    for idx in usable:
        centroid = fits[idx]['mean']
        direction = oriented[idx]
        for nbr in fits[idx]['neighbors']:
            if nbr == idx:
                continue
            eq = len(targets)
            row_ids.extend((eq, eq))
            col_ids.extend((nbr, n_points + idx))
            coeffs.extend((1.0, -1.0))
            targets.append(np.dot(X[nbr] - centroid, direction))
    A = csr_matrix((coeffs, (row_ids, col_ids)), shape=(len(targets), 2 * n_points))
    T = np.array(targets)

    # --- Solve with LSQR and order by the centred coordinate ---
    solution_z = lsqr(A, T, show=False, iter_lim=min(A.shape)*2)[0]
    coords = solution_z[:n_points]
    return np.argsort(coords - np.mean(coords))

def get_neighbors_and_pca_adaptive(point_idx, X, h, N_min):
    """
    Finds neighbors of a point within radius h, performs PCA, and calculates
    an eigengap score for the neighbourhood.

    Args:
        point_idx (int): Index of the current point in X.
        X (np.ndarray): The dataset (n_points, n_dims).
        h (float): The radius for neighbor search (Euclidean, inclusive).
        N_min (int): Minimum number of neighbors (including the point itself)
                     required for a valid PCA.

    Returns:
        tuple: (neighbor_indices, neighbor_mean, tangent_vector, valid_pca, eigengap)
            - neighbor_indices (np.ndarray): Indices of the neighbors.
            - neighbor_mean (np.ndarray or None): Mean of the neighbors; None
              when fewer than N_min neighbors were found.
            - tangent_vector (np.ndarray or None): First principal component
              (unit length, sign arbitrary); None when PCA is invalid.
            - valid_pca (bool): True if PCA was successful and conditions met.
            - eigengap (float): With two or more singular values, the relative
              gap (s[0] - s[1]) / s[0], which lies in [0, 1]. With a single
              singular value (1-D data), s[0] itself. -1.0 if PCA is invalid.

    A neighbourhood is reported as invalid when it has fewer than N_min
    points, fewer points than dimensions, no dimensions, a failed SVD, or
    (for data with more than one dimension) no spread at all, because the
    relative gap and the tangent direction are then undefined.
    """
    # Distances from the current point to every point (itself included).
    distances = cdist(X[[point_idx]], X)[0]
    neighbor_indices = np.where(distances <= h)[0]

    if len(neighbor_indices) < N_min:
        return neighbor_indices, None, None, False, -1.0

    neighbors = X[neighbor_indices]
    neighbor_mean = np.mean(neighbors, axis=0)
    centered_neighbors = neighbors - neighbor_mean

    # Require at least as many samples as dimensions, and at least one dimension.
    n_samples, n_features = centered_neighbors.shape
    if n_features == 0 or n_samples < n_features:
        return neighbor_indices, neighbor_mean, None, False, -1.0

    try:
        # s holds the singular values in descending order; rows of Vh are the
        # principal directions.
        _, s, Vh = np.linalg.svd(centered_neighbors, full_matrices=False)
    except np.linalg.LinAlgError:
        return neighbor_indices, neighbor_mean, None, False, -1.0

    if len(s) > 1:
        # All neighbours coincide (or s[0] is NaN): the relative gap would be
        # 0/0 and the tangent arbitrary, so treat the neighbourhood as invalid
        # rather than returning a NaN score.
        if not s[0] > 0:
            return neighbor_indices, neighbor_mean, None, False, -1.0
        eigengap = (s[0] - s[1]) / s[0]
    else:
        # Single singular value: there is no second component to compare with.
        eigengap = s[0]

    return neighbor_indices, neighbor_mean, Vh[0, :], True, eigengap

#########
# STAGE #
#########
def spectral_lin_reg_adaptive(X, h_radii_list, N_MIN):
    """
    Performs spectral linear regression with a per-point adaptive radius.

    For each point, every radius in h_radii_list is tried and the one whose
    local PCA gives the largest eigengap is kept. The resulting tangents are
    sign-synchronised spectrally and a sparse least-squares system is solved
    for one scalar coordinate per point.

    Args:
        X (np.ndarray): The input data (n_points, n_dims).
        h_radii_list (list of float): Candidate radii to try for local PCA.
        N_MIN (int): Minimum number of neighbors for local PCA.

    Returns:
        np.ndarray: Sorted indices of points, representing the learned order.
            Falls back to np.arange(n_points), with a RuntimeWarning, when no
            constraints could be generated or the LSQR solve raises.

    Raises:
        ValueError: If no point has a valid local PCA for any candidate radius.
    """
    import warnings

    n_points, _ = X.shape

    # --- Local PCA with adaptive radius ---
    local_pca_results = {}  # point index -> PCA details, or None if no radius worked
    valid_pca_indices = []  # indices with a usable local PCA, in increasing order

    for i in range(n_points):
        best_details = None
        best_eigengap = -float('inf')

        for h_candidate in h_radii_list:
            neigh_idx, mean_i, v_i, valid, eigengap_val = get_neighbors_and_pca_adaptive(i, X, h_candidate, N_MIN)

            # Strict '>' keeps the earliest radius on ties; a NaN score never wins.
            if valid and eigengap_val > best_eigengap:
                best_eigengap = eigengap_val
                best_details = {
                    'neighbors': neigh_idx,
                    'mean': mean_i,
                    'tangent': v_i,
                    'selected_h': h_candidate,          # kept for inspection only
                    'achieved_eigengap': eigengap_val,  # kept for inspection only
                }

        local_pca_results[i] = best_details
        if best_details is not None:
            valid_pca_indices.append(i)

    if not valid_pca_indices:
        raise ValueError("PCA failed for all points or no suitable radius found. Try adjusting h_radii_list or N_MIN.")

    # --- Spectral Orientation ---
    # Affinity between neighbouring points is the signed dot product of their
    # tangents; the signs are reconciled by the leading eigenvector below.
    W = lil_matrix((n_points, n_points), dtype=float)

    for i in valid_pca_indices:
        res_i = local_pca_results[i]
        v_i = res_i['tangent']

        # Neighbours are those found with the radius selected for point i.
        for j in res_i['neighbors']:
            if i == j:
                W[i, i] = 1.0  # self-similarity
                continue

            res_j = local_pca_results.get(j)
            if res_j is not None:
                W[i, j] = np.dot(v_i, res_j['tangent'])

    # Per-point radii differ, so the neighbour relation need not be symmetric;
    # average with the transpose to get a symmetric matrix.
    W = ((W + W.T) / 2.0).tocsr()

    try:
        # eigh returns eigenvalues in ascending order: last column is the
        # eigenvector of the largest eigenvalue.
        _, eigenvectors = eigh(W.toarray())
        u_max = eigenvectors[:, -1]
    except (np.linalg.LinAlgError, ValueError):
        # Dense solve failed: fall back to the sparse solver for the single
        # largest-algebraic eigenpair.
        if W.nnz > 0:
            # Imported here because the module only imports lsqr from
            # scipy.sparse.linalg at the top of the file.
            from scipy.sparse.linalg import eigsh
            _, eigenvectors = eigsh(W, k=1, which='LA', tol=1e-4)
            u_max = eigenvectors[:, 0]
        else:
            u_max = np.ones(n_points)

    # Sign of each entry decides whether that point's tangent is flipped.
    signs_s = np.sign(u_max)
    signs_s[signs_s == 0] = 1  # treat exact zeros as positive

    oriented_tangents = {
        i: signs_s[i] * local_pca_results[i]['tangent'] for i in valid_pca_indices
    }

    # --- Build Linear System Az = T ---
    # Unknowns: z = [y_0, ..., y_{n-1}, c_0, ..., c_{n-1}], where y_j is the
    # scalar coordinate of point j and c_i is an auxiliary offset for the
    # neighbourhood of point i. Each neighbour j != i of a valid point i adds
    #     y_j - c_i = (X[j] - mean_i) . v_i_oriented
    rows, cols, data_A, data_T = [], [], [], []
    num_constraints = 0

    for i in valid_pca_indices:
        res_i = local_pca_results[i]
        mean_i = res_i['mean']
        v_i_oriented = oriented_tangents[i]

        for j in res_i['neighbors']:
            if i == j:
                continue  # no constraint of a point against its own neighbourhood

            target_val = np.dot(X[j] - mean_i, v_i_oriented)

            # +1 on y_j (column j), -1 on c_i (column n_points + i).
            rows.append(num_constraints)
            cols.append(j)
            data_A.append(1.0)

            rows.append(num_constraints)
            cols.append(n_points + i)
            data_A.append(-1.0)

            data_T.append(target_val)
            num_constraints += 1

    if num_constraints == 0:
        # Every valid neighbourhood contained only the point itself.
        warnings.warn(
            "No constraints generated for the linear system; returning default order.",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.arange(n_points)

    A = csr_matrix((data_A, (rows, cols)), shape=(num_constraints, 2 * n_points))
    T = np.array(data_T)

    # --- Solve Linear System using LSQR ---
    try:
        solution_z = lsqr(A, T, show=False, iter_lim=min(A.shape)*2)[0]
    except Exception as exc:
        # Best-effort fallback is kept, but the failure is no longer silent.
        warnings.warn(
            f"LSQR failed: {exc}. Returning default order.",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.arange(n_points)

    # First n_points entries are the scalar coordinates; centre them.
    y = solution_z[:n_points]
    y = y - np.mean(y)

    # --- Get Final Order ---
    return np.argsort(y)

# X = np.genfromtxt("HSMM_data_filtered_scaled_100.csv", delimiter=",")
# X = X[1:X.shape[0], 1:X.shape[1]]

# true_order = np.arange(X.T.shape[0])

# tau, p = kendalltau(recovered_order_stage, true_order)
# print("Kendall's Tau: ", tau, ", p-val: ", p)

# Load the expression data from 'embryo_expr.Rds' into a NumPy array. The
# object returned by pyreadr.read_r is indexed with the key None to obtain the
# stored table. Assumes rows are samples and columns are features, since X is
# later passed to PCA.fit_transform — TODO confirm against the data file.
X = np.array(pyreadr.read_r('embryo_expr.Rds')[None])
# recovered_order_stage = spectral_lin_reg(X, H_RADIUS = 10**20, N_MIN = 3)

# A_spectral = squareform(pdist(X, inverse_distance))
# spec_ord = mdso.SpectralOrdering()
# recovered_order_so = np.argsort(spec_ord.fit_transform(A_spectral))
# tau, p = kendalltau(recovered_order_so, true_order)
# print("Kendall's Tau: ", tau, ", p-val: ", p)

# Per-sample values from 'embryo_order.Rds', flattened to a 1-D vector so
# they can be used directly as the colour of each scatter point below.
# Presumably the collection hour of each sample — confirm against the data.
hour_collected = np.array(pyreadr.read_r("embryo_order.Rds")[None]).flatten()

# umap_obj = umap.UMAP(n_components=2)
# umap_result = umap_obj.fit_transform(X)
# plt.scatter(umap_result[:, 0], umap_result[:, 1], c = hour_collected)
# plt.show()
# # plt.scatter(umap_result[:, 0], umap_result[:, 1], c = recovered_order_stage)
# # plt.show()
# tsne_obj = TSNE()
# tsne_result = tsne_obj.fit_transform(X)
# plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c = hour_collected)
# plt.show()
# # plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c = recovered_order_stage)
# # plt.show()

# print("Starting ICA")
# ica_obj = FastICA(n_components=3)
# ica_result = ica_obj.fit_transform(X)
# recovered_order_stage = spectral_lin_reg(ica_result, H_RADIUS = .5, N_MIN = 5)
# print("Ending ICA")
# fig = plt.figure()
# # syntax for 3-D projection
# ax = plt.axes(projection ='3d')
# ax.scatter(ica_result[:, 0], ica_result[:, 1], ica_result[:, 2], c = hour_collected, cmap = "viridis")
# plt.show()
# fig = plt.figure()
# # syntax for 3-D projection
# ax = plt.axes(projection ='3d')
# ax.scatter(ica_result[:, 0], ica_result[:, 1], ica_result[:, 2], c = recovered_order_stage, cmap = "viridis")
# plt.show()

# Project the data onto its first two principal components and recover an
# ordering of the samples from that 2-D embedding.
pca_obj = PCA(n_components=2)
pca_result = pca_obj.fit_transform(X)
recovered_order_stage = spectral_lin_reg(pca_result, H_RADIUS = 50, N_MIN = 5)
print(recovered_order_stage)
# NOTE(review): prints n minus each index (values 1..n); presumably a quick
# look at the ordering in the opposite direction — confirm intent.
print(len(recovered_order_stage) - recovered_order_stage)

# spectral_lin_reg returns an argsort: entry k is the index of the sample at
# position k. Colouring sample i needs the inverse permutation, i.e. the
# position (rank) of sample i, which is the argsort of that ordering.
recovered_rank = np.argsort(recovered_order_stage)

# Reference plot: embedding coloured by the values loaded into hour_collected.
plt.scatter(pca_result[:,0], pca_result[:,1], c = hour_collected, cmap = "viridis")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.show()

# Same embedding coloured by the rank each sample received in the recovered order.
plt.scatter(pca_result[:,0], pca_result[:,1], c = recovered_rank, cmap = "viridis")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.colorbar(label = "Recovered Rank")
plt.show()



