{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e972463e-94f1-4afe-aa22-56a2ad2b3cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import importlib\n",
    "import sys\n",
    "import torch\n",
    "\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "sys.path.append(\"../../regLM/\")\n",
    "import reglm.dataset, reglm.lightning\n",
    "\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "torch.cuda.manual_seed_all(0)\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "DEVICE = torch.device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f4a33b-27d1-4734-bb8c-e5ec1d7973fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = pd.read_csv('lm_data/train.csv', index_col=0, usecols=(0, 1, 6), dtype=\"str\")\n",
    "val = pd.read_csv('lm_data/val.csv', index_col=0, usecols=(0, 1, 6), dtype=\"str\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f27c01-2840-47c0-a40f-86afba90e957",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = reglm.dataset.CharDataset(train.seq.tolist(), labels = train.label.tolist())\n",
    "val_ds = reglm.dataset.CharDataset(seqs = val.seq.tolist(), labels = val.label.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc47905c-9084-42c1-babb-96127190c1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    " 'd_model': 256,\n",
    " 'n_layer': 8,\n",
    " 'd_inner': 1024,\n",
    " 'vocab_size': 12,\n",
    " 'resid_dropout': 0.0,\n",
    " 'embed_dropout': 0.1,\n",
    " 'fused_mlp': False,\n",
    " 'fused_dropout_add_ln': True,\n",
    " 'residual_in_fp32': True,\n",
    " 'pad_vocab_size_multiple': 8,\n",
    " 'return_hidden_state': True,\n",
    " 'layer': {'emb_dim': 5,\n",
    "  'filter_order': 64,\n",
    "  'local_order': 3,\n",
    "  'l_max': 84,\n",
    "  'modulate': True,\n",
    "  'w': 10,\n",
    "  'lr': 0.0006,\n",
    "  'wd': 0.0,\n",
    "  'lr_pos_emb': 0.0,\n",
    "  '_name_': 'hyena'}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f0d0a08-8ea1-4ed2-a351-5369db09d02a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = reglm.lightning.LightningModel(config=config, lr=3e-4, logger='csv')\n",
    "model.count_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad2d89c4-b1be-4a4a-9652-f3efbdead4a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model.train_on_dataset(train_ds, val_ds, max_epochs=100,\n",
    "            batch_size=128, num_workers=16, device=0, \n",
    "            val_check_interval=2000,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98573e47-8e91-4772-a13c-413cb2c3ed47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
