import numpy as np
from scipy.spatial.distance import cdist, euclidean # Added euclidean for get_neighbors_and_pca
from scipy.linalg import eigh # For symmetric matrix eigenvalues/vectors
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import lsqr, eigsh # Added eigsh for the except block
from scipy.stats import rankdata
##############################
# Spectral Linear Regression #
# (Original version from high_dim_sims.py)
##############################

def get_neighbors_and_pca(point_idx, X, K_NEIGHBORS, N_min):
    """
    Find the nearest neighbors of one point and fit a local PCA direction.

    Parameters
    ----------
    point_idx : int
        Row index into ``X`` of the query point.
    X : ndarray
        2-D data matrix, one point per row.
    K_NEIGHBORS : int
        Maximum number of nearest neighbors (Euclidean distance) to use.
        The query point itself is never counted as one of its neighbors.
    N_min : int
        Minimum number of neighbors required before PCA is attempted.

    Returns
    -------
    neighbor_indices : ndarray of int
        Row indices of the selected neighbors, nearest first.
    neighbor_mean : ndarray or None
        Mean of the neighbor coordinates; ``None`` when there are fewer than
        ``N_min`` neighbors or no neighbors at all.
    tangent_vector : ndarray or None
        First principal direction of the centered neighbors (first row of
        ``Vh`` from their SVD); ``None`` whenever ``valid`` is ``False``.
    valid : bool
        ``True`` only if the SVD was computed. It is ``False`` when there are
        too few neighbors, when there are fewer neighbors than dimensions,
        or when the SVD itself fails.
    """
    distances = cdist(X[[point_idx]], X, metric='euclidean')[0]

    # Drop the query point explicitly instead of assuming it sorts first:
    # with duplicated points several distances are exactly zero and the query
    # point is not guaranteed to occupy position 0 of the sorted order.
    # A stable sort keeps the ordering of tied distances deterministic.
    self_idx = int(point_idx) % X.shape[0]  # normalise negative indices
    order = np.argsort(distances, kind='stable')
    neighbor_indices = order[order != self_idx][:max(K_NEIGHBORS, 0)]

    n_neighbors = len(neighbor_indices)
    if n_neighbors < N_min or n_neighbors == 0:
        # Too few neighbors (or none): no mean and no PCA can be computed.
        return neighbor_indices, None, None, False

    neighbors = X[neighbor_indices]
    neighbor_mean = np.mean(neighbors, axis=0)
    centered_neighbors = neighbors - neighbor_mean

    # With fewer samples than dimensions the local PCA is treated as invalid.
    if centered_neighbors.shape[0] < centered_neighbors.shape[1]:
        return neighbor_indices, neighbor_mean, None, False

    try:
        _, _, Vh = np.linalg.svd(centered_neighbors, full_matrices=False)
    except (np.linalg.LinAlgError, ValueError):
        # SVD did not converge or the input was unusable.
        return neighbor_indices, neighbor_mean, None, False

    return neighbor_indices, neighbor_mean, Vh[0, :], True


def _fallback_order(n_points):
    """Return the identity ordering used when no embedding can be computed."""
    return np.arange(n_points) if n_points > 0 else np.array([], dtype=int)


def _orientation_signs(W_csr, n_points):
    """Return one +1/-1 sign per point from the leading eigenvector of ``W_csr``."""
    # Function-scope import so this block does not depend on the file header.
    from scipy.sparse.linalg import ArpackError

    if W_csr.nnz == 0:
        # No tangent affinities at all: leave every tangent as PCA produced it.
        u_max = np.ones(n_points)
    else:
        try:
            # Dense solve; eigenvalues come back ascending, so the last column
            # belongs to the largest one. Memory grows as n_points ** 2.
            _, eigenvectors = eigh(W_csr.toarray())
            u_max = eigenvectors[:, -1]
        except (np.linalg.LinAlgError, MemoryError, ValueError):
            u_max = np.ones(n_points)  # used if the sparse solver is unusable too
            if n_points > 1:  # eigsh requires 0 < k < n
                try:
                    _, eigenvectors = eigsh(
                        W_csr, k=1, which='LA', tol=1e-6, maxiter=n_points * 5
                    )
                    u_max = eigenvectors[:, 0]
                except (np.linalg.LinAlgError, TypeError, ArpackError):
                    # ArpackError also covers ArpackNoConvergence, the usual
                    # way eigsh fails; fall back to all-positive signs.
                    u_max = np.ones(n_points)

    signs = np.sign(u_max)
    signs[signs == 0] = 1.0  # a zero component must not erase a tangent
    return signs


