#%%
import pickle
import numpy as np
from prepro import select_genes_by_l2_norm
from graph_utils import plot_graph_spring, build_radius_graph, smallest_connected_radius


# Load the expression data and the reference ordering.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
with open('data/expr_data.pkl', 'rb') as fh:   # context manager closes the handle
    data = pickle.load(fh)
X = np.array(data['expr'])
order = np.array(data['order'])

# raw counts → log1p
X_log = np.log1p(X)                     # simple transform
# Filter columns with the project helper (presumably gene selection by L2 norm
# using a GMM threshold — confirm in prepro); plot=True requests its diagnostic plot.
X_filt = select_genes_by_l2_norm(X_log, method="gmm", plot=True)
# 1. row-L2 normalise for cosine distance (done after filtering, so the norms
#    are computed on the retained columns only).
# NOTE(review): a row whose norm is zero would produce NaN/inf here — confirm
# that the filtered matrix cannot contain all-zero rows.
X_norm = X_filt / np.linalg.norm(X_filt, axis=1, keepdims=True)

#%%
# 2. optional dimensionality reduction before building the graph
#    (author's estimate: ≈ 60× speed-up vs 8 000-D — not verified here)
from sklearn.decomposition import TruncatedSVD
# random_state is fixed so the decomposition is repeatable between runs.
svd    = TruncatedSVD(n_components=20, random_state=0)
# NOTE: despite the name, X_pca comes from TruncatedSVD (no mean-centring), not PCA.
X_pca  = svd.fit_transform(X_norm)               # n × 20 (n_components=20)

# X_norm is already log-transformed and row-L2-normalised at this point.
# The helper returns a radius and an edge array; presumably r0 is the smallest
# radius at which the radius graph is connected — confirm in graph_utils.
r0, edges    = smallest_connected_radius(X_pca, verbose=True)
#%%

# Visual check of the graph at radius r0. `edges` is whatever the
# smallest_connected_radius call earlier in this script returned; the disabled
# alternative was to rebuild it with build_radius_graph(X_pca, r=r0).
plot_graph_spring(
    edges,
    iterations=100,
    node_size=10,
    title=f"{edges.shape[0]:,} edges   |   {X_pca.shape[1]}-D SVD",
)

#%%
# Visual check with the radius inflated to 1.2 × r0.
# NOTE: this rebinds the module-level `edges`, so later cells see the inflated graph.
edges = build_radius_graph(X_pca, r=1.2 * r0)
plot_graph_spring(
    edges,
    iterations=100,
    node_size=10,
    title=f"{edges.shape[0]:,} edges   |   {X_pca.shape[1]}-D SVD",
)

#%%
from matplotlib import pyplot as plt

# Histogram of every SVD coordinate pooled into one 1-D sample.
# (The same plot can be drawn for X_norm by swapping the array.)
plt.hist(X_pca.ravel(), bins=100, color='salmon', edgecolor='black')
# %%

from stage3 import stage_embedding, evaluate_kendall_abs

# Sweep the neighbourhood radius as a multiple of r0 and record the score
# returned by evaluate_kendall_abs for each run.
inflation_grid = [1.0, 1.2, 1.5, 1.7, 2.0, 2.2, 2.5, 2.7, 3.0]
n_repeats      = 5

# One list of scores per inflation factor, filled in grid order.
tau_results = {f: [] for f in inflation_grid}

for f in inflation_grid:
    # NOTE(review): the repeats use identical arguments, so they only add
    # information if stage_embedding is stochastic — confirm in stage3.
    for _ in range(n_repeats):
        # Only the second return value (the inferred ordering) is used.
        _, order_idx = stage_embedding(
            X_pca,
            r=f * r0,
            neighbour_min=5,
            embedding="linreg",
        )
        # Presumably |Kendall's τ| between the reference `order` and the
        # inferred ordering — confirm against evaluate_kendall_abs.
        tau = evaluate_kendall_abs(order, order_idx)
        tau_results[f].append(tau)

# ── Box-plot of Kendall’s τ vs. inflation factor ───────────────────────────
fig, ax = plt.subplots(figsize=(8, 4))
ax.boxplot(
    [tau_results[f] for f in inflation_grid],
    showmeans=True,
)
# Tick labels are set separately: boxplot's `labels=` keyword is deprecated
# since Matplotlib 3.9, while set_xticklabels works on old and new versions.
# boxplot places one box per dataset at positions 1..N, in grid order.
ax.set_xticklabels([str(f) for f in inflation_grid])
ax.set_xlabel("Radius inflation factor")
ax.set_ylabel("Kendall's τ")
ax.set_title("STAGE performance across radius inflations")
plt.tight_layout()
plt.show()
# %%
