{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e139e1f-f2d4-42a3-8ab9-a4459d93f363",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "import torch\n",
    "sys.path.insert(0, \"../../regLM/\")\n",
    "import reglm.dataset, reglm.lightning\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "torch.cuda.manual_seed_all(0)\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"6\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c0e7a9e-b42c-443c-b386-915753cb79d2",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "991850c1-660b-45e8-bc1b-971c7e496754",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('lm_data/train.csv', index_col=0, usecols=(0,1,6), dtype='str')\n",
    "val = pd.read_csv('lm_data/val.csv', index_col=0, usecols=(0,1,6), dtype='str')\n",
    "\n",
    "val.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "384f8002-e5c5-4ecf-a149-b1c36a0abe83",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2136d9-64ef-4c5c-8620-ab503a950c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = reglm.dataset.CharDataset(train.nt_sequence.tolist(), train.label.tolist(), rc=False)\n",
    "val_ds = reglm.dataset.CharDataset(val.nt_sequence.tolist(), val.label.tolist(), rc=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2bab8ad-a11a-46ca-be68-3329110676de",
   "metadata": {},
   "source": [
    "## Build model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cdb515b-992e-4a5c-a16b-ebf65264c2b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = reglm.lightning.LightningModel(\n",
    "    logger='csv',\n",
    "    ckpt_dir='/code/sequence_modeling_playground/hyena/checkpoints/hyenadna-medium-160k-seqlen/')\n",
    "model.count_params()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72b7fd07-b0fc-46a5-aec0-04d5697148ce",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "517da2db-b3e8-4ee6-b637-dcb44d58752c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.train_on_dataset(train_ds, val_ds, max_epochs=100,\n",
    "            batch_size=128, num_workers=16, device=0, \n",
    "            val_check_interval=1000,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ed7b6f-a7c5-4108-ae1f-15b34101d710",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
