{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.metrics import roc_auc_score, accuracy_score\n",
    "from sklearn.utils.class_weight import compute_class_weight\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from experiments.chordmixer import ChordMixerNet\n",
    "from experiments.dataloader_utils import DatasetCreator, concater_collate\n",
    "from experiments.training_utils import count_params, seed_everything, init_weights, train_epoch, eval_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Genbank Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): unpickling can execute arbitrary code; only load these files\n",
    "# when they come from a trusted source.\n",
    "data_train = pd.read_pickle('experiments/data/Carassius_Labeo_train.pkl')\n",
    "data_test = pd.read_pickle('experiments/data/Carassius_Labeo_test.pkl')\n",
    "\n",
    "# Longest sequence over both splits; handed to the model as max_seq_len below.\n",
    "max_seq_len = max(data_train['len'].max(), data_test['len'].max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sequence</th>\n",
       "      <th>label</th>\n",
       "      <th>len</th>\n",
       "      <th>bin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>130668</th>\n",
       "      <td>[1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 3, 0, 0, 3, 0, ...</td>\n",
       "      <td>0</td>\n",
       "      <td>600</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46015</th>\n",
       "      <td>[0, 3, 0, 3, 2, 3, 0, 3, 3, 0, 3, 1, 0, 1, 1, ...</td>\n",
       "      <td>0</td>\n",
       "      <td>1056</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2648</th>\n",
       "      <td>[1, 2, 2, 0, 1, 2, 1, 2, 3, 2, 2, 1, 1, 0, 3, ...</td>\n",
       "      <td>0</td>\n",
       "      <td>1412</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8205</th>\n",
       "      <td>[2, 0, 2, 3, 2, 1, 1, 1, 1, 2, 3, 2, 3, 2, 4, ...</td>\n",
       "      <td>1</td>\n",
       "      <td>447</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>122869</th>\n",
       "      <td>[0, 3, 2, 1, 3, 0, 0, 0, 0, 2, 3, 0, 1, 3, 3, ...</td>\n",
       "      <td>0</td>\n",
       "      <td>1383</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 sequence  label   len bin\n",
       "130668  [1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 3, 0, 0, 3, 0, ...      0   600   3\n",
       "46015   [0, 3, 0, 3, 2, 3, 0, 3, 3, 0, 3, 1, 0, 1, 1, ...      0  1056   7\n",
       "2648    [1, 2, 2, 0, 1, 2, 1, 2, 3, 2, 2, 1, 1, 0, 3, ...      0  1412   9\n",
       "8205    [2, 0, 2, 3, 2, 1, 1, 1, 1, 2, 3, 2, 3, 2, 4, ...      1   447   2\n",
       "122869  [0, 3, 2, 1, 3, 0, 0, 0, 0, 2, 3, 0, 1, 3, 3, ...      0  1383   9"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the training split (columns: sequence, label, len, bin).\n",
    "data_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "set up device: cuda:0\n",
      "class weights: [0.96255649 1.04047453]\n"
     ]
    }
   ],
   "source": [
    "# Use GPU 0 when CUDA is present, otherwise the CPU. set_device is only called\n",
    "# under the availability check because it fails on hosts without CUDA, which\n",
    "# previously made the CPU fallback unreachable.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(0)\n",
    "    device = 'cuda:{}'.format(0)\n",
    "else:\n",
    "    device = 'cpu'\n",
    "print('set up device:', device)\n",
    "\n",
    "net = ChordMixerNet(\n",
    "    problem='genbank',\n",
    "    vocab_size=20,\n",
    "    max_seq_len=max_seq_len,\n",
    "    embedding_size=128,\n",
    "    track_size=16,\n",
    "    hidden_size=128,\n",
    "    mlp_dropout=0.0,\n",
    "    layer_dropout=0.0,\n",
    "    n_class=2\n",
    ")\n",
    "net = net.to(device)\n",
    "net.apply(init_weights)\n",
    "\n",
    "\n",
    "# Balanced per-class weights for labels 0 and 1, used to weight the loss below.\n",
    "# `classes` is documented as an ndarray; passing an array (not a list) keeps the\n",
    "# call valid on scikit-learn releases that validate parameter types.\n",
    "class_weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=data_train['label'])\n",
    "print('class weights:', class_weights)\n",
    "\n",
    "loss = nn.CrossEntropyLoss(\n",
    "    weight=torch.tensor(class_weights, dtype=torch.float).to(device),\n",
    "    reduction='mean'\n",
    ")\n",
    "\n",
    "optimizer = optim.Adam(\n",
    "    net.parameters(),\n",
    "    lr=0.0001\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single source of truth for the batch size: DatasetCreator and DataLoader\n",
    "# both receive it, so the two values cannot drift apart.\n",
    "GENBANK_BATCH_SIZE = 2\n",
    "\n",
    "# Settings shared by the train and test loaders.\n",
    "loader_kwargs = dict(\n",
    "    batch_size=GENBANK_BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    collate_fn=concater_collate,\n",
    "    drop_last=False,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "# Prepare the training loader\n",
    "trainset = DatasetCreator(\n",
    "    df=data_train,\n",
    "    batch_size=GENBANK_BATCH_SIZE,\n",
    "    var_len=True\n",
    ")\n",
    "trainloader = DataLoader(trainset, **loader_kwargs)\n",
    "\n",
    "# Prepare the testing loader\n",
    "testset = DatasetCreator(\n",
    "    df=data_test,\n",
    "    batch_size=GENBANK_BATCH_SIZE,\n",
    "    var_len=True\n",
    ")\n",
    "testloader = DataLoader(testset, **loader_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(config, net, optimizer, loss, trainloader, device, log_every=50, scheduler=None, problem=None):\n",
    "    \"\"\"Run one training pass over `trainloader` and return the mean batch loss.\n",
    "\n",
    "    NOTE: this definition shadows the `train_epoch` imported from\n",
    "    `experiments.training_utils` in the first cell.\n",
    "\n",
    "    Args:\n",
    "        config: Not read by this function; kept for signature compatibility.\n",
    "        net: Model, called as `net(X, length)`.\n",
    "        optimizer: Optimizer stepped once per batch.\n",
    "        loss: Criterion, called as `loss(output, Y)`.\n",
    "        trainloader: Iterable of `(X, Y, length, bin)` batches that supports `len()`.\n",
    "        device: Device the inputs and targets are moved to.\n",
    "        log_every: Not read by this function; kept for signature compatibility.\n",
    "        scheduler: Optional scheduler, stepped after every batch.\n",
    "        problem: `'adding'` casts X/Y to float and squeezes the model output;\n",
    "            any other value feeds X/Y to the model and criterion unchanged.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the per-batch loss values divided by the number of batches\n",
    "        (also printed).\n",
    "    \"\"\"\n",
    "    net.train()\n",
    "    running_loss = 0.0\n",
    "    num_batches = len(trainloader)\n",
    "    for X, Y, length, _bin in tqdm(trainloader, total=num_batches):\n",
    "        if problem == 'adding':\n",
    "            X = X.float().to(device)\n",
    "            Y = Y.float().to(device)\n",
    "            # NOTE(review): squeeze() drops every size-1 dim, including the batch\n",
    "            # dim of a single-item batch -- confirm the output shape.\n",
    "            output = net(X, length).squeeze()\n",
    "        else:\n",
    "            X = X.to(device)\n",
    "            Y = Y.to(device)\n",
    "            output = net(X, length)\n",
    "        batch_loss = loss(output, Y)\n",
    "\n",
    "        # Clear gradients before backward so leftovers from an interrupted\n",
    "        # earlier run cannot leak into this update.\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        if scheduler:\n",
    "            scheduler.step()\n",
    "\n",
    "        running_loss += batch_loss.item()\n",
    "\n",
    "    total_loss = running_loss / num_batches\n",
    "    print(f'Training loss after epoch: {total_loss}')\n",
    "    return total_loss\n",
    "\n",
    "def eval_model(config, net, valloader, metric, device, problem) -> float:\n",
    "    \"\"\"Evaluate `net` on `valloader` and return `metric(targets, predictions)`.\n",
    "\n",
    "    NOTE: this definition shadows the `eval_model` imported from\n",
    "    `experiments.training_utils` in the first cell.\n",
    "\n",
    "    Args:\n",
    "        config: Not read by this function; kept for signature compatibility.\n",
    "        net: Model, called as `net(X, length)`.\n",
    "        valloader: Iterable of `(X, Y, length, bin)` batches that supports `len()`.\n",
    "        metric: Callable invoked as `metric(y_true, y_pred)` on two flat lists.\n",
    "        device: Device the inputs and targets are moved to.\n",
    "        problem: `'adding'` uses the squeezed model output as the prediction;\n",
    "            any other value uses the index of the largest score along dim 1.\n",
    "\n",
    "    Returns:\n",
    "        Whatever `metric` returns for the collected targets and predictions.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "\n",
    "    preds = []\n",
    "    targets = []\n",
    "\n",
    "    num_batches = len(valloader)\n",
    "    # Inference only: without no_grad every forward pass would record an\n",
    "    # autograd graph that is never used.\n",
    "    with torch.no_grad():\n",
    "        for X, Y, length, _bin in tqdm(valloader, total=num_batches, position=0, leave=True, ascii=False):\n",
    "            if problem == 'adding':\n",
    "                X = X.float().to(device)\n",
    "                Y = Y.float().to(device)\n",
    "                predicted = net(X, length).squeeze()\n",
    "            else:\n",
    "                X = X.to(device)\n",
    "                Y = Y.to(device)\n",
    "                # Hard class decision (max index), not a probability.\n",
    "                predicted = net(X, length).max(1)[1]\n",
    "\n",
    "            targets.extend(Y.detach().cpu().numpy().flatten())\n",
    "            preds.extend(predicted.detach().cpu().numpy().flatten())\n",
    "\n",
    "    # Ground truth first, predictions second: the order scikit-learn metrics\n",
    "    # such as roc_auc_score expect. The previous (preds, targets) order gave a\n",
    "    # different value for asymmetric metrics.\n",
    "    total_metric = metric(targets, preds)\n",
    "    return total_metric\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4458/4458 [01:23<00:00, 53.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.5930328630110909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:11<00:00, 120.57it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 completed. Test ROCAUC: 0.8715695476814205\n",
      "Starting epoch 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:20<00:00, 55.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.31280290801909577\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:11<00:00, 119.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 completed. Test ROCAUC: 0.902502478755048\n",
      "Starting epoch 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:21<00:00, 54.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.2229982015722905\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:14<00:00, 99.49it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 completed. Test ROCAUC: 0.9140791183062801\n",
      "Starting epoch 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:22<00:00, 54.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.16965537135811298\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:14<00:00, 99.12it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 completed. Test ROCAUC: 0.9153168445214707\n",
      "Starting epoch 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:21<00:00, 54.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.11841853988325589\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:14<00:00, 98.38it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 completed. Test ROCAUC: 0.9261526615547487\n",
      "Starting epoch 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:21<00:00, 54.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.07616942251129151\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:14<00:00, 98.86it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 completed. Test ROCAUC: 0.9241066430607808\n",
      "Starting epoch 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:23<00:00, 53.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.05381972214043256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:10<00:00, 129.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 completed. Test ROCAUC: 0.9215720136437474\n",
      "Starting epoch 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:21<00:00, 54.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0501274761766566\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:13<00:00, 99.77it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 completed. Test ROCAUC: 0.923776822560785\n",
      "Starting epoch 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:20<00:00, 55.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.03865994253573443\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:13<00:00, 101.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 completed. Test ROCAUC: 0.9302007666756696\n",
      "Starting epoch 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:20<00:00, 55.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.02962015546689079\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:11<00:00, 118.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 completed. Test ROCAUC: 0.9301623595598033\n",
      "Starting epoch 11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:20<00:00, 55.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.02706635002271226\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:11<00:00, 121.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 completed. Test ROCAUC: 0.9305088746812644\n",
      "Starting epoch 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:21<00:00, 54.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.023921714036276322\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:10<00:00, 128.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 completed. Test ROCAUC: 0.9218935134056078\n",
      "Starting epoch 13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:22<00:00, 54.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.025380650376584\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:11<00:00, 123.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 completed. Test ROCAUC: 0.9215460896538118\n",
      "Starting epoch 14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4458/4458 [01:20<00:00, 55.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.013640293765500106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:10<00:00, 129.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 completed. Test ROCAUC: 0.9316046498177037\n",
      "Starting epoch 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 4458/4458 [01:22<00:00, 54.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.018828823426635284\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 1395/1395 [00:12<00:00, 108.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 completed. Test ROCAUC: 0.9339238501446672\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Number of passes over the training data.\n",
    "N_EPOCHS = 15\n",
    "# Placeholder: train_epoch / eval_model as defined in this notebook take a\n",
    "# config argument but never read it.\n",
    "config = None\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "    print(f'Starting epoch {epoch+1}')\n",
    "    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=10000, problem='genbank')\n",
    "    test_roc_auc = eval_model(config, net, testloader, metric=roc_auc_score, device=device, problem='genbank')\n",
    "    print(f'Epoch {epoch+1} completed. Test ROCAUC: {test_roc_auc}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): unpickling can execute arbitrary code; only load these files\n",
    "# when they come from a trusted source.\n",
    "data_train = pd.read_pickle('experiments/data/adding_200_train.pkl')\n",
    "data_test = pd.read_pickle('experiments/data/adding_200_test.pkl')\n",
    "\n",
    "# Longest sequence over both splits; handed to the model as max_seq_len below.\n",
    "max_seq_len = max(data_train['len'].max(), data_test['len'].max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sequence</th>\n",
       "      <th>label</th>\n",
       "      <th>len</th>\n",
       "      <th>bin</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>13414</th>\n",
       "      <td>[[-0.08982194422382483, 0.0], [-0.841596420741...</td>\n",
       "      <td>0.241234</td>\n",
       "      <td>476</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56779</th>\n",
       "      <td>[[-0.6032521601163512, 0.0], [-0.3534035945595...</td>\n",
       "      <td>0.557188</td>\n",
       "      <td>296</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11976</th>\n",
       "      <td>[[0.6875824449807342, 0.0], [0.732433410329577...</td>\n",
       "      <td>0.489890</td>\n",
       "      <td>205</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29866</th>\n",
       "      <td>[[-0.11423020701527808, 0.0], [-0.743413113459...</td>\n",
       "      <td>0.169627</td>\n",
       "      <td>178</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1418</th>\n",
       "      <td>[[-0.9588989728128303, 0.0], [-0.0150490739209...</td>\n",
       "      <td>0.605129</td>\n",
       "      <td>186</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                sequence     label  len bin\n",
       "13414  [[-0.08982194422382483, 0.0], [-0.841596420741...  0.241234  476   6\n",
       "56779  [[-0.6032521601163512, 0.0], [-0.3534035945595...  0.557188  296   4\n",
       "11976  [[0.6875824449807342, 0.0], [0.732433410329577...  0.489890  205   2\n",
       "29866  [[-0.11423020701527808, 0.0], [-0.743413113459...  0.169627  178   1\n",
       "1418   [[-0.9588989728128303, 0.0], [-0.0150490739209...  0.605129  186   2"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the training split (columns: sequence, label, len, bin).\n",
    "data_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "set up device: cuda:1\n"
     ]
    }
   ],
   "source": [
    "# This experiment prefers GPU 1. Fall back to GPU 0 when that index does not\n",
    "# exist and to the CPU when CUDA is unavailable, instead of failing inside\n",
    "# set_device before the fallback is ever evaluated.\n",
    "PREFERRED_GPU = 1\n",
    "if torch.cuda.is_available():\n",
    "    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0\n",
    "    torch.cuda.set_device(gpu_index)\n",
    "    device = 'cuda:{}'.format(gpu_index)\n",
    "else:\n",
    "    device = 'cpu'\n",
    "print('set up device:', device)\n",
    "\n",
    "net = ChordMixerNet(\n",
    "    problem='adding',\n",
    "    vocab_size=1,\n",
    "    max_seq_len=max_seq_len,\n",
    "    embedding_size=196,\n",
    "    track_size=16,\n",
    "    hidden_size=128,\n",
    "    mlp_dropout=0.0,\n",
    "    layer_dropout=0.0,\n",
    "    n_class=1\n",
    ")\n",
    "net = net.to(device)\n",
    "net.apply(init_weights)\n",
    "\n",
    "# Regression target (n_class=1), hence mean squared error.\n",
    "loss = nn.MSELoss()\n",
    "\n",
    "optimizer = optim.Adam(\n",
    "    net.parameters(),\n",
    "    lr=0.0003\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single source of truth for the batch size: DatasetCreator and DataLoader\n",
    "# both receive it, so the two values cannot drift apart.\n",
    "ADDING_BATCH_SIZE = 20\n",
    "\n",
    "# Settings shared by the train and test loaders.\n",
    "loader_kwargs = dict(\n",
    "    batch_size=ADDING_BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    collate_fn=concater_collate,\n",
    "    drop_last=False,\n",
    "    num_workers=4\n",
    ")\n",
    "\n",
    "# Prepare the training loader\n",
    "trainset = DatasetCreator(\n",
    "    df=data_train,\n",
    "    batch_size=ADDING_BATCH_SIZE,\n",
    "    var_len=True\n",
    ")\n",
    "trainloader = DataLoader(trainset, **loader_kwargs)\n",
    "\n",
    "# Prepare the testing loader\n",
    "testset = DatasetCreator(\n",
    "    df=data_test,\n",
    "    batch_size=ADDING_BATCH_SIZE,\n",
    "    var_len=True\n",
    ")\n",
    "testloader = DataLoader(testset, **loader_kwargs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting epoch 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2493/2493 [03:26<00:00, 12.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.027713460761252782\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:11<00:00, 23.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 completed. Test accuracy: 0.2978515625\n",
      "Starting epoch 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:29<00:00, 11.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.009715184406176578\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:13<00:00, 18.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 completed. Test accuracy: 0.39140625\n",
      "Starting epoch 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:30<00:00, 11.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.005981061408553807\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 completed. Test accuracy: 0.610546875\n",
      "Starting epoch 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:26<00:00, 12.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.002842787358276122\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 completed. Test accuracy: 0.730078125\n",
      "Starting epoch 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:29<00:00, 11.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0016909314767512926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:14<00:00, 17.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 completed. Test accuracy: 0.879296875\n",
      "Starting epoch 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:31<00:00, 11.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.001128642675652514\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 completed. Test accuracy: 0.8083984375\n",
      "Starting epoch 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:27<00:00, 12.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0012362024246399361\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 completed. Test accuracy: 0.9296875\n",
      "Starting epoch 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:27<00:00, 12.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0006878992494961187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:11<00:00, 22.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 completed. Test accuracy: 0.9330078125\n",
      "Starting epoch 9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:27<00:00, 12.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0006834379984630572\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.58it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 completed. Test accuracy: 0.908984375\n",
      "Starting epoch 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:25<00:00, 12.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0005689226818744593\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 completed. Test accuracy: 0.9216796875\n",
      "Starting epoch 11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:25<00:00, 12.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0003845247808538325\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:14<00:00, 17.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 completed. Test accuracy: 0.94609375\n",
      "Starting epoch 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:31<00:00, 11.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0004038411665346416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:11<00:00, 21.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 completed. Test accuracy: 0.976171875\n",
      "Starting epoch 13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:27<00:00, 12.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0003447734195277123\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:14<00:00, 17.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 completed. Test accuracy: 0.98203125\n",
      "Starting epoch 14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:27<00:00, 12.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.00026726826554729393\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 24.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 completed. Test accuracy: 0.9796875\n",
      "Starting epoch 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:32<00:00, 11.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.00031323987477451144\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 completed. Test accuracy: 0.994140625\n",
      "Starting epoch 16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:25<00:00, 12.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.00023953229634787445\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:14<00:00, 18.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16 completed. Test accuracy: 0.9953125\n",
      "Starting epoch 17\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:24<00:00, 12.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.00020064286104835277\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17 completed. Test accuracy: 0.9970703125\n",
      "Starting epoch 18\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:26<00:00, 12.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0002564679064256286\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:11<00:00, 23.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18 completed. Test accuracy: 0.9861328125\n",
      "Starting epoch 19\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:30<00:00, 11.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.00015979634654238775\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:10<00:00, 23.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19 completed. Test accuracy: 0.998046875\n",
      "Starting epoch 20\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 2493/2493 [03:24<00:00, 12.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss after epoch: 0.0001929583848571533\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 256/256 [00:14<00:00, 17.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 completed. Test accuracy: 0.997265625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "def accuracy_adding(predictions, target):\n",
    "    predictions = np.array(predictions)\n",
    "    target = np.array(target)\n",
    "    score = (np.abs(predictions - target) < 0.04).mean()\n",
    "    return score\n",
    "\n",
    "for epoch in range(20):\n",
    "    print(f'Starting epoch {epoch+1}')\n",
    "    config=None\n",
    "    train_epoch(config, net, optimizer, loss, trainloader, device=device, log_every=100000, problem='adding')\n",
    "    accuracy = eval_model(config, net, testloader, metric=accuracy_adding, device=device, problem='adding') \n",
    "    print(f'Epoch {epoch+1} completed. Test accuracy: {accuracy}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.9 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
