# install once: pip install scanpy anndata pandas numpy scipy scikit-learn umap-learn matplotlib
import scanpy as sc
import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import FastICA, PCA
import umap
from stage3 import stage_embedding, evaluate_kendall_abs
from scipy.stats import rankdata
import os
import urllib.request

# Download the dataset if it is not already in the working directory
# (data from https://pmc.ncbi.nlm.nih.gov/articles/PMC7428862/).
url = "https://data.caltech.edu/records/b1kj4-nh475/files/packer2019.h5ad"
filename = "packer2019.h5ad"

if not os.path.exists(filename):
    print(f"Downloading {filename}...")
    # Fetch into a temporary name and rename only on success, so an
    # interrupted download is never mistaken for a complete file on the
    # next run.
    partial_filename = filename + ".part"
    try:
        urllib.request.urlretrieve(url, partial_filename)
        os.replace(partial_filename, filename)
    finally:
        # After a successful rename the partial file no longer exists;
        # this only cleans up after a failed or interrupted download.
        if os.path.exists(partial_filename):
            os.remove(partial_filename)
else:
    print(f"{filename} already exists.")

# Load from the same path that was just checked/downloaded.
adata = sc.read_h5ad(filename)

# Restrict to seam cells, as they display a clear trajectory.
seam_cell_data = adata[adata.obs["cell_type"] == "Seam_cell"]
# Reference ordering that the methods below are scored against: the rank of
# each cell's embryo time.
order = rankdata(seam_cell_data.obs["embryo_time"])

# 2-D UMAP of the seam-cell expression matrix, used for the scatter plots.
# Seeded so the embedding (and every figure drawn from it) is reproducible
# from run to run.
umap_obj = umap.UMAP(n_components=2, n_neighbors=5, random_state=0)
umap_result = umap_obj.fit_transform(seam_cell_data.X)

import matplotlib.pyplot as plt

# View the UMAP of cells coloured by embryo time and by rank of embryo time.
# Each colouring gets its own panel: drawing both scatters on the same axes
# would leave only the second one visible.
umap_fig, (time_ax, rank_ax) = plt.subplots(1, 2, figsize=(10, 4))

time_points = time_ax.scatter(
    umap_result[:, 0], umap_result[:, 1], c=seam_cell_data.obs["embryo_time"]
)
time_ax.set_title("Embryo time")
umap_fig.colorbar(time_points, ax=time_ax)

rank_points = rank_ax.scatter(umap_result[:, 0], umap_result[:, 1], c=order)
rank_ax.set_title("Rank of embryo time")
umap_fig.colorbar(rank_points, ax=rank_ax)

from prepro import select_genes_by_l2_norm
from graph_utils import plot_graph_spring, build_radius_graph, smallest_connected_radius

# Densify the expression matrix (.toarray() assumes .X is stored sparse).
X = seam_cell_data.X.toarray()
X_log = np.log1p(X)  # simple transform
# Gene selection is delegated to prepro.select_genes_by_l2_norm
# (presumably keeps a subset of gene columns -- confirm against prepro).
X_filt = select_genes_by_l2_norm(X_log, method="gmm", plot=True)

# Row-L2 normalise for cosine distance.
row_norms = np.linalg.norm(X_filt, axis=1, keepdims=True)
# A cell with no signal in the selected genes has norm 0; dividing by it
# would produce a NaN row. Divide by 1 instead so such rows stay all zero.
row_norms[row_norms == 0] = 1.0
X_norm = X_filt / row_norms

from sklearn.decomposition import TruncatedSVD
# Project the normalised matrix onto 50 components with a seeded truncated
# SVD; despite its name, X_pca holds the SVD output.
svd = TruncatedSVD(
    n_components=50,
    random_state=0,
)
X_pca = svd.fit_transform(X_norm)  # n x 50

# The reduced matrix is derived from log-transformed, row-L2-normalised data.
# r0 is later passed to stage_embedding as its radius `r`
# (presumably the smallest radius giving a connected graph -- confirm
# against graph_utils).
r0, edges = smallest_connected_radius(
    X_pca,
    verbose=True,
)

from experiment_utils import fiedler_permutation, spectral_ordering

# Baseline: 1-D UMAP of the reduced data; each cell's rank along that single
# axis is its inferred position. Seeded, like the SVD above, so the reported
# tau is reproducible from run to run.
umap_ordering_obj = umap.UMAP(n_components=1, random_state=0)
umap_ordering_result = umap_ordering_obj.fit_transform(X_pca).ravel()
UMAP_order_indices = rankdata(umap_ordering_result)
tau = evaluate_kendall_abs(order, UMAP_order_indices)
print("UMAP Kendall's Tau: ", tau)

# Baseline: Fiedler-vector ordering of the reduced data. Only the second
# value returned by fiedler_permutation is used; sigma is passed through
# unchanged (presumably a kernel bandwidth -- confirm in experiment_utils).
_, fiedler_order_indices = fiedler_permutation(
    X_pca,
    sigma=200,
)
tau = evaluate_kendall_abs(
    order,
    fiedler_order_indices,
)
print("Fiedler Kendall's Tau: ", tau)

# Baseline: 1-D t-SNE of the reduced data, ranked along its single axis.
# Seeded, like the SVD above, so the reported tau is reproducible.
tsne_embedding = TSNE(
    n_components=1, perplexity=30, random_state=0
).fit_transform(X_pca)
t_SNE_order_indices = rankdata(tsne_embedding.ravel())
tau = evaluate_kendall_abs(order, t_SNE_order_indices)
print("tSNE Kendall's Tau: ", tau)

# Baseline: spectral ordering from experiment_utils.
# NOTE(review): argsort turns a permutation into per-cell positions; this
# assumes spectral_ordering returns a permutation of cell indices -- confirm.
recanati_ordering = spectral_ordering(X_pca, sigma=300)
recanati_order_indices = np.argsort(recanati_ordering)
tau = evaluate_kendall_abs(
    order,
    recanati_order_indices,
)
print("Recanati Kendall's Tau: ", tau)

# STAGE embedding of the reduced data, using the radius r0 found by
# smallest_connected_radius. Only the second returned value is used.
_, stage_order_indices = stage_embedding(
    X_pca,
    r=r0,
    neighbour_min=5,
    embedding="linreg",
)
tau = evaluate_kendall_abs(
    order,
    stage_order_indices,
)
print("STAGE Kendall's Tau: ", tau)

# Colour the 2-D UMAP by the STAGE ordering. Use a fresh figure so this
# scatter is not drawn on top of the earlier ones.
plt.figure()
plt.scatter(umap_result[:, 0], umap_result[:, 1], c=stage_order_indices)
plt.title("STAGE ordering")
plt.colorbar()
# Display the figures when this file is run as a plain script.
plt.show()