{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c1e5673-2d58-4c88-b4ca-20d950239619",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from plotnine import *\n",
    "import torch\n",
    "from torch import nn\n",
    "import os\n",
    "# sys must be imported before sys.path is touched; without it this cell\n",
    "# raises NameError on a fresh kernel.\n",
    "import sys\n",
    "\n",
    "# Add '../../' to the import path so the `scripts` package can be found.\n",
    "sys.path.append('../../')\n",
    "import scripts.regression, scripts.evolve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f437fac-404a-4121-8104-084ce06a503b",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a12be8ab-c607-4ac8-9398-15ba55aaa5ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model passed to scripts.evolve.evolve in the Evolve section.\n",
    "joint_ckpt = 'reg_joint/checkpoints/epoch=8-step=65331.ckpt'\n",
    "model = scripts.regression.LightningModel.load_from_checkpoint(joint_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45f2c2a0-bcc6-4c66-983c-638f039d0184",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two further checkpoints, used later to fill the 'Complex' and 'Defined' columns.\n",
    "complex_ckpt = 'reg_complex/lightning_logs/version_2/checkpoints/epoch=1-step=30824.ckpt'\n",
    "defined_ckpt = 'reg_defined/lightning_logs/version_1/checkpoints/epoch=8-step=80568.ckpt'\n",
    "\n",
    "load_checkpoint = scripts.regression.LightningModel.load_from_checkpoint\n",
    "cmodel = load_checkpoint(complex_ckpt)\n",
    "dmodel = load_checkpoint(defined_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b89e8c0-16c5-46f5-a87d-784099210563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read three columns by position; column 0 becomes the index and all values are read as strings.\n",
    "# Later cells access the loaded columns as `label` and `seq`.\n",
    "test = pd.read_csv(\n",
    "    'lm_data/test.csv',\n",
    "    usecols=(0, 1, 6),\n",
    "    index_col=0,\n",
    "    dtype='str',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf1fa294-79e7-4514-a207-2004d1f88d1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory that the to_csv calls below write into.\n",
    "# exist_ok=True keeps this cell safe to re-run when the directory already exists.\n",
    "os.makedirs('synthetic_promoters', exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87d2f9c9-a8ab-4a9d-8b2c-ae63f07b37a3",
   "metadata": {},
   "source": [
    "## Get starting sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "249e3a04-bbf3-4266-9a3e-2d405a35ed59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 500 sequences from the rows labelled '00'.\n",
    "# A fixed random_state makes the same sequences come out on every run.\n",
    "start_seqs = test[test.label=='00'].seq.sample(500, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62951042-eb3e-404c-8d05-1da34884a49f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the sampled starting sequences (with their index) for later reference.\n",
    "start_seqs.to_csv(path_or_buf='synthetic_promoters/start_seqs.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31080028-2955-4b0f-af62-5f66295a2890",
   "metadata": {},
   "source": [
    "## Evolve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ba2f4e-3a55-4d97-8816-07b15c216b33",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the evolution separately for each starting sequence and tag the rows\n",
    "# with the position of that sequence in start_seqs.\n",
    "evolved_frames = []\n",
    "for i, seq in enumerate(start_seqs.tolist()):\n",
    "    df = scripts.evolve.evolve([seq], model, to_max=[0,1], to_min=None, device=0)\n",
    "    df['start_seq'] = i\n",
    "    evolved_frames.append(df)\n",
    "\n",
    "# Concatenate once instead of growing nn_df inside the loop, which re-copies\n",
    "# all earlier rows on every iteration. pd.concat raises on an empty list, so\n",
    "# fall back to an empty frame when there are no starting sequences.\n",
    "nn_df = pd.concat(evolved_frames) if evolved_frames else pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e53d8b7-a266-4bc5-8fad-8b7a25b33f50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every evolved sequence with cmodel and dmodel, then keep only the\n",
    "# columns that get saved, in a fixed order.\n",
    "ds = scripts.regression.SeqDataset(nn_df.seq.tolist())\n",
    "nn_df = nn_df.assign(\n",
    "    Complex=cmodel.predict_on_dataset(ds, devices=[0], num_workers=8),\n",
    "    Defined=dmodel.predict_on_dataset(ds, devices=[0], num_workers=8),\n",
    ")[['seq', 'iter', 'start_seq', 'Complex', 'Defined']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "669773fb-902b-4fc1-8eb1-6c1657e985c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the index carried over from concatenation with a fresh 0..n-1 index\n",
    "# before writing the results to disk.\n",
    "(\n",
    "    nn_df\n",
    "    .reset_index(drop=True)\n",
    "    .to_csv('synthetic_promoters/DE_nn.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df8c051a-4348-47e8-8294-248d9c5a0e9a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
