{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b119b7e",
   "metadata": {},
   "source": [
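    "### Siamese Network Training with Betti Vectorization: Target AKT1\n",
    "\n",
    "Trains a Siamese network (ConvNeXt backbone, triplet margin loss) on Betti-vectorized atom-weight superlevel filtration features of the AKT1 actives and decoys from DUDE-Diverse, then evaluates how well actives are separated from decoys using enrichment factors and AUC."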
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47d6ddeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import optim\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from src.ml_models.siamese import (Siamese, ContrastiveLoss, TripletMarginLoss,\n",
    "                                   get_anchor_samples, generate_data_pairs, generate_data_triplets, \n",
    "                                   split_into_batches, train_model, produce_results, produce_results_alternative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "83b3e195-a436-4c4d-959e-90b72754c22a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda:1'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use 'cuda:1' when a second GPU exists (the original choice); otherwise fall back to\n",
    "# 'cuda:0', then CPU. Hard-coding 'cuda:1' fails with an invalid device ordinal as soon as\n",
    "# anything is moved to it on a single-GPU machine.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = 'cuda:1'\n",
    "elif torch.cuda.is_available():\n",
    "    device = 'cuda:0'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7ef9bedc-0252-4c24-9d8f-f52a0bacf591",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached, unused GPU memory held by PyTorch's caching allocator.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a050d06-044e-4813-9e01-42a6d71b74d9",
   "metadata": {},
   "source": [
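    "#### Load data: atom-weight superlevel filtration followed by Betti vectorization\n",
    "\n",
    "Actives are labelled 0 and decoys 1. 80% of the actives, plus an equal number of decoys, form the training set; all remaining molecules form the test set."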
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8dec9177-d070-482c-a682-88023de04e12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17001/17001 [03:27<00:00, 81.79it/s] \n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the active/decoy train-test split below is reproducible across runs.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "\n",
    "X, y = [], []\n",
    "path = \"../data/DUDE-Diverse_TopologyFeatures/atom_weight_superlevel_betti/target_akt1/\"\n",
    "\n",
    "# os.listdir order is arbitrary; sort it so the seeded shuffle yields the same split every run.\n",
    "for file in tqdm(sorted(os.listdir(path))):\n",
    "    if file.endswith('.pkl'):\n",
    "        # pickle.load can execute arbitrary code: only load trusted, locally generated feature files.\n",
    "        with open(os.path.join(path, file), 'rb') as f:\n",
    "            X.append(pickle.load(f))\n",
    "        # Label convention: 0 = active, 1 = decoy.\n",
    "        # NOTE(review): substring match, so a name containing e.g. \"inactive\" would be labelled active; confirm file naming.\n",
    "        target_label = 0 if \"active\" in file else 1\n",
    "        y.append(target_label)\n",
    "\n",
    "active_ind = [i for i, label in enumerate(y) if label == 0]\n",
    "decoy_ind = [i for i, label in enumerate(y) if label == 1]\n",
    "\n",
    "# shuffle the lists (deterministic given the sorted listing and the seed above)\n",
    "random.shuffle(active_ind)\n",
    "random.shuffle(decoy_ind)\n",
    "\n",
    "# Use 80% of the actives for training\n",
    "num_actives = len(active_ind)\n",
    "num_training = int(0.8 * num_actives)\n",
    "active_training_ind = active_ind[:num_training]\n",
    "active_test_ind = active_ind[num_training:]\n",
    "\n",
    "# Use same number of decoys as actives in training (avoid imbalance);\n",
    "# every remaining decoy goes to the test set.\n",
    "decoy_training_ind = decoy_ind[:num_training]\n",
    "decoy_test_ind = decoy_ind[num_training:]\n",
    "\n",
    "# Sanity checks on the split before building datasets from it.\n",
    "assert active_training_ind and active_test_ind, \"need actives in both the training and test splits\"\n",
    "assert len(decoy_training_ind) == num_training, \"not enough decoys to balance the training set\"\n",
    "\n",
    "# training and test index lists\n",
    "training_ind = active_training_ind + decoy_training_ind\n",
    "test_ind = active_test_ind + decoy_test_ind\n",
    "\n",
    "train_x, train_y = [X[i] for i in training_ind], [y[i] for i in training_ind]\n",
    "test_x, test_y = [X[i] for i in test_ind], [y[i] for i in test_ind]\n",
    "\n",
    "anchor_x, anchor_y = get_anchor_samples(train_x, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6c988bb2-9fa4-47d4-b811-deabb086a55b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "676 16324\n",
      "(22, 30) (22, 30)\n"
     ]
    }
   ],
   "source": [
    "# Sizes of the train/test splits, then the shape of the first sample in each.\n",
    "print(*map(len, (train_x, test_x)))\n",
    "print(*(sample.shape for sample in (train_x[0], test_x[0])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87fcb8a3",
   "metadata": {},
   "source": [
    "#### Model Training - ConvNeXt Backbone & Triplet Margin Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f08b5f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss (alternatives previously used here: \"circle\", \"contrastive\").\n",
    "loss_fn = \"triplet_margin\"\n",
    "\n",
    "# Mini-batch size for both the training and test dataloaders.\n",
    "batch_size = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "67176479-2c93-49d8-8d9a-c9df1f03f811",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Triplet sets for training/testing, built around the anchor samples drawn from the training split.\n",
    "# mul presumably scales how many triplets are generated per sample -- TODO confirm.\n",
    "# For pair-based training, generate_data_pairs takes the same first six arguments.\n",
    "train_set, test_set = generate_data_triplets(\n",
    "    train_x, train_y,\n",
    "    test_x, test_y,\n",
    "    anchor_x, anchor_y,\n",
    "    generate_large=False,\n",
    "    mul=20,\n",
    ")\n",
    "dataloaders = split_into_batches(train_set, test_set, batch_size, loss_fn=loss_fn)\n",
    "\n",
    "# Drop the notebook-level references to the raw sets; only the dataloaders are used from here on.\n",
    "del train_set, test_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9dfa6aa4-83a4-48cf-9709-d05f03c57861",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/demiran1/.conda/envs/tda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:371: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  warnings.warn(\"To get the last learning rate computed by the scheduler, \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set size:  27040\n",
      "Test set size:  32648\n",
      "Epoch 1/1, LR 0.00010000\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6760/6760 [07:14<00:00, 15.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.14417502\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8162/8162 [02:06<00:00, 64.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test Loss: 1.16191815\n",
      "Training complete in 9m 25s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Seed torch so weight initialisation (and, presumably, any torch-driven shuffling in the\n",
    "# dataloaders during training) is reproducible; torch.manual_seed also seeds CUDA devices.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# ConvNeXt backbone with fine-tuning enabled. Alternatives tried earlier (embedding_size=1000,\n",
    "# finetune=False): base_model=\"vision_transformer\" and base_model=\"resnet\".\n",
    "model = Siamese(base_model=\"convnext\", embedding_size=5000, deeper=False, finetune=True)\n",
    "\n",
    "# Optimise only the parameters left trainable by the model configuration.\n",
    "optimizer = optim.Adam(\n",
    "    filter(lambda p: p.requires_grad, model.parameters()),\n",
    "    lr=0.0001,\n",
    "    eps=1e-8,\n",
    "    weight_decay=0.0005,\n",
    ")\n",
    "\n",
    "# One epoch without LR decay (an earlier configuration used lr_decay=True, num_epochs=4).\n",
    "model = train_model(model, dataloaders, device, optimizer, lr_decay=False, num_epochs=1, loss_fn=loss_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc743508",
   "metadata": {},
   "source": [
    "#### Evaluation: Enrichment Factors and AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1a875b98",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>EF 1%</th>\n",
       "      <th>EF 2%</th>\n",
       "      <th>EF 5%</th>\n",
       "      <th>EF 10%</th>\n",
       "      <th>EF 15%</th>\n",
       "      <th>EF 20%</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>60.0</td>\n",
       "      <td>37.058824</td>\n",
       "      <td>15.764706</td>\n",
       "      <td>8.470588</td>\n",
       "      <td>5.803922</td>\n",
       "      <td>4.529412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>60.0</td>\n",
       "      <td>37.058824</td>\n",
       "      <td>15.764706</td>\n",
       "      <td>8.470588</td>\n",
       "      <td>5.803922</td>\n",
       "      <td>4.529412</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      EF 1%      EF 2%      EF 5%    EF 10%    EF 15%    EF 20%\n",
       "0      60.0  37.058824  15.764706  8.470588  5.803922  4.529412\n",
       "1       NaN        NaN        NaN       NaN       NaN       NaN\n",
       "mean   60.0  37.058824  15.764706  8.470588  5.803922  4.529412"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC score: 0.9060098600681731\n"
     ]
    }
   ],
   "source": [
    "# Fractions of the test set at which enrichment factors are reported (EF 1% ... EF 20%);\n",
    "# produce_results displays the EF table and prints the AUC.\n",
    "ef_fractions = [0.01, 0.02, 0.05, 0.10, 0.15, 0.20]\n",
    "produce_results(\n",
    "    model, test_x, test_y, anchor_x,\n",
    "    device, loss_fn, ef_fractions,\n",
    "    batch_size=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1465252-0b81-4ff5-b3e8-0397784e03de",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tda",
   "language": "python",
   "name": "tda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