def spectral_lin_reg_knn(X, K_NEIGHBORS, N_MIN):
    """
    Spectral Linear Regression: recover a 1-D ordering of the rows of ``X``.

    Steps:
      1. Local PCA on the k nearest neighbors of every point gives a tangent
         direction per point (via ``get_neighbors_and_pca``).
      2. Tangent signs are made consistent using the leading eigenvector of
         the symmetrized matrix of tangent dot products between neighbors.
      3. A sparse least-squares system ``y_j - c_i = (x_j - mean_i) . v_i`` is
         built for every point ``i`` with a valid PCA and each neighbor ``j``,
         and solved with LSQR.
      4. The centered solution ``y`` is converted to ranks.

    Parameters
    ----------
    X : ndarray
        2-D data matrix, one point per row.
    K_NEIGHBORS : int
        Number of nearest neighbors used for each local PCA.
    N_MIN : int
        Minimum number of neighbors required for a local PCA to be attempted.

    Returns
    -------
    ndarray
        ``rankdata(y, method='average')`` of the recovered coordinate
        (1-based ranks, ties averaged). The direction of the ordering follows
        the sign of the leading eigenvector and is therefore not fixed.
        When no local PCA is valid, no constraint can be formed, or LSQR
        raises, ``np.arange(n_points)`` is returned instead. Note that this
        fallback is 0-based, unlike the ranks.
    """
    n_points, _ = X.shape  # also enforces that X is 2-D

    # --- Local PCA: keep only the points whose PCA succeeded ---
    local_pca = {}
    for i in range(n_points):
        neigh_idx, mean_i, v_i, valid = get_neighbors_and_pca(i, X, K_NEIGHBORS, N_MIN)
        if valid:
            local_pca[i] = {'neighbors': neigh_idx, 'mean': mean_i, 'tangent': v_i}

    if not local_pca:
        return _fallback_order(n_points)

    # --- Spectral orientation ---
    W = lil_matrix((n_points, n_points), dtype=float)
    for i, res_i in local_pca.items():
        v_i = res_i['tangent']
        for j_idx in res_i['neighbors']:
            res_j = local_pca.get(j_idx)
            if res_j is not None:  # neighbors without a valid PCA contribute 0
                W[i, j_idx] = np.dot(v_i, res_j['tangent'])
        W[i, i] = 1.0  # self-similarity on the diagonal
    W_csr = ((W + W.T) / 2.0).tocsr()  # symmetrize

    signs_s = _orientation_signs(W_csr, n_points)

    # --- Build linear system A z = T, with z = [y_0..y_{n-1}, c_0..c_{n-1}] ---
    rows, cols, data_A, data_T = [], [], [], []
    for i, res_i in local_pca.items():
        mean_i = res_i['mean']
        v_i_oriented = signs_s[i] * res_i['tangent']
        for j_idx in res_i['neighbors']:
            if i == j_idx:
                continue  # a point places no constraint on itself
            row = len(data_T)
            rows.extend((row, row))
            cols.extend((j_idx, n_points + i))  # +y_j and -c_i
            data_A.extend((1.0, -1.0))
            data_T.append(np.dot(X[j_idx] - mean_i, v_i_oriented))

    num_constraints = len(data_T)
    if num_constraints == 0:
        return _fallback_order(n_points)

    A = csr_matrix((data_A, (rows, cols)), shape=(num_constraints, 2 * n_points))
    T = np.array(data_T)

    # --- Solve with LSQR ---
    try:
        solution_z = lsqr(
            A, T, show=False, iter_lim=max(1, min(A.shape) * 2), atol=1e-8, btol=1e-8
        )[0]
    except Exception:
        # Deliberate best-effort: any solver failure yields the fallback order.
        return _fallback_order(n_points)

    y = solution_z[:n_points]
    y = y - np.mean(y)  # center the result

    # --- Final order ---
    return rankdata(y, method='average')

