{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(-c2nc3ccc(Nc4ncnc5ccccc45)cc3[nH]2)cc1</td>\n",
       "      <td>6.419075</td>\n",
       "      <td>51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3ccc(F)cc3)cc2)cn1</td>\n",
       "      <td>8.047208</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(C(C)C)c3)c...</td>\n",
       "      <td>8.508638</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(Cl)c3)cc2)cn1</td>\n",
       "      <td>8.474955</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cccc(-c2ccc(NC(=O)Nc3c(C)cccc3C)cc2)n1</td>\n",
       "      <td>6.380687</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>432</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(N6CCNCC6)cc5)cnc34)cc...</td>\n",
       "      <td>6.666150</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>433</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(N6CCOCC6)cc5)cnc34)cc...</td>\n",
       "      <td>5.273191</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>434</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(OCCN6CCOCC6)cc5)cnc34...</td>\n",
       "      <td>5.616364</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>435</th>\n",
       "      <td>c1ccc2c(-c3nc4cc(-n5ccnc5)ccc4[nH]3)[nH]nc2c1</td>\n",
       "      <td>7.075721</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>436</th>\n",
       "      <td>c1ccc2c(-c3nc4cc(N5CCOCC5)ccc4[nH]3)[nH]nc2c1</td>\n",
       "      <td>6.793174</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>437 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0       Brc1ccc(-c2nc3ccc(Nc4ncnc5ccccc45)cc3[nH]2)cc1  6.419075       51\n",
       "1     C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3ccc(F)cc3)cc2)cn1  8.047208       32\n",
       "2    C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(C(C)C)c3)c...  8.508638       32\n",
       "3    C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(Cl)c3)cc2)cn1  8.474955       32\n",
       "4    C=CC(=O)Nc1cccc(-c2ccc(NC(=O)Nc3c(C)cccc3C)cc2)n1  6.380687       32\n",
       "..                                                 ...       ...      ...\n",
       "432  c1ccc2c(-c3cnn4cc(-c5ccc(N6CCNCC6)cc5)cnc34)cc...  6.666150       39\n",
       "433  c1ccc2c(-c3cnn4cc(-c5ccc(N6CCOCC6)cc5)cnc34)cc...  5.273191       39\n",
       "434  c1ccc2c(-c3cnn4cc(-c5ccc(OCCN6CCOCC6)cc5)cnc34...  5.616364       39\n",
       "435      c1ccc2c(-c3nc4cc(-n5ccnc5)ccc4[nH]3)[nH]nc2c1  7.075721       45\n",
       "436      c1ccc2c(-c3nc4cc(N5CCOCC5)ccc4[nH]3)[nH]nc2c1  6.793174       45\n",
       "\n",
       "[437 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/kdr/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/kdr/test_1.csv', index_col=0)\n",
    "\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    if len(X) == len(train):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(train, y_pred)\n",
    "        return metrics['spearman']\n",
    "    elif len(X) == len(test):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(test, y_pred)\n",
    "        return metrics['spearman']\n",
    "    else:\n",
    "        raise ValueError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps):\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold = split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "    'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "    'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "    'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "    'min_samples_split': [2, 3, 5, 7],\n",
    "    'min_samples_leaf': [1, 3, 5],\n",
    "    'max_depth': [2, 3, 4],\n",
    "    'max_features': [None, 'sqrt']\n",
    "    }\n",
    "    knn = GradientBoostingRegressor()\n",
    "\n",
    "    grid_search = RandomizedSearchCV(knn, params, cv=pds, n_iter=30, refit=False, scoring=spearman_scorer, verbose=3)\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    knn = GradientBoostingRegressor(**best_params)\n",
    "    knn.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = knn.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.4;, score=0.058 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=100, subsample=0.9;, score=0.095 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.152 total time=   0.2s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=10, subsample=0.9;, score=0.048 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.7;, score=0.032 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=500, subsample=0.4;, score=0.136 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=10, subsample=0.7;, score=0.042 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.038 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=1.0;, score=0.162 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=500, subsample=0.9;, score=0.135 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=0.4;, score=0.165 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.9;, score=0.133 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=10, subsample=0.7;, score=-0.033 total time=   0.3s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=10, subsample=0.9;, score=0.067 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=100, subsample=1.0;, score=0.124 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=0.4;, score=0.039 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=1.0;, score=0.149 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=0.9;, score=0.042 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=250, subsample=0.7;, score=0.006 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.9;, score=0.117 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=10, subsample=0.9;, score=0.078 total time=   0.3s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=100, subsample=1.0;, score=0.131 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=1.0;, score=0.168 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.7;, score=0.078 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=0.9;, score=0.102 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=10, subsample=0.4;, score=0.099 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=50, subsample=0.9;, score=0.096 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=150, subsample=0.4;, score=0.112 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.113 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.155 total time=   0.7s\n",
      "{'subsample': 1.0, 'n_estimators': 150, 'min_samples_split': 3, 'min_samples_leaf': 5, 'max_features': 'sqrt', 'max_depth': 4, 'learning_rate': 0.3}\n",
      "{'r2': -1.8871029215532038, 'spearman': 0.040062490024974445, 'mae': 0.9202905058728743}\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test):\n",
    "    train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "    train_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in train_mols]\n",
    "\n",
    "    test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "    test_morgan_fps = [AllChem.GetMorganFingerprintAsBitVect(x, 2, 1024) for x in test_mols]\n",
    "\n",
    "    gb = GradientBoostingRegressor(\n",
    "        n_estimators=150,\n",
    "        subsample=1.0,\n",
    "        min_samples_split=3,\n",
    "        min_samples_leaf=5,\n",
    "        max_features='sqrt',\n",
    "        max_depth=4,\n",
    "        learning_rate=0.3\n",
    "    )\n",
    "    gb.fit(train_morgan_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_morgan_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_morgan_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in [1, 2, 3]:\n",
    "    train = pd.read_csv(f'../../../../data/lo/kdr/train_{i}.csv')\n",
    "    test = pd.read_csv(f'../../../../data/lo/kdr/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(train, test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kdr/gb_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kdr/gb_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
