{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, MACCSkeys\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(-c2nc3ccc(Nc4ncnc5ccccc45)cc3[nH]2)cc1</td>\n",
       "      <td>6.419075</td>\n",
       "      <td>51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3ccc(F)cc3)cc2)cn1</td>\n",
       "      <td>8.047208</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(C(C)C)c3)c...</td>\n",
       "      <td>8.508638</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(Cl)c3)cc2)cn1</td>\n",
       "      <td>8.474955</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cccc(-c2ccc(NC(=O)Nc3c(C)cccc3C)cc2)n1</td>\n",
       "      <td>6.380687</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>432</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(N6CCNCC6)cc5)cnc34)cc...</td>\n",
       "      <td>6.666150</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>433</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(N6CCOCC6)cc5)cnc34)cc...</td>\n",
       "      <td>5.273191</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>434</th>\n",
       "      <td>c1ccc2c(-c3cnn4cc(-c5ccc(OCCN6CCOCC6)cc5)cnc34...</td>\n",
       "      <td>5.616364</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>435</th>\n",
       "      <td>c1ccc2c(-c3nc4cc(-n5ccnc5)ccc4[nH]3)[nH]nc2c1</td>\n",
       "      <td>7.075721</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>436</th>\n",
       "      <td>c1ccc2c(-c3nc4cc(N5CCOCC5)ccc4[nH]3)[nH]nc2c1</td>\n",
       "      <td>6.793174</td>\n",
       "      <td>45</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>437 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0       Brc1ccc(-c2nc3ccc(Nc4ncnc5ccccc45)cc3[nH]2)cc1  6.419075       51\n",
       "1     C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3ccc(F)cc3)cc2)cn1  8.047208       32\n",
       "2    C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(C(C)C)c3)c...  8.508638       32\n",
       "3    C=CC(=O)Nc1ccc(-c2ccc(NC(=O)Nc3cccc(Cl)c3)cc2)cn1  8.474955       32\n",
       "4    C=CC(=O)Nc1cccc(-c2ccc(NC(=O)Nc3c(C)cccc3C)cc2)n1  6.380687       32\n",
       "..                                                 ...       ...      ...\n",
       "432  c1ccc2c(-c3cnn4cc(-c5ccc(N6CCNCC6)cc5)cnc34)cc...  6.666150       39\n",
       "433  c1ccc2c(-c3cnn4cc(-c5ccc(N6CCOCC6)cc5)cnc34)cc...  5.273191       39\n",
       "434  c1ccc2c(-c3cnn4cc(-c5ccc(OCCN6CCOCC6)cc5)cnc34...  5.616364       39\n",
       "435      c1ccc2c(-c3nc4cc(-n5ccnc5)ccc4[nH]3)[nH]nc2c1  7.075721       45\n",
       "436      c1ccc2c(-c3nc4cc(N5CCOCC5)ccc4[nH]3)[nH]nc2c1  6.793174       45\n",
       "\n",
       "[437 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split 1 drives the hyperparameter search below; index_col=0 uses the file's first column as the row index.\n",
    "data_dir = '../../../../data/lo/kdr'\n",
    "train = pd.read_csv(f'{data_dir}/train_1.csv', index_col=0)\n",
    "test = pd.read_csv(f'{data_dir}/test_1.csv', index_col=0)\n",
    "\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
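    "# Hyperparameter Optimization\n",
    "\n",
    "Random search over `GradientBoostingRegressor` settings on MACCS fingerprints, using split 1 only: each candidate is fit on `train_1` and scored by Spearman correlation on `test_1`. The best configuration is then reused unchanged for every split in the final evaluation."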
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Scorer for RandomizedSearchCV returning the Spearman entry of `get_lo_metrics`.\n",
    "\n",
    "    `get_lo_metrics` is called with a whole dataframe rather than a label\n",
    "    vector, so the scored fold is mapped back to the global `test` or `train`\n",
    "    frame. Matching on the labels `y` (not only on the row count) keeps this\n",
    "    unambiguous when both frames happen to have the same length. `test` is\n",
    "    tried first because the search in `run_gb_gridsearch` only scores its\n",
    "    single test fold.\n",
    "\n",
    "    Args:\n",
    "        clf: fitted estimator with a `predict` method.\n",
    "        X: features of the fold being scored.\n",
    "        y: labels of the fold being scored, in row order.\n",
    "\n",
    "    Returns:\n",
    "        The 'spearman' value from `get_lo_metrics`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `y` matches neither the `test` nor the `train` labels.\n",
    "    \"\"\"\n",
    "    y_true = np.asarray(y)\n",
    "    for frame in (test, train):\n",
    "        if len(frame) == len(y_true) and np.array_equal(frame['value'].to_numpy(), y_true):\n",
    "            y_pred = clf.predict(X)\n",
    "            return get_lo_metrics(frame, y_pred)['spearman']\n",
    "    raise ValueError(f'scored fold ({len(X)} rows) matches neither train nor test labels')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps, n_iter=30, random_state=0):\n",
    "    \"\"\"Random-search GradientBoostingRegressor hyperparameters, then score on test.\n",
    "\n",
    "    The search uses one predefined fold: every candidate is fit on `train_fps`\n",
    "    and scored on `test_fps` with `spearman_scorer`. The best configuration is\n",
    "    then refit on the training rows and evaluated on the same test rows.\n",
    "\n",
    "    NOTE(review): hyperparameters are chosen on the test set that also produces\n",
    "    the returned metrics, so those metrics are optimistically biased.\n",
    "\n",
    "    Labels are read from the global `train` / `test` frames, which must be\n",
    "    row-aligned with `train_fps` / `test_fps`.\n",
    "\n",
    "    Args:\n",
    "        train_fps: sequence of training fingerprints, one per row of `train`.\n",
    "        test_fps: sequence of test fingerprints, one per row of `test`.\n",
    "        n_iter: number of sampled parameter settings.\n",
    "        random_state: seed for the parameter sampler and for the boosting\n",
    "            models (`subsample` < 1 makes fitting random); None = unseeded.\n",
    "\n",
    "    Returns:\n",
    "        The `get_lo_metrics` dict for the refit model on the test set.\n",
    "    \"\"\"\n",
    "    # -1: row is never in a validation fold; 0: row belongs to the only fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    # list() so numpy-array inputs are concatenated, not added elementwise.\n",
    "    X = list(train_fps) + list(test_fps)\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    search = RandomizedSearchCV(\n",
    "        GradientBoostingRegressor(random_state=random_state),\n",
    "        params,\n",
    "        cv=pds,\n",
    "        n_iter=n_iter,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        random_state=random_state,\n",
    "        verbose=3,\n",
    "    )\n",
    "    search.fit(X, y)\n",
    "\n",
    "    best_params = search.best_params_\n",
    "    print(best_params)\n",
    "    gb = GradientBoostingRegressor(random_state=random_state, **best_params)\n",
    "    gb.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = gb.predict(test_fps)\n",
    "    return get_lo_metrics(test, test_preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.7;, score=0.165 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=10, subsample=0.4;, score=0.021 total time=   0.1s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=1.0;, score=-0.017 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=500, subsample=0.9;, score=0.129 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.7;, score=0.014 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=50, subsample=1.0;, score=0.093 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=200, subsample=1.0;, score=0.065 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=150, subsample=0.4;, score=0.078 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=500, subsample=0.9;, score=0.055 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.041 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=200, subsample=0.4;, score=0.035 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=100, subsample=0.7;, score=0.067 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=100, subsample=1.0;, score=0.089 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=1.0;, score=0.095 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=500, subsample=0.7;, score=0.134 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.021 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=0.9;, score=0.092 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.004 total time=   0.2s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.087 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=200, subsample=0.7;, score=0.049 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.153 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.4;, score=0.069 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=500, subsample=0.9;, score=0.146 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=150, subsample=1.0;, score=0.080 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=200, subsample=0.7;, score=0.089 total time=   0.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=500, subsample=1.0;, score=0.064 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=150, subsample=0.9;, score=0.072 total time=   0.4s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=100, subsample=1.0;, score=0.074 total time=   0.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.098 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=0.7;, score=0.042 total time=   0.2s\n",
      "{'subsample': 0.7, 'n_estimators': 250, 'min_samples_split': 3, 'min_samples_leaf': 1, 'max_features': None, 'max_depth': 3, 'learning_rate': 0.1}\n",
      "{'r2': -1.1448003934273825, 'spearman': 0.12070731790875343, 'mae': 0.8885389200375444}\n"
     ]
    }
   ],
   "source": [
    "# MACCS key fingerprints for split 1, then the random hyperparameter search.\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "\n",
    "train_maccs_fps = [MACCSkeys.GenMACCSKeys(mol) for mol in train_mols]\n",
    "test_maccs_fps = [MACCSkeys.GenMACCSKeys(mol) for mol in test_mols]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_maccs_fps, test_maccs_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
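    "# Final Evaluation\n",
    "\n",
    "Refit the tuned model on each of the three splits and save train and test predictions under `predictions/lo/kdr/gb_maccs/`."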
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best configuration printed by the hyperparameter search on split 1.\n",
    "GB_MACCS_PARAMS = {\n",
    "    'n_estimators': 250,\n",
    "    'subsample': 0.7,\n",
    "    'min_samples_split': 3,\n",
    "    'min_samples_leaf': 1,\n",
    "    'max_features': None,\n",
    "    'max_depth': 3,\n",
    "    'learning_rate': 0.1,\n",
    "}\n",
    "\n",
    "\n",
    "def fit_predict(train, test, params=None, random_state=0):\n",
    "    \"\"\"Fit a GradientBoostingRegressor on MACCS keys of `train`; predict both splits.\n",
    "\n",
    "    Args:\n",
    "        train: dataframe with `smiles` and `value` columns used for fitting.\n",
    "        test: dataframe with a `smiles` column to predict.\n",
    "        params: GradientBoostingRegressor keyword arguments; defaults to\n",
    "            `GB_MACCS_PARAMS`.\n",
    "        random_state: seed passed to the regressor (needed for reproducible\n",
    "            predictions because `subsample` < 1 makes fitting random);\n",
    "            None = unseeded.\n",
    "\n",
    "    Returns:\n",
    "        (train_result, test_result): copies of the inputs with an added\n",
    "        `preds` column.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if any SMILES string cannot be parsed by RDKit.\n",
    "    \"\"\"\n",
    "    def _maccs_fps(df):\n",
    "        # Parse SMILES and compute MACCS keys, failing loudly on bad input.\n",
    "        mols = [Chem.MolFromSmiles(smi) for smi in df['smiles']]\n",
    "        bad = [smi for smi, mol in zip(df['smiles'], mols) if mol is None]\n",
    "        if bad:\n",
    "            raise ValueError(f'unparseable SMILES: {bad[:5]}')\n",
    "        return [MACCSkeys.GenMACCSKeys(mol) for mol in mols]\n",
    "\n",
    "    params = GB_MACCS_PARAMS if params is None else params\n",
    "    train_maccs_fps = _maccs_fps(train)\n",
    "    test_maccs_fps = _maccs_fps(test)\n",
    "\n",
    "    gb = GradientBoostingRegressor(random_state=random_state).set_params(**params)\n",
    "    gb.fit(train_maccs_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_maccs_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_maccs_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit the tuned model on every split and save its predictions.\n",
    "# index_col=0 matches how split 1 was loaded above; without it the saved index\n",
    "# is read back as an extra 'Unnamed: 0' column and copied into the output CSVs.\n",
    "# Separate names leave the global `train`/`test` (split 1) read by\n",
    "# `spearman_scorer` and `run_gb_gridsearch` untouched.\n",
    "for i in [1, 2, 3]:\n",
    "    train_df = pd.read_csv(f'../../../../data/lo/kdr/train_{i}.csv', index_col=0)\n",
    "    test_df = pd.read_csv(f'../../../../data/lo/kdr/test_{i}.csv', index_col=0)\n",
    "\n",
    "    train_preds, test_preds = fit_predict(train_df, test_df)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kdr/gb_maccs/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kdr/gb_maccs/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
