{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e3078a50-c234-4cce-abbc-f632adc54625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload edited Python modules before executing each cell\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb867c1e-fe69-40fd-a991-f1323a9debb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Route HTTP(S) traffic through the local proxy and pin the visible GPU.\n",
    "# These are set before the heavy imports below: CUDA_VISIBLE_DEVICES generally\n",
    "# only takes effect if it is set before CUDA is initialised.\n",
    "proxy = 'http://127.0.0.1:7890'\n",
    "for proxy_env_var in ('http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY'):\n",
    "    os.environ[proxy_env_var] = proxy\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import random\n",
    "import warnings\n",
    "from datetime import datetime\n",
    "import gdown\n",
    "\n",
    "import anndata as ad\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "import scipy.sparse as sp\n",
    "import seaborn as sns\n",
    "import squidpy as sq\n",
    "from matplotlib import gridspec\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "from nichecompass.models import NicheCompass\n",
    "from nichecompass.utils import (add_gps_from_gp_dict_to_adata,\n",
    "                                create_new_color_dict,\n",
    "                                compute_communication_gp_network,\n",
    "                                visualize_communication_gp_network,\n",
    "                                extract_gp_dict_from_mebocost_es_interactions,\n",
    "                                extract_gp_dict_from_nichenet_lrt_interactions,\n",
    "                                extract_gp_dict_from_omnipath_lr_interactions,\n",
    "                                filter_and_combine_gp_dict_gps,\n",
    "                                generate_enriched_gp_info_plots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c79e340a-e932-4fb6-8f93-ebe70033de90",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Dataset ###\n",
    "dataset = \"merfish\"\n",
    "species = \"mouse\"\n",
    "spatial_key = \"spatial\"  # key of the spatial coordinates passed to sq.gr.spatial_neighbors\n",
    "n_neighbors = 4  # n_neighs for the spatial graph built in the first data-loading loop\n",
    "\n",
    "### Model ###\n",
    "# AnnData Keys\n",
    "counts_key = \"counts\"  # layer that receives the raw counts\n",
    "adj_key = \"spatial_connectivities\"  # obsp key of the (symmetrised) spatial graph\n",
    "cat_covariates_keys = [\"brain_section_label\"]  # obs column(s) used as categorical covariates\n",
    "gp_names_key = \"nichecompass_gp_names\"\n",
    "active_gp_names_key = \"nichecompass_active_gp_names\"\n",
    "gp_targets_mask_key = \"nichecompass_gp_targets\"\n",
    "gp_targets_categories_mask_key = \"nichecompass_gp_targets_categories\"\n",
    "gp_sources_mask_key = \"nichecompass_gp_sources\"\n",
    "gp_sources_categories_mask_key = \"nichecompass_gp_sources_categories\"\n",
    "latent_key = \"nichecompass_latent\"\n",
    "\n",
    "# Architecture\n",
    "cat_covariates_embeds_injection = [\"gene_expr_decoder\"]\n",
    "cat_covariates_embeds_nums = [3]\n",
    "cat_covariates_no_edges = [True]\n",
    "conv_layer_encoder = \"gcnconv\"\n",
    "active_gp_thresh_ratio = 0.01\n",
    "\n",
    "# Trainer\n",
    "n_epochs = 400  # NOTE(review): the training cells pass n_epochs=50 explicitly, so this value is unused\n",
    "n_epochs_all_gps = 25\n",
    "lr = 0.001\n",
    "lambda_edge_recon = 500000.\n",
    "lambda_gene_expr_recon = 300.\n",
    "lambda_l1_masked = 0. # increase if gene selection desired\n",
    "lambda_l1_addon = 100.  # NOTE(review): not passed to model.train() in this notebook\n",
    "edge_batch_size = 1024 # increase if more memory available\n",
    "n_sampled_neighbors = 4\n",
    "use_cuda_if_available = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5d8e5b8f-668b-4eed-a231-42e453430723",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Analysis ###\n",
    "# obs column with the cluster alias (also the join key of the annotation tables)\n",
    "cell_type_key = \"cluster_alias\"\n",
    "latent_leiden_resolution = 0.4\n",
    "# Evaluates to \"latent_leiden_0.4\"\n",
    "latent_cluster_key = f\"latent_leiden_{latent_leiden_resolution}\"\n",
    "sample_key = \"brain_section_label\"\n",
    "spot_size = 0.03\n",
    "differential_gp_test_results_key = \"nichecompass_differential_gp_test_results\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0a3569d3-3054-4849-8db8-4688d3002524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence all warnings for the rest of the session.\n",
    "# NOTE(review): this also hides potentially useful library warnings.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Timestamp (DDMMYYYY_HHMMSS) of this run, used to name the artifact folders\n",
    "now = datetime.now()\n",
    "current_timestamp = f\"{now:%d%m%Y_%H%M%S}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8631cd9b-a0db-4083-987a-92f3210e6a3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input data folders (relative to this notebook)\n",
    "ga_data_folder_path = \"../../../data/gene_annotations\"\n",
    "gp_data_folder_path = \"../../../data/gene_programs\"\n",
    "so_data_folder_path = \"../../../data/spatial_omics\"\n",
    "\n",
    "# Gene-program resource files\n",
    "omnipath_lr_network_file_path = f\"{gp_data_folder_path}/omnipath_lr_network.csv\"\n",
    "collectri_tf_network_file_path = f\"{gp_data_folder_path}/collectri_tf_network_{species}.csv\"\n",
    "nichenet_lr_network_file_path = f\"{gp_data_folder_path}/nichenet_lr_network_v2_{species}.csv\"\n",
    "nichenet_ligand_target_matrix_file_path = f\"{gp_data_folder_path}/nichenet_ligand_target_matrix_v2_{species}.csv\"\n",
    "mebocost_enzyme_sensor_interactions_folder_path = f\"{gp_data_folder_path}/metabolite_enzyme_sensor_gps\"\n",
    "gene_orthologs_mapping_file_path = f\"{ga_data_folder_path}/human_mouse_gene_orthologs.csv\"\n",
    "\n",
    "# Output folders: one timestamped folder per notebook run\n",
    "artifacts_folder_path = \"../../../artifacts\"\n",
    "run_folder_path = f\"{artifacts_folder_path}/single_sample/{current_timestamp}\"\n",
    "model_folder_path = f\"{run_folder_path}/model\"\n",
    "figure_folder_path = f\"{run_folder_path}/figures\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cf4188ad-4b2b-404b-b1f6-ea8b5bd144be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output and data folders (no-op for folders that already exist)\n",
    "for folder_path in (model_folder_path, figure_folder_path, so_data_folder_path):\n",
    "    os.makedirs(folder_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c33339-b177-406f-8f24-aec2d648b735",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve OmniPath GPs (source: ligand genes; target: receptor genes)\n",
    "omnipath_gp_gene_count_plot_path = f\"{figure_folder_path}/omnipath_gp_gene_count_distributions.svg\"\n",
    "omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(\n",
    "    species=species,\n",
    "    min_curation_effort=0,\n",
    "    load_from_disk=True,\n",
    "    save_to_disk=True,\n",
    "    lr_network_file_path=omnipath_lr_network_file_path,\n",
    "    gene_orthologs_mapping_file_path=gene_orthologs_mapping_file_path,\n",
    "    plot_gp_gene_count_distributions=True,\n",
    "    gp_gene_count_distributions_save_path=omnipath_gp_gene_count_plot_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33579a52-6e8a-4699-b03f-af4d90ffd595",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display a randomly chosen example OmniPath GP\n",
    "omnipath_gp_names = list(omnipath_gp_dict.keys())\n",
    "omnipath_gp_name = random.choice(omnipath_gp_names)\n",
    "print(f\"{omnipath_gp_name}: {omnipath_gp_dict[omnipath_gp_name]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e833b68-1148-428c-95b8-3655f24fdf91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve MEBOCOST GPs (source: enzyme genes; target: sensor genes)\n",
    "mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(\n",
    "    dir_path=mebocost_enzyme_sensor_interactions_folder_path,\n",
    "    species=species,\n",
    "    plot_gp_gene_count_distributions=True,\n",
    "    # Save the distribution plot with the other run figures, as for OmniPath\n",
    "    gp_gene_count_distributions_save_path=f\"{figure_folder_path}\"\n",
    "                                          \"/mebocost_gp_gene_count_distributions.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a1f43b3-8e10-42d7-8532-8a56c059d4c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display a randomly chosen example MEBOCOST GP\n",
    "mebocost_gp_names = list(mebocost_gp_dict.keys())\n",
    "mebocost_gp_name = random.choice(mebocost_gp_names)\n",
    "print(f\"{mebocost_gp_name}: {mebocost_gp_dict[mebocost_gp_name]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935e91c0-4169-4d6e-b5b4-81f2cc3a3501",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve NicheNet GPs (source: ligand genes; target: receptor genes, target genes)\n",
    "nichenet_gp_dict = extract_gp_dict_from_nichenet_lrt_interactions(\n",
    "    species=species,\n",
    "    version=\"v2\",\n",
    "    keep_target_genes_ratio=1.,\n",
    "    max_n_target_genes_per_gp=250,\n",
    "    load_from_disk=True,\n",
    "    save_to_disk=True,\n",
    "    lr_network_file_path=nichenet_lr_network_file_path,\n",
    "    ligand_target_matrix_file_path=nichenet_ligand_target_matrix_file_path,\n",
    "    gene_orthologs_mapping_file_path=gene_orthologs_mapping_file_path,\n",
    "    plot_gp_gene_count_distributions=True,\n",
    "    # Save the distribution plot with the other run figures, as for OmniPath\n",
    "    gp_gene_count_distributions_save_path=f\"{figure_folder_path}\"\n",
    "                                          \"/nichenet_gp_gene_count_distributions.svg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2589cc65-6e95-4f5f-ab6b-eccc272a6dc6",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display a randomly chosen example NicheNet GP\n",
    "nichenet_gp_names = list(nichenet_gp_dict.keys())\n",
    "nichenet_gp_name = random.choice(nichenet_gp_names)\n",
    "print(f\"{nichenet_gp_name}: {nichenet_gp_dict[nichenet_gp_name]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2a5dc002-7d4e-41cb-bcfc-73fa50d4527d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge all GP sources into one dictionary for model training.\n",
    "# On duplicate GP names the later source wins: NicheNet > MEBOCOST > OmniPath.\n",
    "combined_gp_dict = {**omnipath_gp_dict, **mebocost_gp_dict, **nichenet_gp_dict}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e6e275-e57a-41d2-b552-4f66c52dd557",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter and combine GPs to avoid overlaps\n",
    "combined_new_gp_dict = filter_and_combine_gp_dict_gps(\n",
    "    gp_dict=combined_gp_dict,\n",
    "    gp_filter_mode=\"subset\",\n",
    "    combine_overlap_gps=True,\n",
    "    overlap_thresh_source_genes=0.9,\n",
    "    overlap_thresh_target_genes=0.9,\n",
    "    overlap_thresh_genes=0.9)\n",
    "\n",
    "# Report how many GPs survive filtering/combining\n",
    "for stage, gp_dict_at_stage in ((\"before\", combined_gp_dict), (\"after\", combined_new_gp_dict)):\n",
    "    print(f\"Number of gene programs {stage} filtering and combining: {len(gp_dict_at_stage)}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bb17b40f-d449-49ed-960d-9098c3c1cd4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = '/stor/usr/sgenetmp/'\n",
    "basedir = '/nfs/public/usr/MERFISH2023/Zhuang-ABCA-2/processed/Sgeneration/'\n",
    "# Section names = '<name>.h5ad' files in basedir whose file name contains no '_'.\n",
    "# Only .h5ad files are considered, so stray entries cannot yield bogus section names.\n",
    "alldataname = [os.path.splitext(fname)[0]\n",
    "               for fname in sorted(os.listdir(basedir))\n",
    "               if fname.endswith('.h5ad') and '_' not in fname]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e9d02dfb-e59b-4722-9c0a-63608135127b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw MERFISH AnnData; its X supplies the counts layer of each section\n",
    "raw_adata_path = '/nfs/public/usr/MERFISH2023/Zhuang-ABCA-2/Zhuang-ABCA-2-raw.h5ad'\n",
    "rawadata = sc.read_h5ad(raw_adata_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "2cebb402-3e40-4cd5-9f94-8ebbdf4e78e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "annotation_folder_path = '/nfs/public/usr/MERFISH2023/Annotation'\n",
    "\n",
    "# Cluster annotation names + colors, indexed by cluster_alias\n",
    "annotationtable = pd.read_csv(f'{annotation_folder_path}/cluster_to_cluster_annotation_membership_pivoted.csv').set_index('cluster_alias')\n",
    "annotationcolor = pd.read_csv(f'{annotation_folder_path}/cluster_to_cluster_annotation_membership_color.csv').set_index('cluster_alias')\n",
    "annotation = pd.concat([annotationtable, annotationcolor], axis=1)\n",
    "\n",
    "# Parcellation term names + colors, indexed by the first CSV column (parcellation index)\n",
    "regiontable = pd.read_csv(f'{annotation_folder_path}/parcellation_to_parcellation_term_membership_name.csv', index_col=0)\n",
    "regioncolor = pd.read_csv(f'{annotation_folder_path}/parcellation_to_parcellation_term_membership_color.csv', index_col=0)\n",
    "regionanno = pd.concat([regiontable, regioncolor], axis=1)\n",
    "\n",
    "# Per-cell CCF table (provides 'parcellation_index'), indexed like adata obs names\n",
    "ccfv1 = pd.read_csv(f'{annotation_folder_path}/A2ccf_coordinates.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbc937e4-bfbe-4941-94b2-2c02c2c53527",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build one AnnData over all sections: per-section CCF filtering and annotation,\n",
    "# raw-count layer, and a symmetric spatial neighbour graph with n_neighbors per cell.\n",
    "adata_batch_list = []\n",
    "\n",
    "for batch in alldataname:\n",
    "    print(f\"Processing batch {batch}...\")\n",
    "    print(\"Loading data...\")\n",
    "    adata_batch = sc.read_h5ad(\n",
    "        f\"{basedir}/{batch}.h5ad\")\n",
    "\n",
    "    # Keep only cells with CCF coordinates; skip sections that have none\n",
    "    setidx = adata_batch.obs.index[adata_batch.obs.index.isin(ccfv1.index)]\n",
    "    if len(setidx)==0:\n",
    "        print(f'{batch} no ccf')\n",
    "        continue\n",
    "    adata_batch = adata_batch[setidx].copy()\n",
    "    adata_batch.obs['parcellation_index']=ccfv1.loc[adata_batch.obs.index,'parcellation_index']\n",
    "    # Drop parcellation indices 0 and 987. .copy() materialises the subset so the obs\n",
    "    # assignments below modify a real AnnData instead of implicitly copying a view.\n",
    "    adata_batch = adata_batch[~adata_batch.obs['parcellation_index'].isin([0, 987])].copy()\n",
    "\n",
    "    # Attach region (parcellation) annotations, aligned to this section's cells\n",
    "    query = regionanno.loc[adata_batch.obs.parcellation_index.values,:]\n",
    "    query.index = adata_batch.obs.index\n",
    "    adata_batch.obs = pd.concat([adata_batch.obs, query],axis=1)\n",
    "    # Attach cluster annotations looked up by cluster_alias\n",
    "    query = annotation.loc[adata_batch.obs.cluster_alias.values,:]\n",
    "    query.index = adata_batch.obs.index\n",
    "    adata_batch.obs = pd.concat([adata_batch.obs, query],axis=1)\n",
    "    # 'region' = structure name, with every Isocortex cell pooled under one label\n",
    "    adata_batch.obs['region']=adata_batch.obs['structure'].astype(str)\n",
    "    adata_batch.obs.loc[adata_batch.obs['division']=='Isocortex','region']='Isocortex'\n",
    "\n",
    "    # Raw counts of the retained cells.\n",
    "    # NOTE(review): assumes rawadata.var has the same gene order as adata_batch.var -- confirm\n",
    "    adata_batch.layers[counts_key]=rawadata[adata_batch.obs.index,:].X.copy()\n",
    "    print(\"Computing spatial neighborhood graph...\\n\")\n",
    "    # Compute (separate) spatial neighborhood graphs\n",
    "    sq.gr.spatial_neighbors(adata_batch,\n",
    "                            coord_type=\"generic\",\n",
    "                            spatial_key=spatial_key,\n",
    "                            n_neighs=n_neighbors)\n",
    "\n",
    "    # Make adjacency matrix symmetric\n",
    "    adata_batch.obsp[adj_key] = (\n",
    "        adata_batch.obsp[adj_key].maximum(\n",
    "            adata_batch.obsp[adj_key].T))\n",
    "    adata_batch_list.append(adata_batch)\n",
    "\n",
    "assert adata_batch_list, \"no section contains cells with CCF coordinates\"\n",
    "adata = ad.concat(adata_batch_list, join=\"inner\")\n",
    "\n",
    "# Combine spatial neighborhood graphs as disconnected components: a block-diagonal\n",
    "# matrix whose blocks follow the order in which the sections were concatenated\n",
    "adata.obsp[adj_key] = sp.block_diag(\n",
    "    [section.obsp[adj_key] for section in adata_batch_list], format=\"csr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "0bfdf274-a924-453f-8752-a8e7620a3d37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use gene symbols as var names. Symbols are looked up by the current var names, so they\n",
    "# stay aligned with adata.var_names after the inner-join concatenation, instead of relying\n",
    "# on the positional gene order of whichever section the loading loop read last.\n",
    "# Run once per rebuild of adata: the lookup needs the original (pre-symbol) var names.\n",
    "adata.var.index = adata_batch_list[0].var.loc[adata.var_names, 'gene_symbol'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "cf809f30-10b8-40c6-bc44-86267a16c435",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the GP dictionary as binary masks to the adata.\n",
    "# The adata keys under which the masks/names are stored:\n",
    "gp_mask_keys = dict(\n",
    "    gp_targets_mask_key=gp_targets_mask_key,\n",
    "    gp_targets_categories_mask_key=gp_targets_categories_mask_key,\n",
    "    gp_sources_mask_key=gp_sources_mask_key,\n",
    "    gp_sources_categories_mask_key=gp_sources_categories_mask_key,\n",
    "    gp_names_key=gp_names_key)\n",
    "# Gene-count limits: >= 2 genes, >= 1 source and >= 1 target gene per GP; no maxima\n",
    "add_gps_from_gp_dict_to_adata(\n",
    "    gp_dict=combined_new_gp_dict,\n",
    "    adata=adata,\n",
    "    min_genes_per_gp=2,\n",
    "    min_source_genes_per_gp=1,\n",
    "    min_target_genes_per_gp=1,\n",
    "    max_genes_per_gp=None,\n",
    "    max_source_genes_per_gp=None,\n",
    "    max_target_genes_per_gp=None,\n",
    "    **gp_mask_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a89e1f2-451e-484b-a246-4c4e5e9e9395",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model (NicheCompass is imported in the imports cell at the top)\n",
    "model = NicheCompass(adata,\n",
    "                     counts_key=counts_key,\n",
    "                     adj_key=adj_key,\n",
    "                     cat_covariates_embeds_injection=cat_covariates_embeds_injection,\n",
    "                     cat_covariates_keys=cat_covariates_keys,\n",
    "                     cat_covariates_no_edges=cat_covariates_no_edges,\n",
    "                     cat_covariates_embeds_nums=cat_covariates_embeds_nums,\n",
    "                     gp_names_key=gp_names_key,\n",
    "                     active_gp_names_key=active_gp_names_key,\n",
    "                     gp_targets_mask_key=gp_targets_mask_key,\n",
    "                     gp_targets_categories_mask_key=gp_targets_categories_mask_key,\n",
    "                     gp_sources_mask_key=gp_sources_mask_key,\n",
    "                     gp_sources_categories_mask_key=gp_sources_categories_mask_key,\n",
    "                     latent_key=latent_key,\n",
    "                     conv_layer_encoder=conv_layer_encoder,\n",
    "                     active_gp_thresh_ratio=active_gp_thresh_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3bca4c-da0f-4d73-b5b9-0ed08d9a146a",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train model.\n",
    "# NOTE: n_epochs is fixed to 50 here; the n_epochs = 400 from the config cell is not used.\n",
    "# NOTE(review): lambda_l1_addon from the config cell is not passed to train().\n",
    "model.train(n_epochs=50,\n",
    "            n_epochs_all_gps=n_epochs_all_gps,\n",
    "            lr=lr,\n",
    "            lambda_edge_recon=lambda_edge_recon,\n",
    "            lambda_gene_expr_recon=lambda_gene_expr_recon,\n",
    "            lambda_l1_masked=lambda_l1_masked,\n",
    "            edge_batch_size=edge_batch_size,\n",
    "            n_sampled_neighbors=n_sampled_neighbors,\n",
    "            use_cuda_if_available=use_cuda_if_available,\n",
    "            verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5552e706-fbc1-4386-af14-2b0620f4eadf",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save the latent embedding of every section as its own .h5ad file\n",
    "# (X = NicheCompass latent; obs and obsm copied from the model's AnnData).\n",
    "embedding_folder_path = f'{savedir}results/embedding/mouse2'\n",
    "os.makedirs(embedding_folder_path, exist_ok=True)\n",
    "for name in model.adata.obs[sample_key].unique():\n",
    "    print(name)\n",
    "    tmpadata = model.adata[model.adata.obs[sample_key]==name].copy()\n",
    "    embnp = tmpadata.obsm[latent_key]\n",
    "    embadata = sc.AnnData(embnp,obs=tmpadata.obs,obsm=tmpadata.obsm)\n",
    "    embadata.write_h5ad(f'{embedding_folder_path}/nichecompass_{name}.h5ad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "a92cac20-3d0f-44da-9448-2dbb19e56dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the latent embedding of one section for inspection\n",
    "name = 'Zhuang-ABCA-2.031'\n",
    "tmpadata = model.adata[model.adata.obs[sample_key]==name].copy()\n",
    "assert tmpadata.n_obs > 0, f\"section {name} not found in model.adata\"\n",
    "embnp = tmpadata.obsm[latent_key]\n",
    "embadata = sc.AnnData(embnp,obs=tmpadata.obs,obsm=tmpadata.obsm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7bc782-67dd-4b2f-ab65-e77a70c5d127",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spatial plot of the extracted section, coloured by annotated region\n",
    "sc.pl.spatial(embadata, color='region', spot_size=spot_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6565a97e-8a1c-4c43-a0b7-cab5aae1ac85",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebuild the combined AnnData exactly as before, but with a much larger spatial\n",
    "# neighbourhood (n_neighbors_large neighbours per cell).\n",
    "n_neighbors_large = 50\n",
    "adata_batch_list = []\n",
    "\n",
    "for batch in alldataname:\n",
    "    print(f\"Processing batch {batch}...\")\n",
    "    print(\"Loading data...\")\n",
    "    adata_batch = sc.read_h5ad(\n",
    "        f\"{basedir}/{batch}.h5ad\")\n",
    "\n",
    "    # Keep only cells with CCF coordinates; skip sections that have none\n",
    "    setidx = adata_batch.obs.index[adata_batch.obs.index.isin(ccfv1.index)]\n",
    "    if len(setidx)==0:\n",
    "        print(f'{batch} no ccf')\n",
    "        continue\n",
    "    adata_batch = adata_batch[setidx].copy()\n",
    "    adata_batch.obs['parcellation_index']=ccfv1.loc[adata_batch.obs.index,'parcellation_index']\n",
    "    # Drop parcellation indices 0 and 987. .copy() materialises the subset so the obs\n",
    "    # assignments below modify a real AnnData instead of implicitly copying a view.\n",
    "    adata_batch = adata_batch[~adata_batch.obs['parcellation_index'].isin([0, 987])].copy()\n",
    "\n",
    "    # Attach region (parcellation) annotations, aligned to this section's cells\n",
    "    query = regionanno.loc[adata_batch.obs.parcellation_index.values,:]\n",
    "    query.index = adata_batch.obs.index\n",
    "    adata_batch.obs = pd.concat([adata_batch.obs, query],axis=1)\n",
    "    # Attach cluster annotations looked up by cluster_alias\n",
    "    query = annotation.loc[adata_batch.obs.cluster_alias.values,:]\n",
    "    query.index = adata_batch.obs.index\n",
    "    adata_batch.obs = pd.concat([adata_batch.obs, query],axis=1)\n",
    "    # 'region' = structure name, with every Isocortex cell pooled under one label\n",
    "    adata_batch.obs['region']=adata_batch.obs['structure'].astype(str)\n",
    "    adata_batch.obs.loc[adata_batch.obs['division']=='Isocortex','region']='Isocortex'\n",
    "\n",
    "    # Raw counts of the retained cells.\n",
    "    # NOTE(review): assumes rawadata.var has the same gene order as adata_batch.var -- confirm\n",
    "    adata_batch.layers[counts_key]=rawadata[adata_batch.obs.index,:].X.copy()\n",
    "    print(\"Computing spatial neighborhood graph...\\n\")\n",
    "    # Compute (separate) spatial neighborhood graphs\n",
    "    sq.gr.spatial_neighbors(adata_batch,\n",
    "                            coord_type=\"generic\",\n",
    "                            spatial_key=spatial_key,\n",
    "                            n_neighs=n_neighbors_large)\n",
    "\n",
    "    # Make adjacency matrix symmetric\n",
    "    adata_batch.obsp[adj_key] = (\n",
    "        adata_batch.obsp[adj_key].maximum(\n",
    "            adata_batch.obsp[adj_key].T))\n",
    "    adata_batch_list.append(adata_batch)\n",
    "\n",
    "assert adata_batch_list, \"no section contains cells with CCF coordinates\"\n",
    "adata = ad.concat(adata_batch_list, join=\"inner\")\n",
    "\n",
    "# Combine spatial neighborhood graphs as disconnected components: a block-diagonal\n",
    "# matrix whose blocks follow the order in which the sections were concatenated\n",
    "adata.obsp[adj_key] = sp.block_diag(\n",
    "    [section.obsp[adj_key] for section in adata_batch_list], format=\"csr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "590da02e-78cf-458f-aa2b-8852661ebead",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use gene symbols as var names. Symbols are looked up by the current var names, so they\n",
    "# stay aligned with adata.var_names after the inner-join concatenation, instead of relying\n",
    "# on the positional gene order of whichever section the loading loop read last.\n",
    "# Run once per rebuild of adata: the lookup needs the original (pre-symbol) var names.\n",
    "adata.var.index = adata_batch_list[0].var.loc[adata.var_names, 'gene_symbol'].values\n",
    "\n",
    "# Add the GP dictionary as binary masks to the adata\n",
    "add_gps_from_gp_dict_to_adata(\n",
    "    gp_dict=combined_new_gp_dict,\n",
    "    adata=adata,\n",
    "    gp_targets_mask_key=gp_targets_mask_key,\n",
    "    gp_targets_categories_mask_key=gp_targets_categories_mask_key,\n",
    "    gp_sources_mask_key=gp_sources_mask_key,\n",
    "    gp_sources_categories_mask_key=gp_sources_categories_mask_key,\n",
    "    gp_names_key=gp_names_key,\n",
    "    min_genes_per_gp=2,\n",
    "    min_source_genes_per_gp=1,\n",
    "    min_target_genes_per_gp=1,\n",
    "    max_genes_per_gp=None,\n",
    "    max_source_genes_per_gp=None,\n",
    "    max_target_genes_per_gp=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9966555-8eef-42ac-b6c0-6f872c74f13b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model on the n_neighs=50 graph (NicheCompass is imported in the imports cell)\n",
    "model = NicheCompass(adata,\n",
    "                     counts_key=counts_key,\n",
    "                     adj_key=adj_key,\n",
    "                     cat_covariates_embeds_injection=cat_covariates_embeds_injection,\n",
    "                     cat_covariates_keys=cat_covariates_keys,\n",
    "                     cat_covariates_no_edges=cat_covariates_no_edges,\n",
    "                     cat_covariates_embeds_nums=cat_covariates_embeds_nums,\n",
    "                     gp_names_key=gp_names_key,\n",
    "                     active_gp_names_key=active_gp_names_key,\n",
    "                     gp_targets_mask_key=gp_targets_mask_key,\n",
    "                     gp_targets_categories_mask_key=gp_targets_categories_mask_key,\n",
    "                     gp_sources_mask_key=gp_sources_mask_key,\n",
    "                     gp_sources_categories_mask_key=gp_sources_categories_mask_key,\n",
    "                     latent_key=latent_key,\n",
    "                     conv_layer_encoder=conv_layer_encoder,\n",
    "                     active_gp_thresh_ratio=active_gp_thresh_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce6ba65-7cdc-4a66-9d7a-808cf48bed94",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train model.\n",
    "# NOTE: n_epochs is fixed to 50 here; the n_epochs = 400 from the config cell is not used.\n",
    "# NOTE(review): lambda_l1_addon from the config cell is not passed to train().\n",
    "model.train(n_epochs=50,\n",
    "            n_epochs_all_gps=n_epochs_all_gps,\n",
    "            lr=lr,\n",
    "            lambda_edge_recon=lambda_edge_recon,\n",
    "            lambda_gene_expr_recon=lambda_gene_expr_recon,\n",
    "            lambda_l1_masked=lambda_l1_masked,\n",
    "            edge_batch_size=edge_batch_size,\n",
    "            n_sampled_neighbors=n_sampled_neighbors,\n",
    "            use_cuda_if_available=use_cuda_if_available,\n",
    "            verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faa5e35d-7a76-4d9e-8652-1f236a874378",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save the latent embedding (n_neighs=50 model) of every section as its own .h5ad file\n",
    "# (X = NicheCompass latent; obs and obsm copied from the model's AnnData).\n",
    "embedding_folder_path = f'{savedir}results/embedding/mouse2'\n",
    "os.makedirs(embedding_folder_path, exist_ok=True)\n",
    "for name in model.adata.obs[sample_key].unique():\n",
    "    print(name)\n",
    "    tmpadata = model.adata[model.adata.obs[sample_key]==name].copy()\n",
    "    embnp = tmpadata.obsm[latent_key]\n",
    "    embadata = sc.AnnData(embnp,obs=tmpadata.obs,obsm=tmpadata.obsm)\n",
    "    embadata.write_h5ad(f'{embedding_folder_path}/nichecompassN50_{name}.h5ad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b6ccf3-6764-41fb-9c9f-60b96dc35d95",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nichecompass",
   "language": "python",
   "name": "nichecompass"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
