{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>C=C(C)COc1ccccc1CN1CCC2(CC1)CCN(C(=O)c1ccncc1)CC2</td>\n",
       "      <td>5.794709</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C=CCOC[C@H]1CC[C@@H](N2CC(NC(=O)CNc3nn(C)c4ccc...</td>\n",
       "      <td>5.300943</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CCO[C@H]1CC[C@@H](N2CC(NC(=O)CNc3n[nH]c4ccc(...</td>\n",
       "      <td>5.130710</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>CC(=O)N1CCC(C2N[C@@H](c3nc(-c4ccccc4)c[nH]3)Cc...</td>\n",
       "      <td>5.008730</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CC(=O)NC1CCN(CCc2ccc(Oc3nc4ccccc4s3)cc2)CC1</td>\n",
       "      <td>5.045709</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>401</th>\n",
       "      <td>c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@@H...</td>\n",
       "      <td>6.419075</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@H]...</td>\n",
       "      <td>6.136083</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>403</th>\n",
       "      <td>c1ccc(-c2ccc(-c3c[nH]c([C@H]4Cc5c([nH]c6ccccc5...</td>\n",
       "      <td>7.744727</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>404</th>\n",
       "      <td>c1ccc(CCCNCCN(c2ccccc2)c2ccccc2)cc1</td>\n",
       "      <td>6.217567</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>405</th>\n",
       "      <td>c1ccc(CCNCCN(c2ccccc2)c2ccccc2)cc1</td>\n",
       "      <td>7.196901</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>406 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0    C=C(C)COc1ccccc1CN1CCC2(CC1)CCN(C(=O)c1ccncc1)CC2  5.794709       20\n",
       "1    C=CCOC[C@H]1CC[C@@H](N2CC(NC(=O)CNc3nn(C)c4ccc...  5.300943       31\n",
       "2    C=CCO[C@H]1CC[C@@H](N2CC(NC(=O)CNc3n[nH]c4ccc(...  5.130710       31\n",
       "3    CC(=O)N1CCC(C2N[C@@H](c3nc(-c4ccccc4)c[nH]3)Cc...  5.008730       34\n",
       "4          CC(=O)NC1CCN(CCc2ccc(Oc3nc4ccccc4s3)cc2)CC1  5.045709       12\n",
       "..                                                 ...       ...      ...\n",
       "401  c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@@H...  6.419075       34\n",
       "402  c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@H]...  6.136083       34\n",
       "403  c1ccc(-c2ccc(-c3c[nH]c([C@H]4Cc5c([nH]c6ccccc5...  7.744727       34\n",
       "404                c1ccc(CCCNCCN(c2ccccc2)c2ccccc2)cc1  6.217567       11\n",
       "405                 c1ccc(CCNCCN(c2ccccc2)c2ccccc2)cc1  7.196901       11\n",
       "\n",
       "[406 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load split 1 of the kcnh2 LO data; the first CSV column is used as the index.\n",
    "# These module-level frames are read as globals by `spearman_scorer` and\n",
    "# `run_gb_gridsearch` further down.\n",
    "train = pd.read_csv('../../../../data/lo/kcnh2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/kcnh2/test_1.csv', index_col=0)\n",
    "\n",
    "# Display the test frame as the cell output.\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Score a fitted estimator by the 'spearman' entry of `get_lo_metrics`.\n",
    "\n",
    "    `get_lo_metrics` is called with a DataFrame and the predictions, not with\n",
    "    the bare targets, so the matching frame is chosen by comparing `len(X)`\n",
    "    with the module-level `train` and `test` frames. `y` is accepted only to\n",
    "    satisfy the scorer signature and is not used.\n",
    "\n",
    "    If `train` and `test` have the same length, `train` is chosen.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `len(X)` matches neither `train` nor `test`.\n",
    "    \"\"\"\n",
    "    n_rows = len(X)\n",
    "    if n_rows == len(train):\n",
    "        frame = train\n",
    "    elif n_rows == len(test):\n",
    "        frame = test\n",
    "    else:\n",
    "        # Fail loudly with the sizes involved instead of a bare ValueError.\n",
    "        raise ValueError(\n",
    "            f'Cannot match X with {n_rows} rows to train ({len(train)}) or test ({len(test)})'\n",
    "        )\n",
    "\n",
    "    y_pred = clf.predict(X)\n",
    "    return get_lo_metrics(frame, y_pred)['spearman']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps, random_state=None):\n",
    "    \"\"\"Randomized hyperparameter search for a GradientBoostingRegressor.\n",
    "\n",
    "    Each of the 30 sampled candidates is fitted on `train_fps` and scored on\n",
    "    `test_fps` through a single predefined split (fold -1 = never held out,\n",
    "    fold 0 = held out) using `spearman_scorer`. The best setting is printed,\n",
    "    refitted on `train_fps` and evaluated with `get_lo_metrics` on the\n",
    "    module-level `test` frame.\n",
    "\n",
    "    NOTE(review): hyperparameters are selected on the same `test` split that\n",
    "    the returned metrics are computed on, so those metrics are likely\n",
    "    optimistic -- confirm this is intended.\n",
    "\n",
    "    Args:\n",
    "        train_fps: list of training fingerprints, in the row order of the\n",
    "            module-level `train` frame (lists are concatenated with `+`).\n",
    "        test_fps: list of test fingerprints, in the row order of `test`.\n",
    "        random_state: optional seed forwarded to the search and to the\n",
    "            regressors. The default `None` keeps the original unseeded behaviour.\n",
    "\n",
    "    Returns:\n",
    "        The metrics returned by `get_lo_metrics` for the test predictions.\n",
    "    \"\"\"\n",
    "    # -1 keeps a sample in the training part, 0 puts it in the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "\n",
    "    # refit=False: the search only ranks candidates; the final fit is done below\n",
    "    # on the training fingerprints alone.\n",
    "    search = RandomizedSearchCV(\n",
    "        GradientBoostingRegressor(random_state=random_state),\n",
    "        params,\n",
    "        cv=pds,\n",
    "        n_iter=30,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        verbose=3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    search.fit(X, y)\n",
    "\n",
    "    best_params = search.best_params_\n",
    "    print(best_params)\n",
    "\n",
    "    gb = GradientBoostingRegressor(random_state=random_state, **best_params)\n",
    "    gb.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = gb.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=250, subsample=0.7;, score=0.359 total time=  18.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=500, subsample=0.4;, score=0.229 total time=   2.1s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=10, subsample=0.7;, score=0.156 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.411 total time=  35.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=100, subsample=0.7;, score=0.382 total time=   8.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=250, subsample=0.4;, score=0.285 total time=   1.8s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=10, subsample=0.4;, score=0.001 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=500, subsample=0.4;, score=-0.047 total time=  11.4s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=200, subsample=1.0;, score=0.207 total time=   1.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=250, subsample=0.7;, score=0.206 total time=  14.4s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.370 total time=   2.9s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=500, subsample=0.4;, score=0.253 total time=  11.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=500, subsample=1.0;, score=0.237 total time=   2.5s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=50, subsample=0.4;, score=0.089 total time=   3.4s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=50, subsample=0.4;, score=0.201 total time=   1.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=10, subsample=0.4;, score=0.172 total time=   1.8s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.193 total time=   8.2s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=10, subsample=0.4;, score=0.052 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=200, subsample=0.4;, score=0.321 total time=   1.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=250, subsample=0.7;, score=0.280 total time=   1.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=0.7;, score=0.323 total time=   1.8s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=100, subsample=0.4;, score=0.264 total time=   3.5s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=0.4;, score=0.144 total time=   2.5s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=1.0;, score=0.220 total time=   2.2s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=5, n_estimators=50, subsample=0.7;, score=0.261 total time=   3.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=100, subsample=0.4;, score=0.177 total time=   4.4s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=50, subsample=1.0;, score=0.294 total time=   5.2s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=10, subsample=0.7;, score=0.106 total time=   1.8s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=150, subsample=0.4;, score=0.127 total time=   5.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=500, subsample=0.7;, score=0.306 total time=   2.9s\n",
      "{'subsample': 0.7, 'n_estimators': 500, 'min_samples_split': 3, 'min_samples_leaf': 5, 'max_features': None, 'max_depth': 4, 'learning_rate': 0.01}\n",
      "{'r2': -0.8798180876954403, 'spearman': 0.3805236471923937, 'mae': 0.9609210981670802}\n"
     ]
    }
   ],
   "source": [
    "# Parse the SMILES of both splits and encode every molecule as a 1024-bit\n",
    "# Morgan fingerprint of radius 2, then run the hyperparameter search on them.\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)\n",
    "    for mol in train_mols\n",
    "]\n",
    "\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "test_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)\n",
    "    for mol in test_mols\n",
    "]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test, gb_params=None, random_state=None):\n",
    "    \"\"\"Fit a GradientBoostingRegressor on Morgan fingerprints and predict both splits.\n",
    "\n",
    "    Args:\n",
    "        train: DataFrame with a 'smiles' column and a 'value' target column.\n",
    "        test: DataFrame with a 'smiles' column.\n",
    "        gb_params: optional dict of GradientBoostingRegressor keyword arguments\n",
    "            that override the defaults below (the setting printed by the\n",
    "            hyperparameter search in this notebook).\n",
    "        random_state: optional seed for the regressor. With `subsample=0.7`\n",
    "            the fit is stochastic; the default `None` keeps the original\n",
    "            unseeded behaviour.\n",
    "\n",
    "    Returns:\n",
    "        (train_result, test_result): copies of the inputs with an added\n",
    "        'preds' column. The input frames are not modified.\n",
    "    \"\"\"\n",
    "    def morgan_fps(smiles_column):\n",
    "        # SMILES -> RDKit molecule -> 1024-bit Morgan fingerprint of radius 2.\n",
    "        mols = [Chem.MolFromSmiles(smi) for smi in smiles_column]\n",
    "        return [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in mols]\n",
    "\n",
    "    train_morgan_fps = morgan_fps(train['smiles'])\n",
    "    test_morgan_fps = morgan_fps(test['smiles'])\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': 500,\n",
    "        'subsample': 0.7,\n",
    "        'min_samples_split': 3,\n",
    "        'min_samples_leaf': 5,\n",
    "        'max_features': None,\n",
    "        'max_depth': 4,\n",
    "        'learning_rate': 0.01,\n",
    "        'random_state': random_state,\n",
    "    }\n",
    "    # Caller-supplied settings win over the defaults.\n",
    "    params.update(gb_params or {})\n",
    "\n",
    "    gb = GradientBoostingRegressor(**params)\n",
    "    gb.fit(train_morgan_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_morgan_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_morgan_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit on each of the three splits and store the train/test predictions.\n",
    "# Split-local names are used on purpose: `spearman_scorer` and\n",
    "# `run_gb_gridsearch` read the module-level `train`/`test` frames, and\n",
    "# rebinding them here would make a re-run of the search silently use split 3.\n",
    "# NOTE(review): unlike the load cell above, these reads do not pass\n",
    "# index_col=0, so the stored index column presumably ends up as an extra\n",
    "# data column in the saved files -- confirm what downstream readers expect.\n",
    "for i in [1, 2, 3]:\n",
    "    split_train = pd.read_csv(f'../../../../data/lo/kcnh2/train_{i}.csv')\n",
    "    split_test = pd.read_csv(f'../../../../data/lo/kcnh2/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(split_train, split_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kcnh2/gb_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kcnh2/gb_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
