import os

import numpy as np
import pandas as pd
import scanpy as sc
import scib
import umap
from scib.metrics import silhouette_batch, silhouette
from scipy import sparse


def integration_score_mean(
    adata,
    n_neighbors=100,
    use_rep="X_emb",
    batch_key="batch",
    label_key="cell_type",
    embed="X_emb",
):
    """Return the mean of three scib integration metrics for ``adata``.

    Steps, in order:

    1. ``sc.pp.neighbors`` builds a kNN graph from ``adata.obsm[use_rep]``
       with ``n_neighbors`` neighbors. It is called without ``copy=``, so
       (per scanpy's default) the graph is stored on ``adata`` itself.
    2. ``scib.metrics.graph_connectivity`` is scored on that graph using
       ``label_key``.
    3. ``silhouette_batch`` is scored on ``adata.obsm[embed]`` using
       ``batch_key`` and ``label_key``.
    4. ``silhouette`` is scored on ``adata.obsm[embed]`` using ``label_key``.

    The three metric values are averaged with equal weight.

    Returns:
        The unweighted arithmetic mean of the three metrics, as a Python
        ``float``.
    """
    # The graph must exist before graph_connectivity reads it.
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep)

    # Tuple elements are evaluated left to right, so the metric calls run
    # in the same order as listed in the docstring.
    metric_values = (
        scib.metrics.graph_connectivity(adata, label_key=label_key),
        silhouette_batch(
            adata,
            batch_key=batch_key,
            label_key=label_key,
            embed=embed,
        ),
        silhouette(adata, label_key=label_key, embed=embed),
    )
    return float(np.mean(metric_values))


# Load the AnnData object to embed and evaluate.
adata = sc.read_h5ad("./input/cxg_immune_5k.h5ad")

# Fit a UMAP embedding on the "normalized" layer. n_neighbors=100 is the same
# value used as integration_score_mean's default neighbor count.
umap_model = umap.UMAP(n_neighbors=100, min_dist=0.05, random_state=42)
expr = adata.layers["normalized"]
# The layer may be stored sparse or dense. Only scipy sparse matrices have
# .toarray(), so densify conditionally instead of always calling it (which
# raised AttributeError on a dense ndarray layer).
if sparse.issparse(expr):
    expr = expr.toarray()
adata.obsm["X_emb"] = umap_model.fit_transform(expr)

# Score the embedding with the averaged scib metrics.
score = integration_score_mean(adata)
print("Integration Score Mean:", score)

# Write the embedding coordinates (no index column) as the submission.
# Create the output directory first so to_csv does not fail when it is absent.
os.makedirs("./working", exist_ok=True)
submission = pd.DataFrame(adata.obsm["X_emb"], columns=["X1", "X2"])
submission.to_csv("./working/submission.csv", index=False)
