{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc2b39f6-fb50-4b4d-ba69-1cdf0503667b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "import torch\n",
    "import itertools\n",
    "sys.path.insert(0, \"../../\")\n",
    "import scripts.evolve, scripts.regression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56ccc59b-fb23-480b-8eef-bd16646553fc",
   "metadata": {},
   "source": [
    "## Load models and start sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a568f821-f4a1-4e76-aff5-f7806005020a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore a LightningModel from the checkpoint under reg_joint/.\n",
    "# `model` is the object handed to scripts.evolve.ledidi in the Ledidi section.\n",
    "joint_ckpt_path = 'reg_joint/checkpoints/epoch=8-step=65331.ckpt'\n",
    "model = scripts.regression.LightningModel.load_from_checkpoint(joint_ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce7390f-b7fb-4936-8443-159b8948a3a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore two further LightningModel checkpoints. They are used later in the\n",
    "# notebook to fill the 'Complex' and 'Defined' prediction columns.\n",
    "complex_ckpt_path = 'reg_complex/lightning_logs/version_2/checkpoints/epoch=1-step=30824.ckpt'\n",
    "defined_ckpt_path = 'reg_defined/lightning_logs/version_1/checkpoints/epoch=8-step=80568.ckpt'\n",
    "cmodel = scripts.regression.LightningModel.load_from_checkpoint(complex_ckpt_path)\n",
    "dmodel = scripts.regression.LightningModel.load_from_checkpoint(defined_ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9087a3e-75c1-48a0-a74a-4f0da2663ec4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Read the start-sequence table (the first CSV column becomes the index) and\n",
    "# keep only its 'seq' column as a plain Python list.\n",
    "start_seq_table = pd.read_csv('synthetic_promoters/start_seqs.csv', index_col=0)\n",
    "start_seqs = start_seq_table['seq'].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d248819f-7bf5-4cd5-bd6b-70fa720c4beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the model on device index 4.\n",
    "# NOTE(review): the Ledidi call further down passes device=0 -- confirm the two\n",
    "# indices are meant to differ.\n",
    "target_device = torch.device(4)\n",
    "model = model.to(target_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "429c4154-c0bb-4f20-94fa-06ebcb82361f",
   "metadata": {},
   "source": [
    "## Ledidi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fef7daeb-7424-41c5-853f-fc796ce1b4ab",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run Ledidi once per start sequence and gather every output in one table.\n",
    "# Columns: 'seq' (values returned by scripts.evolve.ledidi) and 'start_seq'\n",
    "# (position of the originating sequence within `start_seqs`).\n",
    "ledidi_frames = []\n",
    "for i, start_seq in enumerate(start_seqs):\n",
    "    # Pass the loop variable itself: no cell visible in this notebook defines `seq`,\n",
    "    # so the previous call only worked with leftover kernel state.\n",
    "    out_seq = scripts.evolve.ledidi(start_seq, model, to_max=[0,1], to_min=None, device=0, num_workers=8)\n",
    "    # assumes out_seq is a list-like of sequences -- TODO confirm against scripts.evolve.ledidi\n",
    "    curr_l = pd.DataFrame({'seq': out_seq})\n",
    "    curr_l['start_seq'] = i\n",
    "    ledidi_frames.append(curr_l)\n",
    "# Concatenate once instead of growing the frame inside the loop (no repeated copying).\n",
    "# With no start sequences, fall back to an empty table that still has the expected columns.\n",
    "l_all = pd.concat(ledidi_frames) if ledidi_frames else pd.DataFrame(columns=['seq', 'start_seq'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4254689b-a2a8-4987-ac66-37cd2b1c17b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the collected sequences in a SeqDataset, then add one prediction column\n",
    "# per model: 'Complex' from cmodel, followed by 'Defined' from dmodel.\n",
    "ds = scripts.regression.SeqDataset(l_all['seq'].tolist())\n",
    "for column, predictor in (('Complex', cmodel), ('Defined', dmodel)):\n",
    "    l_all[column] = predictor.predict_on_dataset(ds, devices=[0], num_workers=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414eb783-c2cf-4d99-b1dc-ef3f7968fdb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the per-start-sequence row labels with a fresh 0..n-1 index before\n",
    "# writing the table to disk.\n",
    "le_table = l_all.reset_index(drop=True)\n",
    "le_table.to_csv('synthetic_promoters/LE_nn.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631b0f7b-85f5-424c-946c-a7ecaf4532b6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
