import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from scipy.linalg import eigh, svd
from scipy.sparse import lil_matrix, csr_matrix
from scipy.sparse.linalg import lsqr, eigsh
from scipy.stats import kendalltau
import time

# If spectral_ordering_robust is implemented in experiment_utils:
from experiment_utils import spectral_ordering_robust

# Apply the base style first so the rcParams overrides below take precedence.
plt.style.use('seaborn-v0_8-whitegrid')

# --- Global font sizes (affect every figure created afterwards) ---
plt.rcParams.update({
    'axes.titlesize': 24,
    'axes.labelsize': 20,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 16,
})

# --- Parameters ---
N_POINTS = 200               # number of samples along the curve
NOISE_STD = 0.4              # fixed noise level; can later be swept over a list of values
H_RADIUS = 1                 # neighbourhood radius (local PCA, orientation graph, loss)
N_MIN = 5                    # minimum neighbourhood size for a valid local PCA
ROTATION_ANGLE = -np.pi / 3  # rotation applied to the noisy data, in radians

# Seed the global NumPy RNG so the sampled noise, the random tangent sign
# flips and therefore the saved figures are reproducible between runs.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)


# ========== 1. Generate and Rotate Data ==========
def generate_data(n, noise_std, angle, rng=None):
    """Sample a noisy, rotated copy of the curve y = 2 sin(x), x in [0, 2*pi).

    ``n`` evenly spaced curve parameters ``theta`` (endpoint excluded) are
    mapped to ``(theta, 2 sin(theta))``. Each point is then displaced along
    the unit normal of the curve by a normal draw with standard deviation
    ``noise_std``. Finally all points are rotated about the origin by
    ``angle`` radians.

    Parameters
    ----------
    n : int
        Number of points.
    noise_std : float
        Standard deviation of the signed displacement along the unit normal.
    angle : float
        Rotation angle in radians; exactly ``0`` skips the rotation.
    rng : numpy.random.Generator, optional
        Source of randomness for the noise. When omitted, the global
        ``np.random`` state is used (as before), so ``np.random.seed``
        controls the result.

    Returns
    -------
    rotated_points : ndarray, shape (n, 2)
        The noisy, rotated samples.
    thetas_true : ndarray, shape (n,)
        The curve parameter of every sample (its ground-truth position).
    true_order_indices : ndarray, shape (n,)
        ``np.arange(n)``: the samples are generated already in true order.
    """
    thetas_true = np.linspace(0, 2 * np.pi, n, endpoint=False)
    manifold_points = np.vstack((thetas_true, 2 * np.sin(thetas_true))).T

    # The tangent of (t, 2 sin t) is (1, 2 cos t). Rotating it by +90 degrees
    # gives the normal (-2 cos t, 1). Its norm is >= 1, so the epsilon only
    # guards against degenerate input.
    normal_vectors = np.vstack(
        (-2 * np.cos(thetas_true), np.ones_like(thetas_true))
    ).T
    norms = np.linalg.norm(normal_vectors, axis=1, keepdims=True)
    normalized_normal_vectors = normal_vectors / (norms + 1e-9)

    normal_draw = np.random.normal if rng is None else rng.normal
    noise_magnitudes = normal_draw(0, noise_std, size=(n, 1))
    noisy_points_on_manifold = (
        manifold_points + noise_magnitudes * normalized_normal_vectors
    )

    true_order_indices = np.arange(n)

    if angle != 0:
        rotation_matrix = np.array([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle),  np.cos(angle)],
        ])
        # Row vectors: p_rot = R p  <=>  P_rot = P R^T.
        rotated_points = noisy_points_on_manifold @ rotation_matrix.T
    else:
        rotated_points = noisy_points_on_manifold

    print(
        f"Data generated (y = 2 sin(x) with normal noise std = {noise_std}) "
        f"and rotated by {np.degrees(angle):.1f} degrees."
    )
    return rotated_points, thetas_true, true_order_indices


# Draw the dataset once; every later section works on X.
X, thetas_true, true_order_indices = generate_data(
    n=N_POINTS,
    noise_std=NOISE_STD,
    angle=ROTATION_ANGLE,
)
n_points, n_dims = X.shape  # X has one row per sample and 2 columns


