import scanpy as sc
import numpy as np
import scipy.sparse as sp
from sklearn.neighbors import KDTree
import torch
import os

def save_all_data(
    save_path,
    *,
    h5ad_path="data/cellular_niche/st_data.h5ad",
    label_column="cell_type",
    radius_um=50.0,
    n_neighbors=11,
):
    """Build fixed-size spatial-neighbourhood samples and save them to disk.

    For every cell that has at least ``n_neighbors`` other cells within
    ``radius_um`` of it, the normalised expression vectors of its
    ``n_neighbors`` nearest neighbours (the cell itself excluded) are
    stacked into one sample. The sample is labelled with the centre cell's
    value in ``obs[label_column]``; centres with a missing label are dropped.

    The result is written to ``<save_path>/all_samples.pt`` as a dict with:

    * ``"features"``: float32 tensor, shape ``(n_samples, n_neighbors, n_genes)``
    * ``"labels"``: int64 tensor of class indices into ``label_classes``
    * ``"label_classes"``: NumPy array of the sorted unique label values
    * ``"centers"``: int64 NumPy array of centre-cell row positions

    NOTE(review): the two NumPy entries are pickled by ``torch.save``, so
    loaders may need ``torch.load(..., weights_only=False)`` — confirm
    against the consuming code.

    Args:
        save_path: Directory the output file is written to; created if it
            does not exist.
        h5ad_path: Path of the AnnData ``.h5ad`` file to read. It must
            provide ``obsm["spatial"]`` and ``obs[label_column]``.
        label_column: Name of the ``obs`` column holding the class labels.
        radius_um: Neighbourhood search radius, expressed in the units of
            ``obsm["spatial"]`` (the name suggests micrometres — confirm).
        n_neighbors: Number of neighbours kept per centre cell.

    Raises:
        KeyError: If ``label_column`` is not a column of ``obs``.
        ValueError: If ``n_neighbors`` is smaller than 1, or if no cell has
            at least ``n_neighbors`` neighbours within ``radius_um``.
    """
    if n_neighbors < 1:
        raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")

    st_data = sc.read_h5ad(h5ad_path)
    # Fail before the expensive neighbour search rather than after it.
    if label_column not in st_data.obs.columns:
        raise KeyError(label_column)

    coords = np.asarray(st_data.obsm["spatial"], dtype=np.float32)

    # Densify with ``toarray()`` (the ``.A`` shorthand is deprecated/removed
    # in newer SciPy) and cast both the sparse and the dense case to float32
    # so the saved features always have the same dtype.
    expr = st_data.X
    if sp.issparse(expr):
        expr = expr.toarray()
    X_np = np.asarray(expr, dtype=np.float32)
    d = X_np.shape[1]

    # Per-gene min-max scaling to [-1, 1]; the 1e-8 floor keeps constant
    # genes from dividing by zero (they end up at -1).
    gmin, gmax = X_np.min(0), X_np.max(0)
    X_np = 2.0 * (X_np - gmin) / np.maximum(gmax - gmin, 1e-8) - 1.0
    # Scale by 1/sqrt(d) so each vector's L2 norm is at most 1.
    X_np = X_np * np.float32(1.0 / np.sqrt(d))

    # ``sort_results=True`` orders each neighbour list by increasing
    # distance; it requires ``return_distance=True`` even though the
    # distances themselves are not needed afterwards.
    tree = KDTree(coords)
    nbr_idx_list, _ = tree.query_radius(
        coords, r=radius_um, return_distance=True, sort_results=True
    )

    valid_centers = []
    neighbors_fix = []
    for center, nbrs in enumerate(nbr_idx_list):
        # Drop the query cell by index rather than by position: cells with
        # duplicate coordinates also sit at distance 0.
        nbrs = nbrs[nbrs != center]
        if nbrs.size >= n_neighbors:
            valid_centers.append(center)
            neighbors_fix.append(nbrs[:n_neighbors])

    if not neighbors_fix:
        raise ValueError(
            f"no cell has at least {n_neighbors} neighbours within "
            f"radius {radius_um}"
        )

    valid_centers = np.asarray(valid_centers, dtype=np.int64)
    idx_fix = np.stack(neighbors_fix, axis=0).astype(np.int64)

    # Discard centres without a label before gathering any features.
    labels_series = st_data.obs[label_column].iloc[valid_centers]
    mask_ok = labels_series.notna().to_numpy()
    if not mask_ok.all():
        valid_centers = valid_centers[mask_ok]
        idx_fix = idx_fix[mask_ok]
        labels_series = labels_series[mask_ok]

    # Encode labels as indices into the sorted array of unique classes.
    label_classes, labels_int = np.unique(
        labels_series.to_numpy(), return_inverse=True
    )

    # Gather the neighbour features once, after all filtering is done.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X_t = torch.from_numpy(X_np).to(device)
    idx_fix_t = torch.as_tensor(idx_fix, device=device, dtype=torch.long)
    features = X_t[idx_fix_t]
    labels = torch.as_tensor(labels_int, dtype=torch.long)

    os.makedirs(save_path, exist_ok=True)
    torch.save({
        "features": features.cpu(),
        "labels": labels,
        "label_classes": label_classes,
        "centers": valid_centers,
    }, os.path.join(save_path, "all_samples.pt"))


if __name__ == "__main__":
    # Script entry point: make sure the output directory exists, then build
    # the dataset and write it there.
    output_dir = "preprocessed_dataset/merfish"
    os.makedirs(output_dir, exist_ok=True)
    save_all_data(output_dir)
