import time

import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import eigh
from scipy.sparse import csr_matrix, lil_matrix
from scipy.sparse.linalg import eigsh, lsqr
from scipy.spatial.distance import cdist
from scipy.stats import kendalltau

# --- Parameters for the Sweep ---
# Every (NOISE_STD, H_RADIUS) combination below is evaluated N_REPETITIONS times.
N_POINTS = 3000 # Number of sample points in each generated dataset
H_RADIUS_VALUES = [0.4, 0.6, 0.8, 1.0, 1.2] # Neighborhood radii swept (used for local PCA and pairwise constraints)
NOISE_STD_VALUES = [0.1, 0.2, 0.3, 0.4, 0.5] # This is your "sigma": std of the Gaussian noise added to the curve
N_MIN_PCA = 5      # Minimum number of in-radius neighbors required before local PCA is attempted
ROTATION_ANGLE = -np.pi / 3 # Rotation applied to the noisy data, in radians
N_REPETITIONS = 20 # Number of independent repetitions for each (NOISE_STD, H_RADIUS) pair

# --- 1. Generate and Rotate Data ---
def generate_data(n, current_noise_std, angle, run_seed=None):
    """Sample a noisy, rotated sine curve together with its ground truth.

    The clean curve is ``(theta, 2*sin(theta))`` for ``n`` values of theta
    evenly spaced in ``[0, 2*pi)``. Gaussian noise is added to both
    coordinates and the result is rotated by ``angle`` radians.

    Args:
        n: Number of points to generate.
        current_noise_std: Standard deviation of the additive Gaussian noise.
        angle: Rotation angle in radians.
        run_seed: If not None, seeds NumPy's global RNG before drawing noise.

    Returns:
        Tuple ``(rotated_points, thetas_true, true_order_indices,
        true_rotated_tangents)``: the (n, 2) noisy rotated samples, the
        curve parameters, ``arange(n)``, and the (n, 2) unit tangents of the
        noise-free curve after rotation.
    """
    if run_seed is not None:
        np.random.seed(run_seed)

    thetas_true = np.linspace(0, 2 * np.pi, n, endpoint=False)
    true_order_indices = np.arange(n)

    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a],
                         [sin_a, cos_a]])

    # Noisy samples of the curve, rotated into the final frame.
    clean_curve = np.array([thetas_true, 2 * np.sin(thetas_true)]).T
    perturbation = np.random.normal(0, current_noise_std, size=clean_curve.shape)
    rotated_points = (clean_curve + perturbation) @ rotation.T

    # Analytic tangents d/dtheta (theta, 2*sin(theta)) = (1, 2*cos(theta));
    # they depend only on theta and the rotation, never on the noise.
    raw_tangents = np.array([np.ones_like(thetas_true), 2 * np.cos(thetas_true)]).T
    lengths = np.linalg.norm(raw_tangents, axis=1, keepdims=True)
    lengths = np.where(lengths == 0, 1, lengths)  # guard against division by zero
    true_rotated_tangents = (raw_tangents / lengths) @ rotation.T

    return rotated_points, thetas_true, true_order_indices, true_rotated_tangents

# --- STAGE Method Helper: Neighborhood and PCA (Strictly H_RADIUS) ---
def get_neighbors_and_pca_h_radius(point_idx, X_data, h_radius_param, n_min_neighbors_pca):
    """Estimate the local tangent at one point via PCA over its radius neighbors.

    Args:
        point_idx: Row index of the query point in ``X_data``.
        X_data: 2-D array of points, one per row.
        h_radius_param: Neighborhood radius; the query point itself (distance
            0) is always part of its own neighborhood.
        n_min_neighbors_pca: Minimum neighborhood size required to run PCA.

    Returns:
        Tuple ``(neighbor_indices, neighbor_mean, tangent, is_valid)``.
        ``neighbor_mean`` is None when the neighborhood is too small;
        ``tangent`` is the first right-singular vector of the centered
        neighborhood and is None whenever ``is_valid`` is False.
    """
    dist_to_all = cdist(X_data[[point_idx]], X_data)[0]
    pca_neighbor_indices = np.flatnonzero(dist_to_all <= h_radius_param)
    if len(pca_neighbor_indices) < n_min_neighbors_pca:
        return pca_neighbor_indices, None, None, False

    patch = X_data[pca_neighbor_indices]
    patch_mean = patch.mean(axis=0)
    centered_patch = patch - patch_mean
    no_tangent = (pca_neighbor_indices, patch_mean, None, False)

    # Need more than one sample and at least as many samples as dimensions.
    n_samples, n_features = centered_patch.shape
    if n_samples < n_features or n_samples <= 1:
        return no_tangent

    try:
        # In 2-D, a rank-0 patch (all neighbors coincide) has no direction.
        if n_features == 2 and np.linalg.matrix_rank(centered_patch, tol=1e-9) < 1:
            return no_tangent
        right_singular_vectors = np.linalg.svd(centered_patch, full_matrices=False)[2]
    except np.linalg.LinAlgError:
        return no_tangent

    return pca_neighbor_indices, patch_mean, right_singular_vectors[0, :], True

