{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(-[n+]2cc[n+](Cc3ccccc3)cc2)c2cc[nH]c12</td>\n",
       "      <td>7.717691</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brc1cccc(N2CCN(Cc3cnn4ccccc34)CC2)n1</td>\n",
       "      <td>6.748370</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C#CC1=CCC(N(CCC)CCCCn2cc(-c3ccc(-c4ccccc4)cc3)...</td>\n",
       "      <td>6.490481</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C#CCN(CCN1CCN(c2ccccc2)CC1)C1CCc2ccc(O)cc2C1</td>\n",
       "      <td>6.609065</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C1=C(c2ccccc2)CCN(Cc2cnn(-c3ccccc3)c2)C1</td>\n",
       "      <td>7.473269</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>262</th>\n",
       "      <td>c1ccc2c(c1)N=C(N1CCNCC1)c1ccccc1S2</td>\n",
       "      <td>7.420216</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>263</th>\n",
       "      <td>c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5[nH]4)cc3)CC...</td>\n",
       "      <td>6.568636</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>264</th>\n",
       "      <td>c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5o4)cc3)CC2)nc1</td>\n",
       "      <td>6.701147</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>265</th>\n",
       "      <td>c1cnc(N2CCN(Cc3c[nH]c4ncccc34)CC2)nc1</td>\n",
       "      <td>5.931443</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>266</th>\n",
       "      <td>c1nc2c(s1)CCN(CCCCN1CCc3ncsc3CC1)CC2</td>\n",
       "      <td>5.199931</td>\n",
       "      <td>36</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>267 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0       Brc1ccc(-[n+]2cc[n+](Cc3ccccc3)cc2)c2cc[nH]c12  7.717691       11\n",
       "1                 Brc1cccc(N2CCN(Cc3cnn4ccccc34)CC2)n1  6.748370       26\n",
       "2    C#CC1=CCC(N(CCC)CCCCn2cc(-c3ccc(-c4ccccc4)cc3)...  6.490481       14\n",
       "3         C#CCN(CCN1CCN(c2ccccc2)CC1)C1CCc2ccc(O)cc2C1  6.609065       32\n",
       "4             C1=C(c2ccccc2)CCN(Cc2cnn(-c3ccccc3)c2)C1  7.473269       12\n",
       "..                                                 ...       ...      ...\n",
       "262                 c1ccc2c(c1)N=C(N1CCNCC1)c1ccccc1S2  7.420216        6\n",
       "263  c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5[nH]4)cc3)CC...  6.568636       35\n",
       "264   c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5o4)cc3)CC2)nc1  6.701147       35\n",
       "265              c1cnc(N2CCN(Cc3c[nH]c4ncccc34)CC2)nc1  5.931443       12\n",
       "266               c1nc2c(s1)CCN(CCCCN1CCc3ncsc3CC1)CC2  5.199931       36\n",
       "\n",
       "[267 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split-1 files under data/lo/drd2; the first CSV column is used as the index.\n",
    "train = pd.read_csv(\n",
    "    '../../../../data/lo/drd2/train_1.csv',\n",
    "    index_col=0,\n",
    ")\n",
    "test = pd.read_csv(\n",
    "    '../../../../data/lo/drd2/test_1.csv',\n",
    "    index_col=0,\n",
    ")\n",
    "\n",
    "# Bare last expression so the frame is shown with the rich DataFrame display.\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Scorer callable (estimator, X, y) returning the 'spearman' entry of get_lo_metrics.\n",
    "\n",
    "    The DataFrame handed to get_lo_metrics is chosen by matching len(X) against\n",
    "    the notebook-level `train` and `test` frames. `y` is accepted to satisfy the\n",
    "    scorer signature but is not used here.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if len(X) matches neither len(train) nor len(test).\n",
    "    \"\"\"\n",
    "    n_rows = len(X)\n",
    "    # `train` is checked first, so it wins if both frames happen to have the\n",
    "    # same length (same precedence as the original if/elif chain).\n",
    "    if n_rows == len(train):\n",
    "        frame = train\n",
    "    elif n_rows == len(test):\n",
    "        frame = test\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f'spearman_scorer received {n_rows} rows; expected '\n",
    "            f'len(train)={len(train)} or len(test)={len(test)}'\n",
    "        )\n",
    "\n",
    "    y_pred = clf.predict(X)\n",
    "    return get_lo_metrics(frame, y_pred)['spearman']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps, n_iter=30, random_state=None):\n",
    "    \"\"\"Random-search GradientBoostingRegressor hyperparameters, refit the best, score it.\n",
    "\n",
    "    Candidates are fit on `train_fps` and scored on `test_fps` through a\n",
    "    single-fold PredefinedSplit, using spearman_scorer. Targets are read from\n",
    "    the 'value' column of the notebook-level `train` and `test` frames.\n",
    "\n",
    "    NOTE(review): the fold used to pick the hyperparameters is the same data the\n",
    "    returned metrics are computed on, so those metrics look optimistically\n",
    "    biased -- confirm this is the intended protocol.\n",
    "\n",
    "    Args:\n",
    "        train_fps: fingerprints aligned with the rows of `train`; must support\n",
    "            concatenation with `+` (a list in this notebook).\n",
    "        test_fps: fingerprints aligned with the rows of `test`; same type.\n",
    "        n_iter: number of sampled parameter settings (default 30, as before).\n",
    "        random_state: seed forwarded to the search and to both regressors.\n",
    "            None (the default) keeps the original unseeded behaviour.\n",
    "\n",
    "    Returns:\n",
    "        The result of get_lo_metrics(test, predictions) for the refitted model.\n",
    "    \"\"\"\n",
    "    # -1 = sample is never used for validation; 0 = the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    param_distributions = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    base_model = GradientBoostingRegressor(random_state=random_state)\n",
    "\n",
    "    search = RandomizedSearchCV(\n",
    "        base_model,\n",
    "        param_distributions,\n",
    "        cv=pds,\n",
    "        n_iter=n_iter,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        verbose=3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    search.fit(X, y)\n",
    "\n",
    "    best_params = search.best_params_\n",
    "    print(best_params)\n",
    "\n",
    "    # refit=False above, so the winning configuration is refit here on the\n",
    "    # training fingerprints only.\n",
    "    best_model = GradientBoostingRegressor(**best_params, random_state=random_state)\n",
    "    best_model.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = best_model.predict(test_fps)\n",
    "    return get_lo_metrics(test, test_preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.208 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=200, subsample=1.0;, score=0.229 total time=  13.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.9;, score=0.200 total time=  15.7s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=500, subsample=0.4;, score=0.228 total time=   1.6s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.200 total time=  23.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=10, subsample=1.0;, score=0.071 total time=   1.1s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.279 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=1.0;, score=0.183 total time=  13.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=0.9;, score=0.186 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.4;, score=0.214 total time=   1.3s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=10, subsample=0.9;, score=0.085 total time=   1.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.4;, score=0.181 total time=   1.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=50, subsample=0.9;, score=0.152 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=500, subsample=0.7;, score=0.271 total time=  22.9s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=100, subsample=1.0;, score=0.243 total time=   4.3s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.4;, score=0.031 total time=   6.0s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=200, subsample=1.0;, score=0.201 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=200, subsample=0.9;, score=0.223 total time=   9.7s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.072 total time=  12.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=250, subsample=1.0;, score=0.240 total time=   9.1s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=0.4;, score=0.119 total time=   4.3s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=10, subsample=0.9;, score=0.098 total time=   1.1s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=10, subsample=0.7;, score=0.121 total time=   1.1s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=0.4;, score=0.162 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=1.0;, score=0.156 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=150, subsample=0.9;, score=0.250 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=50, subsample=0.4;, score=0.173 total time=   1.7s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=500, subsample=0.9;, score=0.161 total time=   1.7s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=10, subsample=0.7;, score=0.061 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=200, subsample=0.9;, score=0.185 total time=   6.9s\n",
      "{'subsample': 1.0, 'n_estimators': 100, 'min_samples_split': 3, 'min_samples_leaf': 1, 'max_features': 'sqrt', 'max_depth': 4, 'learning_rate': 1.0}\n",
      "{'r2': -0.8765205702086243, 'spearman': 0.13590358704373595, 'mae': 0.8393918242681755}\n"
     ]
    }
   ],
   "source": [
    "# Parse the SMILES of both splits into RDKit molecules.\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "\n",
    "# Morgan bit-vector fingerprints: radius 2, 1024 bits.\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols\n",
    "]\n",
    "test_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in test_mols\n",
    "]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ecfp4_fingerprints(smiles_iter, radius=2, n_bits=1024):\n",
    "    \"\"\"Return one Morgan bit-vector fingerprint per SMILES string, in input order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a SMILES string cannot be parsed by RDKit.\n",
    "    \"\"\"\n",
    "    fps = []\n",
    "    for smi in smiles_iter:\n",
    "        mol = Chem.MolFromSmiles(smi)\n",
    "        if mol is None:\n",
    "            # MolFromSmiles reports a parse failure by returning None; name the\n",
    "            # offending string instead of failing later in the fingerprint call.\n",
    "            raise ValueError(f'Could not parse SMILES: {smi!r}')\n",
    "        fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits))\n",
    "    return fps\n",
    "\n",
    "\n",
    "def fit_predict(train, test, random_state=None):\n",
    "    \"\"\"Fit the fixed-hyperparameter gradient boosting model and predict both splits.\n",
    "\n",
    "    Args:\n",
    "        train: DataFrame with 'smiles' and 'value' columns; used for fitting.\n",
    "        test: DataFrame with a 'smiles' column.\n",
    "        random_state: seed for GradientBoostingRegressor. None (the default)\n",
    "            keeps the original unseeded behaviour; max_features='sqrt' makes\n",
    "            the fit stochastic, so pass an int for repeatable predictions.\n",
    "\n",
    "    Returns:\n",
    "        (train_result, test_result): copies of the inputs with an added\n",
    "        'preds' column holding the model predictions.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if any SMILES string cannot be parsed.\n",
    "    \"\"\"\n",
    "    train_morgan_fps = _ecfp4_fingerprints(train['smiles'])\n",
    "    test_morgan_fps = _ecfp4_fingerprints(test['smiles'])\n",
    "\n",
    "    # Hyperparameters match the best_params printed by the search cell above.\n",
    "    gb = GradientBoostingRegressor(\n",
    "        n_estimators=100,\n",
    "        subsample=1.0,\n",
    "        min_samples_split=3,\n",
    "        min_samples_leaf=1,\n",
    "        max_features='sqrt',\n",
    "        max_depth=4,\n",
    "        learning_rate=1.0,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    gb.fit(train_morgan_fps, train['value'])\n",
    "\n",
    "    # Work on copies so the caller's frames are left untouched.\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_morgan_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_morgan_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loop-local names on purpose: spearman_scorer and run_gb_gridsearch read the\n",
    "# notebook-level `train`/`test`, so this loop must not rebind them.\n",
    "for i in [1, 2, 3]:\n",
    "    # NOTE(review): unlike the split-1 load earlier in this notebook, these reads\n",
    "    # omit index_col=0, so the stored index likely comes back as an extra\n",
    "    # 'Unnamed: 0' column that is then written to the prediction files --\n",
    "    # confirm whether downstream readers expect that before changing it.\n",
    "    split_train = pd.read_csv(f'../../../../data/lo/drd2/train_{i}.csv')\n",
    "    split_test = pd.read_csv(f'../../../../data/lo/drd2/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(split_train, split_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/drd2/gb_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/drd2/gb_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
