{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a96d17b-3802-4423-8103-b1a5acd75c3f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-14T18:40:45.144229Z",
     "iopub.status.busy": "2024-09-14T18:40:45.143731Z",
     "iopub.status.idle": "2024-09-14T18:41:15.510294Z",
     "shell.execute_reply": "2024-09-14T18:41:15.508292Z",
     "shell.execute_reply.started": "2024-09-14T18:40:45.144187Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "%pylab inline\n",
    "import scanpy as sc\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a609b1f-9d90-40e3-975d-e6a1bdbc35d5",
   "metadata": {},
   "source": [
    "# Load Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "966493ac-926b-4318-861d-d0e4e9eb8536",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "os.chdir('../../')\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='7'\n",
    "os.environ['LOGURU_LEVEL']='INFO'\n",
    "import yaml\n",
    "import pathlib\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from dataset import SpatialSeq\n",
    "from inference import *\n",
    "\n",
    "def seed_all(seed, cuda_deterministic=False):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed; it is also exported as PYTHONHASHSEED.\n",
    "        cuda_deterministic: True selects deterministic cuDNN kernels and turns\n",
    "            autotuning off (slower, more reproducible); False does the opposite.\n",
    "    \"\"\"\n",
    "    # Import locally: `%pylab inline` and the star-imports above may bind the bare\n",
    "    # name `random` to something other than the stdlib module.\n",
    "    import random\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    # The two cuDNN flags are always set to opposite values.\n",
    "    torch.backends.cudnn.deterministic = bool(cuda_deterministic)\n",
    "    torch.backends.cudnn.benchmark = not cuda_deterministic\n",
    "seed_all(19491001,cuda_deterministic=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1da451-0858-4acb-a684-df1e70716bcb",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7323bc49-28a7-46f8-b97d-6d64ef2034d9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from dataset import MultiSpatialTarget,SpatialTarget\n",
    "from inference import expandTarget_quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5db13a33-bfa9-4b1a-8692-f89e1eeb058d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Scratch directory: later cells read the *_ref/*_query files from it and write the\n",
    "# baseline predictions and SVG tables back into it. scanpy figures go there too.\n",
    "savedir = '/stor/usr/sgenetmp/'\n",
    "sc.settings.figdir = savedir\n",
    "# Directory listing that defines the dataset names; entries containing '_' are skipped.\n",
    "basedir = '/MERFISH2023/Zhuang-ABCA-2/processed/Sgeneration/'\n",
    "alldataname = [fname for fname in sorted(os.listdir(basedir)) if '_' not in fname]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb31ec90-807c-4663-87e3-d98499e56f58",
   "metadata": {},
   "source": [
    "# MLP Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270bbf39-c970-4a4c-b181-7d5a596f4c8a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-04-29T13:54:27.018355Z",
     "iopub.status.busy": "2024-04-29T13:54:27.017662Z",
     "iopub.status.idle": "2024-04-29T15:51:56.504091Z",
     "shell.execute_reply": "2024-04-29T15:51:56.503158Z",
     "shell.execute_reply.started": "2024-04-29T13:54:27.018305Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPRegressor\n",
    "# Baseline: regress expression on the 2-D coordinates with an MLP, one model per dataset.\n",
    "coordcols = ['array_col','array_row']\n",
    "for adataname in alldataname:\n",
    "    # Drop the last 5 characters of the listing entry (presumably '.h5ad' - TODO confirm).\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    genelist = valadata.var.index.tolist()\n",
    "    refcoord = refadata.obs[coordcols].values\n",
    "    refexp = refadata[:,genelist].X.toarray()\n",
    "    valcoord = valadata.obs[coordcols].values\n",
    "    mlpregr = MLPRegressor(random_state=0)\n",
    "    mlpregr.fit(refcoord,refexp)\n",
    "    mlppred = pd.DataFrame(mlpregr.predict(valcoord),index=valadata.obs.index,columns=genelist)\n",
    "    # Wrap the predictions in an AnnData that carries the query coordinates and obs table.\n",
    "    mlpadata = sc.AnnData(mlppred)\n",
    "    mlpadata.obsm['spatial']=valcoord\n",
    "    mlpadata.obs = valadata.obs\n",
    "    outpath = f'{savedir}{adataname}_mlppred.h5ad'\n",
    "    mlpadata.write_h5ad(outpath)\n",
    "    print(outpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8872f4d6-d3fa-4f86-9ebd-15d30bfba6db",
   "metadata": {},
   "source": [
    "# GP Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbb394d1-f3fe-4a4b-b0bc-8c588a335f27",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-03T14:05:44.207354Z",
     "iopub.status.busy": "2024-05-03T14:05:44.206666Z",
     "iopub.status.idle": "2024-05-06T02:50:05.978245Z",
     "shell.execute_reply": "2024-05-06T02:50:05.977328Z",
     "shell.execute_reply.started": "2024-05-03T14:05:44.207304Z"
    },
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "# Baseline: Gaussian-process regression from 2-D coordinates to expression, one model\n",
    "# per dataset. Results live in gp-prefixed names so they no longer overwrite the\n",
    "# mlpregr/mlppred/mlpadata variables of the MLP cell.\n",
    "for adataname in alldataname:\n",
    "    # Drop the last 5 characters of the listing entry (presumably '.h5ad' - TODO confirm).\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[['array_col','array_row']].values\n",
    "    genelist = valadata.var.index.tolist()\n",
    "    gpregr = GaussianProcessRegressor(random_state=0).fit(refadata.obs[['array_col','array_row']].values,refadata[:,genelist].X.toarray())\n",
    "    gppred = pd.DataFrame(gpregr.predict(valcoord),index=valadata.obs.index,columns=genelist)\n",
    "    # Wrap the predictions in an AnnData that carries the query coordinates and obs table.\n",
    "    gpadata = sc.AnnData(gppred)\n",
    "    gpadata.obsm['spatial']=valcoord\n",
    "    gpadata.obs = valadata.obs\n",
    "    gpadata.write_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    print(f'{savedir}{adataname}_gppred.h5ad')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33dd8b38-9882-4832-a1d8-a7316bcdb920",
   "metadata": {},
   "source": [
    "# Model Test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42f7bc0d-88d8-42bc-ac10-94b55c941a56",
   "metadata": {},
   "source": [
    "## L8H8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f21f810-e62f-49e6-98c7-970b48511bdd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-01T12:22:45.362523Z",
     "iopub.status.busy": "2024-05-01T12:22:45.361840Z",
     "iopub.status.idle": "2024-05-01T12:22:48.357855Z",
     "shell.execute_reply": "2024-05-01T12:22:48.357227Z",
     "shell.execute_reply.started": "2024-05-01T12:22:45.362471Z"
    }
   },
   "outputs": [],
   "source": [
    "# Codebook and PCA assets.\n",
    "# NOTE(review): the `codebook` tuple built here is not passed to the model below\n",
    "# (model_args uses cfg.codebook) - confirm where, if anywhere, it is consumed.\n",
    "# NOTE(review): the codebook path lacks the 'Sgeneration/MERFISH' component that the\n",
    "# other three assets use - confirm it is correct.\n",
    "codebook = np.load('/data2/usr/0415_pca/codebook3000.npy')\n",
    "pca = pd.read_csv('/data2/usr/Sgeneration/MERFISH/0415_pca/meta_cells_pca.csv',index_col=0).values\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "nbrs = NearestNeighbors(n_neighbors=1).fit(pca)\n",
    "pcaproj = np.load('/data2/usr/Sgeneration/MERFISH/0415_pca/PCs.npy')\n",
    "mean = np.load('/data2/usr/Sgeneration/MERFISH/0415_pca/mean.npy')\n",
    "codebook = ((nbrs,mean,pcaproj),codebook)\n",
    "\n",
    "\n",
    "config_file = pathlib.Path('./dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/hparams.yaml')\n",
    "class setting( object ):\n",
    "    # Plain attribute bag; filled from hparams.yaml below.\n",
    "    pass\n",
    "cfg=setting()\n",
    "# Fail fast: without the file cfg stays empty and the first cfg.<attr> read below would\n",
    "# raise a far less helpful AttributeError. The path is relative, so it depends on the\n",
    "# current working directory.\n",
    "if not config_file.exists():\n",
    "    raise FileNotFoundError(f'hparams file not found: {config_file.resolve()}')\n",
    "with config_file.open('r') as f:\n",
    "    # NOTE(review): unsafe_load can build arbitrary Python objects - only load trusted files.\n",
    "    d = yaml.unsafe_load(f)\n",
    "for k,v in d.items():\n",
    "    setattr(cfg, k, v)\n",
    "# Override whatever checkpoint path the yaml carried with the one evaluated here.\n",
    "cfg.ckpt_path = './dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/ckpt/ckpt_best.pt'\n",
    "\n",
    "torch.cuda.init()\n",
    "# Mirror the config into notebook-level names (several are not read again in this cell).\n",
    "backend = cfg.backend\n",
    "compile = cfg.compile # Default True, use PyTorch 2.0 to compile the model to be faster\n",
    "gradient_accumulation_steps = cfg.gradient_accumulation_steps\n",
    "batch_size = cfg.batch_size\n",
    "block_size = cfg.block_size\n",
    "task = cfg.task\n",
    "\n",
    "n_layer = cfg.n_layer\n",
    "n_head = cfg.n_head\n",
    "n_embd = cfg.n_embd\n",
    "bias = cfg.bias\n",
    "dropout = cfg.dropout\n",
    "train_mode = cfg.train_mode\n",
    "init_from = cfg.init_from\n",
    "device_set = cfg.device\n",
    "dtype = cfg.dtype\n",
    "data_path = cfg.data_path\n",
    "ckpt_path = cfg.ckpt_path\n",
    "infersave_path = cfg.infersave_path\n",
    "N = cfg.N\n",
    "vocab_size = cfg.vocab_size\n",
    "torch.manual_seed(19491001)\n",
    "ckpt = torch.load(ckpt_path)\n",
    "print(f'load cfg.ckpt_path:{cfg.ckpt_path}')\n",
    "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
    "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
    "device_type = 'cuda'\n",
    "\n",
    "# model init (block_size is fixed at 4096 here rather than taken from cfg.block_size)\n",
    "from model import DaoConfig, GeST\n",
    "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=4096,batch_size=batch_size,\n",
    "                    bias=bias,dropout=dropout,train_mode=train_mode,task=task,vocab_size=vocab_size,loss_len=cfg.loss_len,\n",
    "                    encoder = cfg.encoder, decoder = cfg.decoder,\n",
    "                    skipconnect=cfg.skipconnect, noise = cfg.noise, rope_base = cfg.rope_base,loc_emb = cfg.loc_emb,device_type=cfg.device,modeltype=cfg.model,\n",
    "                 codebook = cfg.codebook) # start with model_args from command line\n",
    "gptconf = DaoConfig(**model_args)\n",
    "model = GeST(gptconf)\n",
    "model.load_state_dict(ckpt['model'])\n",
    "model.to(device_type)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f93569-6898-4bba-a4d3-4743f1355217",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-04-29T16:41:45.913001Z",
     "iopub.status.busy": "2024-04-29T16:41:45.912235Z",
     "iopub.status.idle": "2024-04-29T18:53:07.214726Z",
     "shell.execute_reply": "2024-04-29T18:53:07.214170Z",
     "shell.execute_reply.started": "2024-04-29T16:41:45.912926Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Zero-shot prediction with the currently loaded model (default mode), one file per dataset.\n",
    "outdir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "coordcols = ['array_col', 'array_row']\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[coordcols].values\n",
    "    model.to(device_type)\n",
    "    model.eval()\n",
    "    # 11 passes in total: the first starts from the reference data, each of the next\n",
    "    # 10 is fed the AnnData returned by the previous pass.\n",
    "    expandadata = refadata\n",
    "    for npass in range(11):\n",
    "        expandadata,vardata,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.01,0.01),thres=50,verbose=False)\n",
    "\n",
    "    # Map each query coordinate to the expanded row holding exactly the same float pair;\n",
    "    # a coordinate with no exact match raises KeyError.\n",
    "    valcoord = valadata.obs[coordcols].values.astype(float)\n",
    "    expandcoord = expandadata.obs[coordcols].values.astype(float)\n",
    "    expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "    cidx = [expandcoorddict[tuple(row)] for row in valcoord]\n",
    "    predval = expandadata[cidx,:].copy()\n",
    "    outpath = f'{outdir}val_pred_{adataname}.h5ad'\n",
    "    print(f'saving {outpath}')\n",
    "    predval.write_h5ad(outpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f9c34e6-28a3-4589-91e3-fd39defd1978",
   "metadata": {},
   "source": [
    "weighted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da059454-98e3-4f16-9420-b73e899829e3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-01T12:22:50.970908Z",
     "iopub.status.busy": "2024-05-01T12:22:50.970196Z",
     "iopub.status.idle": "2024-05-01T14:33:54.328188Z",
     "shell.execute_reply": "2024-05-01T14:33:54.327487Z",
     "shell.execute_reply.started": "2024-05-01T12:22:50.970851Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Zero-shot prediction with the currently loaded model, mode='weighted', one file per dataset.\n",
    "outdir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "coordcols = ['array_col', 'array_row']\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[coordcols].values\n",
    "    model.to(device_type)\n",
    "    model.eval()\n",
    "    # 11 passes in total: the first starts from the reference data, each of the next\n",
    "    # 10 is fed the AnnData returned by the previous pass.\n",
    "    expandadata = refadata\n",
    "    for npass in range(11):\n",
    "        expandadata,vardata,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.01,0.01),thres=50,verbose=False,mode='weighted')\n",
    "\n",
    "    # Map each query coordinate to the expanded row holding exactly the same float pair;\n",
    "    # a coordinate with no exact match raises KeyError.\n",
    "    valcoord = valadata.obs[coordcols].values.astype(float)\n",
    "    expandcoord = expandadata.obs[coordcols].values.astype(float)\n",
    "    expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "    cidx = [expandcoorddict[tuple(row)] for row in valcoord]\n",
    "    predval = expandadata[cidx,:].copy()\n",
    "    outpath = f'{outdir}val_pred_weighted_{adataname}.h5ad'\n",
    "    print(f'saving {outpath}')\n",
    "    predval.write_h5ad(outpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad7a632d-4382-41bf-8bd7-87a5fc5aba4b",
   "metadata": {},
   "source": [
    "## L16H16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fba34da-f48c-4ff6-a6c7-8b033cd70d75",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-01T14:33:54.329999Z",
     "iopub.status.busy": "2024-05-01T14:33:54.329752Z",
     "iopub.status.idle": "2024-05-01T14:34:03.371201Z",
     "shell.execute_reply": "2024-05-01T14:34:03.370727Z",
     "shell.execute_reply.started": "2024-05-01T14:33:54.329978Z"
    }
   },
   "outputs": [],
   "source": [
    "# Codebook and PCA assets.\n",
    "# NOTE(review): the `codebook` tuple built here is not passed to the model below\n",
    "# (model_args uses cfg.codebook) - confirm where, if anywhere, it is consumed.\n",
    "# NOTE(review): the codebook path lacks the 'Sgeneration/MERFISH' component that the\n",
    "# other three assets use - confirm it is correct.\n",
    "codebook = np.load('/data2/usr/0415_pca/codebook3000.npy')\n",
    "pca = pd.read_csv('/data2/usr/Sgeneration/MERFISH/0415_pca/meta_cells_pca.csv',index_col=0).values\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "nbrs = NearestNeighbors(n_neighbors=1).fit(pca)\n",
    "pcaproj = np.load('/data2/usr/Sgeneration/MERFISH/0415_pca/PCs.npy')\n",
    "mean = np.load('/data2/usr/Sgeneration/MERFISH/0415_pca/mean.npy')\n",
    "codebook = ((nbrs,mean,pcaproj),codebook)\n",
    "\n",
    "\n",
    "config_file = pathlib.Path('./dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L16H16_sinu_5e4_v142_mouse1_hierarchy_multimlp_continue_2024-04-29/hparams.yaml')\n",
    "class setting( object ):\n",
    "    # Plain attribute bag; filled from hparams.yaml below.\n",
    "    pass\n",
    "cfg=setting()\n",
    "# Fail fast: without the file cfg stays empty and the first cfg.<attr> read below would\n",
    "# raise a far less helpful AttributeError. The path is relative, so it depends on the\n",
    "# current working directory.\n",
    "if not config_file.exists():\n",
    "    raise FileNotFoundError(f'hparams file not found: {config_file.resolve()}')\n",
    "with config_file.open('r') as f:\n",
    "    # NOTE(review): unsafe_load can build arbitrary Python objects - only load trusted files.\n",
    "    d = yaml.unsafe_load(f)\n",
    "for k,v in d.items():\n",
    "    setattr(cfg, k, v)\n",
    "# Override whatever checkpoint path the yaml carried with the one evaluated here.\n",
    "cfg.ckpt_path = './dir/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L16H16_sinu_5e4_v142_mouse1_hierarchy_multimlp_continue_2024-04-29/ckpt/ckpt_epoch1_3321.pt'\n",
    "\n",
    "torch.cuda.init()\n",
    "# Mirror the config into notebook-level names (several are not read again in this cell).\n",
    "backend = cfg.backend\n",
    "compile = cfg.compile # Default True, use PyTorch 2.0 to compile the model to be faster\n",
    "gradient_accumulation_steps = cfg.gradient_accumulation_steps\n",
    "batch_size = cfg.batch_size\n",
    "block_size = cfg.block_size\n",
    "task = cfg.task\n",
    "\n",
    "n_layer = cfg.n_layer\n",
    "n_head = cfg.n_head\n",
    "n_embd = cfg.n_embd\n",
    "bias = cfg.bias\n",
    "dropout = cfg.dropout\n",
    "train_mode = cfg.train_mode\n",
    "init_from = cfg.init_from\n",
    "device_set = cfg.device\n",
    "dtype = cfg.dtype\n",
    "data_path = cfg.data_path\n",
    "ckpt_path = cfg.ckpt_path\n",
    "infersave_path = cfg.infersave_path\n",
    "N = cfg.N\n",
    "vocab_size = cfg.vocab_size\n",
    "torch.manual_seed(19491001)\n",
    "ckpt = torch.load(ckpt_path)\n",
    "print(f'load cfg.ckpt_path:{cfg.ckpt_path}')\n",
    "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
    "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
    "device_type = 'cuda'\n",
    "\n",
    "# model init (block_size is fixed at 4096 here rather than taken from cfg.block_size)\n",
    "from model import DaoConfig, GeST\n",
    "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=4096,batch_size=batch_size,\n",
    "                    bias=bias,dropout=dropout,train_mode=train_mode,task=task,vocab_size=vocab_size,loss_len=cfg.loss_len,\n",
    "                    encoder = cfg.encoder, decoder = cfg.decoder,\n",
    "                    skipconnect=cfg.skipconnect, noise = cfg.noise, rope_base = cfg.rope_base,loc_emb = cfg.loc_emb,device_type=cfg.device,modeltype=cfg.model,\n",
    "                 codebook = cfg.codebook) # start with model_args from command line\n",
    "gptconf = DaoConfig(**model_args)\n",
    "model = GeST(gptconf)\n",
    "model.load_state_dict(ckpt['model'])\n",
    "model.to(device_type)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16114808-de11-4617-8c09-f95dfe36bd54",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-04-29T18:53:11.110509Z",
     "iopub.status.busy": "2024-04-29T18:53:11.110346Z",
     "iopub.status.idle": "2024-04-29T21:38:15.822970Z",
     "shell.execute_reply": "2024-04-29T21:38:15.822421Z",
     "shell.execute_reply.started": "2024-04-29T18:53:11.110492Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Zero-shot prediction with the currently loaded model (default mode), one file per dataset.\n",
    "outdir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "coordcols = ['array_col', 'array_row']\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[coordcols].values\n",
    "    model.to(device_type)\n",
    "    model.eval()\n",
    "    # 11 passes in total: the first starts from the reference data, each of the next\n",
    "    # 10 is fed the AnnData returned by the previous pass.\n",
    "    expandadata = refadata\n",
    "    for npass in range(11):\n",
    "        expandadata,vardata,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.01,0.01),thres=50,verbose=False)\n",
    "\n",
    "    # Map each query coordinate to the expanded row holding exactly the same float pair;\n",
    "    # a coordinate with no exact match raises KeyError.\n",
    "    valcoord = valadata.obs[coordcols].values.astype(float)\n",
    "    expandcoord = expandadata.obs[coordcols].values.astype(float)\n",
    "    expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "    cidx = [expandcoorddict[tuple(row)] for row in valcoord]\n",
    "    predval = expandadata[cidx,:].copy()\n",
    "    outpath = f'{outdir}val_pred_L16H16_{adataname}.h5ad'\n",
    "    print(f'saving {outpath}')\n",
    "    predval.write_h5ad(outpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbd0d118-9cbc-4946-8138-9b741d602099",
   "metadata": {},
   "source": [
    "weighted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "283a2f70-2b7a-46e1-9d1a-4065a163c3de",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-01T14:34:03.372197Z",
     "iopub.status.busy": "2024-05-01T14:34:03.371949Z",
     "iopub.status.idle": "2024-05-01T17:19:05.440286Z",
     "shell.execute_reply": "2024-05-01T17:19:05.439453Z",
     "shell.execute_reply.started": "2024-05-01T14:34:03.372172Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Zero-shot prediction with the currently loaded model, mode='weighted', one file per dataset.\n",
    "outdir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "coordcols = ['array_col', 'array_row']\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[coordcols].values\n",
    "    model.to(device_type)\n",
    "    model.eval()\n",
    "    # 11 passes in total: the first starts from the reference data, each of the next\n",
    "    # 10 is fed the AnnData returned by the previous pass.\n",
    "    expandadata = refadata\n",
    "    for npass in range(11):\n",
    "        expandadata,vardata,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.01,0.01),thres=50,verbose=False,mode='weighted')\n",
    "\n",
    "    # Map each query coordinate to the expanded row holding exactly the same float pair;\n",
    "    # a coordinate with no exact match raises KeyError.\n",
    "    valcoord = valadata.obs[coordcols].values.astype(float)\n",
    "    expandcoord = expandadata.obs[coordcols].values.astype(float)\n",
    "    expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "    cidx = [expandcoorddict[tuple(row)] for row in valcoord]\n",
    "    predval = expandadata[cidx,:].copy()\n",
    "    # NOTE(review): the file name spells 'weigthed', unlike the 'weighted' used for the\n",
    "    # other weighted outputs; kept as is so existing files keep their name.\n",
    "    outpath = f'{outdir}val_pred_L16H16_weigthed_{adataname}.h5ad'\n",
    "    print(f'saving {outpath}')\n",
    "    predval.write_h5ad(outpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfea9bd0-80af-49e4-a48c-1e64108b327f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# SVG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04e07036-506d-41d3-8098-0f7a940c1790",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import somde\n",
    "from somde import SomNode\n",
    "# Run SOMDE on every query dataset and store its result table as <name>_query_svg.csv.\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    valcoord = valadata.obs[['array_col','array_row']].values\n",
    "    valexp = pd.DataFrame(valadata.X,index=valadata.obs.index,columns=valadata.var.index)\n",
    "    som = SomNode(valcoord.astype(float),2)\n",
    "    # Keep only genes whose value exceeds 0.1 in more than 50 cells.\n",
    "    ncellabove = (valexp>0.1).sum(0)\n",
    "    cutvalexp = valexp.loc[:,ncellabove>50]\n",
    "    # som.mtx receives genes x cells; its first return value is transposed back before the run.\n",
    "    ndf,ninfo = som.mtx(cutvalexp.T)\n",
    "    som.nres =ndf.T\n",
    "    result, SVnum =som.run()\n",
    "    svgpath = f'{savedir}{adataname}_query_svg.csv'\n",
    "    result.to_csv(svgpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "998e4227-787e-48d9-b7e0-6616de283198",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "75f83b30-c1b3-477b-a25d-de98fdbeb080",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T02:45:00.759177Z",
     "iopub.status.busy": "2024-06-18T02:45:00.758443Z",
     "iopub.status.idle": "2024-06-18T02:45:00.787477Z",
     "shell.execute_reply": "2024-06-18T02:45:00.786782Z",
     "shell.execute_reply.started": "2024-06-18T02:45:00.759126Z"
    }
   },
   "outputs": [],
   "source": [
    "# Two per-cluster tables, both keyed by cluster_alias, joined side by side.\n",
    "annotdir = '/data2/usr/Sgeneration/MERFISH/Annotation/'\n",
    "annotationtable = pd.read_csv(annotdir + 'cluster_to_cluster_annotation_membership_pivoted.csv').set_index('cluster_alias')\n",
    "annotationcolor = pd.read_csv(annotdir + 'cluster_to_cluster_annotation_membership_color.csv').set_index('cluster_alias')\n",
    "annotation = pd.concat([annotationtable,annotationcolor],axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94890dc9-aaa7-49a7-b144-c6a8af588b24",
   "metadata": {},
   "source": [
    "### Gene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "633be51d-a0f2-4a4e-9fa4-2d11c6bce803",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T08:34:34.764286Z",
     "iopub.status.busy": "2024-06-18T08:34:34.763621Z",
     "iopub.status.idle": "2024-06-18T08:38:21.093770Z",
     "shell.execute_reply": "2024-06-18T08:38:21.092769Z",
     "shell.execute_reply.started": "2024-06-18T08:34:34.764238Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One 5x5 figure per dataset: rows are methods (GT last), columns are the first five\n",
    "# genes of the stored SVG table.\n",
    "preddir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "titles = ['Ours','Ours-W','MLP','GP', 'GT']\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    mlpadata = sc.read_h5ad(f'{savedir}{adataname}_mlppred.h5ad')\n",
    "    gpadata = sc.read_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    predadata = sc.read_h5ad(f'{preddir}val_pred_{adataname}.h5ad')\n",
    "    predweightedadata = sc.read_h5ad(f'{preddir}val_pred_weighted_{adataname}.h5ad')\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "    genelist = result.g.tolist()[:5]\n",
    "    # Attach the cluster annotation of every cell to the ground-truth obs table.\n",
    "    query = annotation.loc[valadata.obs.cluster_alias.values,:]\n",
    "    query.index = valadata.obs.index\n",
    "    valadata.obs = pd.concat([valadata.obs, query],axis=1)\n",
    "    fig, axs = plt.subplots(5, 5, figsize=(22, 20))\n",
    "    datasets = [predadata,predweightedadata, mlpadata, gpadata, valadata]\n",
    "    for adata, rowtitle, rowaxes in zip(datasets, titles, axs):\n",
    "        for gene, ax in zip(genelist, rowaxes):\n",
    "            sc.pl.spatial(adata, color=gene, spot_size=0.02, show=False, ax=ax)\n",
    "            # The panel title shows the value stored in valadata.var for this gene id.\n",
    "            ax.set_title(f'{rowtitle}: {valadata.var.loc[gene].item()}')\n",
    "    plt.subplots_adjust(wspace=0.3, hspace=0.6)  # Adjust the spacing\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'/stor/usr/sgenetmp/results/figures/m2gene/{adataname}.png')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "800a50bf-4439-49a4-a1d0-ca6f3ba89555",
   "metadata": {},
   "source": [
    "### Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7c0f95e-8b47-4125-9e2c-306e0a4536a1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T08:48:44.573818Z",
     "iopub.status.busy": "2024-06-18T08:48:44.573095Z",
     "iopub.status.idle": "2024-06-18T09:15:58.417037Z",
     "shell.execute_reply": "2024-06-18T09:15:58.416240Z",
     "shell.execute_reply.started": "2024-06-18T08:48:44.573761Z"
    }
   },
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import colorbm as cbm\n",
    "# Per-dataset RMSE of every method against the query data, plus one diagnostic figure each.\n",
    "allresult = []\n",
    "preddir = '/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/'\n",
    "for adataname in tqdm(alldataname):\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    mlpadata = sc.read_h5ad(f'{savedir}{adataname}_mlppred.h5ad')\n",
    "    gpadata = sc.read_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    predadata = sc.read_h5ad(f'{preddir}val_pred_{adataname}.h5ad')\n",
    "    predweightedadata = sc.read_h5ad(f'{preddir}val_pred_weighted_{adataname}.h5ad')\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "\n",
    "    # Only cells whose status differs from 'nomask' are scored.\n",
    "    cellmask = (predadata.obs.status!='nomask').values\n",
    "    valexp = pd.DataFrame(valadata.X,index=valadata.obs.index,columns=valadata.var.index)\n",
    "    # Every prediction matrix is relabelled with the ground-truth index and columns\n",
    "    # (assumes identical cell and gene order in all files - TODO confirm).\n",
    "    asframe = lambda adata: pd.DataFrame(adata.X,index=valexp.index,columns=valexp.columns)\n",
    "    preddf = asframe(predadata)\n",
    "    predwdf = asframe(predweightedadata)\n",
    "    gpdf = asframe(gpadata)\n",
    "    mlppred = asframe(mlpadata)\n",
    "\n",
    "    validgene = valexp.columns.tolist()\n",
    "    evalobj = GeneExpEval(valexp.loc[:,validgene], valexp.loc[:,validgene],'GeST','./analysis/target/Evaluate/',['SSIM','PCC','RMSE','JS'])\n",
    "    # Both arguments are sliced afresh on every call.\n",
    "    score = lambda preds: evalobj.RMSE(valexp.loc[cellmask,validgene], preds.loc[cellmask,validgene],scale='zscore')\n",
    "    rModel = score(preddf)\n",
    "    rModelW = score(predwdf)\n",
    "    rgp = score(gpdf)\n",
    "    rmlp = score(mlppred)\n",
    "    # Plot/legend order; the order stored in allresult below is different on purpose.\n",
    "    curves = [('Ours',rModel),('Ours-W',rModelW),('MLP',rmlp),('GP',rgp)]\n",
    "    nanmask = ~np.isnan(rModel).values[0]\n",
    "    topsvg = result.g[:200].tolist()\n",
    "\n",
    "    figsize(10,3)\n",
    "    plt.figure()\n",
    "    # Left panel: number of expressing cells per gene, split by whether our RMSE is NaN.\n",
    "    subplot(121)\n",
    "    (valexp.loc[:,validgene].loc[:,nanmask]>0).sum(0).hist(bins=100,label='Non Nan',color='black')\n",
    "    (valexp.loc[:,validgene].loc[:,~nanmask]>0).sum(0).hist(bins=50,label='Nan',color='red')\n",
    "    plt.xlabel('#Expressed cell number')\n",
    "    plt.ylabel('#Count')\n",
    "    plt.legend();\n",
    "\n",
    "    # Right panel: RMSE of each method over the first 200 genes of the SVG table.\n",
    "    subplot(122)\n",
    "    palette = cbm.pal('npg').as_hex\n",
    "    svgrank = np.arange(len(topsvg))\n",
    "    for k, (label, rdf) in enumerate(curves):\n",
    "        plt.scatter(svgrank,rdf[topsvg],label=label,s=10,c=palette[k])\n",
    "    plt.legend();\n",
    "    plt.xlabel('#SVG rank')\n",
    "    plt.ylabel('RMSE');\n",
    "    summary = ', '.join(f'{label}:{rdf[topsvg].mean(1).values[0]:.3f}' for label, rdf in curves)\n",
    "    plt.title(f'Top {len(topsvg)} {summary}')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'/stor/usr/sgenetmp/results/figures/m2rmse/RMSE_{adataname}.png')\n",
    "    plt.close()\n",
    "\n",
    "    # Label columns are appended only after plotting, as the last two columns\n",
    "    # ('model', then 'data') of each result frame.\n",
    "    for label, rdf in curves:\n",
    "        rdf['model'] = label\n",
    "        rdf['data'] = adataname\n",
    "    allresult.extend([rModel,rModelW,rgp,rmlp])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "cdea226a-da0b-4835-9a6c-34a65a92225a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T09:15:58.418405Z",
     "iopub.status.busy": "2024-06-18T09:15:58.418114Z",
     "iopub.status.idle": "2024-06-18T09:16:10.658263Z",
     "shell.execute_reply": "2024-06-18T09:16:10.657622Z",
     "shell.execute_reply.started": "2024-06-18T09:15:58.418390Z"
    }
   },
   "outputs": [],
   "source": [
    "# Stack the per-dataset, per-method RMSE rows into one table.\n",
    "resultdf = pd.concat(allresult, axis='index')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "7679ee86-f8c8-4325-b7a7-a91020a5fd87",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T09:16:10.660177Z",
     "iopub.status.busy": "2024-06-18T09:16:10.659944Z",
     "iopub.status.idle": "2024-06-18T09:16:11.013740Z",
     "shell.execute_reply": "2024-06-18T09:16:11.013174Z",
     "shell.execute_reply.started": "2024-06-18T09:16:10.660161Z"
    }
   },
   "outputs": [],
   "source": [
    "# Persist the combined table; the row index is written as the first column.\n",
    "metricpath = '/stor/usr/sgenetmp/results/figures/Mouse2prediction0618.csv'\n",
    "resultdf.to_csv(metricpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "7f202bba-b42b-4f42-8d68-7e25426db8fa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "# Font type 42 (TrueType) for both the PDF and the PS backend.\n",
    "rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec3b1beb-dd79-48ac-93b1-96f235a90fc5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-14T18:41:28.757417Z",
     "iopub.status.busy": "2024-09-14T18:41:28.756115Z",
     "iopub.status.idle": "2024-09-14T18:41:31.258972Z",
     "shell.execute_reply": "2024-09-14T18:41:31.257819Z",
     "shell.execute_reply.started": "2024-09-14T18:41:28.757351Z"
    }
   },
   "outputs": [],
   "source": [
    "resultdf = pd.read_csv('/stor/usr/sgenetmp/results/figures/Mouse2prediction0618.csv')\n",
    "resultdf = resultdf.iloc[:,1:]\n",
    "\n",
    "resultdf['meanRMSE'] = resultdf.iloc[:,:-2].mean(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32146f72-a917-4b19-be01-9d4467538ca9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-14T18:42:26.055059Z",
     "iopub.status.busy": "2024-09-14T18:42:26.054480Z",
     "iopub.status.idle": "2024-09-14T18:42:26.070727Z",
     "shell.execute_reply": "2024-09-14T18:42:26.068745Z",
     "shell.execute_reply.started": "2024-09-14T18:42:26.055017Z"
    }
   },
   "outputs": [],
   "source": [
    "resultdf.groupby('model')['meanRMSE'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0114f0fd-9993-4ac6-baba-375ca7ab1a06",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "resultdf = pd.read_csv('/stor/usr/sgenetmp/results/figures/Mouse2prediction0618.csv')\n",
    "resultdf = resultdf.iloc[:,1:]\n",
    "\n",
    "resultdf['meanRMSE'] = resultdf.iloc[:,:-2].mean(1)\n",
    "\n",
    "figsize(5,3)\n",
    "sns.set_palette(sns.color_palette(cbm.pal('npg').as_hex))\n",
    "sns.boxplot(data = resultdf,x='model',y='meanRMSE',hue='model',saturation=0.5)\n",
    "sns.stripplot(data = resultdf,x='model',y='meanRMSE',size=5,jitter=True,hue='model',edgecolor='black',linewidth=0.2)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"/stor/usr/sgenetmp/results/figures/RMSEALL.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "8b04ee9d-a3a9-4f90-bf81-fa292285e005",
   "metadata": {},
   "outputs": [],
   "source": [
    "resultdf = pd.read_csv('/stor/usr/sgenetmp/results/figures/Mouse2prediction0618.csv')\n",
    "resultdf = resultdf.iloc[:,1:]\n",
    "resultdf['topSVG_RMSE']=0\n",
    "\n",
    "for adataname in alldataname:\n",
    "    adataname = adataname[:-5]\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "    resultdf.loc[resultdf['data']==adataname,'topSVG_RMSE']= resultdf.loc[resultdf['data']==adataname,result.g[:200].tolist()].mean(1).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d33ac04-e84b-414b-8f96-7c79c7416cf4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import colorbm as cbm\n",
    "figsize(5,3)\n",
    "sns.set_palette(sns.color_palette(cbm.pal('npg').as_hex))\n",
    "sns.boxplot(data = resultdf,x='model',y='topSVG_RMSE',hue='model',saturation=0.5)\n",
    "sns.stripplot(data = resultdf,x='model',y='topSVG_RMSE',size=5,jitter=True,hue='model',edgecolor='black',linewidth=0.2)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"/stor/usr/sgenetmp/results/figures/RMSEALL_top200.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa6b11ba-953e-4a11-a48c-37c39d65d0f5",
   "metadata": {},
   "source": [
    "### Spearman"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "058eb3b8-132d-479c-851e-a76e341b2bc4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T09:19:14.759093Z",
     "iopub.status.busy": "2024-06-18T09:19:14.758612Z",
     "iopub.status.idle": "2024-06-18T09:20:56.182508Z",
     "shell.execute_reply": "2024-06-18T09:20:56.181608Z",
     "shell.execute_reply.started": "2024-06-18T09:19:14.759048Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from utils import *\n",
    "allresult = []\n",
    "for adataname in tqdm(alldataname):\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    mlpadata = sc.read_h5ad(f'{savedir}{adataname}_mlppred.h5ad')\n",
    "    gpadata = sc.read_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    predadata = sc.read_h5ad(f'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/val_pred_{adataname}.h5ad')\n",
    "    predwadata = sc.read_h5ad(f'/data2/usr/Sgeneration/MERFISH/Zhuang-ABCA-2/processed/Sgeneration_zeroshot_pred/val_pred_weighted_{adataname}.h5ad')\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "\n",
    "    cellmask = (predadata.obs.status!='nomask').values\n",
    "    valexp = pd.DataFrame(valadata.X,index=valadata.obs.index,columns=valadata.var.index)\n",
    "    preddf = pd.DataFrame(predadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    predwdf = pd.DataFrame(predwadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    mlppred = pd.DataFrame(mlpadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    gppred = pd.DataFrame(gpadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    \n",
    "    validgene = valexp.columns.tolist()\n",
    "    evalobj = GeneExpEval(valexp.loc[:,validgene], valexp.loc[:,validgene],'GeST','./analysis/target/Evaluate/',['SSIM','PCC','RMSE','JS'])\n",
    "    rModel = evalobj.sepearman_cell(valexp.loc[cellmask,validgene], preddf.loc[cellmask,validgene])\n",
    "    rModelW = evalobj.sepearman_cell(valexp.loc[cellmask,validgene], predwdf.loc[cellmask,validgene])\n",
    "    rmlp = evalobj.sepearman_cell(valexp.loc[cellmask,validgene], mlppred.loc[cellmask,validgene])\n",
    "    rgp = evalobj.sepearman_cell(valexp.loc[cellmask,validgene], gppred.loc[cellmask,validgene])\n",
    "    rModel['model']='Ours'\n",
    "    rModelW['model']='Ours-W'\n",
    "    rmlp['model']='MLP'\n",
    "    rgp['model']='GP'\n",
    "    rModel['data']=adataname\n",
    "    rModelW['data']=adataname\n",
    "    rmlp['data']=adataname\n",
    "    rgp['data']=adataname\n",
    "    allresult.extend([rModel,rModelW,rmlp,rgp])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "58a6b3e5-cf06-4053-921f-0858663a5036",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T09:20:56.184399Z",
     "iopub.status.busy": "2024-06-18T09:20:56.184023Z",
     "iopub.status.idle": "2024-06-18T09:20:56.517031Z",
     "shell.execute_reply": "2024-06-18T09:20:56.516478Z",
     "shell.execute_reply.started": "2024-06-18T09:20:56.184373Z"
    }
   },
   "outputs": [],
   "source": [
    "resultdf = pd.concat(allresult,axis=0)\n",
    "resultdf.to_csv('/stor/usr/sgenetmp/results/figures/Mouse2prediction0618_Spearman.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8990e31e-8770-4404-a0fd-fef5f19fc6dc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "figsize(5,3)\n",
    "resultdf = pd.read_csv('/stor/usr/sgenetmp/results/figures/Mouse2prediction0618_Spearman.csv')\n",
    "resultdf = resultdf.iloc[:,1:]\n",
    "resultdf.rename(columns={'Unnamed: 1':'Spearman'},inplace=True)\n",
    "average_df = resultdf.groupby(['model', 'data'])['Spearman'].mean().reset_index()\n",
    "sns.boxplot(data = average_df,x='model',y='Spearman',hue='model',saturation=0.7,order=['Ours','Ours-W','GP','MLP'],hue_order=['Ours','Ours-W','GP','MLP'])\n",
    "sns.stripplot(data = average_df,x='model',y='Spearman',size=5,jitter=True,hue='model',edgecolor='black',linewidth=0.2,order=['Ours','Ours-W','GP','MLP'],hue_order=['Ours','Ours-W','GP','MLP'])\n",
    "plt.savefig(f\"/stor/usr/sgenetmp/results/figures/SpearMANALL.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb6c4c5-bd43-4590-befd-82bbd913a7b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import colorbm as cbm\n",
    "allresult = []\n",
    "for adataname in tqdm(['Zhuang-ABCA-2.045.h5ad', 'Zhuang-ABCA-2.002.h5ad']):\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    mlpadata = sc.read_h5ad(f'{savedir}{adataname}_mlppred.h5ad')\n",
    "    gpadata = sc.read_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    predadata = sc.read_h5ad(f'{savedir}Sgeneration_zeroshot_pred/val_pred_{adataname}.h5ad')\n",
    "    predweightedadata = sc.read_h5ad(f'{savedir}Sgeneration_zeroshot_pred/val_pred_weighted_{adataname}.h5ad')\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "\n",
    "    cellmask = (predadata.obs.status!='nomask').values\n",
    "    valexp = pd.DataFrame(valadata.X,index=valadata.obs.index,columns=valadata.var.index)\n",
    "    preddf = pd.DataFrame(predadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    predwdf = pd.DataFrame(predweightedadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    gpdf = pd.DataFrame(gpadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    mlppred = pd.DataFrame(mlpadata.X,index=valexp.index,columns=valexp.columns)\n",
    "    \n",
    "    validgene = valexp.columns.tolist()\n",
    "    evalobj = GeneExpEval(valexp.loc[:,validgene], valexp.loc[:,validgene],'GeST','./analysis/target/Evaluate/',['SSIM','PCC','RMSE','JS'])\n",
    "    rModel = evalobj.RMSE(valexp.loc[cellmask,validgene], preddf.loc[cellmask,validgene],scale='zscore')\n",
    "    rModelW = evalobj.RMSE(valexp.loc[cellmask,validgene], predwdf.loc[cellmask,validgene],scale='zscore')\n",
    "    rgp = evalobj.RMSE(valexp.loc[cellmask,validgene], gpdf.loc[cellmask,validgene],scale='zscore')\n",
    "    rmlp = evalobj.RMSE(valexp.loc[cellmask,validgene], mlppred.loc[cellmask,validgene],scale='zscore')\n",
    "    nanmask = ~np.isnan(rModel).values[0]\n",
    "    topsvg = result.g[:200].tolist()\n",
    "    figsize(10,3)\n",
    "    plt.figure()\n",
    "    subplot(121)\n",
    "    (valexp.loc[:,validgene].loc[:,nanmask]>0).sum(0).hist(bins=100,label='Non Nan',color='black')\n",
    "    (valexp.loc[:,validgene].loc[:,~nanmask]>0).sum(0).hist(bins=50,label='Nan',color='red')\n",
    "    plt.xlabel('#Expressed cell number')\n",
    "    plt.ylabel('#Count')\n",
    "    plt.legend();\n",
    "    \n",
    "    subplot(122)\n",
    "    plt.scatter(np.arange(len(topsvg)),rModel[topsvg],label='Ours',s=10,c=cbm.pal('npg').as_hex[0])\n",
    "    plt.scatter(np.arange(len(topsvg)),rModelW[topsvg],label='Ours-W',s=10,c=cbm.pal('npg').as_hex[1])\n",
    "    plt.scatter(np.arange(len(topsvg)),rmlp[topsvg],label='MLP',s=10,c=cbm.pal('npg').as_hex[2])\n",
    "    plt.scatter(np.arange(len(topsvg)),rgp[topsvg],label='GP',s=10,c=cbm.pal('npg').as_hex[3])\n",
    "    plt.legend();\n",
    "    plt.xlabel('#SVG rank')\n",
    "    plt.ylabel(\"RMSE\");\n",
    "    plt.title(f'Top {len(topsvg)} Ours:{rModel[topsvg].mean(1).values[0]:.3f}, Ours-W:{rModelW[topsvg].mean(1).values[0]:.3f}, MLP:{rmlp[topsvg].mean(1).values[0]:.3f}, GP:{rgp[topsvg].mean(1).values[0]:.3f}')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"/stor/usr/sgenetmp/results/figures/m2rmse/RMSE_{adataname}.pdf\")\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "185511fb-06d4-444c-a401-b8f1ec8d29dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "rcParams['pdf.fonttype'] = 42\n",
    "rcParams['ps.fonttype'] = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4002e9-43f7-4136-adf7-8fe862ff9fc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "for adataname in tqdm(['Zhuang-ABCA-2.045.h5ad', 'Zhuang-ABCA-2.002.h5ad']):\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    mlpadata = sc.read_h5ad(f'{savedir}{adataname}_mlppred.h5ad')\n",
    "    gpadata = sc.read_h5ad(f'{savedir}{adataname}_gppred.h5ad')\n",
    "    predadata = sc.read_h5ad(f'{savedir}Sgeneration_zeroshot_pred/val_pred_{adataname}.h5ad')\n",
    "    predweightedadata = sc.read_h5ad(f'{savedir}Sgeneration_zeroshot_pred/val_pred_weighted_{adataname}.h5ad')\n",
    "    result = pd.read_csv(f'{savedir}{adataname}_query_svg.csv',index_col=0)\n",
    "    genelist = result.g.tolist()[:5]\n",
    "    fig, axs = plt.subplots(5, 5, figsize=(22, 18))\n",
    "    datasets = [valadata,predadata,predweightedadata, mlpadata, gpadata]\n",
    "    titles = ['GT','Ours','Ours-W','MLP','GP']\n",
    "    for i, adata in enumerate(datasets):\n",
    "        for j, gene in enumerate(genelist):\n",
    "            ax = axs[j, i]\n",
    "            sc.pl.spatial(adata, color=gene, spot_size=0.02, show=False, ax=ax)\n",
    "            ax.set_title(\"\")\n",
    "            ax.set_xlabel(\"\")\n",
    "            ax.set_ylabel(\"\")\n",
    "            # Access the colorbar\n",
    "            cbar = ax.collections[0].colorbar\n",
    "            cbar.ax.tick_params(labelsize=15)  # Set colorbar tick font size\n",
    "            cbar.update_ticks()  # Update the colorbar ticks with the new format\n",
    "    plt.subplots_adjust(wspace=0.6, hspace=0.4)  # Adjust the spacing\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"/stor/usr/sgenetmp/results/figures/m2gene/{adataname}.pdf\")\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8944d353-62e5-4645-af88-904cb2ba9eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "for adataname in tqdm(['Zhuang-ABCA-2.045.h5ad', 'Zhuang-ABCA-2.002.h5ad']):\n",
    "    adataname = adataname[:-5]\n",
    "    valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')\n",
    "    rawadata = sc.read_h5ad(f'{basedir}{adataname}.h5ad')\n",
    "\n",
    "    rawadata.obs['split']='Ref'\n",
    "    rawadata.obs.loc[rawadata.obs.index.isin(valadata.obs.index),'split']='Unknown'\n",
    "\n",
    "    figsize(4,5)\n",
    "    ax =sc.pl.spatial(rawadata,color='split',spot_size=0.03,palette=['black','red'],title=f'Slide-{adataname[-2:]}',return_fig=True,show=False)\n",
    "    ax = ax[0]\n",
    "    ax.set_xlabel(\"\")\n",
    "    ax.set_ylabel(\"\")\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"/stor/usr/sgenetmp/results/figures/{adataname}_split.pdf\")\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a68b760-b824-4e60-9cbe-270e791eee20",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt",
   "language": "python",
   "name": "pt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
