{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9cb3c0ed-6e96-4f3f-b752-bf9d22e9c770",
   "metadata": {},
   "source": [
    "# Train a joint regression model on Vaishnav et al. yeast promoter GPRA sequences measured in both media"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51863bf5-faee-4d4b-8a79-ef01bcba020b",
   "metadata": {},
   "source": [
    "This regression model is trained separately from all generative models, so it can be used to evaluate synthetic promoters independently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "742c67a0-a3cc-4fb8-9689-e5a2c4fe0175",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from plotnine import *\n",
    "\n",
    "# Seed NumPy's global RNG: the unseeded DataFrame.sample calls in the split\n",
    "# section draw from it, so the train/val/test split is reproducible on a fresh kernel.\n",
    "np.random.seed(0)\n",
    "# Make the project-local `scripts` package importable from this notebook's directory.\n",
    "sys.path.append('../../')\n",
    "import scripts.viz\n",
    "import scripts.regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbcfd100-e2a0-41d0-beed-ec495bf403c9",
   "metadata": {},
   "source": [
    "## Load sequences that were measured in both media"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "171e7d1b-77a1-46d7-8636-8b5e4ef11569",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column 0 becomes the index; column 3 of the CSV is not loaded.\n",
    "# NOTE(review): confirm what column 3 holds and why it is excluded.\n",
    "data = pd.read_csv(\n",
    "    'processed_data/random_both.csv',\n",
    "    index_col=0,\n",
    "    usecols=(0, 1, 2, 4),\n",
    ")\n",
    "print(len(data))\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff03d230-9bc1-4993-b789-88575d37eb0b",
   "metadata": {},
   "source": [
    "## Split data into training, validation (50,000 sequences) and test (50,000 sequences) sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f37b831-2464-4923-8a2e-5bb445324f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# exist_ok=True keeps this cell safe to re-run: os.mkdir raises FileExistsError\n",
    "# once the directory exists, which breaks Restart & Run All.\n",
    "os.makedirs('joint_regression_data', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3978fc67-afca-4b19-9d0b-6fad191c4e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 50,000 random rows for validation. `sample` without `random_state`\n",
    "# draws from NumPy's global RNG (seeded in the imports cell).\n",
    "val = data.sample(50000).copy()\n",
    "in_val = data.index.isin(val.index)\n",
    "train = data[~in_val]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa17562-5d18-45b9-8494-2c98ba9b00e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out another 50,000 rows for testing, drawn from the rows not in `val`.\n",
    "# The pool is rebuilt from `data` instead of overwriting `train` from itself, so\n",
    "# re-running this cell no longer shrinks `train` by a further 50,000 rows.\n",
    "# On a fresh kernel `remaining` equals the `train` built in the previous cell,\n",
    "# so the same test rows are drawn as before.\n",
    "remaining = data[~data.index.isin(val.index)]\n",
    "test = remaining.sample(50000).copy()\n",
    "train = remaining[~remaining.index.isin(test.index)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a1a190-86c5-4988-bb59-d95becb94752",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train), len(val), len(test))\n",
    "# Sanity-check the split before it is written to disk: the three sets must be\n",
    "# pairwise disjoint and together account for every row of `data`. This catches\n",
    "# stale state left by re-running split cells out of order.\n",
    "assert not train.index.isin(val.index).any()\n",
    "assert not train.index.isin(test.index).any()\n",
    "assert not val.index.isin(test.index).any()\n",
    "assert len(train) + len(val) + len(test) == len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d375d5fc-8b98-4935-bbbc-db5057192844",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the split; the evaluation section reloads test.csv from here.\n",
    "split_paths = {\n",
    "    'joint_regression_data/train.csv': train,\n",
    "    'joint_regression_data/val.csv': val,\n",
    "    'joint_regression_data/test.csv': test,\n",
    "}\n",
    "for path, frame in split_paths.items():\n",
    "    frame.to_csv(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d358d70-619e-42f5-8047-f848b08933ed",
   "metadata": {},
   "source": [
    "## Build datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "375c34cb-d75f-4ec6-a446-d1abe5661c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training set is built with rc=True (presumably reverse-complement\n",
    "# augmentation - TODO confirm in scripts.regression.SeqDataset).\n",
    "train_ds = scripts.regression.SeqDataset(\n",
    "    train,\n",
    "    rc=True,\n",
    ")\n",
    "val_ds = scripts.regression.SeqDataset(\n",
    "    val,\n",
    "    rc=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5858721b-241b-44cb-b86a-2070d2b0c3bd",
   "metadata": {},
   "source": [
    "## Build model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c8b12b-9d13-4f7f-ac9a-a415c1947d8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): n_tasks=3, but only two prediction columns (pred_complex,\n",
    "# pred_defined) are assigned in the evaluation section - confirm the task count.\n",
    "model = scripts.regression.LightningModel(\n",
    "    model_type='EnformerModel',\n",
    "    loss='Poisson',\n",
    "    lr=5e-4,\n",
    "    n_tasks=3,\n",
    "    dim=384,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0297ea3-f8d1-4476-bc8d-179c4d9e2708",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a152507-d8f9-40a0-82f2-506262b0ff73",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): long-running. Checkpoints are presumably written under\n",
    "# reg_joint/checkpoints/ (the directory the evaluation section loads from) - confirm.\n",
    "model.train_on_dataset(train_ds, val_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93694897-8d73-4857-a342-5f84c4f80fce",
   "metadata": {},
   "source": [
    "## Evaluate model on the held-out test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75279c83-fa2a-4af6-ac77-c1417e8f9d5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded checkpoint from a previous training run;\n",
    "# update the filename after retraining.\n",
    "checkpoint_path = 'reg_joint/checkpoints/epoch=8-step=65331.ckpt'\n",
    "model = scripts.regression.LightningModel.load_from_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6470382-fb14-46e9-a04c-f8c4ad59cd0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved test split from disk so this section does not depend on\n",
    "# the in-memory splits created above.\n",
    "test = pd.read_csv('joint_regression_data/test.csv', index_col=0)\n",
    "test_ds = scripts.regression.SeqDataset(\n",
    "    test,\n",
    "    rc=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5afa532d-7e39-41c3-99f7-027affbc77d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): devices=[3] hard-codes one accelerator index (presumably GPU 3);\n",
    "# adjust it to match the machine running this notebook.\n",
    "predictions = model.predict_on_dataset(test_ds, devices=[3], num_workers=8)\n",
    "test[['pred_complex', 'pred_defined']] = predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e3c085-db1b-434a-bda0-ecf87e1169a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# corrx/corry presumably position the correlation annotation - TODO confirm in scripts.viz.\n",
    "scripts.viz.pointdensityplot(\n",
    "    test,\n",
    "    true_col='exp_complex',\n",
    "    pred_col='pred_complex',\n",
    "    corrx=4,\n",
    "    corry=19,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "094afae8-f9a6-449f-a8bd-41934619bb1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# corrx/corry presumably position the correlation annotation - TODO confirm in scripts.viz.\n",
    "scripts.viz.pointdensityplot(\n",
    "    test,\n",
    "    true_col='exp_defined',\n",
    "    pred_col='pred_defined',\n",
    "    corrx=4,\n",
    "    corry=18,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bc4872b-ba9c-4e60-ac37-4031cc42e0e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e16eb1e-2211-4e70-8beb-f04959f06949",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b2fd79-b98c-4af3-a8bf-610d38425c6c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