# ========== 2. Helper Function: Neighborhood and PCA ==========
def get_neighbors_and_pca(point_idx, X_data, h_radius, n_min_neighbors):
    """Fit a one-direction local PCA around ``X_data[point_idx]``.

    The neighbourhood is every row of ``X_data`` (the point itself included)
    within Euclidean distance ``h_radius``. The tangent is the leading right
    singular vector of the centred neighbourhood. Its sign is flipped at
    random (one ``np.random.rand()`` draw) to simulate unoriented tangents.

    Returns
    -------
    tuple
        ``(neighbor_indices, neighbor_mean, tangent_vector, is_valid)``.
        ``neighbor_mean`` is None only when there are fewer than
        ``n_min_neighbors`` neighbours. ``tangent_vector`` is None and
        ``is_valid`` is False whenever no fit is produced: too few
        neighbours, fewer rows than columns, a single row, or an SVD
        failure.
    """
    dist_row = cdist(X_data[[point_idx]], X_data)[0]
    nbr_idx = np.flatnonzero(dist_row <= h_radius)
    if nbr_idx.size < n_min_neighbors:
        return nbr_idx, None, None, False

    local_block = X_data[nbr_idx]
    centroid = local_block.mean(axis=0)
    centred = local_block - centroid

    n_rows, n_cols = centred.shape
    # The SVD needs at least as many rows as columns, and more than one row.
    if n_rows < n_cols or n_rows <= 1:
        return nbr_idx, centroid, None, False

    try:
        _, _, vh = svd(centred, full_matrices=False)
    except np.linalg.LinAlgError:
        return nbr_idx, centroid, None, False

    direction = vh[0, :]
    # Random sign to simulate unoriented local tangents.
    if np.random.rand() > 0.5:
        direction = -direction
    return nbr_idx, centroid, direction, True


# ========== 3. Pre-compute Local PCAs & Neighborhoods for the Loss ==========
# For every point: fit a local PCA tangent and record its loss neighbours.
# Loss neighbours are all other points within H_RADIUS; distance > 1e-9
# excludes the point itself and exact duplicates.
print("Computing local PCA and neighborhoods for the loss...")
start_time_pca = time.time()
local_pca_results = {}
valid_pca_indices = []
all_neighbors_for_loss = [[] for _ in range(n_points)]
unoriented_tangents_for_plot = {}

for point_idx in range(n_points):
    _, local_mean, local_tangent, pca_ok = get_neighbors_and_pca(
        point_idx, X, H_RADIUS, N_MIN
    )
    # Invalid fits keep their (possibly None) mean, but never a tangent.
    local_pca_results[point_idx] = {
        "mean": local_mean,
        "tangent": local_tangent if pca_ok else None,
    }
    if pca_ok:
        valid_pca_indices.append(point_idx)
        unoriented_tangents_for_plot[point_idx] = local_tangent

    row_dists = cdist(X[[point_idx]], X)[0]
    in_ball = (row_dists <= H_RADIUS) & (row_dists > 1e-9)
    all_neighbors_for_loss[point_idx] = np.flatnonzero(in_ball).tolist()

print(f"PCA & loss neighborhood computation took {time.time() - start_time_pca:.2f}s.")
print(f"Number of points with valid PCA: {len(valid_pca_indices)}/{n_points}.")
if len(valid_pca_indices) == 0:
    raise ValueError("PCA failed for a sufficient number of points.")


# ========== 4. Spectral Orientation ==========
# Build a signed affinity matrix W with W[i, i] = 1 and W[i, j] = <v_i, v_j>.
# Off-diagonal entries exist only for pairs within H_RADIUS where both points
# have a tangent. The eigenvector of W's largest eigenvalue is then taken, and
# the sign of each of its entries decides whether that point's tangent is
# kept (+1) or flipped (-1).
print("Performing spectral orientation of tangent vectors...")
start_time_orient = time.time()
W_orient = lil_matrix((n_points, n_points), dtype=float)  # LIL: cheap per-element writes

for i in valid_pca_indices:
    v_i = local_pca_results[i]["tangent"]
    distances_for_W = cdist(X[[i]], X)[0]
    w_neighbor_indices = np.where(distances_for_W <= H_RADIUS)[0]
    for j_idx in w_neighbor_indices:
        if i == j_idx:
            W_orient[i, i] = 1.0
            continue
        # Only neighbours that also have a tangent contribute an edge.
        if (
            j_idx in local_pca_results
            and local_pca_results[j_idx]["tangent"] is not None
        ):
            v_j = local_pca_results[j_idx]["tangent"]
            W_orient[i, j_idx] = np.dot(v_i, v_j)

