# scrna_kmeans_gene2500_no_pca.py
import os, math, json
import numpy as np
import scanpy as sc
import scipy.sparse as sp
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors
import torch

# ---- I/O locations ----------------------------------------------------------
H5AD_PATH = "data/cellular_niche/sc_data.h5ad"
OUT_DIR = "preprocessed_dataset/cellular_niche"

# ---- Cloud construction ------------------------------------------------------
NUM_CLOUDS = 2185        # requested number of KMeans clusters (one cloud each)
CLOUD_SIZE = 69          # neighbours gathered around each seed cell
LABEL_COL_PRI = ["cell_type", "cluster_label", "cell_label"]  # first match in obs wins
RANDOM_STATE = 42
SAVE_FLOAT16 = False     # down-cast the cloud tensor before saving

# ---- Gene selection ----------------------------------------------------------
TARGET_GENES = 2500      # number of highly variable genes to keep
SUBSAMPLE = 20000        # max cells used for HVG selection
MIN_DET_FRAC = 0.01      # min fraction of cells in which a gene must be detected

os.makedirs(OUT_DIR, exist_ok=True)
rng = np.random.default_rng(RANDOM_STATE)

adata = sc.read_h5ad(H5AD_PATH)
adata.var_names_make_unique()
print(adata)

# ---- Drop rarely detected genes ---------------------------------------------
# A gene counts as "detected" in a cell when its value is strictly positive.
# The same `> 0` test is used for sparse and dense X. The old sparse path
# counted stored entries via indptr, which also counted explicitly stored
# zeros and required a full CSC copy of the matrix. Comparing a sparse matrix
# with 0 via `>` only touches stored entries and stays sparse.
N = adata.n_obs
n_detected = np.asarray((adata.X > 0).sum(axis=0)).ravel()
det_frac = n_detected.astype(np.float64) / float(N)
keep_det = det_frac >= MIN_DET_FRAC
if not keep_det.any():
    raise ValueError(
        f"no gene is detected in >= {MIN_DET_FRAC:.2%} of cells; "
        "lower MIN_DET_FRAC or check the input matrix"
    )
if keep_det.sum() < adata.n_vars:
    adata = adata[:, keep_det].copy()
print(f"[filter] detection ≥ {MIN_DET_FRAC:.2%}: {adata.n_vars} genes")

# ---- Highly variable gene selection on a cell subsample ----------------------
# Re-seed from the single configured seed, not a hard-coded literal, so that
# changing RANDOM_STATE changes every random choice consistently.
rng = np.random.default_rng(RANDOM_STATE)
if adata.n_obs > SUBSAMPLE:
    rows = np.sort(rng.choice(adata.n_obs, size=SUBSAMPLE, replace=False))
    ad_sub = adata[rows].copy()
else:
    ad_sub = adata.copy()

# flavor='seurat_v3' fits its mean/variance model on RAW COUNTS. It must not
# receive log-normalised data, so no 'log' layer is built here. This assumes
# adata.X holds raw counts, which the script's later normalize_total/log1p of
# X also assumes.
# Clamp n_top_genes in case detection filtering left fewer genes than requested.
n_top = min(TARGET_GENES, ad_sub.n_vars)
sc.pp.highly_variable_genes(
    ad_sub, flavor='seurat_v3',
    n_top_genes=n_top, subset=False, inplace=True,
)
mask_sub = ad_sub.var['highly_variable'].to_numpy()
# var order (not rank order) is preserved, matching the column order of adata.
genes = ad_sub.var_names[mask_sub].tolist()
print(f"[HVG] selected {len(genes)} genes on subset ({ad_sub.n_obs} cells)")

# ---- Normalise, then restrict to the selected genes --------------------------
# Library-size normalisation must see every retained gene. Normalising after
# subsetting would compute each cell's size factor from the HVGs alone, so a
# few highly expressed HVGs would dominate the scaling.
sc.pp.normalize_total(adata, target_sum=1e4)
adata = adata[:, genes].copy()
print("[HVG] full data shape:", adata.shape)
sc.pp.log1p(adata)

# ---- Per-gene min-max scaling to [-1, 1], then shrink by 1/sqrt(d) -----------
# Always work in float32 on a private copy. The sparse path previously kept the
# input dtype (possibly float64), so the output dtype depended on storage format.
# Scaling is done in place, with the same operation order as the reference
# formula 2*(x-min)/(max-min) - 1, to avoid several full-size dense temporaries.
if sp.issparse(adata.X):
    X_scaled = adata.X.toarray().astype(np.float32, copy=False)
else:
    # np.array (not asarray) copies, so the in-place ops never touch adata.X.
    X_scaled = np.array(adata.X, dtype=np.float32)
