{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17a92d1e-e702-4e64-8315-6aeeb91f0f63",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:29.623753Z",
     "iopub.status.busy": "2024-05-12T03:38:29.623056Z",
     "iopub.status.idle": "2024-05-12T03:38:32.625090Z",
     "shell.execute_reply": "2024-05-12T03:38:32.624561Z",
     "shell.execute_reply.started": "2024-05-12T03:38:29.623701Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pylab inline\n",
    "import scanpy as sc\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5fc8e230-8076-40d7-9139-91251644b4de",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:32.626762Z",
     "iopub.status.busy": "2024-05-12T03:38:32.626308Z",
     "iopub.status.idle": "2024-05-12T03:38:34.291451Z",
     "shell.execute_reply": "2024-05-12T03:38:34.290944Z",
     "shell.execute_reply.started": "2024-05-12T03:38:32.626733Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "os.chdir('../../')\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='7'\n",
    "os.environ['LOGURU_LEVEL']='INFO'\n",
    "import yaml\n",
    "import pathlib\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from dataset import SpatialSeq\n",
    "from inference import *\n",
    "\n",
    "def seed_all(seed, cuda_deterministic=False):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    if cuda_deterministic: # slower, more reproducible\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "    else:\n",
    "        torch.backends.cudnn.deterministic = False\n",
    "        torch.backends.cudnn.benchmark = True\n",
    "seed_all(19491001,cuda_deterministic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a0127a04-02c3-4f7d-87e1-7f6d59ad649e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:34.292526Z",
     "iopub.status.busy": "2024-05-12T03:38:34.292337Z",
     "iopub.status.idle": "2024-05-12T03:38:34.297501Z",
     "shell.execute_reply": "2024-05-12T03:38:34.296915Z",
     "shell.execute_reply.started": "2024-05-12T03:38:34.292510Z"
    }
   },
   "outputs": [],
   "source": [
    "def cluster_k_leiden(embadata,n_cluster,max_steps=50,this_min=0,this_max=10):\n",
    "    \"\"\"Bisect the Leiden resolution until `n_cluster` clusters are found.\n",
    "\n",
    "    Every step runs `sc.tl.leiden` at the midpoint of the current\n",
    "    [this_min, this_max] interval and then shrinks the interval: the upper\n",
    "    bound moves down when too many clusters were found, the lower bound\n",
    "    moves up when too few were found. The intermediate runs overwrite\n",
    "    `embadata.obs['leiden']`; the labels for the final resolution are\n",
    "    written to `embadata.obs[f'cluster_{n_cluster}']`.\n",
    "\n",
    "    Args:\n",
    "        embadata: Object passed to `sc.tl.leiden` (modified in place).\n",
    "        n_cluster: Target number of clusters.\n",
    "        max_steps: Maximum number of bisection steps.\n",
    "        this_min: Lower bound of the resolution search interval.\n",
    "        this_max: Upper bound of the resolution search interval.\n",
    "    \"\"\"\n",
    "    print('reference cluster number',n_cluster)\n",
    "    # Start from the interval midpoint so the resolution is defined even if\n",
    "    # max_steps <= 0 (this used to end in an UnboundLocalError).\n",
    "    this_resolution = this_min + ((this_max-this_min)/2)\n",
    "    for _ in range(max_steps):\n",
    "        this_resolution = this_min + ((this_max-this_min)/2)\n",
    "        sc.tl.leiden(embadata,resolution=this_resolution,random_state=42)\n",
    "        this_clusters = embadata.obs['leiden'].nunique()\n",
    "        if this_clusters > n_cluster:\n",
    "            this_max = this_resolution\n",
    "        elif this_clusters < n_cluster:\n",
    "            this_min = this_resolution\n",
    "        else:\n",
    "            print('use resolution',this_resolution)\n",
    "            break\n",
    "    else:\n",
    "        # Step budget exhausted: fall back to the last resolution tried.\n",
    "        print('Cannot find the number of clusters')\n",
    "        print('Use resolution',this_resolution)\n",
    "    # Final run stores the labels under a key named after the target count.\n",
    "    sc.tl.leiden(embadata,resolution=this_resolution,random_state=42,key_added=f'cluster_{n_cluster}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a452e61-6404-4d1d-bb7f-13197ffc76af",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:34.299046Z",
     "iopub.status.busy": "2024-05-12T03:38:34.298813Z",
     "iopub.status.idle": "2024-05-12T03:38:34.303359Z",
     "shell.execute_reply": "2024-05-12T03:38:34.302713Z",
     "shell.execute_reply.started": "2024-05-12T03:38:34.299027Z"
    }
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "def collate_fn(batch):\n",
    "    exp, coord,mask = zip(*batch)\n",
    "    padded_exp = pad_sequence(exp, batch_first=True)\n",
    "    padded_coord = pad_sequence(coord, batch_first=True)\n",
    "    padded_mask = pad_sequence(mask, batch_first=True)\n",
    "    return padded_exp, padded_coord, padded_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "58e445df-e511-4545-a8fc-bf48e731f8bb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:34.304885Z",
     "iopub.status.busy": "2024-05-12T03:38:34.304363Z",
     "iopub.status.idle": "2024-05-12T03:38:34.311782Z",
     "shell.execute_reply": "2024-05-12T03:38:34.311272Z",
     "shell.execute_reply.started": "2024-05-12T03:38:34.304852Z"
    },
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "datalist = ['/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.001.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.002.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.003.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.004.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.005.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.006.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.007.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.008.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.009.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.010.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.011.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.012.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.013.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.014.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.015.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.016.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.017.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.018.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.019.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.020.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.021.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.022.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.023.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.025.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.026.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.027.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.028.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.030.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.031.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.032.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.033.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.034.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.035.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.036.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.037.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.039.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.040.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.041.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.042.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.044.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.045.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.046.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.047.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.048.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.049.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.050.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.051.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.052.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.053.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.054.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.055.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.056.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.057.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.058.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.059.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.060.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.061.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.062.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.063.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.065.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.066.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.067.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.070.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.071.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.072.h5ad',\n",
    "'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration/Zhuang-ABCA-2.073.h5ad']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ebb89a54-6f5f-4ca5-aa60-e9e4bc1e4d0b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:35.088049Z",
     "iopub.status.busy": "2024-05-12T03:38:35.087192Z",
     "iopub.status.idle": "2024-05-12T03:38:36.385335Z",
     "shell.execute_reply": "2024-05-12T03:38:36.384668Z",
     "shell.execute_reply.started": "2024-05-12T03:38:35.087997Z"
    }
   },
   "outputs": [],
   "source": [
    "# All annotation tables live in one directory.\n",
    "_ANNO_DIR = '/data2/usr/Sgeneration/MERFISH/Annotation'\n",
    "\n",
    "# Cluster annotations: membership table and colour table, both indexed by\n",
    "# `cluster_alias`, joined column-wise.\n",
    "annotationtable, annotationcolor = (\n",
    "    pd.read_csv(f'{_ANNO_DIR}/cluster_to_cluster_annotation_membership_{suffix}.csv').set_index('cluster_alias')\n",
    "    for suffix in ('pivoted', 'color')\n",
    ")\n",
    "annotation = pd.concat([annotationtable,annotationcolor],axis=1)\n",
    "\n",
    "# Parcellation annotations: name table and colour table, indexed by their\n",
    "# first column, joined column-wise.\n",
    "regiontable, regioncolor = (\n",
    "    pd.read_csv(f'{_ANNO_DIR}/parcellation_to_parcellation_term_membership_{suffix}.csv',index_col=0)\n",
    "    for suffix in ('name', 'color')\n",
    ")\n",
    "regionanno = pd.concat([regiontable,regioncolor],axis=1)\n",
    "\n",
    "# Per-cell table providing the `parcellation_index` column used downstream.\n",
    "ccfv1 = pd.read_csv(f'{_ANNO_DIR}/A2ccf_coordinates.csv',index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70d15a65-a6bb-4138-9bcd-929c429ed8cd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:37.088255Z",
     "iopub.status.busy": "2024-05-12T03:38:37.087361Z",
     "iopub.status.idle": "2024-05-12T03:38:40.051773Z",
     "shell.execute_reply": "2024-05-12T03:38:40.051184Z",
     "shell.execute_reply.started": "2024-05-12T03:38:37.088187Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters saved with the training run.\n",
    "config_file = pathlib.Path('./dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/hparams.yaml')\n",
    "class setting( object ):\n",
    "    \"\"\"Plain attribute container for the loaded hyper-parameters.\"\"\"\n",
    "    pass\n",
    "cfg=setting()\n",
    "# Fail early with a clear message: every attribute read below comes from\n",
    "# this file, so continuing without it only produces an AttributeError later.\n",
    "if not config_file.exists():\n",
    "    raise FileNotFoundError(f'hparams file not found: {config_file}')\n",
    "with config_file.open('r') as f:\n",
    "    # NOTE(security): unsafe_load can construct arbitrary Python objects;\n",
    "    # only use it on hparams files produced by trusted training runs.\n",
    "    d = yaml.unsafe_load(f)\n",
    "for k,v in d.items():\n",
    "    setattr(cfg, k, v)\n",
    "\n",
    "cfg.ckpt_path = './dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/ckpt/ckpt_best.pt'\n",
    "\n",
    "# The task is forced to 'mse' regardless of what the hparams file says.\n",
    "setattr(cfg,'task','mse')\n",
    "batch_size = cfg.batch_size\n",
    "block_size = cfg.block_size\n",
    "task = cfg.task\n",
    "\n",
    "n_layer = cfg.n_layer\n",
    "n_head = cfg.n_head\n",
    "n_embd = cfg.n_embd\n",
    "bias = cfg.bias\n",
    "dropout = cfg.dropout\n",
    "train_mode = cfg.train_mode\n",
    "init_from = cfg.init_from\n",
    "device_set = cfg.device\n",
    "dtype = cfg.dtype\n",
    "data_path = cfg.data_path\n",
    "ckpt_path = cfg.ckpt_path\n",
    "infersave_path = cfg.infersave_path\n",
    "N = cfg.N\n",
    "vocab_size = cfg.vocab_size\n",
    "torch.manual_seed(19491001)\n",
    "# Load onto the CPU first so the checkpoint does not depend on the GPU index\n",
    "# it was saved from (only one device is visible in this notebook); the model\n",
    "# is moved to the GPU after its weights are restored.\n",
    "ckpt = torch.load(ckpt_path, map_location='cpu')\n",
    "# Drop the codebook weights before restoring the state dict.\n",
    "del ckpt['model']['codebook.weight']\n",
    "del ckpt['model']['criterion.codebook.weight']\n",
    "print(f'load cfg.ckpt_path:{cfg.ckpt_path}')\n",
    "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
    "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
    "device_type = 'cuda'\n",
    "\n",
    "# model init\n",
    "from model import DaoConfig, GeST\n",
    "# NOTE(review): block_size is fixed to 4096 here and cfg.block_size (read\n",
    "# above) is not used - confirm this is intentional.\n",
    "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=4096,batch_size=batch_size,\n",
    "                    bias=bias,dropout=dropout,train_mode=train_mode,task=task,vocab_size=vocab_size,loss_len=cfg.loss_len,\n",
    "                    encoder = cfg.encoder, decoder = cfg.decoder,\n",
    "                    skipconnect=cfg.skipconnect, noise = cfg.noise, rope_base = cfg.rope_base,loc_emb = cfg.loc_emb,device_type=cfg.device,modeltype=cfg.model,\n",
    "                 codebook = cfg.codebook) # start with model_args from command line\n",
    "gptconf = DaoConfig(**model_args)\n",
    "model = GeST(gptconf)\n",
    "model.load_state_dict(ckpt['model'])\n",
    "model.to(device_type)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fcb0318-9ece-47bf-b8ce-9b4c2dcd97fb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-13T06:49:01.519762Z",
     "iopub.status.busy": "2024-05-13T06:49:01.518787Z",
     "iopub.status.idle": "2024-05-13T06:49:30.261321Z",
     "shell.execute_reply": "2024-05-13T06:49:30.260383Z",
     "shell.execute_reply.started": "2024-05-13T06:49:01.519679Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Annotate every slice with region / cluster metadata and save it.\n",
    "for name in datalist:\n",
    "    # File name without its directory and without the '.h5ad' suffix.\n",
    "    datasetname = name.rsplit('/', 1)[-1].split('.h5ad')[0]\n",
    "    adatahvg = sc.read_h5ad(name)\n",
    "    print(datasetname,name)\n",
    "    # Keep only the cells that also have a row in ccfv1.\n",
    "    has_ccf = adatahvg.obs.index.isin(ccfv1.index)\n",
    "    setidx = adatahvg.obs.index[has_ccf]\n",
    "    if not len(setidx):\n",
    "        print(f'{datasetname} no ccf')\n",
    "        continue\n",
    "    adatahvg = adatahvg[setidx].copy()\n",
    "    adatahvg.obs['parcellation_index']=ccfv1.loc[adatahvg.obs.index,'parcellation_index']\n",
    "    # Drop cells with parcellation_index 0 or 987\n",
    "    # (presumably unassigned / excluded parcels - TODO confirm).\n",
    "    for excluded in (0, 987):\n",
    "        adatahvg = adatahvg[adatahvg.obs['parcellation_index'] != excluded]\n",
    "    # Look up the region annotation by parcellation_index, then the cluster\n",
    "    # annotation by cluster_alias, and append each as new obs columns.\n",
    "    for table, key in ((regionanno, 'parcellation_index'), (annotation, 'cluster_alias')):\n",
    "        query = table.loc[adatahvg.obs[key].values,:]\n",
    "        query.index = adatahvg.obs.index\n",
    "        adatahvg.obs = pd.concat([adatahvg.obs, query],axis=1)\n",
    "    # `region` is the structure label, except that every cell of the\n",
    "    # Isocortex division is labelled 'Isocortex'.\n",
    "    adatahvg.obs['region']=adatahvg.obs['structure'].astype(str)\n",
    "    is_isocortex = adatahvg.obs['division']=='Isocortex'\n",
    "    adatahvg.obs.loc[is_isocortex,'region']='Isocortex'\n",
    "    adatahvg.write_h5ad(f'/data1/usr/results/embedding/mouse2/raw_{datasetname}.h5ad')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0cdf0e0-921e-4b3e-9b9d-451352508d62",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-12T03:38:58.076880Z",
     "iopub.status.busy": "2024-05-12T03:38:58.075897Z",
     "iopub.status.idle": "2024-05-12T09:37:49.371526Z",
     "shell.execute_reply": "2024-05-12T09:37:49.370864Z",
     "shell.execute_reply.started": "2024-05-12T03:38:58.076822Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Imported once here instead of on every iteration of the inner loop.\n",
    "from sklearn.metrics.cluster import adjusted_mutual_info_score\n",
    "\n",
    "# For every slice: annotate the cells, embed them with the model at several\n",
    "# round0/round1 settings, cluster the embeddings and score the clusters\n",
    "# against the region / division labels.\n",
    "for name in datalist:\n",
    "    datasetname = name.split('/')[-1].split('.h5ad')[0]\n",
    "    adatahvg = sc.read_h5ad(name)\n",
    "    print(datasetname,name)\n",
    "    # Keep only the cells that also have a row in ccfv1.\n",
    "    setidx = adatahvg.obs.index[adatahvg.obs.index.isin(ccfv1.index)]\n",
    "    if len(setidx)==0:\n",
    "        print(f'{datasetname} no ccf')\n",
    "        continue\n",
    "    adatahvg = adatahvg[setidx].copy()\n",
    "    adatahvg.obs['parcellation_index']=ccfv1.loc[adatahvg.obs.index,'parcellation_index']\n",
    "    # Drop cells with parcellation_index 0 or 987\n",
    "    # (presumably unassigned / excluded parcels - TODO confirm).\n",
    "    adatahvg = adatahvg[adatahvg.obs['parcellation_index'] !=0]\n",
    "    adatahvg = adatahvg[adatahvg.obs['parcellation_index'] !=987]\n",
    "    # Append region annotation columns, looked up by parcellation_index.\n",
    "    query = regionanno.loc[adatahvg.obs.parcellation_index.values,:]\n",
    "    query.index = adatahvg.obs.index\n",
    "    adatahvg.obs = pd.concat([adatahvg.obs, query],axis=1)\n",
    "    # Append cluster annotation columns, looked up by cluster_alias.\n",
    "    query = annotation.loc[adatahvg.obs.cluster_alias.values,:]\n",
    "    query.index = adatahvg.obs.index\n",
    "    adatahvg.obs = pd.concat([adatahvg.obs, query],axis=1)\n",
    "    # `region` is the structure label, except that every cell of the\n",
    "    # Isocortex division is labelled 'Isocortex'.\n",
    "    adatahvg.obs['region']=adatahvg.obs['structure'].astype(str)\n",
    "    adatahvg.obs.loc[adatahvg.obs['division']=='Isocortex','region']='Isocortex'\n",
    "    \n",
    "    for round0 in [0.1,0.2,0.3]:\n",
    "        # round0 and round1 are always set to the same value.\n",
    "        setattr(cfg,'round0',round0)\n",
    "        setattr(cfg,'round1',round0)\n",
    "        cell_ds = SpatialTarget(adatahvg,cfg=cfg,shuffle=None)\n",
    "        allcemb = np.zeros([len(cell_ds),cfg.n_embd])\n",
    "        # Add each cell's own index to the list stored under its raw\n",
    "        # coordinate in `nbordict`.\n",
    "        for idx in range(len(cell_ds)):\n",
    "            nboridx = cell_ds.nbordict[tuple(cell_ds.rawcoord[idx,:])].copy()\n",
    "            nboridx.append(idx)\n",
    "            cell_ds.nbordict[tuple(cell_ds.rawcoord[idx,:])] = nboridx\n",
    "        # Row of `allcemb` to fill next; relies on shuffle=False so that\n",
    "        # loader order equals dataset order.\n",
    "        localidx=0\n",
    "        loader = DataLoader(cell_ds, batch_size=64,pin_memory=True,num_workers=2,collate_fn=collate_fn,prefetch_factor=1,shuffle=False)\n",
    "        with torch.no_grad():\n",
    "            for iter_num, (exp,coord,mask) in enumerate(loader):\n",
    "                blocksize = mask.shape[1]\n",
    "                exp,coord,mask = exp.to(device_type).float(), coord.to(device_type).float(), mask.to(device_type)\n",
    "                _, tmpemb = model(inputs_embeds=exp,coord=coord,mask=mask,embedding=True)\n",
    "                tmpemb = tmpemb.detach().cpu().numpy()\n",
    "                mask = mask.detach().cpu().numpy()\n",
    "                for i in range(tmpemb.shape[0]):\n",
    "                    # From the positions after the first `blocksize` ones,\n",
    "                    # keep those selected by the mask and take the\n",
    "                    # second-to-last as this cell's embedding\n",
    "                    # (assumes a boolean mask - TODO confirm).\n",
    "                    cellemb = tmpemb[i,blocksize:][mask[i]][-2]\n",
    "                    allcemb[localidx]=cellemb\n",
    "                    localidx+=1\n",
    "        adataemb = sc.AnnData(np.array(allcemb),obs=adatahvg.obs)\n",
    "        adataemb.obsm['spatial']=adatahvg.obsm['spatial']\n",
    "        sc.pp.pca(adataemb)\n",
    "        sc.pp.neighbors(adataemb)\n",
    "        sc.tl.umap(adataemb)\n",
    "        # Cluster once per label set, asking for as many clusters as there\n",
    "        # are distinct labels; results land in obs['cluster_<count>'].\n",
    "        n_labels = {}\n",
    "        for label in ('region', 'division'):\n",
    "            n_labels[label] = len(adataemb.obs[label].unique())\n",
    "            cluster_k_leiden(adataemb,n_labels[label])\n",
    "        numcls = n_labels['region']\n",
    "        ami1 = adjusted_mutual_info_score(adataemb.obs['region'].values.tolist(),adataemb.obs[f'cluster_{numcls}'].values.tolist())\n",
    "        numcls = n_labels['division']\n",
    "        ami2 = adjusted_mutual_info_score(adataemb.obs['division'].values.tolist(),adataemb.obs[f'cluster_{numcls}'].values.tolist())\n",
    "        # The scores are adjusted mutual information, so label them as AMI.\n",
    "        print(f'_{round0}_{datasetname} AMI:',ami1,ami2)\n",
    "        adataemb.write_h5ad(f'/data1/usr/results/embedding/mouse2/adataemb_{round0}_{datasetname}.h5ad')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
