{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, MACCSkeys\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>C=C(C)COc1ccccc1CN1CCC2(CC1)CCN(C(=O)c1ccncc1)CC2</td>\n",
       "      <td>5.794709</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C=CCOC[C@H]1CC[C@@H](N2CC(NC(=O)CNc3nn(C)c4ccc...</td>\n",
       "      <td>5.300943</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CCO[C@H]1CC[C@@H](N2CC(NC(=O)CNc3n[nH]c4ccc(...</td>\n",
       "      <td>5.130710</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>CC(=O)N1CCC(C2N[C@@H](c3nc(-c4ccccc4)c[nH]3)Cc...</td>\n",
       "      <td>5.008730</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CC(=O)NC1CCN(CCc2ccc(Oc3nc4ccccc4s3)cc2)CC1</td>\n",
       "      <td>5.045709</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>401</th>\n",
       "      <td>c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@@H...</td>\n",
       "      <td>6.419075</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>402</th>\n",
       "      <td>c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@H]...</td>\n",
       "      <td>6.136083</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>403</th>\n",
       "      <td>c1ccc(-c2ccc(-c3c[nH]c([C@H]4Cc5c([nH]c6ccccc5...</td>\n",
       "      <td>7.744727</td>\n",
       "      <td>34</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>404</th>\n",
       "      <td>c1ccc(CCCNCCN(c2ccccc2)c2ccccc2)cc1</td>\n",
       "      <td>6.217567</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>405</th>\n",
       "      <td>c1ccc(CCNCCN(c2ccccc2)c2ccccc2)cc1</td>\n",
       "      <td>7.196901</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>406 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0    C=C(C)COc1ccccc1CN1CCC2(CC1)CCN(C(=O)c1ccncc1)CC2  5.794709       20\n",
       "1    C=CCOC[C@H]1CC[C@@H](N2CC(NC(=O)CNc3nn(C)c4ccc...  5.300943       31\n",
       "2    C=CCO[C@H]1CC[C@@H](N2CC(NC(=O)CNc3n[nH]c4ccc(...  5.130710       31\n",
       "3    CC(=O)N1CCC(C2N[C@@H](c3nc(-c4ccccc4)c[nH]3)Cc...  5.008730       34\n",
       "4          CC(=O)NC1CCN(CCc2ccc(Oc3nc4ccccc4s3)cc2)CC1  5.045709       12\n",
       "..                                                 ...       ...      ...\n",
       "401  c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@@H...  6.419075       34\n",
       "402  c1ccc(-c2c[nH]c([C@H]3Cc4c([nH]c5ccccc45)[C@H]...  6.136083       34\n",
       "403  c1ccc(-c2ccc(-c3c[nH]c([C@H]4Cc5c([nH]c6ccccc5...  7.744727       34\n",
       "404                c1ccc(CCCNCCN(c2ccccc2)c2ccccc2)cc1  6.217567       11\n",
       "405                 c1ccc(CCNCCN(c2ccccc2)c2ccccc2)cc1  7.196901       11\n",
       "\n",
       "[406 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/kcnh2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/kcnh2/test_1.csv', index_col=0)\n",
    "\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    if len(X) == len(train):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(train, y_pred)\n",
    "        return metrics['spearman']\n",
    "    elif len(X) == len(test):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(test, y_pred)\n",
    "        return metrics['spearman']\n",
    "    else:\n",
    "        raise ValueError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps):\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold = split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "    'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "    'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "    'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "    'min_samples_split': [2, 3, 5, 7],\n",
    "    'min_samples_leaf': [1, 3, 5],\n",
    "    'max_depth': [2, 3, 4],\n",
    "    'max_features': [None, 'sqrt']\n",
    "    }\n",
    "    knn = GradientBoostingRegressor()\n",
    "\n",
    "    grid_search = RandomizedSearchCV(knn, params, cv=pds, n_iter=30, refit=False, scoring=spearman_scorer, verbose=3)\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    knn = GradientBoostingRegressor(**best_params)\n",
    "    knn.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = knn.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=2, n_estimators=200, subsample=0.7;, score=0.164 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.210 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.9;, score=0.187 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.9;, score=0.165 total time=   1.0s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=0.7;, score=0.223 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=150, subsample=0.9;, score=0.088 total time=   1.4s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=200, subsample=0.4;, score=0.034 total time=   1.3s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=1.0;, score=0.110 total time=   2.5s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=500, subsample=0.9;, score=0.197 total time=   3.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=1.0;, score=0.229 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=0.4;, score=0.074 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=250, subsample=0.9;, score=0.181 total time=   2.8s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=7, n_estimators=50, subsample=0.4;, score=0.185 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=None, min_samples_leaf=1, min_samples_split=7, n_estimators=50, subsample=0.7;, score=0.182 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=150, subsample=0.7;, score=0.248 total time=   1.7s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=100, subsample=0.7;, score=0.087 total time=   1.2s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=10, subsample=1.0;, score=-0.038 total time=   0.7s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=500, subsample=0.4;, score=0.097 total time=   1.0s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=150, subsample=1.0;, score=0.161 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=50, subsample=1.0;, score=0.271 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=250, subsample=0.4;, score=0.100 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.9;, score=0.135 total time=   1.0s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.204 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=100, subsample=0.7;, score=0.182 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=150, subsample=1.0;, score=0.119 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=2, n_estimators=10, subsample=0.9;, score=0.014 total time=   0.7s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=500, subsample=0.9;, score=0.190 total time=   4.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=7, n_estimators=150, subsample=0.4;, score=0.171 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=50, subsample=0.7;, score=0.163 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=4, max_features=sqrt, min_samples_leaf=1, min_samples_split=7, n_estimators=250, subsample=1.0;, score=0.148 total time=   0.9s\n",
      "{'subsample': 1.0, 'n_estimators': 50, 'min_samples_split': 3, 'min_samples_leaf': 3, 'max_features': 'sqrt', 'max_depth': 2, 'learning_rate': 0.1}\n",
      "{'r2': -0.905262062981486, 'spearman': 0.18438295609076483, 'mae': 0.9625562635355533}\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "test_metrics = run_gb_gridsearch(train_maccs_fps, test_maccs_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test):\n",
    "    train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "    train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "    test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "    test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "    gb = GradientBoostingRegressor(\n",
    "        n_estimators=50,\n",
    "        subsample=1.0,\n",
    "        min_samples_split=3,\n",
    "        min_samples_leaf=3,\n",
    "        max_features='sqrt',\n",
    "        max_depth=2,\n",
    "        learning_rate=0.1\n",
    "    )\n",
    "    gb.fit(train_maccs_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_maccs_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_maccs_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in [1, 2, 3]:\n",
    "    train = pd.read_csv(f'../../../../data/lo/kcnh2/train_{i}.csv')\n",
    "    test = pd.read_csv(f'../../../../data/lo/kcnh2/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(train, test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kcnh2/gb_maccs/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kcnh2/gb_maccs/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