# Utility function to evaluate ordering (handles potential sign flip for best absolute Tau)
def evaluate_and_flip_for_best_abs_tau(true_order, recovered_embedding_values):
    """Score a 1-D embedding against the true ordering with Kendall's tau.

    Tau is computed between ``true_order`` and the argsort of the embedding,
    once for the embedding as given and once for its negation.

    Args:
        true_order: Array holding the ground-truth ordering.
        recovered_embedding_values: Array of embedding coordinates, or None.

    Returns:
        The signed tau whose absolute value is larger; the un-flipped tau
        wins ties. ``np.nan`` is returned when the embedding is None,
        entirely NaN, or differs in size from ``true_order``.
    """
    unusable = (
        recovered_embedding_values is None
        or np.all(np.isnan(recovered_embedding_values))
        or recovered_embedding_values.size != true_order.size
    )
    if unusable:
        return np.nan

    def _tau_with_strength(values):
        # A NaN tau gets strength -inf so that it can never beat a finite one.
        tau, _ = kendalltau(true_order, np.argsort(values), nan_policy='propagate')
        strength = -np.inf if np.isnan(tau) else abs(tau)
        return tau, strength

    tau_as_given, strength_as_given = _tau_with_strength(recovered_embedding_values)
    tau_negated, strength_negated = _tau_with_strength(-recovered_embedding_values)

    # Strict comparison: on a tie the original orientation is reported.
    if strength_negated > strength_as_given:
        return tau_negated
    return tau_as_given

# --- Main Parameter Sweep ---
# For every (NOISE_STD, H_RADIUS) pair, STAGE is run N_REPETITIONS times on
# freshly generated data and the mean/std of |Kendall tau| is recorded.
results_summary = {} # To store: results_summary[noise_std][h_radius] = (avg_abs_tau, std_abs_tau)

