{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a96d17b-3802-4423-8103-b1a5acd75c3f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pylab inline\n",
    "import scanpy as sc\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a609b1f-9d90-40e3-975d-e6a1bdbc35d5",
   "metadata": {},
   "source": [
    "# Load Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "966493ac-926b-4318-861d-d0e4e9eb8536",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "os.chdir('../../')\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "os.environ['LOGURU_LEVEL']='INFO'\n",
    "import yaml\n",
    "import pathlib\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from dataset import SpatialSeq\n",
    "from inference import *\n",
    "\n",
    "def seed_all(seed, cuda_deterministic=False):\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    if cuda_deterministic: # slower, more reproducible\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "    else:\n",
    "        torch.backends.cudnn.deterministic = False\n",
    "        torch.backends.cudnn.benchmark = True\n",
    "seed_all(19491001,cuda_deterministic=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1da451-0858-4acb-a684-df1e70716bcb",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7323bc49-28a7-46f8-b97d-6d64ef2034d9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from dataset import MultiSpatialTarget,SpatialTarget\n",
    "from inference import expandTarget_quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007b6bdd-11cb-4687-80bb-7e38474e9c1e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "adatahvg_raw = sc.read_h5ad(\"/stor/usr/sgenetmp/perturb/ischemic_49.h5ad\")\n",
    "valadata = adatahvg_raw[adatahvg_raw.obs['mask']=='False'].copy()\n",
    "adatahvg = adatahvg_raw[adatahvg_raw.obs['mask']=='True'].copy()\n",
    "valcoord = valadata.obs[['array_col', 'array_row']].values.astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba4fd2dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "adatahvg_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f67ca8ea-5fb3-42fb-a834-3be7b7d939c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sc.pl.spatial(adatahvg_raw,spot_size=0.03,color=['class'])\n",
    "sc.pl.spatial(valadata,spot_size=0.03,color=['class'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03c476f9-d21c-4af7-ac75-113189a80ea9",
   "metadata": {},
   "source": [
    "# Generate Control group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fceccb0-add2-4379-8c6f-0b69c5184e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_file = pathlib.Path('/stor/usr/sgenetmp/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/hparams.yaml')\n",
    "class setting( object ):\n",
    "    pass\n",
    "cfg=setting()\n",
    "if config_file.exists():\n",
    "    with config_file.open('r') as f:\n",
    "        d = yaml.unsafe_load(f)\n",
    "        for k,v in d.items():\n",
    "            setattr(cfg, k, v)\n",
    "\n",
    "cfg.ckpt_path = '/stor/usr/sgenetmp/TARGET_CODE3000_MERFISH_base05_sinu_R3_corner_L8H8_sinu_5e4_v142_mouse1_noh_multimlp_2024-04-29/ckpt/ckpt_best.pt'\n",
    "\n",
    "torch.cuda.init()\n",
    "backend = cfg.backend\n",
    "compile = cfg.compile # Default True, use PyTorch 2.0 to compile the model to be faster\n",
    "gradient_accumulation_steps = cfg.gradient_accumulation_steps\n",
    "batch_size = cfg.batch_size\n",
    "block_size = cfg.block_size\n",
    "task = cfg.task\n",
    "    \n",
    "n_layer = cfg.n_layer\n",
    "n_head = cfg.n_head\n",
    "n_embd = cfg.n_embd\n",
    "bias = cfg.bias\n",
    "dropout = cfg.dropout\n",
    "train_mode = cfg.train_mode\n",
    "init_from = cfg.init_from\n",
    "device_set = cfg.device\n",
    "dtype = cfg.dtype\n",
    "data_path = cfg.data_path\n",
    "ckpt_path = cfg.ckpt_path\n",
    "infersave_path = cfg.infersave_path\n",
    "N = cfg.N\n",
    "vocab_size = cfg.vocab_size\n",
    "torch.manual_seed(19491001)\n",
    "ckpt = torch.load(ckpt_path)\n",
    "print(f'load cfg.ckpt_path:{cfg.ckpt_path}')\n",
    "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
    "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
    "device_type = 'cuda'\n",
    "\n",
    "# model init\n",
    "from model import DaoConfig, GeST\n",
    "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=4096,batch_size=batch_size,\n",
    "                    bias=bias,dropout=dropout,train_mode=train_mode,task=task,vocab_size=vocab_size,loss_len=cfg.loss_len,\n",
    "                    encoder = cfg.encoder, decoder = cfg.decoder,\n",
    "                    skipconnect=cfg.skipconnect, noise = cfg.noise, rope_base = cfg.rope_base,loc_emb = cfg.loc_emb,device_type=cfg.device,modeltype=cfg.model,\n",
    "                 codebook = cfg.codebook) \n",
    "\n",
    "gptconf = DaoConfig(**model_args)\n",
    "model = GeST(gptconf)\n",
    "\n",
    "model.load_state_dict(ckpt['model'])\n",
    "model.to(device_type)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4ecfac3-b4ae-4ba0-9622-9951f166474c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "cfg.round0,cfg.round1,cfg.task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "671f2bb5-4d16-49fa-828a-bf786f7bcf41",
   "metadata": {},
   "outputs": [],
   "source": [
    "codebook = np.load(cfg.codebook)\n",
    "basedir = cfg.codebook.split('codebook')[0]\n",
    "pca = pd.read_csv(f'{basedir}meta_cells_pca.csv',index_col=0).values\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(pca)\n",
    "pcaproj = np.load(f'{basedir}PCs.npy')\n",
    "mean = np.load(f'{basedir}mean.npy')\n",
    "codebook = ((nbrs,mean,pcaproj),codebook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee36e18a-c736-4ff0-8e18-313b1f966ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "figsize(4,4)\n",
    "expandadata,_,_ = expandTarget_quantize(adatahvg,model,valcoord,cfg,roundsize=(0.03,0.03),thres=50,verbose=True,mode='weighted')\n",
    "for i in range(1):\n",
    "    expandadata,_,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.03,0.03),thres=50,verbose=True,mode='weighted')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac4df7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "rawcoord = adatahvg_raw.obs[['array_col', 'array_row']].values.astype(float)\n",
    "expandcoord = expandadata.obs[['array_col', 'array_row']].values.astype(float)\n",
    "rawcoorddict = {tuple(coord): idx for idx, coord in enumerate(rawcoord)}\n",
    "expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "commoncoord = set([tuple(row) for row in rawcoord]).intersection(set([tuple(row) for row in expandcoord]))\n",
    "cidx = [rawcoorddict[c] for c in commoncoord]\n",
    "adatahvg_gt = adatahvg_raw[cidx,:].copy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21c79574-b0e2-4081-aa53-eb10700028ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "genelist = ['ENSMUSG00000005716', 'ENSMUSG00000005087', 'ENSMUSG00000020932', 'ENSMUSG00000027270']\n",
    "sc.pl.spatial(expandadata,spot_size=0.03,color= genelist)\n",
    "sc.pl.spatial(adatahvg_gt,spot_size=0.03,color= genelist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2028cace-559f-485e-b473-8c9c46cbc778",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "valcoord = valadata.obs[['array_col', 'array_row']].values.astype(float)\n",
    "expandcoord = expandadata.obs[['array_col', 'array_row']].values.astype(float)\n",
    "valcoorddict = {tuple(coord): idx for idx, coord in enumerate(valcoord)}\n",
    "expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "commoncoord = set([tuple(row) for row in valcoord]).intersection(set([tuple(row) for row in expandcoord]))\n",
    "cvidx = [valcoorddict[c] for c in commoncoord]\n",
    "cidx = [expandcoorddict[c] for c in commoncoord]\n",
    "predval = expandadata[cidx,:].copy()\n",
    "subval = valadata[cvidx,:].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dc12706-4023-4bf2-ba30-444b883ae580",
   "metadata": {},
   "outputs": [],
   "source": [
    "import somde\n",
    "from somde import SomNode\n",
    "valexp = pd.DataFrame(adatahvg_gt.X,index=adatahvg_gt.obs.index,columns=adatahvg_gt.var.index)\n",
    "som = SomNode(adatahvg_gt.obs[['array_col', 'array_row']].values.astype(float),5)\n",
    "ndf,ninfo = som.mtx(valexp.T)\n",
    "som.nres =ndf.T\n",
    "result, SVnum =som.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4365c688-5eda-4714-bea1-8d4dd2c4c73e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "genelist = result.g.values[:5]\n",
    "sc.pl.spatial(predval,color=genelist,spot_size=0.03,ncols=5)\n",
    "sc.pl.spatial(subval,color=genelist,spot_size=0.03,ncols=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47ea22c4-e7bb-4d41-bc05-6c6abb3d52da",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'/data1/usr/results/perturb_control/pred49_{config_file.parts[1]}_control_weigthed.h5ad')\n",
    "predval.write_h5ad(f'/data1/usr/results/perturb_control/pred49_{config_file.parts[1]}_control_weigthed.h5ad')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b5a6377-d0cf-445d-9dfe-8a8105646031",
   "metadata": {},
   "source": [
    "# Spatial Perturbation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b5b5722",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpatialPerturbation:\n",
    "    def __init__(self, perturb_mode='up'):\n",
    "        self.perturb_mode = perturb_mode\n",
    "        if self.perturb_mode not in ['up', 'down']:\n",
    "            raise ValueError(\"The perturb_mode is not supported\")\n",
    "    \n",
    "    def perturb_gaussian(self, adata, genes, bandwidth=1.0, center_point=None):\n",
    "        coords = adata.obs[['array_col', 'array_row']].values\n",
    "        if center_point is None:\n",
    "            center_point = np.mean(coords, axis=0)\n",
    "        \n",
    "        weights = np.exp(-((coords[:,0] - center_point[0])**2 + (coords[:,1] - center_point[1])**2) / (2 * bandwidth**2))\n",
    "        if self.perturb_mode == 'up':\n",
    "            adata.X[:, adata.var.index.isin(genes)] = adata.X.max() * weights[:, np.newaxis]\n",
    "        elif self.perturb_mode == 'down':\n",
    "            adata.X[:, adata.var.index.isin(genes)] = adata.X[:, adata.var.index.isin(genes)] * (1-weights[:, np.newaxis])\n",
    "      \n",
    "    \n",
    "    def perturb_tophat(self, adata, genes, bandwidth=1.0, center_point=None):\n",
    "        coords = adata.obs[['array_col', 'array_row']].values\n",
    "        if center_point is None:\n",
    "            center_point = np.mean(coords, axis=0)\n",
    "        \n",
    "        weights = np.ones(coords.shape[0]) * ((coords[:,0] - center_point[0])**2 + (coords[:,1] - center_point[1])**2 < bandwidth**2)\n",
    "        if self.perturb_mode == 'up':\n",
    "            adata.X[:, adata.var.index.isin(genes)] = adata.X.max() * weights[:, np.newaxis]\n",
    "        elif self.perturb_mode == 'down':\n",
    "            adata.X[:, adata.var.index.isin(genes)] = adata.X[:, adata.var.index.isin(genes)] * (1-weights[:, np.newaxis])   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea7fe82e-556e-406d-9d3f-bb3ca0bb1849",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr,pearsonr\n",
    "import seaborn as sns\n",
    "\n",
    "exlist = pd.read_excel(\"/stor/usr/sgenetmp/ischemicbrain/adg1323_Data_file_S2.xlsx\",sheet_name=['ICA','PIA_D','PIA_P'])\n",
    "\n",
    "mask = exlist['ICA'].iloc[:,0].isin(adatahvg_raw.var.gene_symbol)\n",
    "exlistflt = exlist['ICA'].loc[mask]\n",
    "exlistflt = exlistflt.sort_values('avg_logFC')\n",
    "\n",
    "amplifygene = exlistflt[exlistflt.avg_logFC>0].iloc[:,0].tolist()\n",
    "supressgene = exlistflt[exlistflt.avg_logFC<0].iloc[:,0].tolist()\n",
    "\n",
    "print(len(supressgene),supressgene, len(amplifygene),amplifygene)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e18f206",
   "metadata": {},
   "outputs": [],
   "source": [
    "adatahvg_raw = sc.read_h5ad(\"/stor/usr/sgenetmp/perturb/ischemic_49.h5ad\")\n",
    "valadata = adatahvg_raw[adatahvg_raw.obs['mask']=='False'].copy()\n",
    "adatahvg = adatahvg_raw[adatahvg_raw.obs['mask']=='True'].copy()\n",
    "valcoord = valadata.obs[['array_col', 'array_row']].values.astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed23753",
   "metadata": {},
   "outputs": [],
   "source": [
    "## perturbation visualization\n",
    "genes = adatahvg.var[adatahvg.var[\"gene_symbol\"].isin(amplifygene)].index.to_list()\n",
    "sc.pl.spatial(adatahvg,color=genes[:4],spot_size=0.03)\n",
    "\n",
    "UpRegulate = SpatialPerturbation(perturb_mode='up')\n",
    "UpRegulate.perturb_gaussian(adatahvg, genes, bandwidth=0.5, center_point=[4.5, 2])\n",
    "sc.pl.spatial(adatahvg,color=genes[:4],spot_size=0.03)\n",
    "\n",
    "genes = adatahvg.var[adatahvg.var[\"gene_symbol\"].isin(supressgene)].index.to_list()\n",
    "DownRegulate = SpatialPerturbation(perturb_mode='down')\n",
    "sc.pl.spatial(adatahvg,color=genes[:4],spot_size=0.03)\n",
    "\n",
    "DownRegulate.perturb_gaussian(adatahvg, genes, bandwidth=1, center_point=[4.5, 2])\n",
    "sc.pl.spatial(adatahvg,color=genes[:4], spot_size=0.03)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f941221c",
   "metadata": {},
   "source": [
    "# Generate perturbation group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6ec2d1a-8b59-4635-a224-2fa2e62b46f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "figsize(4,4)\n",
    "expandadata,_,_ = expandTarget_quantize(adatahvg,model,valcoord,cfg,roundsize=(0.03,0.03),thres=50,verbose=True,mode='weighted')\n",
    "for i in range(1):\n",
    "    expandadata,_,_ = expandTarget_quantize(expandadata,model,valcoord,cfg,roundsize=(0.03,0.03),thres=50,verbose=True,mode='weighted')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dc96d37-0881-4f65-95d2-6623e717f918",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "valcoord = valadata.obs[['array_col', 'array_row']].values.astype(float)\n",
    "expandcoord = expandadata.obs[['array_col', 'array_row']].values.astype(float)\n",
    "valcoorddict = {tuple(coord): idx for idx, coord in enumerate(valcoord)}\n",
    "expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}\n",
    "commoncoord = set([tuple(row) for row in valcoord]).intersection(set([tuple(row) for row in expandcoord]))\n",
    "cvidx = [valcoorddict[c] for c in commoncoord]\n",
    "cidx = [expandcoorddict[c] for c in commoncoord]\n",
    "predval = expandadata[cidx,:].copy()\n",
    "subval = valadata[cvidx,:].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47236e72-2bb5-41eb-9696-00fcb13e4953",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "genelist = [\"ENSMUSG00000029304\",\"ENSMUSG00000020932\",\"ENSMUSG00000038331\",\"ENSMUSG00000027270\"]\n",
    "sc.pl.spatial(predval,color=genelist,spot_size=0.03,ncols=5)\n",
    "sc.pl.spatial(subval,color=genelist,spot_size=0.03,ncols=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20743670",
   "metadata": {},
   "outputs": [],
   "source": [
    "genelist = [\"ENSMUSG00000029304\",\"ENSMUSG00000020932\",\"ENSMUSG00000038331\",\"ENSMUSG00000027270\"]\n",
    "sc.pl.spatial(expandadata,spot_size=0.03,color= genelist)\n",
    "sc.pl.spatial(adatahvg_gt,spot_size=0.03,color= genelist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b285d1ec-8736-4331-a4f3-216325e24f9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "predval.write_h5ad(f'/stor/usr/sgenetmp/perturb_control/pred49_{config_file.parts[1]}_perturb_gradient.h5ad')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