# Explicitly symmetrise: eigh/eigsh below are symmetric eigensolvers.
W_orient = (W_orient + W_orient.T) / 2.0
W_orient_csr = W_orient.tocsr()
u_max = None  # stays None if the eigendecomposition is skipped or fails

if W_orient_csr.nnz > 0:
    try:
        # Dense solver for small or relatively dense matrices, sparse otherwise.
        if n_points < 2000 or (W_orient_csr.nnz / (n_points ** 2) > 0.05):
            # Dense eigendecomposition; eigh sorts eigenvalues ascending, so the
            # last column is the eigenvector of the largest eigenvalue.
            _, eigenvectors = eigh(W_orient_csr.toarray())
            u_max = eigenvectors[:, -1]
        else:
            # Sparse eigendecomposition; which="LA" asks for the largest
            # algebraic eigenvalue.
            _, eigenvectors = eigsh(
                W_orient_csr, k=1, which="LA", tol=1e-6, maxiter=n_points * 10
            )
            u_max = eigenvectors[:, 0]
    except Exception as e:
        # Best effort: report and fall back to all +1 signs below.
        print(f"Orientation eigendecomposition failed: {e}, u_max will be None.")
else:
    print("W_orient matrix is empty, cannot perform spectral orientation.")

if u_max is None:
    print("Orientation failed or W_orient is empty, using default signs (all +1).")
    signs_s = np.ones(n_points)
else:
    signs_s = np.sign(u_max)
    # np.sign returns 0 for exact zeros; treat those points as "keep" (+1).
    signs_s[signs_s == 0] = 1

# Apply the per-point sign to every valid tangent.
oriented_tangents = {}
for i in valid_pca_indices:
    if local_pca_results[i]["tangent"] is not None:
        initial_tangent = local_pca_results[i]["tangent"]
        final_sign = signs_s[i]
        oriented_tangents[i] = final_sign * initial_tangent

print(f"Spectral orientation took {time.time() - start_time_orient:.2f}s.")


# ========== 5. Plot Tangent Vectors (Before & After Alignment) ==========
def plot_tangent_vectors_before_alignment_colored(
    X_data,
    thetas_true_for_color,
    unoriented_tangents,
    spectral_signs,
    ax_handle,
    quiver_length_scale=0.50,
):
    """Scatter the data and draw each point's *unoriented* local tangent.

    Arrows are coloured by the sign the spectral orientation assigned to the
    point: dimgray for +1 (kept) and deepskyblue for -1 (flipped). Only
    indices present in both ``spectral_signs`` and ``unoriented_tangents``
    are drawn. Arrow length is ``H_RADIUS * quiver_length_scale``, using the
    module-level ``H_RADIUS``. A legend is added when at least one arrow is
    drawn.
    """
    ax_handle.scatter(
        X_data[:, 0],
        X_data[:, 1],
        s=20,
        alpha=0.4,
        c=thetas_true_for_color,
        cmap="viridis",
        zorder=1,
    )
    arrow_len = H_RADIUS * quiver_length_scale

    # (sign, colour, legend label), drawn in this order ("kept" first).
    arrow_groups = (
        (1, "dimgray", "Initial Direction (Kept by Orientation)"),
        (-1, "deepskyblue", "Initial Direction (Flipped by Orientation)"),
    )
    drew_any = False
    for wanted_sign, arrow_color, arrow_label in arrow_groups:
        members = [
            idx
            for idx, sgn in spectral_signs.items()
            if sgn == wanted_sign and idx in unoriented_tangents
        ]
        if not members:
            continue
        drew_any = True
        directions = np.array([unoriented_tangents[idx] for idx in members])
        ax_handle.quiver(
            X_data[members, 0],
            X_data[members, 1],
            directions[:, 0] * arrow_len,
            directions[:, 1] * arrow_len,
            color=arrow_color,
            alpha=0.7,
            width=0.0045,
            scale=1.0,
            scale_units="xy",
            angles="xy",
            pivot="mid",
            zorder=2,
            label=arrow_label,
        )

    if drew_any:
        ax_handle.legend(loc="lower right", bbox_to_anchor=(1, 0.2))

    ax_handle.set_title("Local Tangent Vectors Before Orientation")
    ax_handle.set_xlabel("X_1 coordinate")
    ax_handle.set_ylabel("X_2 coordinate")
    ax_handle.grid(True, linestyle="--", alpha=0.7)
    ax_handle.tick_params(axis="both", which="major")