for current_noise_std_param in NOISE_STD_VALUES:
    results_summary[current_noise_std_param] = {}
    print(f"\n===== Processing for NOISE_STD (sigma) = {current_noise_std_param} =====")
    for current_h_radius_param in H_RADIUS_VALUES:
        print(f"\n  --- Processing for H_RADIUS = {current_h_radius_param} (Noise = {current_noise_std_param}) ---")

        signed_taus_for_this_setting = [] # One signed tau per repetition

        for rep in range(N_REPETITIONS):
            # run_seed=None: each repetition draws fresh noise from the global RNG.
            X, thetas_true, true_order_indices, _ = generate_data( # true tangents are not needed here
                N_POINTS, current_noise_std_param, ROTATION_ANGLE, run_seed=None
            )
            n_points, n_dims = X.shape

            if not np.isfinite(X).all():
                print(f"      WARNING Rep {rep+1}: Input X NaN/Inf. Skipping STAGE for this rep.")
                signed_taus_for_this_setting.append(np.nan)
                continue

            # --- STAGE step 1: local PCA tangent and radius neighborhoods ---
            local_pca_results = {}
            valid_pca_indices = []
            # In-radius neighbor indices (self included), as returned by the PCA
            # helper; reused below so W does not need another distance pass.
            radius_neighbor_indices = [None] * n_points
            # In-radius neighbors excluding the point itself / exact duplicates.
            all_neighbors_for_loss = [[] for _ in range(n_points)]

            for i in range(n_points):
                neighbors_i, mean_i, v_i, valid_pca = get_neighbors_and_pca_h_radius(
                    i, X, current_h_radius_param, N_MIN_PCA
                )
                radius_neighbor_indices[i] = neighbors_i
                if valid_pca:
                    local_pca_results[i] = {'mean': mean_i, 'tangent': v_i}
                    valid_pca_indices.append(i)
                else:
                    local_pca_results[i] = {'mean': mean_i, 'tangent': None}

                distances_from_i = cdist(X[[i]], X)[0]
                all_neighbors_for_loss[i] = np.where(
                    (distances_from_i <= current_h_radius_param) & (distances_from_i > 1e-9)
                )[0].tolist()

            y_embedding_stage_raw = np.full(n_points, np.nan)
            oriented_tangents = {}

            # With too few valid tangents the embedding stays all-NaN.
            if len(valid_pca_indices) >= n_dims + 1:
                # --- STAGE step 2: tangent-alignment matrix W ---
                # W[i, j] = <v_i, v_j> for in-radius pairs with valid tangents,
                # and 1 on the diagonal.
                W = lil_matrix((n_points, n_points), dtype=float)
                for i_w in valid_pca_indices:
                    v_i_w = local_pca_results[i_w]['tangent']
                    for j_idx_w in radius_neighbor_indices[i_w]:
                        if i_w == j_idx_w:
                            W[i_w, j_idx_w] = 1.0
                            continue
                        v_j_w = local_pca_results[j_idx_w]['tangent']
                        if v_j_w is not None:
                            W[i_w, j_idx_w] = np.dot(v_i_w, v_j_w)
                W = ((W + W.T) / 2.0).tocsr()

                # --- STAGE step 3: sign synchronisation ---
                # Signs of the eigenvector for the largest eigenvalue of W
                # give a consistent orientation to the local tangents.
                u_max = None
                signs_s = np.ones(n_points)
                try:
                    if W.nnz == 0:
                        raise ValueError("W matrix empty.")
                    if n_points < 2000 or (W.nnz / (n_points**2) > 0.05 and n_points > 0): # Heuristic for dense/sparse
                        eigenvalues, eigenvectors = eigh(W.toarray())
                        if eigenvalues.size > 0:
                            u_max = eigenvectors[:, -1]
                    else:
                        # eigsh must be imported from scipy.sparse.linalg; it
                        # previously was not, so this branch raised a
                        # NameError that was silently swallowed.
                        _, eigenvectors_s = eigsh(
                            W, k=1, which='LA', tol=1e-6,
                            maxiter=n_points * 10, v0=np.ones(n_points)
                        )
                        if eigenvectors_s.size > 0:
                            u_max = eigenvectors_s[:, 0]
                except Exception as eig_error:
                    # Best effort: continue with unsynchronised (+1) signs,
                    # but report the failure instead of hiding it.
                    print(f"      WARNING Rep {rep+1}: sign synchronisation failed "
                          f"({type(eig_error).__name__}: {eig_error}); using unsynchronised tangents.")
                if u_max is not None:
                    signs_s = np.sign(u_max)
                    signs_s[signs_s == 0] = 1

                for i_ot in valid_pca_indices:
                    if local_pca_results[i_ot]['tangent'] is not None:
                        oriented_tangents[i_ot] = signs_s[i_ot] * local_pca_results[i_ot]['tangent']

                # --- STAGE step 4: least-squares 1-D embedding ---
                # One constraint per neighbor pair i < j:
                #   y_j - y_i = <(s_i v_i + s_j v_j) / 2, x_j - x_i>
                rows_A, cols_A, data_A, data_T = [], [], [], []
                num_constraints = 0
                for i_ls in range(n_points):
                    if i_ls not in oriented_tangents:
                        continue
                    s_i_v_i = oriented_tangents[i_ls]
                    for j_idx_ls in all_neighbors_for_loss[i_ls]:
                        if j_idx_ls <= i_ls:
                            continue
                        if j_idx_ls not in oriented_tangents:
                            continue
                        s_j_v_j = oriented_tangents[j_idx_ls]
                        avg_oriented_tangent = 0.5 * (s_i_v_i + s_j_v_j)
                        target_val_Tij = np.dot(avg_oriented_tangent, X[j_idx_ls] - X[i_ls])
                        rows_A.append(num_constraints); cols_A.append(j_idx_ls); data_A.append(1.0)
                        rows_A.append(num_constraints); cols_A.append(i_ls); data_A.append(-1.0)
                        data_T.append(target_val_Tij)
                        num_constraints += 1

                if num_constraints > 0:
                    A_matrix = csr_matrix((data_A, (rows_A, cols_A)), shape=(num_constraints, n_points))
                    T_vector = np.array(data_T)
                    try:
                        solution_bundle = lsqr(
                            A_matrix, T_vector, show=False,
                            iter_lim=max(2000, min(A_matrix.shape) * 2),
                            atol=1e-8, btol=1e-8
                        )
                        y_embedding_stage_raw = solution_bundle[0]
                        if not (y_embedding_stage_raw.size > 0 and np.any(~np.isnan(y_embedding_stage_raw))):
                            y_embedding_stage_raw = np.full(n_points, np.nan)
                    except Exception as lsqr_error:
                        print(f"      WARNING Rep {rep+1}: least-squares solve failed "
                              f"({type(lsqr_error).__name__}: {lsqr_error}); embedding set to NaN.")
                        y_embedding_stage_raw = np.full(n_points, np.nan)

            # Evaluate STAGE embedding, getting the signed Tau that yields the best absolute
            tau_stage_signed_for_this_run = evaluate_and_flip_for_best_abs_tau(
                true_order_indices, y_embedding_stage_raw
            )
            signed_taus_for_this_setting.append(tau_stage_signed_for_this_run)

        # Calculate average of ABSOLUTE Taus for this (NOISE_STD, H_RADIUS) setting
        abs_taus_for_this_setting = np.abs(np.array(signed_taus_for_this_setting))
        avg_abs_tau = np.nanmean(abs_taus_for_this_setting)
        std_abs_tau = np.nanstd(abs_taus_for_this_setting)
        results_summary[current_noise_std_param][current_h_radius_param] = (avg_abs_tau, std_abs_tau)

        print(f"    H={current_h_radius_param:.1f}, Noise={current_noise_std_param:.1f} -> Avg Abs Tau: {avg_abs_tau:.4f} (Std: {std_abs_tau:.4f})")

