import scanpy as sc
import scvi
import numpy as np
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# --- Load & preprocess (shared for both methods) ---
# Per-cell total used both for log-normalization and for scVI's normalized
# output, so that predictions and ground truth can live on the same scale.
TARGET_SUM = 1e4
# Fraction of the non-zero count entries hidden in the masked-recovery test.
MASK_FRACTION = 0.1

adata = sc.read_h5ad(
    "/home//Documents/mixed_diffusion/data/CITEseq/citeseq_preprocessed.h5ad"
)
# Keep the raw counts in a layer BEFORE anything overwrites adata.X:
# seurat_v3 HVG selection and scVI both expect counts, not log-normalized data.
# NOTE(review): this assumes adata.X holds raw counts at load time unless the
# file already ships a "counts" layer -- confirm against how the file was made.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()
sc.pp.filter_genes(adata, min_counts=3)
# adata.X becomes log-normalized from here on; the counts layer is untouched.
sc.pp.normalize_total(adata, target_sum=TARGET_SUM)
sc.pp.log1p(adata)
# flavor="seurat_v3" expects count data, so point it at the counts layer
# instead of the (now log-normalized) adata.X.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=2000,
    flavor="seurat_v3",
    subset=True,
    layer="counts",
)
adata.raw = adata  # freeze the HVG-subset, log-normalized view

# --- scVI setup & train ---
scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",  # train on raw counts, not on log-normalized adata.X
    batch_key="batch",  # omit if no batches
    labels_key="celltype",  # for convenience (not used in training)
)
model = scvi.model.SCVI(adata, n_latent=25, gene_likelihood="zinb")
model.train(
    max_epochs=400,
    early_stopping=True,
    early_stopping_patience=25,
    plan_kwargs={"lr": 1e-3},
)

# --- Extract scVI outputs ---
Z_scvi = model.get_latent_representation()
# return_numpy=True: the default return type is a pandas DataFrame, which does
# not support positional [rows, cols] indexing.
Xden_scvi = model.get_normalized_expression(
    library_size=TARGET_SUM,
    return_numpy=True,
)  # posterior mean "denoised"

# --- Clustering on latent (not UMAP) ---
adata.obsm["X_scvi"] = Z_scvi
sc.pp.neighbors(adata, use_rep="X_scvi", n_neighbors=30, metric="euclidean")
sc.tl.leiden(adata, key_added="leiden_scvi", resolution=1.0)

# --- Metrics (labels assumed in adata.obs["celltype"]) ---
y_true = adata.obs["celltype"].astype(str).values
y_pred = adata.obs["leiden_scvi"].astype(str).values
ari_scvi = adjusted_rand_score(y_true, y_pred)
nmi_scvi = normalized_mutual_info_score(y_true, y_pred)
print({"scvi_ARI": ari_scvi, "scvi_NMI": nmi_scvi})

# --- Masked recovery (simple example) ---
rng = np.random.default_rng(0)
X = adata.layers["counts"]
# np.asarray() on a scipy sparse matrix yields a 0-d object array, so densify
# explicitly when the layer is sparse.
X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
mask = X > 0
idx = np.where(mask)
n_masked = int(MASK_FRACTION * len(idx[0]))
sel = rng.choice(len(idx[0]), size=n_masked, replace=False)
rows, cols = idx[0][sel], idx[1][sel]

X_masked = X.copy()
X_masked[rows, cols] = 0  # hide some true counts

# Run inference on the masked counts so the encoder does not see the hidden
# values. NOTE: the model above was still trained on the unmasked counts; for a
# strict held-out benchmark, apply the mask before training instead.
adata_masked = adata.copy()
adata_masked.layers["counts"] = X_masked
Xden_masked = model.get_normalized_expression(
    adata_masked,
    library_size=TARGET_SUM,
    return_numpy=True,
)

# Put the ground truth on the same scale as the prediction: counts per
# TARGET_SUM total counts of the cell. Every selected row contains at least one
# positive count, so the per-cell totals used below are non-zero.
lib_size = X.sum(axis=1)
true_vals = X[rows, cols] / lib_size[rows] * TARGET_SUM
pred_vals = Xden_masked[rows, cols]
mse_scvi = np.mean((pred_vals - true_vals) ** 2)
print({"scvi_mask_MSE": mse_scvi})