d = X_scaled.shape[1]
gmin = X_scaled.min(axis=0, keepdims=True)
gmax = X_scaled.max(axis=0, keepdims=True)
den = np.maximum(gmax - gmin, 1e-8)  # constant genes: avoid division by zero
X_scaled -= gmin
X_scaled *= np.float32(2.0)
X_scaled /= den
X_scaled -= np.float32(1.0)
X_scaled *= np.float32(1.0 / math.sqrt(d))
print(f"[SCALE] gene-space shape = {X_scaled.shape} | floor={-1.0/math.sqrt(d):.4f}")

# ---- KMeans over scaled gene space: one cluster per cloud --------------------
# Never request more clusters than there are cells.
n_cells = X_scaled.shape[0]
K = int(min(NUM_CLOUDS, n_cells))
km = KMeans(n_clusters=K, n_init="auto", random_state=RANDOM_STATE).fit(X_scaled)
labels_km = km.labels_
print(f"[KMeans] clusters={K}")

# ---- Seed = member cell closest to its cluster centroid ----------------------
# An empty cluster (no member in labels_km) gets a uniformly random cell instead.
seeds = np.empty(K, dtype=np.int64)
for k, centre in enumerate(km.cluster_centers_):
    members = np.flatnonzero(labels_km == k)
    if members.size:
        offsets = X_scaled[members] - centre
        sq_dist = np.einsum("ij,ij->i", offsets, offsets)  # row-wise squared norm
        seeds[k] = members[np.argmin(sq_dist)]
    else:
        seeds[k] = rng.integers(0, X_scaled.shape[0])
print(f"[Seeds] {seeds.size} seeds selected")

# ---- Cloud = seed cell + its nearest neighbours in scaled gene space ---------
n_nbrs = min(CLOUD_SIZE, X_scaled.shape[0])
nn = NearestNeighbors(n_neighbors=n_nbrs, metric="euclidean", algorithm="auto", n_jobs=-1)
nn.fit(X_scaled)
nbrs = nn.kneighbors(X_scaled[seeds], return_distance=False)

# Guarantee every cloud starts with its seed. Ties with duplicate cells can
# move the seed away from slot 0 or push it out of the list entirely. The
# seed is moved to the front and the remaining neighbours are shifted right,
# preserving their distance order. If the seed was absent, the farthest
# neighbour is dropped to make room. Previously that case overwrote the LAST
# slot with the seed, breaking the seed-first convention.
for i, s in enumerate(seeds):
    row = nbrs[i]
    if row[0] == s:
        continue
    others = row[row != s]
    nbrs[i] = np.concatenate(([s], others))[:n_nbrs]

# ---- Gather cloud tensors and per-cloud labels -------------------------------
# Fancy indexing with the (n_clouds, n_nbrs) index array yields a tensor of
# shape (n_clouds, n_nbrs, n_genes).
X_clouds = X_scaled[nbrs]
if SAVE_FLOAT16:
    X_clouds = X_clouds.astype(np.float16, copy=False)

# Use the first candidate label column present in obs. Fail with an explicit
# message instead of the opaque KeyError from adata.obs[None].
label_col = next((c for c in LABEL_COL_PRI if c in adata.obs.columns), None)
if label_col is None:
    raise KeyError(
        f"none of the label columns {LABEL_COL_PRI} found in adata.obs; "
        f"available columns: {list(adata.obs.columns)}"
    )

lab_cat = adata.obs[label_col].astype("category")
# Categorical codes: missing labels (NaN) are encoded as -1.
y_all = lab_cat.cat.codes.to_numpy(np.int64)
idx_to_label = {int(i): str(s) for i, s in enumerate(lab_cat.cat.categories)}
y_clouds = y_all[seeds]  # each cloud inherits its seed cell's label

n_missing = int((y_clouds < 0).sum())
if n_missing:
    print(f"[WARN] {n_missing} seed cells have no '{label_col}' label (encoded as -1)")

print(f"[CLOUDS] X_clouds={X_clouds.shape} | y_clouds={y_clouds.shape} | label_col={label_col}")

# ---- Persist clouds, labels and the label index ------------------------------
# The archive name follows TARGET_GENES, so changing the gene budget does not
# produce a misleadingly named file. The default yields "clouds_gene2500.npz".
np.savez_compressed(os.path.join(OUT_DIR, f"clouds_gene{TARGET_GENES}.npz"),
                    X=X_clouds, y=y_clouds, seeds=seeds, neighbors=nbrs)
torch.save(torch.from_numpy(X_clouds), os.path.join(OUT_DIR, "X_clouds.pt"))
torch.save(torch.from_numpy(y_clouds), os.path.join(OUT_DIR, "y_clouds.pt"))
# ensure_ascii=False writes label text verbatim. Force UTF-8 so non-ASCII
# labels do not fail on platforms whose default encoding is not UTF-8.
with open(os.path.join(OUT_DIR, "idx_to_label.json"), "w", encoding="utf-8") as f:
    json.dump(idx_to_label, f, ensure_ascii=False, indent=2)