fig_unoriented_colored, ax_unoriented_colored = plt.subplots(figsize=(10, 8))
# Sign chosen by the spectral orientation, for points with a valid PCA.
spectral_signs_map = {
    idx: signs_s[idx] for idx in valid_pca_indices if idx < len(signs_s)
}

plot_tangent_vectors_before_alignment_colored(
    X_data=X,
    thetas_true_for_color=thetas_true,
    unoriented_tangents=unoriented_tangents_for_plot,
    spectral_signs=spectral_signs_map,
    ax_handle=ax_unoriented_colored,
    quiver_length_scale=0.50,
)
# Remember this frame so the later figures use identical axis limits.
plot_xlim = ax_unoriented_colored.get_xlim()
plot_ylim = ax_unoriented_colored.get_ylim()


def plot_final_oriented_vectors(
    X_data,
    thetas_true_for_color,
    tangents_dict,
    title_str,
    ax_handle,
    color="crimson",
    quiver_length_scale=0.50,
):
    """Scatter the data and overlay the tangent vectors in ``tangents_dict``.

    Parameters
    ----------
    X_data : ndarray
        Point coordinates; columns 0 and 1 are plotted.
    thetas_true_for_color : array-like
        Values used to colour the scatter (viridis colormap).
    tangents_dict : dict
        Maps a point index to the 2-D vector drawn at that point.
    title_str : str
        Axes title. " (No tangents)" is appended when the dict is empty.
    ax_handle : matplotlib Axes
        Target axes.
    color : str
        Arrow colour.
    quiver_length_scale : float
        Arrow length as a fraction of the module-level ``H_RADIUS``.
    """
    ax_handle.scatter(
        X_data[:, 0],
        X_data[:, 1],
        s=20,
        alpha=0.4,
        c=thetas_true_for_color,
        cmap="viridis",
        zorder=1,
    )
    plot_indices = list(tangents_dict.keys())
    if plot_indices:
        vecs_to_plot = np.array([tangents_dict[i] for i in plot_indices])
        quiver_plot_length = H_RADIUS * quiver_length_scale
        ax_handle.quiver(
            X_data[plot_indices, 0],
            X_data[plot_indices, 1],
            vecs_to_plot[:, 0] * quiver_plot_length,
            vecs_to_plot[:, 1] * quiver_plot_length,
            color=color,
            alpha=0.8,
            width=0.0045,
            scale=1.0,
            scale_units="xy",
            angles="xy",
            pivot="mid",
            zorder=2,
        )
        ax_handle.set_title(title_str)
    else:
        ax_handle.set_title(title_str + " (No tangents)")

    # Decorate the axes in both cases so an empty plot matches the styling
    # of a populated one (previously the empty path returned early).
    ax_handle.set_xlabel("X_1 coordinate")
    ax_handle.set_ylabel("X_2 coordinate")
    ax_handle.grid(True, linestyle="--", alpha=0.7)
    ax_handle.tick_params(axis="both", which="major")


fig_oriented, ax_oriented = plt.subplots(figsize=(10, 8))
plot_final_oriented_vectors(
    X_data=X,
    thetas_true_for_color=thetas_true,
    tangents_dict=oriented_tangents,
    title_str="Local Tangent Vectors After Orientation",
    ax_handle=ax_oriented,
    color="crimson",
    quiver_length_scale=0.50,
)
# Same frame as the "before orientation" figure for a like-for-like view.
ax_oriented.set_xlim(plot_xlim)
ax_oriented.set_ylim(plot_ylim)


# ========== 6. Build and Solve Linear System (Our Embedding) ==========
# One unknown y_k per point. Every neighbouring pair (i, j) with j > i, where
# both points have an oriented tangent, adds one equation:
#     y_j - y_i = < (t_i + t_j) / 2 , x_j - x_i >
# i.e. the displacement projected on the averaged oriented tangent.
# The over-determined system is solved in the least-squares sense with LSQR.
print("Building linear system for 1D embedding...")
start_time_ls = time.time()
rows_A, cols_A, data_A, data_T = [], [], [], []
num_constraints = 0

