{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b119b7e",
   "metadata": {},
   "source": [
    "### Siamese Network Training with Betti Vectorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47d6ddeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import optim\n",
    "import pickle\n",
    "from src.ml_models.siamese import (Siamese, ContrastiveLoss, TripletMarginLoss,\n",
    "                                   get_anchor_samples, generate_data_pairs, generate_data_triplets, \n",
    "                                   split_into_batches, train_model, produce_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83b3e195-a436-4c4d-959e-90b72754c22a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda:1'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = 'cuda:1' if torch.cuda.is_available() else 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc50db6c",
   "metadata": {},
   "source": [
    "#### Load data: Partial Charge filtration (superlevel) on the molecule Followed by Betti Vectorization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "94acb9fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/ClevesJain_TopologyFeatures/partial_charge_superlevel_betti.pickle', 'rb') as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "train_x, train_y, test_x, test_y = data[\"train_x\"], data[\"train_y\"], data[\"test_x\"], data[\"test_y\"]\n",
    "anchor_x, anchor_y = get_anchor_samples(train_x, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6c988bb2-9fa4-47d4-b811-deabb086a55b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57 1092\n",
      "(20, 13) (20, 13)\n"
     ]
    }
   ],
   "source": [
    "print(len(train_x), len(test_x))\n",
    "print(train_x[0].shape, test_x[0].shape) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87fcb8a3",
   "metadata": {},
   "source": [
    "#### Model Training - ViT Backbone & Triplet Margin Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f08b5f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#loss_fn = \"circle\"\n",
    "loss_fn = \"triplet_margin\"\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "67176479-2c93-49d8-8d9a-c9df1f03f811",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set, test_set = generate_data_triplets(train_x, train_y, test_x, test_y, anchor_x, anchor_y)\n",
    "dataloaders = split_into_batches(train_set, test_set, batch_size, loss_fn=loss_fn)\n",
    "del train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9dfa6aa4-83a4-48cf-9709-d05f03c57861",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/demiran1/.conda/envs/tda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:371: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set size:  8298\n",
      "Test set size:  24024\n",
      "Epoch 1/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [02:50<00:00,  1.31s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.16165367\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [05:57<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.07709484\n",
      "Epoch 2/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:06<00:00,  1.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.04848604\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [05:58<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.07452442\n",
      "Epoch 3/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:06<00:00,  1.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.02949787\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [05:56<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.07361304\n",
      "Epoch 4/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:05<00:00,  1.43s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01937526\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:01<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06951264\n",
      "Epoch 5/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:10<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01526458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:03<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06407508\n",
      "Epoch 6/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01475529\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [05:59<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06292001\n",
      "Epoch 7/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:06<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01400514\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [05:59<00:00,  1.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06321332\n",
      "Epoch 8/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:09<00:00,  1.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01347256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:06<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.05919770\n",
      "Epoch 9/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:10<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01352143\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:05<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06103069\n",
      "Epoch 10/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01228445\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:01<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06394675\n",
      "Epoch 11/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:10<00:00,  1.46s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01320908\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:06<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06277978\n",
      "Epoch 12/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:10<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01170023\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:06<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06197820\n",
      "Epoch 13/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:10<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01191317\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:07<00:00,  1.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06415269\n",
      "Epoch 14/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:11<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01221310\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:07<00:00,  1.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06313287\n",
      "Epoch 15/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:11<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01179706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:07<00:00,  1.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06022782\n",
      "Epoch 16/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:11<00:00,  1.47s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01356513\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:05<00:00,  1.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06423314\n",
      "Epoch 17/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01279134\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:00<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06906651\n",
      "Epoch 18/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01038568\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:00<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06070759\n",
      "Epoch 19/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01192260\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:00<00:00,  1.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06377448\n",
      "Epoch 20/20, LR 0.00050000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 130/130 [03:07<00:00,  1.44s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.01182734\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 376/376 [06:00<00:00,  1.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 0.06547957\n",
      "Training complete in 183m 33s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# model = Siamese(base_model=\"convnext\", embedding_size=1000)\n",
    "model = Siamese(base_model=\"vision_transformer\", embedding_size=1000)\n",
    "# model = Siamese(base_model=\"resnet\", embedding_size=1000)\n",
    "\n",
    "optimizer = optim.Adam(\n",
    "    filter(lambda p: p.requires_grad, model.parameters()),\n",
    "    lr=0.0005,\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.0005,\n",
    ")\n",
    "model = train_model(model, dataloaders, device, optimizer, lr_decay=False, num_epochs=20, loss_fn=loss_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc743508",
   "metadata": {},
   "source": [
    "#### Enrichment Factor Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d5548d25-2f45-418b-8329-78a56d337579",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a875b98",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>EF 2%</th>\n",
       "      <th>EF 5%</th>\n",
       "      <th>EF 10%</th>\n",
       "      <th>EF 15%</th>\n",
       "      <th>EF 20%</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>33.333333</td>\n",
       "      <td>13.333333</td>\n",
       "      <td>8.333333</td>\n",
       "      <td>5.555556</td>\n",
       "      <td>4.166667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>6.818182</td>\n",
       "      <td>4.545455</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>3.939394</td>\n",
       "      <td>3.181818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>15.384615</td>\n",
       "      <td>13.846154</td>\n",
       "      <td>9.230769</td>\n",
       "      <td>6.153846</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>41.666667</td>\n",
       "      <td>16.666667</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>20.000000</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>4.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>25.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>7.500000</td>\n",
       "      <td>5.000000</td>\n",
       "      <td>3.750000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>30.000000</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>5.333333</td>\n",
       "      <td>4.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>50.000000</td>\n",
       "      <td>20.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>20.000000</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>14.285714</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>5.714286</td>\n",
       "      <td>5.238095</td>\n",
       "      <td>3.928571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>21.428571</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>5.714286</td>\n",
       "      <td>3.809524</td>\n",
       "      <td>3.214286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>30.000000</td>\n",
       "      <td>14.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>44.444444</td>\n",
       "      <td>17.777778</td>\n",
       "      <td>8.888889</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>25.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>7.000000</td>\n",
       "      <td>4.666667</td>\n",
       "      <td>4.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>16.666667</td>\n",
       "      <td>13.333333</td>\n",
       "      <td>8.333333</td>\n",
       "      <td>5.777778</td>\n",
       "      <td>4.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>17.391304</td>\n",
       "      <td>10.434783</td>\n",
       "      <td>5.652174</td>\n",
       "      <td>4.927536</td>\n",
       "      <td>3.695652</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>36.363636</td>\n",
       "      <td>18.181818</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>17.857143</td>\n",
       "      <td>8.571429</td>\n",
       "      <td>8.571429</td>\n",
       "      <td>5.714286</td>\n",
       "      <td>4.642857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>13.333333</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>4.666667</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>3.333333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>30.000000</td>\n",
       "      <td>12.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>5.333333</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>38.888889</td>\n",
       "      <td>20.000000</td>\n",
       "      <td>10.000000</td>\n",
       "      <td>6.666667</td>\n",
       "      <td>5.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>28.571429</td>\n",
       "      <td>11.428571</td>\n",
       "      <td>5.714286</td>\n",
       "      <td>3.809524</td>\n",
       "      <td>3.571429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>26.201542</td>\n",
       "      <td>12.581181</td>\n",
       "      <td>7.650884</td>\n",
       "      <td>5.420858</td>\n",
       "      <td>4.340210</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          EF 2%      EF 5%     EF 10%    EF 15%    EF 20%\n",
       "0     33.333333  13.333333   8.333333  5.555556  4.166667\n",
       "1      6.818182   4.545455   5.000000  3.939394  3.181818\n",
       "2     15.384615  13.846154   9.230769  6.153846  5.000000\n",
       "3     41.666667  16.666667  10.000000  6.666667  5.000000\n",
       "4     20.000000  12.000000   6.000000  4.000000  4.000000\n",
       "5     25.000000  10.000000   7.500000  5.000000  3.750000\n",
       "6     30.000000  12.000000   8.000000  5.333333  4.000000\n",
       "7     50.000000  20.000000  10.000000  6.666667  5.000000\n",
       "8     20.000000  12.000000  10.000000  6.666667  5.000000\n",
       "9     14.285714  10.000000   5.714286  5.238095  3.928571\n",
       "10    21.428571  10.000000   5.714286  3.809524  3.214286\n",
       "11    30.000000  14.000000   8.000000  6.000000  5.000000\n",
       "12    44.444444  17.777778   8.888889  6.666667  5.000000\n",
       "13    25.000000  10.000000   7.000000  4.666667  4.500000\n",
       "14    16.666667  13.333333   8.333333  5.777778  4.500000\n",
       "15    17.391304  10.434783   5.652174  4.927536  3.695652\n",
       "16    36.363636  18.181818  10.000000  6.666667  5.000000\n",
       "17    17.857143   8.571429   8.571429  5.714286  4.642857\n",
       "18    13.333333   6.666667   4.666667  4.000000  3.333333\n",
       "19    30.000000  12.000000   6.000000  5.333333  5.000000\n",
       "20    38.888889  20.000000  10.000000  6.666667  5.000000\n",
       "21    28.571429  11.428571   5.714286  3.809524  3.571429\n",
       "22          NaN        NaN        NaN       NaN       NaN\n",
       "mean  26.201542  12.581181   7.650884  5.420858  4.340210"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC score: 0.847493465015483\n"
     ]
    }
   ],
   "source": [
    "factors = [0.02, 0.05, 0.1, 0.15, 0.2]\n",
    "produce_results(model, test_x, test_y, anchor_x, device, loss_fn, factors, batch_size=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tda",
   "language": "python",
   "name": "tda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
