{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_hi_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>O=S(=O)(O)CCS(=O)(=O)O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>CC(C)CCS(=O)(=O)O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>90</th>\n",
       "      <td>O=S(=O)(O)CCO</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>O=S(=O)(O)CO</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>O=S(=O)(O)CCCCBr</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40932</th>\n",
       "      <td>COC(=O)c1cc2cc3c(c(O)c2c(=O)o1)OC1(Oc2c(O)c4c(...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40973</th>\n",
       "      <td>CCCCC1C(OCOc2ccccc2)COC(=O)N1C(C)c1ccccc1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41024</th>\n",
       "      <td>CC(C)=CC1CC(C)C2CCC(C)C3C(=O)C(O)=C(C)C(=O)C123</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41026</th>\n",
       "      <td>CCOC(=O)C12C(=O)C(C)CCC1C(C)CC2C=C(C)C</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41106</th>\n",
       "      <td>Cc1ccc(C=C2CN(C)CC3C(c4ccc(C)cc4)=C(C#N)C(=O)N...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>15696 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  smiles  value\n",
       "4                                 O=S(=O)(O)CCS(=O)(=O)O      0\n",
       "21                                     CC(C)CCS(=O)(=O)O      0\n",
       "90                                         O=S(=O)(O)CCO      0\n",
       "106                                         O=S(=O)(O)CO      0\n",
       "117                                     O=S(=O)(O)CCCCBr      0\n",
       "...                                                  ...    ...\n",
       "40932  COC(=O)c1cc2cc3c(c(O)c2c(=O)o1)OC1(Oc2c(O)c4c(...      0\n",
       "40973          CCCCC1C(OCOc2ccccc2)COC(=O)N1C(C)c1ccccc1      0\n",
       "41024    CC(C)=CC1CC(C)C2CCC(C)C3C(=O)C(O)=C(C)C(=O)C123      0\n",
       "41026             CCOC(=O)C12C(=O)C(C)CCC1C(C)CC2C=C(C)C      0\n",
       "41106  Cc1ccc(C=C2CN(C)CC3C(c4ccc(C)cc4)=C(C#N)C(=O)N...      0\n",
       "\n",
       "[15696 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/hi/hiv/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/hi/hiv/test_1.csv', index_col=0)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_freqs = np.bincount(train['value'])\n",
    "class_weights = 1 / class_freqs\n",
    "y = train['value'].to_list() + test['value'].to_list()\n",
    "y = np.array(y)\n",
    "sample_weights = class_weights[y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps):\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold = split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "    'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "    'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "    'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "    'min_samples_split': [2, 3, 5, 7],\n",
    "    'min_samples_leaf': [1, 3, 5],\n",
    "    'max_depth': [2, 3, 4],\n",
    "    'max_features': [None, 'sqrt']\n",
    "    }\n",
    "    knn = GradientBoostingClassifier()\n",
    "\n",
    "    grid_search = RandomizedSearchCV(knn, params, cv=pds, n_iter=30, refit=False, scoring='average_precision', verbose=3)\n",
    "    grid_search.fit(X, y, sample_weight=sample_weights)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    knn = GradientBoostingClassifier(**best_params)\n",
    "    knn.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = knn.predict_proba(test_fps)[:, 1]\n",
    "    test_metrics = get_hi_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[09:41:41] WARNING: not removing hydrogen atom without neighbors\n",
      "[09:41:42] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=50, subsample=0.4;, score=0.067 total time=  10.1s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=10, subsample=0.7;, score=0.102 total time=  12.7s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=200, subsample=0.4;, score=0.072 total time=  11.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=0.4;, score=0.069 total time=  10.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.094 total time=  10.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=250, subsample=1.0;, score=0.080 total time=  13.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=0.9;, score=0.083 total time= 1.3min\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=250, subsample=0.7;, score=0.065 total time= 1.9min\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=100, subsample=0.4;, score=0.129 total time=  27.0s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=500, subsample=1.0;, score=0.106 total time=  15.0s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.4;, score=0.109 total time=  53.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=100, subsample=0.7;, score=0.066 total time=  40.9s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=200, subsample=0.9;, score=0.073 total time=  11.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=100, subsample=0.9;, score=0.070 total time=  11.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.9;, score=0.077 total time= 4.6min\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=0.4;, score=0.063 total time=  10.4s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=150, subsample=0.7;, score=0.067 total time=  11.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=10, subsample=0.7;, score=0.061 total time=   9.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=1.0;, score=0.061 total time=  10.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.069 total time=  12.2s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=10, subsample=0.7;, score=0.069 total time=   9.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=10, subsample=0.7;, score=0.074 total time=   9.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=1.0;, score=0.090 total time=  54.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=250, subsample=0.9;, score=0.084 total time=  12.4s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=0.4;, score=0.056 total time=  10.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.074 total time=  54.3s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=0.4;, score=0.127 total time=  21.5s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=150, subsample=0.7;, score=0.071 total time=  12.0s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=200, subsample=0.7;, score=0.095 total time=  12.9s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=10, subsample=0.9;, score=0.059 total time=   9.8s\n",
      "{'subsample': 0.4, 'n_estimators': 100, 'min_samples_split': 7, 'min_samples_leaf': 3, 'max_features': None, 'max_depth': 3, 'learning_rate': 0.3}\n",
      "{'roc_auc': 0.6322350120279896, 'bedroc': 0.08854344792710367, 'prc_auc': 0.07100997800581915}\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