for i in range(n_points):
    if i not in oriented_tangents:
        continue
    s_i_v_i = oriented_tangents[i]
    for j_idx in all_neighbors_for_loss[i]:
        # Visit each unordered pair once; skip neighbours without a tangent.
        if j_idx <= i or j_idx not in oriented_tangents:
            continue
        s_j_v_j = oriented_tangents[j_idx]
        avg_oriented_tangent = 0.5 * (s_i_v_i + s_j_v_j)
        spatial_difference = X[j_idx] - X[i]
        target_val_Tij = np.dot(avg_oriented_tangent, spatial_difference)

        # Row `num_constraints` of A: +1 in column j_idx, -1 in column i.
        rows_A.extend((num_constraints, num_constraints))
        cols_A.extend((j_idx, i))
        data_A.extend((1.0, -1.0))
        data_T.append(target_val_Tij)
        num_constraints += 1

if num_constraints > 0:
    A_matrix = csr_matrix(
        (data_A, (rows_A, cols_A)),
        shape=(num_constraints, n_points),
    )
    T_vector = np.array(data_T)
    print(
        f"Linear system built (M = {num_constraints}, N = {n_points}) "
        f"in {time.time() - start_time_ls:.2f}s."
    )
    print("Solving linear system for embedding (LSQR)...")
    system_solve_start_time = time.time()
    # Both matrix dimensions are positive here (a constraint implies a point),
    # so the limit is simply max(2000, 2 * smaller dimension).
    iter_limit = max(2000, min(A_matrix.shape) * 2)
    y_embedding, istop, itn, r1norm, r2norm = lsqr(
        A_matrix,
        T_vector,
        show=False,
        iter_lim=iter_limit,
        atol=1e-8,
        btol=1e-8,
    )[:5]
    print(
        f"LSQR took {time.time() - system_solve_start_time:.2f}s. "
        f"istop={istop}, itn={itn}, r1norm={r1norm:.2e}, r2norm={r2norm:.2e}"
    )
    # LSQR stop reasons treated as success:
    #   0     - x = 0 is the exact solution
    #   1, 2  - converged within atol/btol
    #   4, 5  - the same, at machine precision
    # Codes 3 and 6 mean the condition-number limit was hit; code 7 means the
    # iteration limit was reached before convergence. Code 7 was previously
    # accepted silently, while the successes 0, 4 and 5 were flagged.
    if istop not in (0, 1, 2, 4, 5):
        print(f"Warning: LSQR convergence issue. istop = {istop}.")
    if y_embedding.size > 0:
        # The equations only fix differences y_j - y_i, so the solution is
        # defined up to a constant shift (per connected neighbour component);
        # centre it on the mean of the finite entries.
        y_embedding = y_embedding - np.mean(y_embedding[~np.isnan(y_embedding)])
    else:
        print("Warning: y_embedding is empty after LSQR.")
        y_embedding = np.arange(n_points, dtype=float)
else:
    print("No constraints for LS embedding. Using default index order.")
    y_embedding = np.arange(n_points, dtype=float)


# ========== 7. Our Method: Get Final Order & Evaluate ==========
# Recovered order: point indices sorted by ascending embedding coordinate.
recovered_order_indices_ours = np.argsort(y_embedding)
# An eigenvector's sign is arbitrary, so the embedding may come out reversed;
# |tau| is therefore reported next to tau.
tau_ours, p_value_ours = kendalltau(true_order_indices, recovered_order_indices_ours)
abs_tau_ours = abs(tau_ours)
print(
    f"\n[Our method] Kendall's Tau: {tau_ours:.4f} "
    f"(abs = {abs_tau_ours:.4f}, p-value = {p_value_ours:.4g})"
)


# ========== 8. Recanati Method on the SAME DATA ==========
# Baseline: spectral_ordering_robust (from experiment_utils) on the same X.
# Its sigma is set to the noise level; this choice is a tunable heuristic.
sigma_recanati = NOISE_STD

