{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc05a9ea-5c85-4ba6-b8ec-9c31fbc0f72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third-party and standard-library imports.\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "from plotnine import *\n",
    "\n",
    "# Extend the import search path two directories up so that the\n",
    "# project-local scripts package can be imported.\n",
    "sys.path.append('../../')\n",
    "\n",
    "# Later cells reference grelu.model.run.LightningModel; without this import\n",
    "# the bare name grelu is undefined on a fresh kernel (NameError).\n",
    "import grelu.model.run\n",
    "import scripts.regression\n",
    "# Later cells call scripts.viz.pointdensityplot; import the submodule\n",
    "# explicitly instead of relying on it being loaded as a side effect.\n",
    "import scripts.viz\n",
    "\n",
    "# Seed NumPy's global RNG (only NumPy is seeded in this cell).\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f77400-152f-4f68-9844-65adc5e4596d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training and validation tables, reading only CSV columns 0-4\n",
    "# and using the first of them as the index.\n",
    "train = pd.read_csv(\n",
    "    'regression_lm_data/train.csv',\n",
    "    index_col=0,\n",
    "    usecols=[0, 1, 2, 3, 4],\n",
    ")\n",
    "val = pd.read_csv(\n",
    "    'regression_lm_data/val.csv',\n",
    "    index_col=0,\n",
    "    usecols=[0, 1, 2, 3, 4],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c53843-5a41-41cd-8f08-e48a5f35c35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the training and validation tables in SeqDataset objects\n",
    "# (constructed in that order) for use by train_on_dataset below.\n",
    "train_ds, val_ds = (\n",
    "    scripts.regression.SeqDataset(split_df) for split_df in (train, val)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3a4e17-2924-45eb-9e4b-50bc4058b232",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate a LightningModel of type 'EnformerPretrainedModel' with\n",
    "# MSE loss, learning rate 2e-4 and n_tasks=3.\n",
    "model = grelu.model.run.LightningModel(\n",
    "    model_type='EnformerPretrainedModel',\n",
    "    loss='MSE',\n",
    "    lr=2e-4,\n",
    "    n_tasks=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26e44917-4a2a-4aaa-97df-26b1796b2de6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train on the training dataset, passing the validation dataset alongside.\n",
    "model.train_on_dataset(\n",
    "    train_ds,\n",
    "    val_ds,\n",
    "    batch_size=128,\n",
    "    num_workers=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "747846a0-618a-4985-a4b9-4f9d0b4ca0fb",
   "metadata": {},
   "source": [
    "## Evaluate model on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "549a79c5-43a1-4720-bfc2-23ef98243cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind model to weights restored from a saved checkpoint.\n",
    "# NOTE(review): the path names one specific run (version_0, epoch=7) --\n",
    "# confirm it matches the checkpoint produced by the training cell.\n",
    "checkpoint_path = 'reg_lm/lightning_logs/version_0/checkpoints/epoch=7-step=20968.ckpt'\n",
    "model = grelu.model.run.LightningModel.load_from_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53470cda-6da2-411c-8a96-308db11b26cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test table with the same column selection as train/val:\n",
    "# CSV columns 0-4, first one as the index.\n",
    "test = pd.read_csv(\n",
    "    'regression_lm_data/test.csv',\n",
    "    index_col=0,\n",
    "    usecols=[0, 1, 2, 3, 4],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e865e7d2-0e3b-4703-bb61-d389096d06ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the test table in a SeqDataset, mirroring the train/val datasets.\n",
    "test_ds = scripts.regression.SeqDataset(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "313142e4-daed-4568-9323-c0dc3111fd1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference on the test dataset and store the result in three new\n",
    "# columns of the test table.\n",
    "# NOTE(review): devices=[7] hard-codes a device index -- confirm it exists on the host.\n",
    "# Assumes the predictions provide exactly one column per name below -- TODO confirm.\n",
    "test_preds = model.predict_on_dataset(\n",
    "    test_ds,\n",
    "    devices=[7],\n",
    "    batch_size=256,\n",
    "    num_workers=8,\n",
    ")\n",
    "test[['hpred', 'kpred', 'spred']] = test_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527397fa-4521-46d9-92de-23e3c2f52cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the 'HepG2_mean' column against the 'hpred' predictions.\n",
    "# corrx/corry presumably position the correlation annotation -- confirm in scripts.viz.\n",
    "scripts.viz.pointdensityplot(\n",
    "    test,\n",
    "    true_col='HepG2_mean',\n",
    "    pred_col='hpred',\n",
    "    corrx=1,\n",
    "    corry=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ee9d92-b9cc-4579-8bc3-fe7a37c874b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the 'K562_mean' column against the 'kpred' predictions.\n",
    "# corrx/corry presumably position the correlation annotation -- confirm in scripts.viz.\n",
    "scripts.viz.pointdensityplot(\n",
    "    test,\n",
    "    true_col='K562_mean',\n",
    "    pred_col='kpred',\n",
    "    corrx=1,\n",
    "    corry=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5148dae3-9d81-42bb-9289-5a811c10feef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the 'SKNSH_mean' column against the 'spred' predictions.\n",
    "# corrx/corry presumably position the correlation annotation -- confirm in scripts.viz.\n",
    "scripts.viz.pointdensityplot(\n",
    "    test,\n",
    "    true_col='SKNSH_mean',\n",
    "    pred_col='spred',\n",
    "    corrx=1,\n",
    "    corry=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "748f8d59-bcab-40de-a7a5-c287d62bf555",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
