{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4cbdfa31-4148-4ecf-9e4a-b983f01b0aa2",
   "metadata": {},
   "source": [
    "# Evaluate LM-generated sequences and compare to real sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4be68f6-592f-468b-b534-21d340b6d5f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "import torch\n",
    "from plotnine import *\n",
    "\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "sys.path.append(\"../../\")\n",
    "import scripts.motifs, scripts.viz, scripts.regression\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"2\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86d8b54c-1f24-4891-9047-ad2f535aed94",
   "metadata": {},
   "source": [
    "## Load independent predictive models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcbe2872-951e-4216-b441-c185f4583dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmodel = scripts.regression.LightningModel.load_from_checkpoint(\n",
    "    'reg_complex/lightning_logs/version_2/checkpoints/epoch=1-step=30824.ckpt')\n",
    "dmodel = scripts.regression.LightningModel.load_from_checkpoint(\n",
    "    'reg_defined/lightning_logs/version_1/checkpoints/epoch=8-step=80568.ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea374611-b9bb-4be6-9198-a4492595d4ec",
   "metadata": {},
   "source": [
    "## Load LM-generated and test set sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89221fc-8471-4ce2-af81-7c1b44a3b281",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = pd.read_csv('synthetic_promoters/lm_filtered.csv', index_col=0, dtype='str')\n",
    "gen.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d771526a-9561-439d-a204-6f2920c1a02b",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = pd.read_csv('lm_data/test.csv', index_col=0, dtype='str', usecols=(0, 1, 6))\n",
    "test = test[test.label.isin(gen.label)]\n",
    "test.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "299b0640-669a-41be-bf79-e9d5bb396c66",
   "metadata": {},
   "source": [
    "## Make predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3311c61a-beb9-4d16-8be3-bc58c79a964a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = scripts.regression.SeqDataset(gen.Sequence.tolist())\n",
    "gen['Complex'] = cmodel.predict_on_dataset(ds, devices=[0], num_workers=8)\n",
    "gen['Defined'] = dmodel.predict_on_dataset(ds, devices=[0], num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e53edc6-0934-46bf-851a-2d4842ac1d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = scripts.regression.SeqDataset(test.seq.tolist())\n",
    "test['Complex'] = cmodel.predict_on_dataset(ds, devices=[0], num_workers=8)\n",
    "test['Defined'] = dmodel.predict_on_dataset(ds, devices=[0], num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23626bf9-eeca-4850-98a5-f59e08ee0ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = test.rename(columns={'seq':'Sequence'})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "867f1964-db63-4811-889d-5a75925b4c56",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd092d56-14ef-4266-82ab-d68c3e226301",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen['Group']='regLM'\n",
    "test['Group'] = 'Test Set'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e50fb5ba-4e77-4525-92e1-1612eb546974",
   "metadata": {},
   "outputs": [],
   "source": [
    "p = (\n",
    "    ggplot(pd.concat([gen, test]).melt(id_vars=['Sequence', 'label', 'Group']), \n",
    "           aes(x='label', y='value', fill='Group'))\n",
    "    + geom_boxplot(outlier_size=.001, size=.35, alpha=.8)\n",
    "    + facet_wrap(\"variable\", ncol=2)\n",
    "    + theme_classic()\n",
    "    + theme(figure_size=(3.8, 2.2))\n",
    "    + xlab(\"Label\")\n",
    "    + ylab(\"Predicted Activity\")\n",
    ")\n",
    "p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f760cd-4215-42d9-a043-6c9309be670d",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen.to_csv('synthetic_promoters/lm_filtered_pred.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "642ca71d-3b12-41a6-b35b-d2abc1d6369a",
   "metadata": {},
   "source": [
    "## Edit distance between LM-generated sequences and training set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4461e9cb-8d1a-4e82-9924-180b038f62f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('lm_data/train.csv', index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0090182f-08c8-4fca-8106-e834ec623699",
   "metadata": {},
   "outputs": [],
   "source": [
    "seqs = pd.concat([\n",
    "    pd.DataFrame({'Sequence':gen.Sequence.tolist(), 'Group':'regLM'}),\n",
    "    pd.DataFrame({'Sequence':train.seq.tolist(), 'Group':'Training Set'}),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3559eb75-c08d-486f-a3a8-06f1b14c1354",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "from scripts.sequence import min_edit_distance_from_reference\n",
    "\n",
    "e = min_edit_distance_group(seqs, 'Training Set')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b21ed0af-d874-47b7-b300-5b5956bd5bca",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    ggplot(e, aes(x='edit')) + geom_histogram()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5d3d16f-1776-419c-9374-beeaee8b2c6c",
   "metadata": {},
   "source": [
    "## Motif abundance vs. strength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd571981-0dd9-4915-9612-b23f9208a931",
   "metadata": {},
   "outputs": [],
   "source": [
    "act = ['SPT15_2172_0', 'PUT3_2223_0', 'GAL4_2126_0', 'HAA1_1425_0', \n",
    "       'PDR3_1387_0', 'NDT80_2145_0', 'MBP1_500_0', 'RSC3_2165_0', \n",
    "       'ADR1_623_0', 'MSN2_1381_0']\n",
    "rep = [\"ASH1_28_0\", \"MOT3_193_0\", \"DOT6_557_0\",  \"MATALPHA2_2212_0\", \"DAL80_636_0\", \"ROX1_537_0\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08872fe5-c85b-452f-a4e9-b13242676973",
   "metadata": {},
   "outputs": [],
   "source": [
    "pwms = pd.read_hdf('/gstore/data/resbioai/lala8/yetfasco_1.02/pms.hdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169ddc1b-99cd-4caa-a11e-fd5299acb272",
   "metadata": {},
   "outputs": [],
   "source": [
    "pwms = pwms[pwms.index.isin(act+rep)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d45674-7052-49dd-9ab4-eeb16b642d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "gen.index = gen.index.astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51408241-88b6-411e-9a44-78a87b17ae89",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "sites = scripts.motifs.scan_seqs_motifs(gen, pwms, num_workers=8)\n",
    "cts = scripts.motifs.calculate_motif_counts(sites)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f805aafb-e978-412d-9063-00fee527bfe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "cts.obs = cts.obs.merge(gen, left_index=True, right_index=True, how='left')\n",
    "cts.obs['Group'] = 'regLM'\n",
    "cts.var_names= [x.split('_')[0] for x in cts.var_names]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "584739fb-cb90-4846-a27c-183b71b55eb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "motifs = {'Activators':[x.split('_')[0] for x in act], 'Repressors':[x.split('_')[0] for x in rep]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c122aa8-a10b-4223-9639-8f23e84866bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "scripts.viz.plot_motif_freq_by_label(cts, motifs = motifs) +\\\n",
    "theme(figure_size=(6.5, 2.5))+\\\n",
    "facet_grid(\"~Category\", scales=\"free_x\", space={'x': [2, 1]})+ \\\n",
    "theme(axis_text_x=element_text(rotation=40, hjust=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c6653bd-7c66-420a-aeb4-19b8b7eeca7b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