# --- Print Final Results Table ---
# One row per noise level, one column per radius; each cell is "mean (std)".
print("\n\n--- STAGE Parameter Sweep Results (Avg Absolute Kendall's Tau) ---")
column_titles = [f"H={h:.1f}  " for h in H_RADIUS_VALUES]
header = "Noise_STD | " + " | ".join(column_titles)
rule = "-" * len(header)
print(header)
print(rule)

for noise_val in NOISE_STD_VALUES:
    row_results = results_summary[noise_val]
    cells = []
    for h_val in H_RADIUS_VALUES:
        if h_val not in row_results:
            cells.append("  N/A   | ") # Should not happen if all ran
            continue
        avg_tau, std_tau = row_results[h_val]
        cells.append(f"{avg_tau:.3f} ({std_tau:.2f}) | ")
    print(f"{noise_val:<9.1f} | " + "".join(cells))

print(rule)

# Optional: report the (noise, radius) pair with the highest mean |tau|.
# NaN averages are skipped; on ties the first pair encountered is kept.
best_overall_tau = -1.0
best_params = {}
scored_settings = (
    (avg_tau, noise_val, h_val)
    for noise_val, h_results in results_summary.items()
    for h_val, (avg_tau, _) in h_results.items()
)
for candidate_tau, candidate_noise, candidate_h in scored_settings:
    if np.isnan(candidate_tau) or candidate_tau <= best_overall_tau:
        continue
    best_overall_tau = candidate_tau
    best_params = {'noise_std': candidate_noise, 'h_radius': candidate_h}

if not best_params:
    print("\nCould not determine best parameters (all results might be NaN).")
else:
    print(f"\nBest average absolute Tau: {best_overall_tau:.4f} achieved with Noise_STD={best_params['noise_std']:.1f} and H_RADIUS={best_params['h_radius']:.1f}")


# Visualization is intentionally left out of this sweep script: plotting every
# one of the len(NOISE_STD_VALUES) * len(H_RADIUS_VALUES) * N_REPETITIONS runs
# would not be practical.
# To visualize a single run, pick one (NOISE_STD, H_RADIUS) setting and run the
# script with N_REPETITIONS=1, similar to the previous diagnostic scripts.

print("\nParameter sweep complete. To visualize, re-run with specific H_RADIUS and NOISE_STD and N_REPETITIONS=1.")