from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
from scipy.spatial import distance_matrix

# Read the input AnnData file. Code below relies on its layers["counts"],
# obsm["spatial"] and var["true_spatial_var_score"] entries.
adata = sc.read_h5ad(
    "./input/stereo_drosophila_e5_6_sim.h5ad",
)


# Function to compute a refined Moran's I with a dynamic Gaussian kernel
def morans_i(gene_expression, coords):
    n = len(gene_expression)
    mean_expr = np.mean(gene_expression)
    # Compute the weights based on distance
    dist_matrix = distance_matrix(coords, coords)

    # Dynamically adjust the standard deviation for the Gaussian kernel
    std_dev = np.std(dist_matrix) * 0.5  # Example adjustment factor
    weights = np.exp(-(dist_matrix**2) / (2 * std_dev**2))  # Gaussian kernel
    np.fill_diagonal(weights, 0)  # No self-weighting

    # Calculate the numerator and denominator for Moran's I
    numerator = np.sum(
        weights
        * (gene_expression[:, None] - mean_expr)
        * (gene_expression[None, :] - mean_expr)
    )
    denominator = np.sum(gene_expression * gene_expression) * np.sum(weights)

    return (n / np.sum(weights)) * (numerator / denominator)


# Identify spatially variable genes
pred_spatial_var_scores = []
coords = adata.obsm["spatial"]
moran_scores = []

for gene in adata.var_names:
    gene_expression = adata.layers["counts"][
        :, adata.var_names.get_loc(gene)
    ].A.flatten()
    # Apply robust scaling using median and IQR
    median_expr = np.median(gene_expression)
    q1 = np.percentile(gene_expression, 25)
    q3 = np.percentile(gene_expression, 75)
    iqr = q3 - q1
    gene_expression = (gene_expression - median_expr) / (
        iqr if iqr > 0 else 1
    )  # Avoid division by zero

    mi = morans_i(gene_expression, coords)
    moran_scores.append(mi)

# Setting dynamic threshold at the 75th percentile
threshold = np.percentile(moran_scores, 75)

# Classify genes based on the dynamic threshold
for mi in moran_scores:
    pred_spatial_var_scores.append(1 if mi > threshold else 0)

# Store predictions
adata.var["pred_spatial_var_score"] = pred_spatial_var_scores


# Evaluate the predictions
def spatial_correlation(adata):
    corr = adata.var["pred_spatial_var_score"].corr(
        adata.var["true_spatial_var_score"], method="kendall"
    )
    return 0.0 if pd.isna(corr) else corr


# Compute and print the evaluation metric
validation_metric = spatial_correlation(adata)
print(validation_metric)

# Save submission
submission = pd.DataFrame(
    {
        "gene": adata.var_names,
        "pred_spatial_var_score": adata.var["pred_spatial_var_score"],
    }
)
submission.to_csv("./working/submission.csv", index=False)
