{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9cb3c0ed-6e96-4f3f-b752-bf9d22e9c770",
   "metadata": {},
   "source": [
    "# Train regression models on Vaishnav et al. yeast promoter GPRA sequences measured in one medium"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51863bf5-faee-4d4b-8a79-ef01bcba020b",
   "metadata": {},
   "source": [
    "This regression model will be independent of all generative models and will be used for independent evaluation of synthetic promoters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "742c67a0-a3cc-4fb8-9689-e5a2c4fe0175",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "from plotnine import *\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "sys.path.append('../../')\n",
    "import scripts.viz, scripts.regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbcfd100-e2a0-41d0-beed-ec495bf403c9",
   "metadata": {},
   "source": [
    "## Load sequences that were measured in only 1 medium"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "171e7d1b-77a1-46d7-8636-8b5e4ef11569",
   "metadata": {},
   "outputs": [],
   "source": [
    "defined = pd.read_csv('processed_data/random_defined.csv', index_col=0, usecols=(0,1, 2))\n",
    "complex = pd.read_csv('processed_data/random_complex.csv', index_col=0, usecols=(0,1, 2))\n",
    "\n",
    "len(defined), len(complex)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff03d230-9bc1-4993-b789-88575d37eb0b",
   "metadata": {},
   "source": [
    "## Split data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f37b831-2464-4923-8a2e-5bb445324f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "#os.mkdir('separate_regression_data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3978fc67-afca-4b19-9d0b-6fad191c4e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "cval = complex.sample(50000).copy()\n",
    "dval = defined.sample(50000).copy()\n",
    "ctrain = complex[~complex.index.isin(cval.index)]\n",
    "dtrain = defined[~defined.index.isin(dval.index)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d375d5fc-8b98-4935-bbbc-db5057192844",
   "metadata": {},
   "outputs": [],
   "source": [
    "ctest = ctrain.sample(50000).copy()\n",
    "dtest = dtrain.sample(50000).copy()\n",
    "\n",
    "ctrain = ctrain[~ctrain.index.isin(ctest.index)]\n",
    "dtrain = dtrain[~dtrain.index.isin(dtest.index)]\n",
    "\n",
    "print(len(ctrain), len(cval), len(ctest))\n",
    "print(len(dtrain), len(dval), len(dtest)) \n",
    "\n",
    "ctrain.to_csv('separate_regression_data/ctrain.csv')\n",
    "cval.to_csv('separate_regression_data/cval.csv')\n",
    "ctest.to_csv('separate_regression_data/ctest.csv')\n",
    "\n",
    "dtrain.to_csv('separate_regression_data/dtrain.csv')\n",
    "dval.to_csv('separate_regression_data/dval.csv')\n",
    "dtest.to_csv('separate_regression_data/dtest.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "375c34cb-d75f-4ec6-a446-d1abe5661c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "ctrain_ds = scripts.regression.SeqDataset(ctrain, rc=True)\n",
    "dtrain_ds = scripts.regression.SeqDataset(dtrain, rc=True)\n",
    "\n",
    "cval_ds = scripts.regression.SeqDataset(cval, rc=False)\n",
    "dval_ds = scripts.regression.SeqDataset(dval, rc=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5858721b-241b-44cb-b86a-2070d2b0c3bd",
   "metadata": {},
   "source": [
    "## Build model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c8b12b-9d13-4f7f-ac9a-a415c1947d8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmodel = scripts.regression.LightningModel(model_type='EnformerModel', loss='Poisson', \n",
    "                                       lr=5e-4, n_tasks=1, dim=384)\n",
    "\n",
    "dmodel = scripts.regression.LightningModel(model_type='EnformerModel', loss='Poisson', \n",
    "                                       lr=5e-4, n_tasks=1, dim=384)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0297ea3-f8d1-4476-bc8d-179c4d9e2708",
   "metadata": {},
   "source": [
    "## Train models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a152507-d8f9-40a0-82f2-506262b0ff73",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cmodel.train_on_dataset(ctrain_ds, cval_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db10aa4-5664-44a0-bd36-6211a44f1126",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dmodel.train_on_dataset(dtrain_ds, dval_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93694897-8d73-4857-a342-5f84c4f80fce",
   "metadata": {},
   "source": [
    "## Evaluate models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75279c83-fa2a-4af6-ac77-c1417e8f9d5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmodel = scripts.regression.LightningModel.load_from_checkpoint(\n",
    "    'reg_complex/lightning_logs/version_2/checkpoints/epoch=1-step=30824.ckpt')\n",
    "dmodel = scripts.regression.LightningModel.load_from_checkpoint(\n",
    "'reg_defined/lightning_logs/version_1/checkpoints/epoch=8-step=80568.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d0fdeb1-4423-4fa8-b9a1-7729d03fc16a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ctest = pd.read_csv('separate_regression_data/ctest.csv', index_col=0)\n",
    "ctest['pred'] = cmodel.predict_on_dataset(\n",
    "    scripts.regression.SeqDataset(ctest), devices=[2], num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e9096a-4863-468b-995f-8102a80d8e03",
   "metadata": {},
   "outputs": [],
   "source": [
    "scripts.viz.pointdensityplot(\n",
    "    ctest, true_col='exp', pred_col='pred', corrx=4, corry=19)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5afa532d-7e39-41c3-99f7-027affbc77d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dtest = pd.read_csv('separate_regression_data/dtest.csv', index_col=0)\n",
    "dtest['pred'] = dmodel.predict_on_dataset(\n",
    "    scripts.regression.SeqDataset(dtest), devices=[2], num_workers=8).cpu().detach().numpy().squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b71956f4-f9ab-4232-b4d7-d0b9b8e8389c",
   "metadata": {},
   "outputs": [],
   "source": [
    "scripts.viz.pointdensityplot(\n",
    "    dtest, true_col='exp', pred_col='pred', corrx=4, corry=19)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
