{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d129435-6b8c-4591-8d1d-264d8f72beaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# CUDA_VISIBLE_DEVICES is only read when CUDA is first initialised, so it must be\n",
    "# set before torch (or any module that imports torch, e.g. reglm) touches the GPU.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import importlib\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# regLM is imported from a local checkout (relative path) rather than an installed package.\n",
    "sys.path.insert(0, \"../../regLM/\")\n",
    "import reglm.lightning\n",
    "\n",
    "# Seed the NumPy and torch (CPU and CUDA) RNGs used for sampling below.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "torch.cuda.manual_seed_all(0)\n",
    "\n",
    "# With CUDA_VISIBLE_DEVICES=\"2\", physical GPU 2 is exposed as logical device 0.\n",
    "DEVICE = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98da0b20-e87e-4712-85ef-f1f2421e9bc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trained regLM checkpoint to sample from.\n",
    "CHECKPOINT_PATH = 'lightning_logs/version_15/checkpoints/epoch=24-step=65304.ckpt'\n",
    "\n",
    "# map_location loads the weights straight onto DEVICE instead of the GPU they were\n",
    "# saved from; eval() switches off training-mode layers (e.g. dropout, if any) for generation.\n",
    "model = reglm.lightning.LightningModel.load_from_checkpoint(\n",
    "    CHECKPOINT_PATH, map_location=DEVICE)\n",
    "model = model.to(DEVICE).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a630898-3120-4681-a870-1880121d17a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Sampling configuration.\n",
    "LABELS = [\"400\", \"040\", \"004\"]  # label tokens generation is conditioned on\n",
    "N_PER_LABEL = 500                 # number of sequences sampled per label\n",
    "MAX_NEW_TOKENS = 200              # passed to model.generate as max_new_tokens\n",
    "\n",
    "seq = []    # generated sequences\n",
    "token = []  # label each entry of `seq` was conditioned on (parallel list)\n",
    "\n",
    "# Inference only: disable autograd so no computation graph is built during sampling.\n",
    "with torch.no_grad():\n",
    "    for t in LABELS:\n",
    "        print(t)\n",
    "        for _ in range(N_PER_LABEL):\n",
    "            token.append(t)\n",
    "            seq.append(model.generate(label=t, sample=True, device=0, max_new_tokens=MAX_NEW_TOKENS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "decdf2c6-5808-4268-b192-8356a574e9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per generated sequence, paired with the label it was conditioned on.\n",
    "df = pd.DataFrame({\n",
    "    'Sequence': seq,\n",
    "    'label': token,\n",
    "})\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e988a237-6f00-4ecb-a93f-098278b6f545",
   "metadata": {},
   "outputs": [],
   "source": [
    "# to_csv does not create parent directories; make sure the output folder exists.\n",
    "os.makedirs('synthetic_enhancers', exist_ok=True)\n",
    "df.to_csv('synthetic_enhancers/lm.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91941bd3-0f9f-4af5-9988-4505e67f6959",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
