{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit\n",
    "import networkx as nx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../code')\n",
    "\n",
    "from metrics import get_hi_metrics\n",
    "import min_vertex_k_cut"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>383</th>\n",
       "      <td>CC(C)Oc1ccccc1N1CCN(Cc2cccc(C(=O)N3CCCCC3)c2)CC1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>386</th>\n",
       "      <td>CC(C)Oc1ccccc1N1CCN(Cc2cccc(CN3CCCCC3=O)c2)CC1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>389</th>\n",
       "      <td>CC(C)Oc1ccccc1N1CCN(Cc2ccccc2CN2CCCCC2=O)CC1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2695</th>\n",
       "      <td>COc1ccccc1N1CCN(CC2COCC(c3ccccc3)(c3ccccc3)O2)CC1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2995</th>\n",
       "      <td>COc1ccccc1N1CCN(C[C@H]2OCCOC2(c2ccccc2)c2ccccc...</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5444</th>\n",
       "      <td>O=C1c2ccccc2C(=O)N1CCCCN1CCCN(C(c2ccccc2)c2ccc...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4391</th>\n",
       "      <td>O=C(CCC(=O)c1ccccc1)NCCc1c[nH]c2ccccc12</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4397</th>\n",
       "      <td>O=C(CCCC(=O)c1ccccc1)NCCc1c[nH]c2ccccc12</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5999</th>\n",
       "      <td>OC12C3C4CC5C6C4C1C6C(C53)N2CC1CCCCC1</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5458</th>\n",
       "      <td>O=C1c2ccccc2CCCN1CCN1CCC(n2c(O)nc3cc(F)ccc32)CC1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2385 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 smiles  value\n",
       "383    CC(C)Oc1ccccc1N1CCN(Cc2cccc(C(=O)N3CCCCC3)c2)CC1   True\n",
       "386      CC(C)Oc1ccccc1N1CCN(Cc2cccc(CN3CCCCC3=O)c2)CC1   True\n",
       "389        CC(C)Oc1ccccc1N1CCN(Cc2ccccc2CN2CCCCC2=O)CC1   True\n",
       "2695  COc1ccccc1N1CCN(CC2COCC(c3ccccc3)(c3ccccc3)O2)CC1   True\n",
       "2995  COc1ccccc1N1CCN(C[C@H]2OCCOC2(c2ccccc2)c2ccccc...  False\n",
       "...                                                 ...    ...\n",
       "5444  O=C1c2ccccc2C(=O)N1CCCCN1CCCN(C(c2ccccc2)c2ccc...   True\n",
       "4391            O=C(CCC(=O)c1ccccc1)NCCc1c[nH]c2ccccc12  False\n",
       "4397           O=C(CCCC(=O)c1ccccc1)NCCc1c[nH]c2ccccc12  False\n",
       "5999               OC12C3C4CC5C6C4C1C6C(C53)N2CC1CCCCC1  False\n",
       "5458   O=C1c2ccccc2CCCN1CCN1CCC(n2c(O)nc3cc(F)ccc32)CC1   True\n",
       "\n",
       "[2385 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the train/test CSVs of the first DRD2 Hi split; the first CSV column becomes the index.\n",
    "train = pd.read_csv('../../data/hi/drd2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../data/hi/drd2/test_1.csv', index_col=0)\n",
    "\n",
    "# Bare last expression: display the training frame.\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Split train into train and val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the neighborhood graph over the training SMILES and separate its main component\n",
    "# from the remaining small components.\n",
    "# NOTE(review): presumably an edge links molecules whose similarity exceeds `threshold` - confirm in min_vertex_k_cut.\n",
    "threshold = 0.4\n",
    "smiles = train['smiles'].to_list()\n",
    "\n",
    "neighborhood_graph = min_vertex_k_cut.get_neighborhood_graph(smiles, threshold)\n",
    "main_component, small_components = min_vertex_k_cut.get_main_component(neighborhood_graph)\n",
    "\n",
    "# Relabel the main component's nodes to 0..n-1 and keep both directions of the mapping,\n",
    "# so partition results can later be translated back to the original node ids.\n",
    "original_nodes = list(main_component.nodes())\n",
    "old_nodes_to_new = {node: new_id for new_id, node in enumerate(original_nodes)}\n",
    "new_nodes_to_old = dict(enumerate(original_nodes))\n",
    "main_component = nx.relabel_nodes(main_component, old_nodes_to_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coarsen the relabelled main component; judging by the names, this yields the coarsened graph\n",
    "# and a node -> cluster mapping (both are passed to process_bisect_results below).\n",
    "# NOTE(review): the literal 0.4 equals `threshold` defined above - confirm whether it should reuse that variable.\n",
    "coarsed_main_component, node_to_cluster = min_vertex_k_cut.coarse_graph(main_component, 0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total molecules: 1193\n",
      "Min train size 596\n",
      "Min test size 119\n",
      "Welcome to the CBC MILP Solver \n",
      "Version: Trunk\n",
      "Build Date: Oct 24 2021 \n",
      "\n",
      "Starting solution of the Linear programming relaxation problem using Primal Simplex\n",
      "\n",
      "Coin0506I Presolve 391 (-178) rows, 178 (0) columns and 956 (-356) elements\n",
      "Clp1000I sum of infeasibilities 1.71113e-07 - average 4.37629e-10, 0 fixed columns\n",
      "Coin0506I Presolve 391 (0) rows, 178 (0) columns and 956 (0) elements\n",
      "Clp0029I End of values pass after 178 iterations\n",
      "Clp0014I Perturbing problem by 0.001% of 0.67057672 - largest nonzero change 2.989373e-05 ( 0.00242536%) - largest zero change 0\n",
      "Clp0000I Optimal - objective value 1193\n",
      "Clp0000I Optimal - objective value 1193\n",
      "Clp0000I Optimal - objective value 1193\n",
      "Coin0511I After Postsolve, objective 1193, infeasibilities - dual 0 (0), primal 0 (0)\n",
      "Clp0032I Optimal objective 1193 - 0 iterations time 0.022, Presolve 0.00, Idiot 0.02\n",
      "\n",
      "Starting MIP optimization\n",
      "Cgl0003I 5 fixed, 0 tightened bounds, 37 strengthened rows, 0 substitutions\n",
      "Cgl0003I 0 fixed, 0 tightened bounds, 24 strengthened rows, 0 substitutions\n",
      "Cgl0003I 0 fixed, 0 tightened bounds, 18 strengthened rows, 0 substitutions\n",
      "Cgl0003I 0 fixed, 0 tightened bounds, 5 strengthened rows, 0 substitutions\n",
      "Cgl0003I 0 fixed, 0 tightened bounds, 1 strengthened rows, 0 substitutions\n",
      "Cgl0004I processed model has 300 rows, 173 columns (173 integer (173 of which binary)) and 809 elements\n",
      "Coin3009W Conflict graph built in 0.000 seconds, density: 0.858%\n",
      "Cgl0015I Clique Strengthening extended 0 cliques, 0 were dominated\n",
      "Cbc0045I Nauty: 718 orbits (12 useful covering 28 variables), 8 generators, group size: 768 - sparse size 4120 - took 0.000639 seconds\n",
      "Cbc0038I Initial state - 7 integers unsatisfied sum - 1.35882\n",
      "Cbc0038I Pass   1: suminf.    0.60000 (22) obj. -1172.72 iterations 128\n",
      "Cbc0038I Solution found of -1020\n",
      "Cbc0038I Before mini branch and bound, 144 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 300 rows 173 columns, reduced to 19 rows 15 columns\n",
      "Cbc0038I Mini branch and bound improved solution from -1020 to -1172 (0.02 seconds)\n",
      "Cbc0038I Round again with cutoff of -1173.3\n",
      "Cbc0038I Reduced cost fixing fixed 14 variables on major pass 2\n",
      "Cbc0038I Pass   2: suminf.    0.67402 (17) obj. -1173.3 iterations 29\n",
      "Cbc0038I Pass   3: suminf.    1.00965 (7) obj. -1173.3 iterations 75\n",
      "Cbc0038I Pass   4: suminf.    0.80013 (16) obj. -1173.3 iterations 45\n",
      "Cbc0038I Pass   5: suminf.    1.00965 (7) obj. -1173.3 iterations 96\n",
      "Cbc0038I Pass   6: suminf.    1.24403 (7) obj. -1173.3 iterations 13\n",
      "Cbc0038I Pass   7: suminf.    1.24403 (7) obj. -1173.3 iterations 14\n",
      "Cbc0038I Pass   8: suminf.    2.37601 (7) obj. -1173.3 iterations 30\n",
      "Cbc0038I Pass   9: suminf.    0.80622 (17) obj. -1173.3 iterations 69\n",
      "Cbc0038I Pass  10: suminf.    1.00965 (7) obj. -1173.3 iterations 104\n",
      "Cbc0038I Pass  11: suminf.    0.80013 (16) obj. -1173.3 iterations 50\n",
      "Cbc0038I Pass  12: suminf.    1.00965 (7) obj. -1173.3 iterations 94\n",
      "Cbc0038I Pass  13: suminf.    1.00965 (7) obj. -1173.3 iterations 11\n",
      "Cbc0038I Pass  14: suminf.    1.24403 (7) obj. -1173.3 iterations 8\n",
      "Cbc0038I Pass  15: suminf.    1.24403 (7) obj. -1173.3 iterations 2\n",
      "Cbc0038I Pass  16: suminf.    2.37601 (7) obj. -1173.3 iterations 35\n",
      "Cbc0038I Pass  17: suminf.    1.24403 (7) obj. -1173.3 iterations 35\n",
      "Cbc0038I Pass  18: suminf.    1.24403 (7) obj. -1173.3 iterations 7\n",
      "Cbc0038I Pass  19: suminf.    1.82399 (7) obj. -1173.3 iterations 38\n",
      "Cbc0038I Pass  20: suminf.    1.82399 (7) obj. -1173.3 iterations 0\n",
      "Cbc0038I Pass  21: suminf.    0.67402 (17) obj. -1173.3 iterations 69\n",
      "Cbc0038I Pass  22: suminf.    1.00965 (7) obj. -1173.3 iterations 91\n",
      "Cbc0038I Pass  23: suminf.    0.80013 (16) obj. -1173.3 iterations 47\n",
      "Cbc0038I Pass  24: suminf.    1.00965 (7) obj. -1173.3 iterations 92\n",
      "Cbc0038I Pass  25: suminf.    1.00965 (7) obj. -1173.3 iterations 14\n",
      "Cbc0038I Pass  26: suminf.    1.00965 (7) obj. -1173.3 iterations 15\n",
      "Cbc0038I Pass  27: suminf.    1.24403 (7) obj. -1173.3 iterations 14\n",
      "Cbc0038I Pass  28: suminf.    1.24403 (7) obj. -1173.3 iterations 6\n",
      "Cbc0038I Pass  29: suminf.    2.37601 (7) obj. -1173.3 iterations 36\n",
      "Cbc0038I Pass  30: suminf.    0.80622 (17) obj. -1173.3 iterations 87\n",
      "Cbc0038I Pass  31: suminf.    1.00965 (7) obj. -1173.3 iterations 96\n",
      "Cbc0038I No solution found this major pass\n",
      "Cbc0038I Before mini branch and bound, 142 integers at bound fixed and 0 continuous\n",
      "Cbc0038I Full problem 300 rows 173 columns, reduced to 12 rows 11 columns\n",
      "Cbc0038I Mini branch and bound did not improve solution (0.04 seconds)\n",
      "Cbc0038I After 0.04 seconds - Feasibility pump exiting with objective of -1172 - took 0.03 seconds\n",
      "Cbc0012I Integer solution of -1172 found by feasibility pump after 0 iterations and 0 nodes (0.04 seconds)\n",
      "Cbc0030I Thread 0 used 0 times,  waiting to start 0.0037388802, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 1 used 0 times,  waiting to start 0.0038971901, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 2 used 0 times,  waiting to start 0.0036940575, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 3 used 0 times,  waiting to start 0.003477335, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 4 used 0 times,  waiting to start 0.0032522678, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 5 used 0 times,  waiting to start 0.0030543804, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 6 used 0 times,  waiting to start 0.0028371811, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 7 used 0 times,  waiting to start 0.0026264191, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 8 used 0 times,  waiting to start 0.002414465, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 9 used 0 times,  waiting to start 0.0022034645, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 10 used 0 times,  waiting to start 0.0019943714, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 11 used 0 times,  waiting to start 0.0017793179, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 12 used 0 times,  waiting to start 0.0015654564, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 13 used 0 times,  waiting to start 0.0013504028, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 14 used 0 times,  waiting to start 0.0011396408, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Thread 15 used 0 times,  waiting to start 0.00092983246, 0 cpu time, 0 locks, 0 locked, 0 waiting for locks\n",
      "Cbc0030I Main thread 0 waiting for threads,  1 locks, 1.1920929e-06 locked, 2.3841858e-07 waiting for locks\n",
      "Cbc0011I Exiting as integer gap of 4.0294118 less than 1e-10 or 30%\n",
      "Cbc0001I Search completed - best objective -1172, took 0 iterations and 0 nodes (0.05 seconds)\n",
      "Cbc0035I Maximum depth 0, 0 variables fixed on reduced cost\n",
      "Total time (CPU seconds):       0.05   (Wallclock seconds):       0.05\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Solve the partition of the coarsened graph; the cell output is the CBC MILP solver log.\n",
    "# The arguments ask for at least 50% of the molecules in one part and at least 10% in the other,\n",
    "# and let the solver stop once the MIP gap is within 30%.\n",
    "model = min_vertex_k_cut.train_test_split_connected_graph(coarsed_main_component, train_min_fraq=0.5, test_min_fraq=0.1, max_mip_gap=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Molecules in train: 1039\n",
      "Molecules in test: 133\n",
      "Molecules lost: 21\n"
     ]
    }
   ],
   "source": [
    "# Turn the solved model into a per-node result for `main_component`; prints how many molecules\n",
    "# landed in each part and how many were lost.\n",
    "# `split` is consumed below as one partition label per relabelled node (0, 1, or anything else).\n",
    "split = min_vertex_k_cut.process_bisect_results(model, coarsed_main_component, main_component, node_to_cluster)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate partition labels back to the original node ids.\n",
    "# Label 0 goes to the first part, label 1 to the second; any other label is left out of both.\n",
    "first_idx, second_idx = [], []\n",
    "\n",
    "for S_idx, partition in enumerate(split):\n",
    "    G_idx = new_nodes_to_old[S_idx]\n",
    "    if partition == 0:\n",
    "        first_idx.append(G_idx)\n",
    "    elif partition == 1:\n",
    "        second_idx.append(G_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand out the small components greedily: each one goes, whole, to whichever part is\n",
    "# currently smaller (ties go to the first part).\n",
    "# NOTE: the lists are extended in place, so re-running only this cell adds the components again;\n",
    "# re-run the previous cell first to reset both lists.\n",
    "for component in small_components:\n",
    "    if len(first_idx) <= len(second_idx):\n",
    "        first_idx.extend(component)\n",
    "    else:\n",
    "        second_idx.extend(component)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1180\n",
      "1184\n"
     ]
    }
   ],
   "source": [
    "# Sizes of the two parts once the small components have been added.\n",
    "for part_indices in (first_idx, second_idx):\n",
    "    print(len(part_indices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select rows by position (.iloc): the index lists hold graph node ids, which are presumably\n",
    "# positions in the `smiles` list taken from `train` (TODO confirm), not the frame's index labels.\n",
    "part_first = train.iloc[first_idx]\n",
    "part_second = train.iloc[second_idx]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hi split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import average_precision_score\n",
    "\n",
    "def run_gb_gridsearch(train_fps, val_fps, train_y, val_y, random_state=0):\n",
    "    \"\"\"\n",
    "    Tune a GradientBoostingClassifier on one fixed train/validation split.\n",
    "\n",
    "    A randomized search over 30 hyperparameter candidates is scored by average precision\n",
    "    on the validation part. The best candidate is then refitted on the training part only\n",
    "    and its validation average precision is returned. The best parameters are printed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_fps, val_fps : sequence\n",
    "        Feature vectors of the training and validation parts.\n",
    "    train_y, val_y : sequence\n",
    "        Labels aligned with `train_fps` and `val_fps`.\n",
    "    random_state : int or None, default 0\n",
    "        Seed shared by the candidate sampling and every classifier fit. With a fixed seed the\n",
    "        final refit reproduces the model that won the search; pass None for unseeded runs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Average precision of the refitted model on the validation part.\n",
    "    \"\"\"\n",
    "    # Copy to plain lists so that `+` below concatenates for any sequence type.\n",
    "    train_fps, val_fps = list(train_fps), list(val_fps)\n",
    "    train_y, val_y = list(train_y), list(val_y)\n",
    "\n",
    "    # PredefinedSplit: -1 keeps a sample in training for every fold, 0 puts it in the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(val_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + val_fps\n",
    "    y = train_y + val_y\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    gb = GradientBoostingClassifier(random_state=random_state)\n",
    "\n",
    "    # refit=False: an automatic refit would train on train + validation together,\n",
    "    # so the best model is refitted by hand on the training part below.\n",
    "    grid_search = RandomizedSearchCV(\n",
    "        gb,\n",
    "        params,\n",
    "        cv=pds,\n",
    "        n_iter=30,\n",
    "        refit=False,\n",
    "        scoring='average_precision',\n",
    "        verbose=3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    gb = GradientBoostingClassifier(random_state=random_state, **best_params)\n",
    "    gb.fit(train_fps, train_y)\n",
    "\n",
    "    # Column 1 of predict_proba is the probability of the second class in gb.classes_.\n",
    "    val_preds = gb.predict_proba(val_fps)[:, 1]\n",
    "    val_metrics = average_precision_score(val_y, val_preds)\n",
    "    return val_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.669 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=100, subsample=0.4;, score=0.625 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=1.0;, score=0.648 total time=   4.8s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.669 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=10, subsample=0.7;, score=0.666 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=250, subsample=0.4;, score=0.610 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=1.0;, score=0.656 total time=   2.3s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=50, subsample=1.0;, score=0.677 total time=   2.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=10, subsample=1.0;, score=0.652 total time=   1.0s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=0.4;, score=0.683 total time=   2.6s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=0.9;, score=0.664 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=1.0;, score=0.648 total time=   8.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.688 total time=   1.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=10, subsample=0.7;, score=0.661 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=100, subsample=0.4;, score=0.654 total time=   2.6s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=1.0;, score=0.655 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.676 total time=   9.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=150, subsample=0.9;, score=0.673 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=0.9;, score=0.672 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=0.4;, score=0.662 total time=   2.1s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.7;, score=0.628 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=500, subsample=0.9;, score=0.619 total time=   9.7s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=0.4;, score=0.649 total time=   4.7s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.671 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=150, subsample=0.7;, score=0.715 total time=   3.0s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=10, subsample=0.4;, score=0.678 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.673 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.4;, score=0.662 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.584 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=10, subsample=0.4;, score=0.606 total time=   1.1s\n",
      "{'subsample': 0.7, 'n_estimators': 150, 'min_samples_split': 3, 'min_samples_leaf': 3, 'max_features': None, 'max_depth': 2, 'learning_rate': 0.5}\n",
      "0.6032513987142669\n"
     ]
    }
   ],
   "source": [
    "# Featurize both parts of the training data as 1024-bit Morgan fingerprints (radius 2).\n",
    "train_mols = [Chem.MolFromSmiles(x) for x in part_first['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "val_mols = [Chem.MolFromSmiles(x) for x in part_second['smiles']]\n",
    "val_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in val_mols]\n",
    "\n",
    "# The returned score is average precision on the validation part (`part_second`),\n",
    "# not on the held-out test set, hence the name.\n",
    "val_metrics = run_gb_gridsearch(train_morgan_fps, val_morgan_fps, part_first['value'].to_list(), part_second['value'].to_list())\n",
    "print(val_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6770894792792055\n"
     ]
    }
   ],
   "source": [
    "# Retrain on the full training set and score once on the held-out test set.\n",
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "# Hyperparameters are copied by hand from the best parameters printed by the Hi-split search;\n",
    "# update them whenever that search is re-run.\n",
    "# random_state pins the row subsampling (subsample < 1) so the reported score is repeatable.\n",
    "gb = GradientBoostingClassifier(\n",
    "    subsample=0.7,\n",
    "    n_estimators=150,\n",
    "    min_samples_split=3,\n",
    "    min_samples_leaf=3,\n",
    "    max_features=None,\n",
    "    max_depth=2,\n",
    "    learning_rate=0.5,\n",
    "    random_state=0,\n",
    ")\n",
    "\n",
    "gb.fit(train_morgan_fps, train['value'].to_list())\n",
    "test_preds = gb.predict_proba(test_morgan_fps)[:, 1]\n",
    "test_metrics = average_precision_score(test['value'], test_preds)\n",
    "print(test_metrics)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scaffold split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit.ML.Cluster import Butina\n",
    "from numpy.random import default_rng\n",
    "\n",
    "\n",
    "def butina_split(smiles: list[str], cutoff: float, seed: int, frac_train=0.8):\n",
    "    \"\"\"\n",
    "    Split molecules into train/test groups by Butina clustering of their fingerprints.\n",
    "\n",
    "    Molecules are clustered on Tanimoto distance (1 - similarity) between 1024-bit, radius-2\n",
    "    Morgan fingerprints using `cutoff` as the clustering threshold. The clusters are ordered\n",
    "    largest first, shuffled with a generator seeded by `seed`, and then assigned whole:\n",
    "    a cluster joins train if it still fits under `frac_train * len(smiles)` molecules,\n",
    "    otherwise it joins test.\n",
    "    Adapted from DeepChem (https://deepchem.io/), with a random seed added.\n",
    "\n",
    "    Returns a pair (train_inds, test_inds) of index lists into `smiles`.\n",
    "    \"\"\"\n",
    "\n",
    "    molecules = [Chem.MolFromSmiles(smile) for smile in smiles]\n",
    "    fingerprints = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in molecules]\n",
    "    n_fingerprints = len(fingerprints)\n",
    "\n",
    "    # Flat list of pairwise distances: for each molecule i, its distances to molecules 0..i-1.\n",
    "    distances = []\n",
    "    for row in range(1, n_fingerprints):\n",
    "        similarities = DataStructs.BulkTanimotoSimilarity(fingerprints[row], fingerprints[:row])\n",
    "        distances.extend(1 - similarity for similarity in similarities)\n",
    "\n",
    "    clusters = Butina.ClusterData(distances, n_fingerprints, cutoff, isDistData=True)\n",
    "    # Largest clusters first (stable for equal sizes), then a seeded in-place shuffle.\n",
    "    clusters = sorted(clusters, key=len, reverse=True)\n",
    "    generator = default_rng(seed)\n",
    "    generator.shuffle(clusters)\n",
    "\n",
    "    max_train_size = frac_train * len(smiles)\n",
    "    train_inds, test_inds = [], []\n",
    "\n",
    "    for cluster in clusters:\n",
    "        fits_in_train = len(train_inds) + len(cluster) <= max_train_size\n",
    "        destination = train_inds if fits_in_train else test_inds\n",
    "        destination.extend(cluster)\n",
    "    return train_inds, test_inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cluster-based split of `train` into two parts: the first part is capped at 50% of the molecules\n",
    "# (frac_train=0.5), and the fixed seed makes the cluster shuffle repeatable.\n",
    "train_idx, val_idx = butina_split(train['smiles'].to_list(), cutoff=0.5, seed=123, frac_train=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# butina_split returns positions in the SMILES list, hence positional .iloc selection.\n",
    "# NOTE: these names overwrite the Hi-split parts defined earlier in the notebook.\n",
    "part_first = train.iloc[train_idx]\n",
    "part_second = train.iloc[val_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import average_precision_score\n",
    "\n",
    "# NOTE(review): a function with this name is also defined in the Hi-split section; this later\n",
    "# definition shadows it - consider keeping a single copy.\n",
    "def run_gb_gridsearch(train_fps, val_fps, train_y, val_y, random_state=0):\n",
    "    \"\"\"\n",
    "    Tune a GradientBoostingClassifier on one fixed train/validation split.\n",
    "\n",
    "    A randomized search over 30 hyperparameter candidates is scored by average precision\n",
    "    on the validation part. The best candidate is then refitted on the training part only\n",
    "    and its validation average precision is returned. The best parameters are printed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_fps, val_fps : sequence\n",
    "        Feature vectors of the training and validation parts.\n",
    "    train_y, val_y : sequence\n",
    "        Labels aligned with `train_fps` and `val_fps`.\n",
    "    random_state : int or None, default 0\n",
    "        Seed shared by the candidate sampling and every classifier fit. With a fixed seed the\n",
    "        final refit reproduces the model that won the search; pass None for unseeded runs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Average precision of the refitted model on the validation part.\n",
    "    \"\"\"\n",
    "    # Copy to plain lists so that `+` below concatenates for any sequence type.\n",
    "    train_fps, val_fps = list(train_fps), list(val_fps)\n",
    "    train_y, val_y = list(train_y), list(val_y)\n",
    "\n",
    "    # PredefinedSplit: -1 keeps a sample in training for every fold, 0 puts it in the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(val_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + val_fps\n",
    "    y = train_y + val_y\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    gb = GradientBoostingClassifier(random_state=random_state)\n",
    "\n",
    "    # refit=False: an automatic refit would train on train + validation together,\n",
    "    # so the best model is refitted by hand on the training part below.\n",
    "    grid_search = RandomizedSearchCV(\n",
    "        gb,\n",
    "        params,\n",
    "        cv=pds,\n",
    "        n_iter=30,\n",
    "        refit=False,\n",
    "        scoring='average_precision',\n",
    "        verbose=3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    gb = GradientBoostingClassifier(random_state=random_state, **best_params)\n",
    "    gb.fit(train_fps, train_y)\n",
    "\n",
    "    # Column 1 of predict_proba is the probability of the second class in gb.classes_.\n",
    "    val_preds = gb.predict_proba(val_fps)[:, 1]\n",
    "    val_metrics = average_precision_score(val_y, val_preds)\n",
    "    return val_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=1.0;, score=0.877 total time=   2.0s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.859 total time=   5.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=10, subsample=0.9;, score=0.836 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.7;, score=0.831 total time=   1.8s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=0.9;, score=0.809 total time=   4.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=0.7;, score=0.834 total time=   3.7s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=100, subsample=0.7;, score=0.802 total time=   2.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=0.4;, score=0.837 total time=   2.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=100, subsample=0.9;, score=0.861 total time=   4.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=500, subsample=0.7;, score=0.775 total time=  14.0s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=150, subsample=0.9;, score=0.838 total time=   6.0s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=200, subsample=0.9;, score=0.837 total time=   7.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=0.7;, score=0.819 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.851 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=150, subsample=1.0;, score=0.850 total time=   6.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=1.0;, score=0.824 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=500, subsample=0.7;, score=0.851 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=50, subsample=0.9;, score=0.792 total time=   2.7s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=0.4;, score=0.816 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=10, subsample=1.0;, score=0.809 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=50, subsample=0.4;, score=0.820 total time=   1.9s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=200, subsample=0.9;, score=0.849 total time=   7.8s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=150, subsample=0.7;, score=0.866 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=100, subsample=0.7;, score=0.851 total time=   3.7s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=0.4;, score=0.868 total time=   3.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=500, subsample=0.7;, score=0.751 total time=   2.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.4;, score=0.799 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.9;, score=0.819 total time=   4.7s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.9;, score=0.841 total time=   2.1s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=0.4;, score=0.857 total time=   2.1s\n",
      "{'subsample': 1.0, 'n_estimators': 200, 'min_samples_split': 3, 'min_samples_leaf': 5, 'max_features': 'sqrt', 'max_depth': 4, 'learning_rate': 0.01}\n",
      "0.8618394350472368\n"
     ]
    }
   ],
   "source": [
    "# Morgan bit-vector fingerprints (radius 2, 1024 bits) for the two parts of the split\n",
    "# used during hyperparameter search.\n",
    "train_mols = list(map(Chem.MolFromSmiles, part_first['smiles']))\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)\n",
    "    for mol in train_mols\n",
    "]\n",
    "\n",
    "val_mols = list(map(Chem.MolFromSmiles, part_second['smiles']))\n",
    "val_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)\n",
    "    for mol in val_mols\n",
    "]\n",
    "\n",
    "# run_gb_gridsearch is defined outside this cell; whatever it returns is printed as-is.\n",
    "test_metrics = run_gb_gridsearch(\n",
    "    train_morgan_fps,\n",
    "    val_morgan_fps,\n",
    "    part_first['value'].tolist(),\n",
    "    part_second['value'].tolist(),\n",
    ")\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6629331875367845\n"
     ]
    }
   ],
   "source": [
    "# Final evaluation: fit on `train`, score once on `test`.\n",
    "# Features are Morgan bit-vector fingerprints (radius 2, 1024 bits).\n",
    "train_mols = [Chem.MolFromSmiles(smi) for smi in train['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) for mol in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(smi) for smi in test['smiles']]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) for mol in test_mols]\n",
    "\n",
    "# Hyperparameters are the best configuration printed by the search cell above.\n",
    "# random_state is pinned because max_features='sqrt' makes every split draw a random\n",
    "# feature subset; without a fixed seed the reported score changes from run to run.\n",
    "gb = GradientBoostingClassifier(\n",
    "    subsample=1.0,\n",
    "    n_estimators=200,\n",
    "    min_samples_split=3,\n",
    "    min_samples_leaf=5,\n",
    "    max_features='sqrt',\n",
    "    max_depth=4,\n",
    "    learning_rate=0.01,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "gb.fit(train_morgan_fps, train['value'].to_list())\n",
    "# Column 1 of predict_proba is the probability of gb.classes_[1]\n",
    "# (the positive label when targets are 0/1).\n",
    "test_preds = gb.predict_proba(test_morgan_fps)[:, 1]\n",
    "test_metrics = average_precision_score(test['value'], test_preds)\n",
    "print(test_metrics)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