recanati_scores = spectral_ordering_robust(X, sigma=sigma_recanati)
# One real-valued score per point is expected; the order is ascending score.
recanati_order_indices = np.argsort(recanati_scores)
tau_recanati, p_value_recanati = kendalltau(true_order_indices, recanati_order_indices)
abs_tau_recanati = abs(tau_recanati)
print(
    f"[Recanati] Kendall's Tau: {tau_recanati:.4f} "
    f"(abs = {abs_tau_recanati:.4f}, p-value = {p_value_recanati:.4g})"
)


# ========== 9. Visualization of Final Recovered Order (Our Method) ==========
fig_final, ax_final = plt.subplots(figsize=(10, 8))
# Inverse permutation: recovered_ranks_ours[k] is the position of point k
# in the recovered ordering.
recovered_ranks_ours = np.empty_like(recovered_order_indices_ours, dtype=float)
recovered_ranks_ours[recovered_order_indices_ours] = np.arange(n_points)

valid_y_mask = ~np.isnan(y_embedding)
has_valid = np.any(valid_y_mask)
has_invalid = np.any(~valid_y_mask)

if has_valid:
    scatter1 = ax_final.scatter(
        X[valid_y_mask, 0],
        X[valid_y_mask, 1],
        c=recovered_ranks_ours[valid_y_mask],
        cmap="viridis",
        s=50,
        alpha=0.7,
        label="Valid Embedding",
        zorder=1,
    )
    fig_final.colorbar(
        scatter1, ax=ax_final, label="Recovered Rank (Our method)"
    )

if has_invalid:
    ax_final.scatter(
        X[~valid_y_mask, 0],
        X[~valid_y_mask, 1],
        c="lightgrey",
        s=50,
        alpha=0.5,
        marker="x",
        label="Invalid Embedding (NaN)",
        zorder=1,
    )

# A legend only helps when both marker kinds are present.
if has_valid and has_invalid:
    ax_final.legend(loc="best")

ax_final.set_title(
    f"Our Method: Recovered Order\n"
    f"Kendall tau = {tau_ours:.3f}, |tau| = {abs(tau_ours):.3f}"
)
ax_final.set_xlabel("X coordinate")
ax_final.set_ylabel("Y coordinate")
ax_final.grid(True, linestyle="--", alpha=0.7)
ax_final.tick_params(axis="both", which="major")
ax_final.set_xlim(plot_xlim)
ax_final.set_ylim(plot_ylim)


# ========== 10. Visualization of Recanati Recovered Order ==========
fig_rec, ax_rec = plt.subplots(figsize=(10, 8))
# Inverse permutation: recanati_ranks[k] is the position of point k.
recanati_ranks = np.empty_like(recanati_order_indices, dtype=float)
recanati_ranks[recanati_order_indices] = np.arange(n_points)

scatter2 = ax_rec.scatter(
    X[:, 0], X[:, 1],
    c=recanati_ranks, cmap="plasma",
    s=50, alpha=0.7, zorder=1,
)
fig_rec.colorbar(scatter2, ax=ax_rec, label="Recovered Rank (Recanati)")

abs_tau_rec_title = abs(tau_recanati)
ax_rec.set_title(
    f"Recanati: Recovered Order\n"
    f"Kendall tau = {tau_recanati:.3f}, |tau| = {abs_tau_rec_title:.3f}"
)
ax_rec.set_xlabel("X coordinate")
ax_rec.set_ylabel("Y coordinate")
ax_rec.grid(True, linestyle="--", alpha=0.7)
ax_rec.tick_params(axis="both", which="major")
# Shared frame with the other figures.
ax_rec.set_xlim(plot_xlim)
ax_rec.set_ylim(plot_ylim)

# ========== 11. Save the four figures as EPS files ==========
# plt.tight_layout() alone only lays out the *current* (last-created)
# figure, so the layout is applied to each figure explicitly before saving.
# NOTE(review): matplotlib's PostScript backend does not support
# transparency, so the alpha < 1 scatter/quiver artists are expected to be
# rendered opaque in these EPS files. Also save a PDF if that matters.
figures_to_save = (
    (fig_unoriented_colored, "before_align.eps"),
    (fig_oriented, "after_align.eps"),
    (fig_final, "recovered_order_ours.eps"),
    (fig_rec, "recovered_order_recanati.eps"),
)
for fig, filename in figures_to_save:
    fig.tight_layout()
    fig.savefig(filename, format="eps", bbox_inches="tight")

plt.show()