{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit\n",
    "import networkx as nx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../../code')\n",
    "\n",
    "from metrics import get_hi_metrics\n",
    "import min_vertex_k_cut"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare dataset\n",
    "\n",
    "Load the HIV train/test split (`train_1.csv` / `test_1.csv`) from `../../data/hi/hiv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>O=S(=O)(O)CCS(=O)(=O)O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>CC(C)CCS(=O)(=O)O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>O=S(=O)(O)CCO</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>O=S(=O)(O)CO</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>O=S(=O)(O)CCCCBr</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40932</th>\n",
       "      <td>COC(=O)c1cc2cc3c(c(O)c2c(=O)o1)OC1(Oc2c(O)c4c(...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40973</th>\n",
       "      <td>CCCCC1C(OCOc2ccccc2)COC(=O)N1C(C)c1ccccc1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41024</th>\n",
       "      <td>CC(C)=CC1CC(C)C2CCC(C)C3C(=O)C(O)=C(C)C(=O)C123</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41026</th>\n",
       "      <td>CCOC(=O)C12C(=O)C(C)CCC1C(C)CC2C=C(C)C</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41106</th>\n",
       "      <td>Cc1ccc(C=C2CN(C)CC3C(c4ccc(C)cc4)=C(C#N)C(=O)N...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15696 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  smiles  value\n",
       "4                                 O=S(=O)(O)CCS(=O)(=O)O      0\n",
       "21                                     CC(C)CCS(=O)(=O)O      0\n",
       "90                                         O=S(=O)(O)CCO      0\n",
       "106                                         O=S(=O)(O)CO      0\n",
       "117                                     O=S(=O)(O)CCCCBr      0\n",
       "...                                                  ...    ...\n",
       "40932  COC(=O)c1cc2cc3c(c(O)c2c(=O)o1)OC1(Oc2c(O)c4c(...      0\n",
       "40973          CCCCC1C(OCOc2ccccc2)COC(=O)N1C(C)c1ccccc1      0\n",
       "41024    CC(C)=CC1CC(C)C2CCC(C)C3C(=O)C(O)=C(C)C(=O)C123      0\n",
       "41026             CCOC(=O)C12C(=O)C(C)CCC1C(C)CC2C=C(C)C      0\n",
       "41106  Cc1ccc(C=C2CN(C)CC3C(c4ccc(C)cc4)=C(C#N)C(=O)N...      0\n",
       "\n",
       "[15696 rows x 2 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The loaded files are train_{SPLIT_ID}.csv / test_{SPLIT_ID}.csv inside DATA_DIR.\n",
    "DATA_DIR = '../../data/hi/hiv'\n",
    "SPLIT_ID = 1\n",
    "\n",
    "train = pd.read_csv(f'{DATA_DIR}/train_{SPLIT_ID}.csv', index_col=0)\n",
    "test = pd.read_csv(f'{DATA_DIR}/test_{SPLIT_ID}.csv', index_col=0)\n",
    "# Every later cell reads these two columns from both frames.\n",
    "assert {'smiles', 'value'} <= set(train.columns) and {'smiles', 'value'} <= set(test.columns)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
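    "# Split train into train and validation parts\n",
    "\n",
    "Molecules of the training set are linked in a neighborhood graph (threshold 0.4, built by `min_vertex_k_cut`). The main connected component is split with a MILP solved by CBC; the small components are then added to whichever part is currently smaller."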
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[13:38:09] WARNING: not removing hydrogen atom without neighbors\n",
      "[13:38:09] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    }
   ],
   "source": [
    "smiles = train['smiles'].to_list()\n",
    "threshold = 0.4\n",
    "\n",
    "# Graph over the molecules; the meaning of edges and of `threshold` is defined by min_vertex_k_cut.get_neighborhood_graph.\n",
    "neighborhood_graph = min_vertex_k_cut.get_neighborhood_graph(smiles, threshold)\n",
    "main_component, small_components = min_vertex_k_cut.get_main_component(neighborhood_graph)\n",
    "\n",
    "# Relabel the main component to contiguous ids 0..n-1, keeping both directions of the mapping.\n",
    "new_nodes_to_old = dict(enumerate(main_component.nodes()))\n",
    "old_nodes_to_new = {old: new for new, old in new_nodes_to_old.items()}\n",
    "main_component = nx.relabel_nodes(main_component, old_nodes_to_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coarsen the relabelled main component before the MILP split; node_to_cluster presumably maps\n",
    "# each main-component node to its coarse node (it is passed to process_bisect_results later).\n",
    "coarsed_main_component, node_to_cluster = min_vertex_k_cut.coarse_graph(\n",
    "    main_component,\n",
    "    0.4,  # coarsening parameter; meaning defined by min_vertex_k_cut.coarse_graph\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total molecules: 4486\n",
      "Min train size 2243\n",
      "Min test size 448\n",
      "Welcome to the CBC MILP Solver \n",
      "Version: Trunk\n",
      "Build Date: Oct 24 2021 \n",
      "\n",
      "Starting solution of the Linear programming relaxation problem using Primal Simplex\n",
      "\n",
      "Coin0506I Presolve 3471 (-1486) rows, 1486 (0) columns and 8424 (-2972) elements\n",
      "Clp1000I sum of infeasibilities 3.0458e-06 - average 8.77499e-10, 0 fixed columns\n",
      "Coin0506I Presolve 3471 (0) rows, 1486 (0) columns and 8424 (0) elements\n",
      "Clp0029I End of values pass after 1486 iterations\n",
      "Clp0014I Perturbing problem by 0.001% of 0.61265822 - largest nonzero change 2.9985502e-05 ( 0.002395959%) - largest zero change 0\n",
      "Clp0000I Optimal - objective value 4486\n",
      "Clp0000I Optimal - objective value 4486\n",
      "Clp0000I Optimal - objective value 4486\n",
      "Coin0511I After Postsolve, objective 4486, infeasibilities - dual 0 (0), primal 0 (0)\n",
      "Clp0032I Optimal objective 4486 - 0 iterations time 0.242, Presolve 0.00, Idiot 0.24\n",
      "\n",
      "Starting MIP optimization\n",
      "Cgl0004I processed model has 3471 rows, 1486 columns (1486 integer (1486 of which binary)) and 8424 elements\n",
      "Coin3009W Conflict graph built in 0.000 seconds, density: 0.112%\n",
      "Cgl0015I Clique Strengthening extended 0 cliques, 0 were dominated\n",
      "Cbc0045I Nauty: 6629 orbits (108 useful covering 298 variables), 95 generators, group size: 2.90748e+43 - sparse size 39224 - took 0.107119 seconds\n",
      "Cbc0038I Initial state - 1486 integers unsatisfied sum - 148.6\n",
      "Cbc0038I Pass   1: (0.17 seconds) suminf.    6.49641 (183) obj. -4046.15 iterations 1088\n",
      "Cbc0038I Pass   2: (0.17 seconds) suminf.    6.49641 (183) obj. -4046.15 iterations 44\n",
      "Cbc0038I Pass   3: (0.18 seconds) suminf.    4.67234 (12) obj. -3740.11 iterations 116\n",
      "Cbc0038I Solution found of -3527\n",
      "Cbc0038I Rounding solution of -3631 is better than previous of -3527\n",
      "\n",
      "Cbc0038I Before mini branch and bound, 0 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 3471 rows 1486 columns - 207 fixed gives 736, 616 - ok now\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 736 rows 616 columns\n",
      "Cbc0038I Mini branch and bound improved solution from -3631 to -4299 (0.25 seconds)\n",
      "Cbc0038I Round again with cutoff of -4318.6\n",
      "Cbc0038I Pass   4: (0.27 seconds) suminf.   22.61572 (282) obj. -4318.6 iterations 216\n",
      "Cbc0038I Pass   5: (0.27 seconds) suminf.   22.61572 (282) obj. -4318.6 iterations 33\n",
      "Cbc0038I Pass   6: (0.27 seconds) suminf.   22.35167 (312) obj. -4318.6 iterations 48\n",
      "Cbc0038I Pass   7: (0.28 seconds) suminf.   22.28249 (334) obj. -4318.6 iterations 16\n",
      "Cbc0038I Pass   8: (0.28 seconds) suminf.   23.53393 (300) obj. -4318.6 iterations 27\n",
      "Cbc0038I Pass   9: (0.29 seconds) suminf.   23.53004 (322) obj. -4318.6 iterations 4\n",
      "Cbc0038I Pass  10: (0.30 seconds) suminf.    1.48718 (5) obj. -4335.99 iterations 456\n",
      "Cbc0038I Pass  11: (0.31 seconds) suminf.    1.48718 (5) obj. -4335.99 iterations 14\n",
      "Cbc0038I Pass  12: (0.31 seconds) suminf.    0.75882 (3) obj. -4318.6 iterations 97\n",
      "Cbc0038I Pass  13: (0.31 seconds) suminf.    0.89231 (3) obj. -4333.89 iterations 37\n",
      "Cbc0038I Pass  14: (0.33 seconds) suminf.    0.13336 (1) obj. -4318.6 iterations 503\n",
      "Cbc0038I Solution found of -4319\n",
      "Cbc0038I Rounding solution of -4444 is better than previous of -4319\n",
      "\n",
      "Cbc0038I Before mini branch and bound, 0 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 3471 rows 1486 columns - 572 fixed gives 200, 131 - ok now\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 104 rows 80 columns\n",
      "Cbc0038I Mini branch and bound improved solution from -4444 to -4451 (0.35 seconds)\n",
      "Cbc0038I Round again with cutoff of -4458.8\n",
      "Cbc0038I Pass  15: (0.37 seconds) suminf.   26.45450 (114) obj. -4458.8 iterations 372\n",
      "Cbc0038I Pass  16: (0.40 seconds) suminf.    0.20008 (1) obj. -4458.8 iterations 1174\n",
      "Cbc0038I Solution found of -4459\n",
      "Cbc0038I Rounding solution of -4469 is better than previous of -4459\n",
      "\n",
      "Cbc0038I Before mini branch and bound, 0 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 3471 rows 1486 columns - 563 fixed gives 243, 152 - ok now\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 120 rows 89 columns\n",
      "Cbc0038I Mini branch and bound improved solution from -4469 to -4474 (0.42 seconds)\n",
      "Cbc0038I Round again with cutoff of -4478.3\n",
      "Cbc0038I Pass  17: (0.44 seconds) suminf.   35.81844 (149) obj. -4478.3 iterations 214\n",
      "Cbc0038I Pass  18: (0.49 seconds) suminf.    2.87491 (10) obj. -4478.3 iterations 1586\n",
      "Cbc0038I Solution found of -4484\n",
      "Cbc0038I Before mini branch and bound, 0 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 3471 rows 1486 columns - 559 fixed gives 262, 167 - ok now\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 50 rows 54 columns\n",
      "Cbc0038I Mini branch and bound did not improve solution (0.51 seconds)\n",
      "Cbc0038I Round again with cutoff of -4485.3\n",
      "Cbc0038I Pass  19: (0.52 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 112\n",
      "Cbc0038I Pass  20: (0.53 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 1\n",
      "Cbc0038I Pass  21: (0.53 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 21\n",
      "Cbc0038I Pass  22: (0.54 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 29\n",
      "Cbc0038I Pass  23: (0.55 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 30\n",
      "Cbc0038I Pass  24: (0.55 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 32\n",
      "Cbc0038I Pass  25: (0.56 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 31\n",
      "Cbc0038I Pass  26: (0.57 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 23\n",
      "Cbc0038I Pass  27: (0.57 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 24\n",
      "Cbc0038I Pass  28: (0.58 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 31\n",
      "Cbc0038I Pass  29: (0.59 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 34\n",
      "Cbc0038I Pass  30: (0.59 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 29\n",
      "Cbc0038I Pass  31: (0.60 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 31\n",
      "Cbc0038I Pass  32: (0.61 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 33\n",
      "Cbc0038I Pass  33: (0.61 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 31\n",
      "Cbc0038I Pass  34: (0.62 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 29\n",
      "Cbc0038I Pass  35: (0.63 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 28\n",
      "Cbc0038I Pass  36: (0.63 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 34\n",
      "Cbc0038I Pass  37: (0.64 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 26\n",
      "Cbc0038I Pass  38: (0.65 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 28\n",
      "Cbc0038I Pass  39: (0.65 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 29\n",
      "Cbc0038I Pass  40: (0.66 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 28\n",
      "Cbc0038I Pass  41: (0.67 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 32\n",
      "Cbc0038I Pass  42: (0.67 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 27\n",
      "Cbc0038I Pass  43: (0.68 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 27\n",
      "Cbc0038I Pass  44: (0.68 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 28\n",
      "Cbc0038I Pass  45: (0.69 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 30\n",
      "Cbc0038I Pass  46: (0.70 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 24\n",
      "Cbc0038I Pass  47: (0.70 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 25\n",
      "Cbc0038I Pass  48: (0.71 seconds) suminf.  136.96083 (1486) obj. -4485.3 iterations 29\n",
      "Cbc0038I No solution found this major pass\n",
      "Cbc0038I Before mini branch and bound, 0 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 3471 rows 1486 columns, reduced to 3471 rows 1486 columns - 559 fixed gives 262, 167 - ok now\n",
      "Cbc0038I Mini branch and bound did not improve solution (0.73 seconds)\n",
      "Cbc0038I After 0.73 seconds - Feasibility pump exiting with objective of -4484 - took 0.60 seconds\n",
      "Cbc0012I Integer solution of -4484 found by feasibility pump after 0 iterations and 0 nodes (0.73 seconds)\n",
      "Cbc0030I Thread 0 used 0 times,  waiting to start 0.026098967, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 1 used 0 times,  waiting to start 0.037034273, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 2 used 0 times,  waiting to start 0.035430908, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 3 used 0 times,  waiting to start 0.033810377, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 4 used 0 times,  waiting to start 0.032182455, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 5 used 0 times,  waiting to start 0.030574799, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 6 used 0 times,  waiting to start 0.02895999, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 7 used 0 times,  waiting to start 0.027349949, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 8 used 0 times,  waiting to start 0.025723457, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 9 used 0 times,  waiting to start 0.024122715, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 10 used 0 times,  waiting to start 0.022567272, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 11 used 0 times,  waiting to start 0.021045208, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 12 used 0 times,  waiting to start 0.019520044, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 13 used 0 times,  waiting to start 0.017935038, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 14 used 0 times,  waiting to start 0.016403675, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 15 used 0 times,  waiting to start 0.014829397, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Main thread 0 waiting for threads,  1 locks, 1.4305115e-06 locked, 4.7683716e-07 waiting for locks\n",
      "Cbc0011I Exiting as integer gap of 2 less than 1e-10 or 30%\n",
      "Cbc0001I Search completed - best objective -4484, took 0 iterations and 0 nodes (0.77 seconds)\n",
      "Cbc0035I Maximum depth 0, 0 variables fixed on reduced cost\n",
      "Total time (CPU seconds):       0.78   (Wallclock seconds):       0.79\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# MILP split of the coarsened graph: first part >= 50% and second part >= 10% of the\n",
    "# main-component molecules; the solver may stop once the MIP gap is below 30%.\n",
    "model = min_vertex_k_cut.train_test_split_connected_graph(\n",
    "    coarsed_main_component,\n",
    "    train_min_fraq=0.5,\n",
    "    test_min_fraq=0.1,\n",
    "    max_mip_gap=0.3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Molecules in train: 3964\n",
      "Molecules in test: 520\n",
      "Molecules lost: 2\n"
     ]
    }
   ],
   "source": [
    "# Project the coarse solution back onto the nodes of main_component (presumably one partition label per node).\n",
    "split = min_vertex_k_cut.process_bisect_results(\n",
    "    model,\n",
    "    coarsed_main_component,\n",
    "    main_component,\n",
    "    node_to_cluster,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_idx, second_idx = [], []\n",
    "\n",
    "# Translate each relabelled node back to its original id and bucket it by partition label.\n",
    "# Labels other than 0/1 are skipped (presumably the molecules reported as lost above).\n",
    "for new_node, part in enumerate(split):\n",
    "    old_node = new_nodes_to_old[new_node]\n",
    "    if part == 0:\n",
    "        first_idx.append(old_node)\n",
    "    elif part == 1:\n",
    "        second_idx.append(old_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedily give each small component to the currently smaller part (ties go to the first part).\n",
    "for component in small_components:\n",
    "    smaller_part = first_idx if len(first_idx) <= len(second_idx) else second_idx\n",
    "    smaller_part.extend(component)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7847\n",
      "7847\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: a molecule in both parts would leak validation data into training.\n",
    "assert not set(first_idx) & set(second_idx), 'train and validation parts overlap'\n",
    "print(len(first_idx))\n",
    "print(len(second_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node ids are assumed to be positions in train['smiles'].to_list(), hence iloc rather than loc.\n",
    "part_first, part_second = train.iloc[first_idx], train.iloc[second_idx]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
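    "## Hi split\n",
    "\n",
    "Tune gradient boosting on Morgan fingerprints of `part_first` and score it by average precision on `part_second`."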
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import average_precision_score\n",
    "\n",
    "def run_gb_gridsearch(train_fps, val_fps, train_y, val_y, random_state=0, n_iter=30):\n",
    "    \"\"\"Tune a GradientBoostingClassifier on a fixed train/validation split.\n",
    "\n",
    "    RandomizedSearchCV samples `n_iter` configurations from the grid below and scores each one by\n",
    "    average precision on the validation part only (PredefinedSplit: -1 = never in the test fold,\n",
    "    0 = the single validation fold). The best configuration is printed, refit on the training\n",
    "    part and re-scored on the validation part.\n",
    "\n",
    "    Args:\n",
    "        train_fps, val_fps: feature vectors for the training / validation parts.\n",
    "        train_y, val_y: binary labels aligned with train_fps / val_fps.\n",
    "        random_state: seed for the parameter sampling and for every boosting model. The models are\n",
    "            stochastic when subsample < 1 or max_features='sqrt', so this makes runs reproducible.\n",
    "        n_iter: number of hyperparameter configurations to sample.\n",
    "\n",
    "    Returns:\n",
    "        Validation average precision of the refit best model.\n",
    "    \"\"\"\n",
    "    # Concatenate via list(): `+` on numpy arrays would add element-wise instead of stacking.\n",
    "    X = list(train_fps) + list(val_fps)\n",
    "    y = list(train_y) + list(val_y)\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(val_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    gb = GradientBoostingClassifier(random_state=random_state)\n",
    "\n",
    "    grid_search = RandomizedSearchCV(\n",
    "        gb, params, cv=pds, n_iter=n_iter, refit=False,\n",
    "        scoring='average_precision', random_state=random_state, verbose=3,\n",
    "    )\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    gb = GradientBoostingClassifier(random_state=random_state, **best_params)\n",
    "    gb.fit(train_fps, train_y)\n",
    "\n",
    "    val_preds = gb.predict_proba(val_fps)[:, 1]\n",
    "    val_metrics = average_precision_score(val_y, val_preds)\n",
    "    return val_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[14:04:02] WARNING: not removing hydrogen atom without neighbors\n",
      "[14:04:02] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.057 total time=  54.7s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.052 total time= 1.4min\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.032 total time=   6.9s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=200, subsample=0.9;, score=0.040 total time=   8.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=150, subsample=1.0;, score=0.046 total time=  28.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.031 total time=   7.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=500, subsample=0.7;, score=0.044 total time=  57.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=500, subsample=0.7;, score=0.023 total time= 1.4min\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=100, subsample=0.4;, score=0.033 total time=   7.0s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=150, subsample=1.0;, score=0.038 total time=  38.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=10, subsample=0.9;, score=0.040 total time=   6.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=10, subsample=0.4;, score=0.036 total time=   6.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=50, subsample=0.4;, score=0.050 total time=   6.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=100, subsample=1.0;, score=0.050 total time=   7.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=150, subsample=1.0;, score=0.064 total time=   7.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=150, subsample=0.7;, score=0.036 total time=  29.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.036 total time=   6.7s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=0.9;, score=0.072 total time=   9.0s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.028 total time=  35.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=0.7;, score=0.045 total time=  16.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=0.9;, score=0.055 total time=   7.8s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=0.9;, score=0.034 total time=  19.1s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=1.0;, score=0.035 total time= 2.5min\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=50, subsample=0.9;, score=0.047 total time=   6.9s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=1.0;, score=0.060 total time=   6.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=50, subsample=0.9;, score=0.060 total time=  13.0s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=10, subsample=0.9;, score=0.027 total time=   6.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=50, subsample=0.4;, score=0.060 total time=   6.5s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=250, subsample=0.4;, score=0.022 total time=   7.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=10, subsample=0.7;, score=0.033 total time=   7.9s\n",
      "{'subsample': 0.9, 'n_estimators': 250, 'min_samples_split': 5, 'min_samples_leaf': 3, 'max_features': 'sqrt', 'max_depth': 4, 'learning_rate': 0.01}\n",
      "0.06877828975491546\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in part_first['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "val_mols = [Chem.MolFromSmiles(x) for x in part_second['smiles']]\n",
    "val_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in val_mols]\n",
    "\n",
    "# Validation average precision on part_second; the held-out test set is scored in Final Evaluation.\n",
    "val_metrics = run_gb_gridsearch(train_morgan_fps, val_morgan_fps, part_first['value'].to_list(), part_second['value'].to_list())\n",
    "print(val_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
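    "### Final Evaluation\n",
    "\n",
    "Retrain on the full training set with the best Hi-split hyperparameters and report average precision on the held-out test set."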
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[14:20:25] WARNING: not removing hydrogen atom without neighbors\n",
      "[14:20:25] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.08437590235381434\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "# Best configuration printed by the Hi-split search above.\n",
    "best_hi_params = {\n",
    "    'subsample': 0.9,\n",
    "    'n_estimators': 250,\n",
    "    'min_samples_split': 5,\n",
    "    'min_samples_leaf': 3,\n",
    "    'max_features': 'sqrt',\n",
    "    'max_depth': 4,\n",
    "    'learning_rate': 0.01,\n",
    "}\n",
    "# Fixed seed: subsample < 1 and max_features='sqrt' make boosting stochastic, so an unseeded model is not reproducible.\n",
    "gb = GradientBoostingClassifier(random_state=0, **best_hi_params)\n",
    "\n",
    "gb.fit(train_morgan_fps, train['value'].to_list())\n",
    "test_preds = gb.predict_proba(test_morgan_fps)[:, 1]\n",
    "test_metrics = average_precision_score(test['value'], test_preds)\n",
    "print(test_metrics)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
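    "# Scaffold split (Butina clustering)\n",
    "\n",
    "Instead of the graph cut, the training set is split into train/validation parts by Butina clustering of Morgan-fingerprint Tanimoto distances (`butina_split` below); whole clusters go to one side."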
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit.ML.Cluster import Butina\n",
    "from numpy.random import default_rng\n",
    "\n",
    "\n",
    "def butina_split(smiles: list[str], cutoff: float, seed: int, frac_train=0.8, radius: int = 2, n_bits: int = 1024):\n",
    "    \"\"\"\n",
    "    Select distinct molecules to train/test. Returns indices of the molecules in the smiles list.\n",
    "    Adapted from DeepChem (https://deepchem.io/), but random seed is added.\n",
    "\n",
    "    Molecules are clustered with the Butina algorithm on Tanimoto distances (1 - similarity)\n",
    "    between Morgan bit-vector fingerprints. Whole clusters are assigned to one side, so no\n",
    "    cluster is shared between train and test.\n",
    "\n",
    "    Args:\n",
    "        smiles: SMILES strings to split.\n",
    "        cutoff: Butina distance threshold; molecules within this distance are neighbours.\n",
    "        seed: seed of the RNG that shuffles the clusters before assignment.\n",
    "        frac_train: upper bound on the fraction of molecules put in train; a cluster that\n",
    "            would exceed it goes to test instead.\n",
    "        radius: Morgan fingerprint radius.\n",
    "        n_bits: Morgan fingerprint length.\n",
    "\n",
    "    Returns:\n",
    "        (train_inds, test_inds): positional indices into `smiles`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if frac_train is outside [0, 1] or some SMILES cannot be parsed.\n",
    "    \"\"\"\n",
    "    if not 0.0 <= frac_train <= 1.0:\n",
    "        raise ValueError(f'frac_train must be in [0, 1], got {frac_train}')\n",
    "\n",
    "    mols = [Chem.MolFromSmiles(smile) for smile in smiles]\n",
    "    unparsable = [smile for smile, mol in zip(smiles, mols) if mol is None]\n",
    "    if unparsable:\n",
    "        # RDKit returns None for invalid SMILES; fail with a clear message instead of an opaque RDKit error.\n",
    "        raise ValueError(f'{len(unparsable)} SMILES could not be parsed, e.g. {unparsable[0]!r}')\n",
    "    fps = [AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits) for mol in mols]\n",
    "\n",
    "    # Lower triangle of the distance matrix, row by row, as Butina.ClusterData(isDistData=True) expects.\n",
    "    dists = []\n",
    "    nfps = len(fps)\n",
    "    for i in range(1, nfps):\n",
    "        sims = DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i])\n",
    "        dists.extend([1 - x for x in sims])\n",
    "    scaffold_sets = Butina.ClusterData(dists, nfps, cutoff, isDistData=True)\n",
    "    # Largest-first canonical order before the seeded shuffle (kept from DeepChem).\n",
    "    scaffold_sets = sorted(scaffold_sets, key=lambda x: -len(x))\n",
    "\n",
    "    rng = default_rng(seed)\n",
    "    rng.shuffle(scaffold_sets)\n",
    "\n",
    "    train_cutoff = frac_train * len(smiles)\n",
    "    train_inds = []\n",
    "    test_inds = []\n",
    "\n",
    "    # Greedy fill: a cluster that would push train past the cutoff goes to test; later, smaller\n",
    "    # clusters may still fit into train.\n",
    "    for scaffold_set in scaffold_sets:\n",
    "        if len(train_inds) + len(scaffold_set) > train_cutoff:\n",
    "            test_inds += scaffold_set\n",
    "        else:\n",
    "            train_inds += scaffold_set\n",
    "    return train_inds, test_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[14:24:31] WARNING: not removing hydrogen atom without neighbors\n",
      "[14:24:31] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    }
   ],
   "source": [
    "# Cluster-level split of the training set: at most ~50% of molecules for train, the rest for validation.\n",
    "train_idx, val_idx = butina_split(\n",
    "    train['smiles'].to_list(),\n",
    "    cutoff=0.5,\n",
    "    seed=123,\n",
    "    frac_train=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# butina_split returns positional indices into the SMILES list, hence iloc.\n",
    "part_first, part_second = train.iloc[train_idx], train.iloc[val_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import average_precision_score\n",
    "\n",
    "def run_gb_gridsearch(train_fps, val_fps, train_y, val_y, n_iter=30, random_state=None):\n",
    "    \"\"\"\n",
    "    Tune a GradientBoostingClassifier on a fixed train/validation split and return validation AP.\n",
    "\n",
    "    Despite the name, this is a randomized search (RandomizedSearchCV) that samples\n",
    "    `n_iter` configurations from `params`, using (val_fps, val_y) as the single held-out\n",
    "    fold via PredefinedSplit. The best configuration is refit on the training part only\n",
    "    and scored on the same validation part, so the returned value is the score used for\n",
    "    model selection (optimistic as an estimate of test performance).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_fps, val_fps : sequences of feature vectors (e.g. RDKit bit vectors)\n",
    "    train_y, val_y : sequences of binary labels aligned with the feature vectors\n",
    "    n_iter : number of sampled hyperparameter configurations (default 30)\n",
    "    random_state : int, RandomState instance or None (default None); forwarded to the\n",
    "        search and to every GradientBoostingClassifier. None keeps the previous\n",
    "        unseeded, non-reproducible behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float : average precision of the refit model on the validation part.\n",
    "    \"\"\"\n",
    "    # Materialize as lists so `+` below always concatenates; with numpy arrays or pandas\n",
    "    # Series it would silently add element-wise instead.\n",
    "    train_fps, val_fps = list(train_fps), list(val_fps)\n",
    "    train_y, val_y = list(train_y), list(val_y)\n",
    "\n",
    "    # -1 = never in a test fold (always trained on); 0 = member of the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(val_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + val_fps\n",
    "    y = train_y + val_y\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    gb = GradientBoostingClassifier(random_state=random_state)\n",
    "\n",
    "    # refit=False: the search only picks hyperparameters; the final model is fit below\n",
    "    # on the training part alone (not on train + val).\n",
    "    grid_search = RandomizedSearchCV(\n",
    "        gb, params, cv=pds, n_iter=n_iter, refit=False,\n",
    "        scoring='average_precision', verbose=3, random_state=random_state,\n",
    "    )\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    gb = GradientBoostingClassifier(**best_params, random_state=random_state)\n",
    "    gb.fit(train_fps, train_y)\n",
    "\n",
    "    # Column 1 = probability of the positive class (assumes labels {0, 1}).\n",
    "    val_preds = gb.predict_proba(val_fps)[:, 1]\n",
    "    val_metrics = average_precision_score(val_y, val_preds)\n",
    "    return val_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[14:27:16] WARNING: not removing hydrogen atom without neighbors\n",
      "[14:27:16] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=0.9;, score=0.100 total time=   7.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=10, subsample=0.9;, score=0.078 total time=   9.0s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=250, subsample=1.0;, score=0.172 total time=   9.1s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=200, subsample=0.4;, score=0.034 total time=   7.6s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=250, subsample=0.4;, score=0.044 total time=  31.4s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=500, subsample=0.9;, score=0.178 total time= 2.3min\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=10, subsample=0.7;, score=0.106 total time=   6.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=200, subsample=0.4;, score=0.148 total time=  16.7s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=100, subsample=0.4;, score=0.101 total time=  11.5s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.118 total time=   6.6s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.049 total time=   9.1s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=200, subsample=1.0;, score=0.086 total time= 1.1min\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=0.4;, score=0.062 total time=   6.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=0.9;, score=0.049 total time=   8.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.9;, score=0.030 total time=  11.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=500, subsample=0.4;, score=0.153 total time=   9.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=500, subsample=0.9;, score=0.149 total time= 1.2min\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=0.4;, score=0.057 total time=   7.0s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.113 total time=  26.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=500, subsample=1.0;, score=0.093 total time= 1.9min\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.089 total time=  26.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=250, subsample=0.9;, score=0.156 total time=   7.8s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=0.9;, score=0.210 total time=   7.4s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=500, subsample=1.0;, score=0.193 total time= 2.5min\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.045 total time=   6.8s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=200, subsample=1.0;, score=0.062 total time= 1.1min\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=150, subsample=0.7;, score=0.060 total time=  29.3s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=200, subsample=0.7;, score=0.105 total time=   8.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=10, subsample=0.9;, score=0.164 total time=   6.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.157 total time=  11.4s\n",
      "{'subsample': 0.9, 'n_estimators': 100, 'min_samples_split': 7, 'min_samples_leaf': 5, 'max_features': 'sqrt', 'max_depth': 4, 'learning_rate': 0.01}\n",
      "0.18896757138989306\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in part_first['smiles']]\n",
    "val_mols = [Chem.MolFromSmiles(x) for x in part_second['smiles']]\n",
    "# MolFromSmiles returns None for unparsable SMILES; fail early with a clear message\n",
    "# instead of inside the fingerprint call.\n",
    "assert all(m is not None for m in train_mols + val_mols), 'unparsable SMILES in the train/val parts'\n",
    "\n",
    "# Morgan fingerprints, radius 2, 1024 bits (same featurization as the final evaluation).\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "val_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in val_mols]\n",
    "\n",
    "# Average precision on the Butina validation part used for model selection, not on the\n",
    "# held-out test set -- hence `val_metrics` rather than `test_metrics`.\n",
    "val_metrics = run_gb_gridsearch(train_morgan_fps, val_morgan_fps, part_first['value'].to_list(), part_second['value'].to_list())\n",
    "print(val_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[16:10:23] WARNING: not removing hydrogen atom without neighbors\n",
      "[16:10:23] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.07843671137825257\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "# MolFromSmiles returns None for unparsable SMILES; fail early with a clear message.\n",
    "assert all(m is not None for m in train_mols + test_mols), 'unparsable SMILES in train/test'\n",
    "\n",
    "# Same featurization as during tuning: Morgan fingerprints, radius 2, 1024 bits.\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "# Hyperparameters copied by hand from the best configuration printed by the randomized\n",
    "# search above; update them if the search is re-run and selects something else.\n",
    "# random_state pins the stochastic parts (subsample < 1, max_features='sqrt') so the\n",
    "# reported test score is reproducible; 123 matches the Butina split seed.\n",
    "gb = GradientBoostingClassifier(\n",
    "    subsample=0.9,\n",
    "    n_estimators=100,\n",
    "    min_samples_split=7,\n",
    "    min_samples_leaf=5,\n",
    "    max_features='sqrt',\n",
    "    max_depth=4,\n",
    "    learning_rate=0.01,\n",
    "    random_state=123,\n",
    ")\n",
    "\n",
    "# Refit on all of `train` (both Butina parts) and score on `test`.\n",
    "gb.fit(train_morgan_fps, train['value'].to_list())\n",
    "test_preds = gb.predict_proba(test_morgan_fps)[:, 1]\n",
    "test_metrics = average_precision_score(test['value'], test_preds)\n",
    "print(test_metrics)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
