{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc05a9ea-5c85-4ba6-b8ec-9c31fbc0f72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import grelu\n",
    "import grelu.model.models\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from plotnine import *\n",
    "\n",
    "# Project-local modules live two directories up; extend the path before importing them.\n",
    "sys.path.append('../../')\n",
    "import scripts.regression\n",
    "import scripts.viz\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f77400-152f-4f68-9844-65adc5e4596d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits are read with identical options: the first five columns,\n",
    "# with column 0 used as the index.\n",
    "split_read_opts = dict(index_col=0, usecols=[0, 1, 2, 3, 4])\n",
    "train = pd.read_csv('regression_separate_data/train.csv', **split_read_opts)\n",
    "val = pd.read_csv('regression_separate_data/val.csv', **split_read_opts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c53843-5a41-41cd-8f08-e48a5f35c35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the training and validation frames in SeqDataset objects (train first, then val).\n",
    "train_ds, val_ds = (\n",
    "    scripts.regression.SeqDataset(split_df) for split_df in (train, val)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f58acc6f-ef06-4b76-8813-647142095c79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the regression model: 3 output tasks, MSE loss, learning rate 2e-4.\n",
    "# The architecture class is referenced by its fully qualified name because a bare\n",
    "# `EnformerPretrainedModel` is never bound in this notebook and raises NameError\n",
    "# on a fresh kernel.\n",
    "# NOTE(review): assumes LightningModel accepts the class object for `model_type`\n",
    "# (as the original bare name implied) - confirm against scripts/regression.py.\n",
    "model = scripts.regression.LightningModel(\n",
    "    model_type=grelu.model.models.EnformerPretrainedModel,\n",
    "    n_tasks=3,\n",
    "    loss='MSE',\n",
    "    lr=2e-4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3a4e17-2924-45eb-9e4b-50bc4058b232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on train_ds; val_ds is passed as the second dataset\n",
    "# (presumably used for validation - confirm in scripts/regression.py).\n",
    "fit_options = dict(batch_size=128, num_workers=16)\n",
    "model.train_on_dataset(train_ds, val_ds, **fit_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cd87b48-c692-4ed8-998f-ab03c2b9ca60",
   "metadata": {},
   "source": [
    "## Evaluate model on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "549a79c5-43a1-4720-bfc2-23ef98243cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the in-memory model with weights restored from a saved checkpoint.\n",
    "# The path is hard-coded to one specific run (version_0, epoch=4, step=4635).\n",
    "checkpoint_path = 'reg_separate/lightning_logs/version_0/checkpoints/epoch=4-step=4635.ckpt'\n",
    "model = scripts.regression.LightningModel.load_from_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c18aad-737a-489b-9b62-8a536130300e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the held-out test split: first five columns, column 0 as the index.\n",
    "test = pd.read_csv(\n",
    "    'regression_separate_data/test.csv',\n",
    "    index_col=0,\n",
    "    usecols=[0, 1, 2, 3, 4],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df23b93a-2600-47d4-b51c-ca522d84a43f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the 'Sequence' column is handed to the dataset used for prediction.\n",
    "test_sequences = test[['Sequence']]\n",
    "test_ds = scripts.regression.SeqDataset(test_sequences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcab23e5-618e-4e16-8e21-c5070ef68137",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the model's predictions for the test sequences as three new columns.\n",
    "# NOTE(review): devices=[4] pins prediction to a specific device index - adjust for your machine.\n",
    "# Assumes predict_on_dataset returns one column per task in the order\n",
    "# matching hpred / kpred / spred - TODO confirm.\n",
    "prediction_cols = ['hpred', 'kpred', 'spred']\n",
    "predictions = model.predict_on_dataset(\n",
    "    test_ds,\n",
    "    devices=[4],\n",
    "    batch_size=256,\n",
    "    num_workers=8,\n",
    ")\n",
    "test[prediction_cols] = predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f36ad928-36d9-4221-8dbc-77e7c3147b50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the 'HepG2_mean' column with the model's 'hpred' predictions on the test set.\n",
    "# corrx / corry presumably position the correlation annotation - confirm in scripts/viz.py.\n",
    "hepg2_plot_args = dict(true_col='HepG2_mean', pred_col='hpred', corrx=1, corry=6)\n",
    "scripts.viz.pointdensityplot(test, **hepg2_plot_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3219a690-0d10-4837-bb69-54a6dd9128af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the 'K562_mean' column with the model's 'kpred' predictions on the test set.\n",
    "# corrx / corry presumably position the correlation annotation - confirm in scripts/viz.py.\n",
    "k562_plot_args = dict(true_col='K562_mean', pred_col='kpred', corrx=1, corry=6)\n",
    "scripts.viz.pointdensityplot(test, **k562_plot_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c00f68e0-71cf-400b-9625-b9a7eb09e57b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the 'SKNSH_mean' column with the model's 'spred' predictions on the test set.\n",
    "# corrx / corry presumably position the correlation annotation - confirm in scripts/viz.py.\n",
    "sknsh_plot_args = dict(true_col='SKNSH_mean', pred_col='spred', corrx=1, corry=6)\n",
    "scripts.viz.pointdensityplot(test, **sknsh_plot_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be989df8-e5fc-4bfc-9cdc-23a6f4052572",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
