{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd3bf76-0dfc-4b36-aaf0-e24ebc14be3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Select the GPU before torch (or any package that might initialise CUDA) is\n",
    "# imported: CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA\n",
    "# runtime is initialised in this process.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from plotnine import *\n",
    "\n",
    "# Allow TF32 in CUDA matmul and cuDNN kernels.\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "# Make the local regLM checkout importable.\n",
    "sys.path.append(\"../../regLM/\")\n",
    "import reglm.dataset, reglm.lightning, reglm.interpret\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5de159a1-660d-463f-9b79-410dab5a9436",
   "metadata": {},
   "source": [
    "## Load regLM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ed70f1-70b6-4efb-9ef1-d82dfa970064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained regLM Lightning module from its checkpoint and move it to\n",
    "# device index 0.\n",
    "ckpt_path = 'lightning_logs/version_10/checkpoints/epoch=9-step=580648.ckpt'\n",
    "model = reglm.lightning.LightningModel.load_from_checkpoint(ckpt_path)\n",
    "model = model.to(torch.device(0))\n",
    "\n",
    "# Override the label length on the loaded model; the labels passed to\n",
    "# motif_insert later in this notebook ('44', '00') are two characters long.\n",
    "model.label_len = 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f552389b-4484-4254-83fa-15209d092fb0",
   "metadata": {},
   "source": [
    "## Load PWMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddb2a02d-87ee-40e6-9bbb-75cbfbef727f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the PWM table from the HDF file.\n",
    "pwms = pd.read_hdf('/gstore/data/resbioai/lala8/yetfasco_1.02/pms.hdf')\n",
    "\n",
    "# Index along axis 0 of a weight matrix -> nucleotide letter.\n",
    "idxs_to_base_dict = dict(enumerate('ACGT'))\n",
    "\n",
    "\n",
    "def consensus_from_weights(weights):\n",
    "    \"\"\"Return the consensus string: the argmax over axis 0 mapped to base letters.\n",
    "\n",
    "    Assumes `weights` is laid out as (4 bases, motif positions) -- TODO confirm.\n",
    "    \"\"\"\n",
    "    return ''.join(idxs_to_base_dict[base_idx] for base_idx in weights.argmax(0))\n",
    "\n",
    "\n",
    "pwms['consensus'] = pwms.weights.apply(consensus_from_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "392a87b9-2a7a-4adb-a968-20e44a3e4bc9",
   "metadata": {},
   "source": [
    "## Insert activating and repressing TF motifs and compute log-likelihood ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e23c3d0-561d-4efe-b5aa-ef0371803875",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Motif IDs labelled 'Activators' further down; matched against the PWM table index.\n",
    "act = [\n",
    "    \"SPT15_2172_0\",\n",
    "    \"PUT3_2223_0\",\n",
    "    \"GAL4_2126_0\",\n",
    "    \"HAA1_1425_0\",\n",
    "    \"PDR3_1387_0\",\n",
    "    \"NDT80_2145_0\",\n",
    "    \"MBP1_500_0\",\n",
    "    \"RSC3_2165_0\",\n",
    "    \"ADR1_623_0\",\n",
    "    \"MSN2_1381_0\",\n",
    "    \"ASH1_1474_0\",\n",
    "]\n",
    "\n",
    "# Motif IDs labelled 'Repressors' further down.\n",
    "rep = [\n",
    "    \"ASH1_28_0\",\n",
    "    \"MOT3_193_0\",\n",
    "    \"DOT6_557_0\",\n",
    "    \"MATALPHA2_2212_0\",\n",
    "    \"DAL80_636_0\",\n",
    "    \"ROX1_537_0\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4af7ac-8e85-4969-8a3b-0b8df736445a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the PWM table to the activator and repressor motifs listed above,\n",
    "# then run motif insertion with label '44' against reference label '00'.\n",
    "# NOTE(review): if motif_insert samples random sequences, results are not\n",
    "# reproducible without a seed -- confirm.\n",
    "selected_motifs = act + rep\n",
    "selected_pwms = pwms[pwms.index.isin(selected_motifs)]\n",
    "out = reglm.interpret.motif_insert(\n",
    "    selected_pwms,\n",
    "    model,\n",
    "    label='44',\n",
    "    ref_label='00',\n",
    "    n=100,\n",
    "    seq_len=40,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e3b44b-db12-427d-8a1f-067530734ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag each row with the group its motif belongs to; rows whose motif is in\n",
    "# neither list are left untouched.\n",
    "for category, motif_ids in (('Activators', act), ('Repressors', rep)):\n",
    "    out.loc[out.Motif.isin(motif_ids), 'Category'] = category"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c21e8c5-5e29-4923-9a05-464b1ef3cb6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TF name = the part of the motif ID before the first underscore.\n",
    "out['TF Motif'] = [motif_id.partition('_')[0] for motif_id in out.Motif]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eb1f418-01d1-47df-a7f1-21f7b3fa59cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plot of LL_ratio per TF motif, one free-scaled panel per Category, with a\n",
    "# dashed reference line at a log-likelihood ratio of 0.\n",
    "p = (\n",
    "    ggplot(out, aes(x='TF Motif', y='LL_ratio'))\n",
    "    + geom_boxplot(outlier_size=0.1, size=0.4)\n",
    "    + geom_hline(yintercept=0, linetype='dashed')\n",
    "    + facet_wrap('Category', scales='free', ncol=2)\n",
    "    + ylab('Log-likelihood\\nratio (44/00)')\n",
    "    + theme_classic()\n",
    "    + theme(\n",
    "        figure_size=(3, 2.5),\n",
    "        axis_text_x=element_text(rotation=60, hjust=1),\n",
    "    )\n",
    ")\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21bf897c-0331-4a08-a881-a757810eac12",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
