{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scanpy as sc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the single-cell data using scanpy. The single-cell data is stored in a special data structure called AnnData (short: adata)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnnData object with n_obs × n_vars = 316138 × 36604\n",
       "    obs: 'condition', 'target_gene'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata = sc.read('L008_rna_counts.h5ad')\n",
    "adata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The count matrix is stored in `adata.X`, a sparse matrix with one row per cell and one column per gene."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<316138x36604 sparse matrix of type '<class 'numpy.float32'>'\n",
       "\twith 1217701057 stored elements in Compressed Sparse Row format>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata.X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`adata.obs` is a dataframe with one row per cell. It holds cell-level annotations such as batch, cell type or perturbation, as well as other technical annotations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>condition</th>\n",
       "      <th>target_gene</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2, Donor#2 Cas9+ IFNg+</td>\n",
       "      <td>ZNF331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2, Donor#2 Cas9+ IFNg+</td>\n",
       "      <td>MAPK11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2, Donor#2 Cas9+ IFNg+</td>\n",
       "      <td>OAS3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2, Donor#2 Cas9+ IFNg+</td>\n",
       "      <td>IL7R</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2, Donor#2 Cas9+ IFNg+</td>\n",
       "      <td>SOX4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316133</th>\n",
       "      <td>3, Donor#3 Cas9+ restim</td>\n",
       "      <td>NA</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316134</th>\n",
       "      <td>3, Donor#3 Cas9+ restim</td>\n",
       "      <td>NA</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316135</th>\n",
       "      <td>3, Donor#3 Cas9+ restim</td>\n",
       "      <td>NA</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316136</th>\n",
       "      <td>3, Donor#3 Cas9+ restim</td>\n",
       "      <td>NA</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316137</th>\n",
       "      <td>3, Donor#3 Cas9+ restim</td>\n",
       "      <td>NA</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>316138 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                      condition target_gene\n",
       "0        2, Donor#2 Cas9+ IFNg+      ZNF331\n",
       "1        2, Donor#2 Cas9+ IFNg+      MAPK11\n",
       "2        2, Donor#2 Cas9+ IFNg+        OAS3\n",
       "3        2, Donor#2 Cas9+ IFNg+        IL7R\n",
       "4        2, Donor#2 Cas9+ IFNg+        SOX4\n",
       "...                         ...         ...\n",
       "316133  3, Donor#3 Cas9+ restim          NA\n",
       "316134  3, Donor#3 Cas9+ restim          NA\n",
       "316135  3, Donor#3 Cas9+ restim          NA\n",
       "316136  3, Donor#3 Cas9+ restim          NA\n",
       "316137  3, Donor#3 Cas9+ restim          NA\n",
       "\n",
       "[316138 rows x 2 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata.obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only cells with an assigned guide target. `.copy()` materialises the subset,\n",
    "# so the `adata.obs[...] = ...` assignments that follow do not write into a view.\n",
    "adata = adata[adata.obs['target_gene'] != 'NA'].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_21508/2022005586.py:1: ImplicitModificationWarning: Trying to modify attribute `.obs` of view, initializing view as actual.\n",
      "  adata.obs['donor'] = adata.obs['condition'].apply(lambda x: x.split(' ')[1])\n"
     ]
    }
   ],
   "source": [
    "# `condition` looks like '2, Donor#2 Cas9+ IFNg+'; splitting on single spaces puts\n",
    "# the donor, CRISPR status and stimulation state at positions 1, 2 and 3.\n",
    "condition_fields = {'donor': 1, 'crispr': 2, 'cell_type': 3}\n",
    "for obs_key, position in condition_fields.items():\n",
    "    adata.obs[obs_key] = adata.obs['condition'].apply(lambda x, i=position: x.split(' ')[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`adata.var` is a dataframe with one row per gene. It holds gene-level annotations such as gene names or pathways, and usually statistics such as dispersion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MIR1302-2HG</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>FAM138A</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OR4F5</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL627309.1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AL627309.3</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AC007325.4</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AC007325.2</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>GFP</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RFP670</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>REF-polyA</th>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>36604 rows × 0 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: []\n",
       "Index: [MIR1302-2HG, FAM138A, OR4F5, AL627309.1, AL627309.3, AL627309.2, AL627309.5, AL627309.4, AP006222.2, AL732372.1, OR4F29, AC114498.1, OR4F16, AL669831.2, LINC01409, FAM87B, LINC01128, LINC00115, FAM41C, AL645608.6, AL645608.2, AL645608.4, LINC02593, SAMD11, NOC2L, KLHL17, PLEKHN1, PERM1, AL645608.7, HES4, ISG15, AL645608.1, AGRN, AL645608.5, AL645608.8, RNF223, C1orf159, AL390719.3, LINC01342, AL390719.2, TTLL10-AS1, TTLL10, TNFRSF18, TNFRSF4, SDF4, B3GALT6, C1QTNF12, AL162741.1, UBE2J2, LINC01786, SCNN1D, ACAP3, PUSL1, INTS11, AL139287.1, CPTP, TAS1R3, DVL1, MXRA8, AURKAIP1, CCNL2, MRPL20-AS1, MRPL20, AL391244.2, ANKRD65, AL391244.1, TMEM88B, LINC01770, VWA1, ATAD3C, ATAD3B, ATAD3A, TMEM240, SSU72, AL645728.1, FNDC10, AL691432.4, AL691432.2, MIB2, MMP23B, CDK11B, FO704657.1, SLC35E2B, CDK11A, SLC35E2A, NADK, GNB1, AL109917.1, CALML6, TMEM52, CFAP74, AL391845.2, GABRD, AL391845.1, PRKCZ, AL590822.2, PRKCZ-AS1, FAAP20, AL590822.1, SKI, ...]\n",
       "\n",
       "[36604 rows x 0 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata.var"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quality control\n",
    "Check the quality of the data and see if it's necessary to remove some cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-cell QC metrics computed from the raw count matrix.\n",
    "adata.obs['n_counts'] = np.ravel(adata.X.sum(1)) #number of counts in the cell\n",
    "adata.obs['n_genes'] = np.ravel(np.sum(adata.X > 0, axis=1)) #number of genes with at least 1 count per cell\n",
    "# Anchor the match at the start of the symbol: `str.contains(\"MT-\")` would also flag any\n",
    "# gene that merely has 'MT-' somewhere inside its name.\n",
    "adata.var['mito'] = adata.var_names.str.startswith(\"MT-\") #flag for mitochondrial genes\n",
    "adata.obs['mt_frac'] = np.ravel(adata.X[:, adata.var.mito].sum(1)) / adata.obs['n_counts'].values #fraction of mitochondrial gene exp, high values mean dead or bad quality cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "View of AnnData object with n_obs × n_vars = 139564 × 36604\n",
       "    obs: 'condition', 'target_gene', 'donor', 'crispr', 'cell_type', 'n_counts', 'n_genes', 'mt_frac'\n",
       "    var: 'mito'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# filtering\n",
    "# Apply all QC thresholds through one boolean mask and take a real copy, so the\n",
    "# in-place preprocessing below works on an actual AnnData, not on a view of a view.\n",
    "MIN_COUNTS, MIN_GENES, MAX_MT_FRAC = 500, 750, 0.2\n",
    "passes_qc = (\n",
    "    (adata.obs['n_counts'] > MIN_COUNTS)\n",
    "    & (adata.obs['n_genes'] > MIN_GENES)\n",
    "    & (adata.obs['mt_frac'] < MAX_MT_FRAC)\n",
    ")\n",
    "adata = adata[passes_qc.values].copy()\n",
    "adata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3709.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata.X.max() #sanity check: an integer-valued maximum suggests raw counts rather than already-preprocessed data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/gvci-env/lib/python3.9/site-packages/scanpy/preprocessing/_normalization.py:170: UserWarning: Received a view of an AnnData. Making a copy.\n",
      "  view_to_actual(adata)\n"
     ]
    }
   ],
   "source": [
    "# Normalise the per-cell totals, then apply log(1 + x).\n",
    "sc.pp.normalize_total(adata)\n",
    "# NOTE(review): left disabled; at this position it would store the normalised values, not the raw counts.\n",
    "#adata.layers['counts'] = adata.X.copy()\n",
    "sc.pp.log1p(adata)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature (gene) selection\n",
    "We keep the 2,000 most variable genes, plus every gene that is a perturbation target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop genes that are not detected in any cell, then flag the most variable ones.\n",
    "sc.pp.filter_genes(adata, min_cells=1)\n",
    "hvg_settings = dict(n_top_genes=2000, flavor='cell_ranger')\n",
    "sc.pp.highly_variable_genes(adata, **hvg_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep every gene that is either a perturbation target or highly variable.\n",
    "# `Index.isin` replaces the per-gene Python membership test (which re-read\n",
    "# `adata.obs['target_gene']` for each entry of var_names), and `.copy()` turns the\n",
    "# subset into a real AnnData so the later `adata.uns[...]` / `adata.obs[...]` writes\n",
    "# do not hit a view.\n",
    "gene_targeted = adata.var_names.isin(adata.obs['target_gene'].cat.categories)\n",
    "keep_gene = gene_targeted | adata.var['highly_variable'].to_numpy(dtype=bool)\n",
    "adata = adata[:, keep_gene].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "View of AnnData object with n_obs × n_vars = 139564 × 2048\n",
       "    obs: 'condition', 'target_gene', 'donor', 'crispr', 'cell_type', 'n_counts', 'n_genes', 'mt_frac'\n",
       "    var: 'mito', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'\n",
       "    uns: 'log1p', 'hvg'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Out-of-distribution selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _l2_to_rest(gene):\n",
    "    \"\"\"L2 distance between the mean profile of `gene`'s cells and that of all other cells.\"\"\"\n",
    "    in_group = adata.obs.target_gene == gene\n",
    "    centroid_in = adata[in_group].X.mean(0)\n",
    "    centroid_out = adata[~in_group].X.mean(0)\n",
    "    return np.linalg.norm(centroid_in - centroid_out)\n",
    "\n",
    "\n",
    "# One row per perturbation target: how far its cells sit from everything else.\n",
    "results = [\n",
    "    {'cond1': gene, 'L2': _l2_to_rest(gene)}\n",
    "    for gene in adata.obs.target_gene.unique()\n",
    "]\n",
    "df_vs_rest = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pick the 20 perturbations with the largest L2 distance; these form the out-of-distribution (OOD) set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cond1</th>\n",
       "      <th>L2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>77</th>\n",
       "      <td>SRSF7</td>\n",
       "      <td>3.169078</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>IL2RG</td>\n",
       "      <td>3.196164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56</th>\n",
       "      <td>MX1</td>\n",
       "      <td>3.218410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>RAC2</td>\n",
       "      <td>3.303963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>NR4A2</td>\n",
       "      <td>3.317241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>71</th>\n",
       "      <td>SARAF</td>\n",
       "      <td>3.317464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>F8</td>\n",
       "      <td>3.447231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>CD247</td>\n",
       "      <td>3.464400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>IRF9</td>\n",
       "      <td>3.514952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>63</th>\n",
       "      <td>THY1</td>\n",
       "      <td>3.525789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OAS3</td>\n",
       "      <td>3.573775</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>JUNB</td>\n",
       "      <td>3.782597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>STAG3</td>\n",
       "      <td>3.953121</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>PDCD1</td>\n",
       "      <td>4.037519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>ENTPD1</td>\n",
       "      <td>4.330400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>FURIN</td>\n",
       "      <td>4.549811</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MAPK11</td>\n",
       "      <td>4.631763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>CD44</td>\n",
       "      <td>4.636609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>SLC2A3</td>\n",
       "      <td>4.657693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ZNF331</td>\n",
       "      <td>4.920850</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     cond1        L2\n",
       "77   SRSF7  3.169078\n",
       "79   IL2RG  3.196164\n",
       "56     MX1  3.218410\n",
       "51    RAC2  3.303963\n",
       "36   NR4A2  3.317241\n",
       "71   SARAF  3.317464\n",
       "7       F8  3.447231\n",
       "70   CD247  3.464400\n",
       "64    IRF9  3.514952\n",
       "63    THY1  3.525789\n",
       "2     OAS3  3.573775\n",
       "73    JUNB  3.782597\n",
       "31   STAG3  3.953121\n",
       "45   PDCD1  4.037519\n",
       "14  ENTPD1  4.330400\n",
       "44   FURIN  4.549811\n",
       "1   MAPK11  4.631763\n",
       "21    CD44  4.636609\n",
       "13  SLC2A3  4.657693\n",
       "0   ZNF331  4.920850"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_vs_rest.sort_values(by='L2').tail(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['SRSF7', 'IL2RG', 'MX1', 'RAC2', 'NR4A2', 'SARAF', 'F8', 'CD247',\n",
       "       'IRF9', 'THY1', 'OAS3', 'JUNB', 'STAG3', 'PDCD1', 'ENTPD1',\n",
       "       'FURIN', 'MAPK11', 'CD44', 'SLC2A3', 'ZNF331'], dtype=object)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The 20 targets with the largest L2 value become the out-of-distribution knockouts.\n",
    "ood_candidates = df_vs_rest.sort_values(by='L2').tail(20)\n",
    "KO_OOD = ood_candidates['cond1'].values\n",
    "KO_OOD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare for the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/gvci-env/lib/python3.9/site-packages/anndata/compat/_overloaded_dict.py:106: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.\n",
      "  self.data[key] = value\n"
     ]
    }
   ],
   "source": [
    "adata.uns['fields'] = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast to plain strings: `target_gene` is categorical up to here, and the 'ctrl'\n",
    "# label assigned to control cells further down is not one of its categories.\n",
    "adata.obs['target_gene'] = adata.obs['target_gene'].astype(str)\n",
    "# Record the name of the obs column that holds the perturbation label.\n",
    "adata.uns['fields']['perturbation'] = 'target_gene'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cells labelled 'Cas9-' are flagged as controls (1, otherwise 0) and their\n",
    "# perturbation label is overwritten with 'ctrl'.\n",
    "is_control = adata.obs['crispr'] == 'Cas9-'\n",
    "adata.obs['control'] = is_control.astype('int64')\n",
    "adata.obs.loc[is_control, 'target_gene'] = 'ctrl'\n",
    "adata.uns['fields']['control'] = 'control'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "adata.obs['dose'] = 1.0\n",
    "adata.uns['fields']['dose'] = 'dose'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "adata.uns['fields']['covariates'] = ['cell_type', 'donor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `condition` has already been split into donor / crispr / cell_type, so drop the raw column.\n",
    "del adata.obs['condition']\n",
    "# NOTE(review): presumably removed so that writing the h5ad file does not trip over this entry — confirm.\n",
    "del adata.uns['log1p']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split dataset\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "adata.uns['fields']['split'] = 'split'\n",
    "\n",
    "# Every cell starts as 'NA'; restimulated cells whose target is in KO_OOD are held out as 'ood'.\n",
    "split = pd.Series('NA', index=adata.obs.index)\n",
    "is_ood = adata.obs['target_gene'].isin(KO_OOD) & (adata.obs['cell_type'] == 'restim')\n",
    "split.loc[is_ood] = 'ood'\n",
    "\n",
    "# The remaining cells are divided 80/20 into train and test (fixed seed).\n",
    "idx = np.flatnonzero(split == 'NA')\n",
    "idx_train, idx_test = train_test_split(idx, test_size=0.2, random_state=42)\n",
    "split.iloc[idx_train] = 'train'\n",
    "split.iloc[idx_test] = 'test'\n",
    "\n",
    "adata.obs['split'] = split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rank DE genes (optional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/envs/gvci-env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# this will be done in main script if it's not done here\n",
    "\n",
    "# Build the grouping labels with vectorised string concatenation instead of a Python\n",
    "# loop over every cell (the loop re-fetched two obs columns on each iteration).\n",
    "covariate_keys = adata.uns['fields']['covariates']\n",
    "perturbation_key = adata.uns['fields']['perturbation']\n",
    "\n",
    "# cov_name: the covariate values joined with '_', e.g. 'IFNg+_Donor#2'\n",
    "cov_names = adata.obs[covariate_keys[0]].astype(str)\n",
    "for cov in covariate_keys[1:]:\n",
    "    cov_names = cov_names + \"_\" + adata.obs[cov].astype(str)\n",
    "adata.obs[\"cov_name\"] = cov_names\n",
    "\n",
    "# cov_pert_name: cov_name plus the perturbation label, e.g. 'IFNg+_Donor#2_ZNF331'\n",
    "cov_pert_names = adata.obs[\"cov_name\"] + \"_\" + adata.obs[perturbation_key].astype(str)\n",
    "adata.obs[\"cov_pert_name\"] = cov_pert_names\n",
    "\n",
    "import warnings\n",
    "\n",
    "from vci.utils.data_utils import rank_genes_groups\n",
    "\n",
    "# NOTE(review): presumably ranks DE genes for each cov_pert_name group against the\n",
    "# control cells sharing its cov_name — confirm in vci.utils.data_utils.\n",
    "# All warnings raised during the call are silenced.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    rank_genes_groups(adata,\n",
    "        groupby=\"cov_pert_name\",\n",
    "        reference=\"cov_name\",\n",
    "        control_key=\"control\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_gene</th>\n",
       "      <th>donor</th>\n",
       "      <th>crispr</th>\n",
       "      <th>cell_type</th>\n",
       "      <th>n_counts</th>\n",
       "      <th>n_genes</th>\n",
       "      <th>mt_frac</th>\n",
       "      <th>control</th>\n",
       "      <th>dose</th>\n",
       "      <th>split</th>\n",
       "      <th>cov_name</th>\n",
       "      <th>cov_pert_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ZNF331</td>\n",
       "      <td>Donor#2</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>IFNg+</td>\n",
       "      <td>4055.0</td>\n",
       "      <td>1881</td>\n",
       "      <td>0.123551</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>test</td>\n",
       "      <td>IFNg+_Donor#2</td>\n",
       "      <td>IFNg+_Donor#2_ZNF331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MAPK11</td>\n",
       "      <td>Donor#2</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>IFNg+</td>\n",
       "      <td>1997.0</td>\n",
       "      <td>1073</td>\n",
       "      <td>0.019529</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>IFNg+_Donor#2</td>\n",
       "      <td>IFNg+_Donor#2_MAPK11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>OAS3</td>\n",
       "      <td>Donor#2</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>IFNg+</td>\n",
       "      <td>11044.0</td>\n",
       "      <td>3155</td>\n",
       "      <td>0.045817</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>IFNg+_Donor#2</td>\n",
       "      <td>IFNg+_Donor#2_OAS3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>IL7R</td>\n",
       "      <td>Donor#2</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>IFNg+</td>\n",
       "      <td>38119.0</td>\n",
       "      <td>6660</td>\n",
       "      <td>0.040295</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>IFNg+_Donor#2</td>\n",
       "      <td>IFNg+_Donor#2_IL7R</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SOX4</td>\n",
       "      <td>Donor#2</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>IFNg+</td>\n",
       "      <td>2585.0</td>\n",
       "      <td>1561</td>\n",
       "      <td>0.042553</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>test</td>\n",
       "      <td>IFNg+_Donor#2</td>\n",
       "      <td>IFNg+_Donor#2_SOX4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316126</th>\n",
       "      <td>TOX</td>\n",
       "      <td>Donor#3</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>restim</td>\n",
       "      <td>19578.0</td>\n",
       "      <td>4508</td>\n",
       "      <td>0.069670</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>test</td>\n",
       "      <td>restim_Donor#3</td>\n",
       "      <td>restim_Donor#3_TOX</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316127</th>\n",
       "      <td>SLC2A3</td>\n",
       "      <td>Donor#3</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>restim</td>\n",
       "      <td>3733.0</td>\n",
       "      <td>1689</td>\n",
       "      <td>0.091347</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>ood</td>\n",
       "      <td>restim_Donor#3</td>\n",
       "      <td>restim_Donor#3_SLC2A3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316128</th>\n",
       "      <td>BTG2</td>\n",
       "      <td>Donor#3</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>restim</td>\n",
       "      <td>32049.0</td>\n",
       "      <td>5401</td>\n",
       "      <td>0.074043</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>restim_Donor#3</td>\n",
       "      <td>restim_Donor#3_BTG2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316131</th>\n",
       "      <td>NTC</td>\n",
       "      <td>Donor#3</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>restim</td>\n",
       "      <td>30929.0</td>\n",
       "      <td>5601</td>\n",
       "      <td>0.070258</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>restim_Donor#3</td>\n",
       "      <td>restim_Donor#3_NTC</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>316132</th>\n",
       "      <td>IL2RA</td>\n",
       "      <td>Donor#3</td>\n",
       "      <td>Cas9+</td>\n",
       "      <td>restim</td>\n",
       "      <td>4731.0</td>\n",
       "      <td>2056</td>\n",
       "      <td>0.039104</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>train</td>\n",
       "      <td>restim_Donor#3</td>\n",
       "      <td>restim_Donor#3_IL2RA</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>139564 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       target_gene    donor crispr cell_type  n_counts  n_genes   mt_frac  \\\n",
       "0           ZNF331  Donor#2  Cas9+     IFNg+    4055.0     1881  0.123551   \n",
       "1           MAPK11  Donor#2  Cas9+     IFNg+    1997.0     1073  0.019529   \n",
       "2             OAS3  Donor#2  Cas9+     IFNg+   11044.0     3155  0.045817   \n",
       "3             IL7R  Donor#2  Cas9+     IFNg+   38119.0     6660  0.040295   \n",
       "4             SOX4  Donor#2  Cas9+     IFNg+    2585.0     1561  0.042553   \n",
       "...            ...      ...    ...       ...       ...      ...       ...   \n",
       "316126         TOX  Donor#3  Cas9+    restim   19578.0     4508  0.069670   \n",
       "316127      SLC2A3  Donor#3  Cas9+    restim    3733.0     1689  0.091347   \n",
       "316128        BTG2  Donor#3  Cas9+    restim   32049.0     5401  0.074043   \n",
       "316131         NTC  Donor#3  Cas9+    restim   30929.0     5601  0.070258   \n",
       "316132       IL2RA  Donor#3  Cas9+    restim    4731.0     2056  0.039104   \n",
       "\n",
       "        control  dose  split        cov_name          cov_pert_name  \n",
       "0             0   1.0   test   IFNg+_Donor#2   IFNg+_Donor#2_ZNF331  \n",
       "1             0   1.0  train   IFNg+_Donor#2   IFNg+_Donor#2_MAPK11  \n",
       "2             0   1.0  train   IFNg+_Donor#2     IFNg+_Donor#2_OAS3  \n",
       "3             0   1.0  train   IFNg+_Donor#2     IFNg+_Donor#2_IL7R  \n",
       "4             0   1.0   test   IFNg+_Donor#2     IFNg+_Donor#2_SOX4  \n",
       "...         ...   ...    ...             ...                    ...  \n",
       "316126        0   1.0   test  restim_Donor#3     restim_Donor#3_TOX  \n",
       "316127        0   1.0    ood  restim_Donor#3  restim_Donor#3_SLC2A3  \n",
       "316128        0   1.0  train  restim_Donor#3    restim_Donor#3_BTG2  \n",
       "316131        0   1.0  train  restim_Donor#3     restim_Donor#3_NTC  \n",
       "316132        0   1.0  train  restim_Donor#3   restim_Donor#3_IL2RA  \n",
       "\n",
       "[139564 rows x 12 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adata.obs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "adata.write('L008_prepped.h5ad')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.0 ('gvci-env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a18dafcc48613bb6c3783e5aa2cfbbab56400df7d9096bc06931f344217b23cf"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
