{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, MACCSkeys\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "from sklearn.model_selection import RandomizedSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(-[n+]2cc[n+](Cc3ccccc3)cc2)c2cc[nH]c12</td>\n",
       "      <td>7.717691</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brc1cccc(N2CCN(Cc3cnn4ccccc34)CC2)n1</td>\n",
       "      <td>6.748370</td>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C#CC1=CCC(N(CCC)CCCCn2cc(-c3ccc(-c4ccccc4)cc3)...</td>\n",
       "      <td>6.490481</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C#CCN(CCN1CCN(c2ccccc2)CC1)C1CCc2ccc(O)cc2C1</td>\n",
       "      <td>6.609065</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C1=C(c2ccccc2)CCN(Cc2cnn(-c3ccccc3)c2)C1</td>\n",
       "      <td>7.473269</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>262</th>\n",
       "      <td>c1ccc2c(c1)N=C(N1CCNCC1)c1ccccc1S2</td>\n",
       "      <td>7.420216</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>263</th>\n",
       "      <td>c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5[nH]4)cc3)CC...</td>\n",
       "      <td>6.568636</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>264</th>\n",
       "      <td>c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5o4)cc3)CC2)nc1</td>\n",
       "      <td>6.701147</td>\n",
       "      <td>35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>265</th>\n",
       "      <td>c1cnc(N2CCN(Cc3c[nH]c4ncccc34)CC2)nc1</td>\n",
       "      <td>5.931443</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>266</th>\n",
       "      <td>c1nc2c(s1)CCN(CCCCN1CCc3ncsc3CC1)CC2</td>\n",
       "      <td>5.199931</td>\n",
       "      <td>36</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>267 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0       Brc1ccc(-[n+]2cc[n+](Cc3ccccc3)cc2)c2cc[nH]c12  7.717691       11\n",
       "1                 Brc1cccc(N2CCN(Cc3cnn4ccccc34)CC2)n1  6.748370       26\n",
       "2    C#CC1=CCC(N(CCC)CCCCn2cc(-c3ccc(-c4ccccc4)cc3)...  6.490481       14\n",
       "3         C#CCN(CCN1CCN(c2ccccc2)CC1)C1CCc2ccc(O)cc2C1  6.609065       32\n",
       "4             C1=C(c2ccccc2)CCN(Cc2cnn(-c3ccccc3)c2)C1  7.473269       12\n",
       "..                                                 ...       ...      ...\n",
       "262                 c1ccc2c(c1)N=C(N1CCNCC1)c1ccccc1S2  7.420216        6\n",
       "263  c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5[nH]4)cc3)CC...  6.568636       35\n",
       "264   c1cnc(N2CCN(CCCOc3ccc(-c4nc5ccccc5o4)cc3)CC2)nc1  6.701147       35\n",
       "265              c1cnc(N2CCN(Cc3c[nH]c4ncccc34)CC2)nc1  5.931443       12\n",
       "266               c1nc2c(s1)CCN(CCCCN1CCc3ncsc3CC1)CC2  5.199931       36\n",
       "\n",
       "[267 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load split 1 from data/lo/drd2. `train` and `test` stay module-level names\n",
    "# because `spearman_scorer` and `run_gb_gridsearch` below read them as globals.\n",
    "train = pd.read_csv('../../../../data/lo/drd2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/drd2/test_1.csv', index_col=0)\n",
    "\n",
    "test"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Scorer callable returning the 'spearman' entry of `get_lo_metrics`.\n",
    "\n",
    "    `get_lo_metrics` is called with a whole data frame, not with bare targets,\n",
    "    so the frame belonging to `X` is recovered from the module-level `train`\n",
    "    and `test` frames by comparing row counts. `y` is accepted only because\n",
    "    the scorer signature requires it; it is not used.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    clf : fitted estimator exposing `predict`.\n",
    "    X : features; `len(X)` must equal `len(train)` or `len(test)`.\n",
    "    y : ignored.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value stored under 'spearman' in the `get_lo_metrics` result.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `len(X)` matches neither frame, or matches both (in which case the\n",
    "        frame cannot be identified by length and any score would be a guess).\n",
    "    \"\"\"\n",
    "    n_rows = len(X)\n",
    "    matching_frames = [frame for frame in (train, test) if len(frame) == n_rows]\n",
    "    if len(matching_frames) != 1:\n",
    "        raise ValueError(\n",
    "            f'spearman_scorer cannot map X with {n_rows} rows to exactly one of '\n",
    "            f'train ({len(train)} rows) and test ({len(test)} rows)'\n",
    "        )\n",
    "\n",
    "    y_pred = clf.predict(X)\n",
    "    return get_lo_metrics(matching_frames[0], y_pred)['spearman']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_gb_gridsearch(train_fps, test_fps, n_iter=30, random_state=None):\n",
    "    \"\"\"Randomized hyperparameter search for a `GradientBoostingRegressor`.\n",
    "\n",
    "    Candidates are fitted on the rows of the module-level `train` frame and\n",
    "    scored with `spearman_scorer` on the rows of the module-level `test` frame\n",
    "    (one predefined fold). The best candidate is then fitted again on\n",
    "    `train_fps` and evaluated on `test_fps` with `get_lo_metrics`.\n",
    "\n",
    "    NOTE(review): candidates are selected on the same `test` rows whose\n",
    "    metrics are returned, so the returned numbers are optimistically biased.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_fps, test_fps : sequences of fingerprints, one per row of `train`\n",
    "        and `test` respectively, in the same row order.\n",
    "    n_iter : int, default 30\n",
    "        Number of parameter settings sampled by the search.\n",
    "    random_state : int or None, default None\n",
    "        Seed passed to the search and to every estimator; None keeps the\n",
    "        original unseeded behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The result of `get_lo_metrics(test, predictions)` for the best candidate.\n",
    "    \"\"\"\n",
    "    # -1 keeps a sample in the training part of every split, 0 puts it in\n",
    "    # the single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    # list(...) so that `+` concatenates even when arrays are passed in.\n",
    "    X = list(train_fps) + list(test_fps)\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'n_estimators': [10, 50, 100, 150, 200, 250, 500],\n",
    "        'learning_rate': [0.01, 0.1, 0.3, 0.5, 0.7, 1.0],\n",
    "        'subsample': [0.4, 0.7, 0.9, 1.0],\n",
    "        'min_samples_split': [2, 3, 5, 7],\n",
    "        'min_samples_leaf': [1, 3, 5],\n",
    "        'max_depth': [2, 3, 4],\n",
    "        'max_features': [None, 'sqrt'],\n",
    "    }\n",
    "    gb = GradientBoostingRegressor(random_state=random_state)\n",
    "\n",
    "    # refit=False: the winner is fitted below on the training rows only,\n",
    "    # instead of on train + test as an automatic refit would do.\n",
    "    search = RandomizedSearchCV(\n",
    "        gb,\n",
    "        params,\n",
    "        cv=pds,\n",
    "        n_iter=n_iter,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        verbose=3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    search.fit(X, y)\n",
    "\n",
    "    best_params = search.best_params_\n",
    "    print(best_params)\n",
    "    best_gb = GradientBoostingRegressor(**best_params, random_state=random_state)\n",
    "    best_gb.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = best_gb.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 30 candidates, totalling 30 fits\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=100, subsample=1.0;, score=0.184 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=10, subsample=1.0;, score=0.106 total time=   0.5s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=7, n_estimators=500, subsample=0.4;, score=0.149 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=250, subsample=0.4;, score=-0.023 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=500, subsample=0.7;, score=0.160 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.9;, score=0.242 total time=   1.5s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=100, subsample=1.0;, score=0.035 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=5, n_estimators=50, subsample=0.4;, score=0.138 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=100, subsample=0.9;, score=0.231 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=4, max_features=None, min_samples_leaf=3, min_samples_split=3, n_estimators=50, subsample=1.0;, score=0.224 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=4, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=200, subsample=0.9;, score=0.091 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=2, n_estimators=50, subsample=0.7;, score=0.003 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=5, n_estimators=200, subsample=0.7;, score=0.168 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.073 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=2, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=500, subsample=0.7;, score=0.179 total time=   0.7s\n",
      "[CV 1/1] END learning_rate=1.0, max_depth=3, max_features=sqrt, min_samples_leaf=5, min_samples_split=3, n_estimators=200, subsample=0.4;, score=0.170 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=150, subsample=0.9;, score=0.164 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=10, subsample=1.0;, score=0.070 total time=   0.5s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=3, n_estimators=250, subsample=0.9;, score=0.166 total time=   0.6s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n",
      "/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/scipy/stats/_stats_py.py:4529: SpearmanRConstantInputWarning: An input array is constant; the correlation coefficient is not defined.\n",
      "  warnings.warn(SpearmanRConstantInputWarning())\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CV 1/1] END learning_rate=0.01, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=150, subsample=0.9;, score=0.035 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=sqrt, min_samples_leaf=1, min_samples_split=5, n_estimators=200, subsample=1.0;, score=0.214 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=3, max_features=None, min_samples_leaf=3, min_samples_split=7, n_estimators=200, subsample=0.4;, score=0.108 total time=   0.9s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=1, min_samples_split=3, n_estimators=10, subsample=0.4;, score=0.090 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=0.3, max_depth=3, max_features=sqrt, min_samples_leaf=3, min_samples_split=3, n_estimators=250, subsample=0.4;, score=0.210 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=None, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=0.4;, score=0.187 total time=   0.8s\n",
      "[CV 1/1] END learning_rate=0.01, max_depth=4, max_features=sqrt, min_samples_leaf=5, min_samples_split=2, n_estimators=50, subsample=0.4;, score=0.049 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=2, n_estimators=200, subsample=0.7;, score=0.211 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.7, max_depth=2, max_features=sqrt, min_samples_leaf=3, min_samples_split=5, n_estimators=250, subsample=1.0;, score=0.114 total time=   0.6s\n",
      "[CV 1/1] END learning_rate=0.5, max_depth=2, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=10, subsample=0.9;, score=0.117 total time=   0.5s\n",
      "[CV 1/1] END learning_rate=0.1, max_depth=3, max_features=None, min_samples_leaf=5, min_samples_split=3, n_estimators=50, subsample=0.9;, score=0.209 total time=   0.7s\n",
      "{'subsample': 0.9, 'n_estimators': 200, 'min_samples_split': 3, 'min_samples_leaf': 5, 'max_features': None, 'max_depth': 4, 'learning_rate': 0.3}\n",
      "{'r2': -0.5560390015045796, 'spearman': 0.13818785383684568, 'mae': 0.758515111098191}\n"
     ]
    }
   ],
   "source": [
    "# MACCS-key fingerprints for every molecule of the currently loaded split.\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "train_maccs_fps = list(map(MACCSkeys.GenMACCSKeys, train_mols))\n",
    "\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "test_maccs_fps = list(map(MACCSkeys.GenMACCSKeys, test_mols))\n",
    "\n",
    "# Run the hyperparameter search on these fingerprints and print its metrics.\n",
    "test_metrics = run_gb_gridsearch(train_maccs_fps, test_maccs_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test, random_state=None):\n",
    "    \"\"\"Fit gradient boosting on MACCS keys and predict for both frames.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train, test : pandas.DataFrame\n",
    "        Both need a 'smiles' column; `train` also needs a 'value' column,\n",
    "        which is the regression target.\n",
    "    random_state : int or None, default None\n",
    "        Seed for the regressor (it subsamples rows, so unseeded runs differ).\n",
    "        None keeps the original unseeded behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (train_result, test_result) : copies of the inputs with an added 'preds'\n",
    "        column holding the model predictions. The inputs are not modified.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a SMILES string cannot be parsed by RDKit.\n",
    "    \"\"\"\n",
    "    def maccs_fingerprints(smiles_column):\n",
    "        # MolFromSmiles returns None for unparsable input; fail with the\n",
    "        # offending string instead of an opaque error from GenMACCSKeys.\n",
    "        fingerprints = []\n",
    "        for smiles in smiles_column:\n",
    "            mol = Chem.MolFromSmiles(smiles)\n",
    "            if mol is None:\n",
    "                raise ValueError(f'Could not parse SMILES: {smiles!r}')\n",
    "            fingerprints.append(Chem.MACCSkeys.GenMACCSKeys(mol))\n",
    "        return fingerprints\n",
    "\n",
    "    train_maccs_fps = maccs_fingerprints(train['smiles'])\n",
    "    test_maccs_fps = maccs_fingerprints(test['smiles'])\n",
    "\n",
    "    # Same values as the best parameters printed by the search cell above.\n",
    "    gb = GradientBoostingRegressor(\n",
    "        n_estimators=200,\n",
    "        subsample=0.9,\n",
    "        min_samples_split=3,\n",
    "        min_samples_leaf=5,\n",
    "        max_features=None,\n",
    "        max_depth=4,\n",
    "        learning_rate=0.3,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    gb.fit(train_maccs_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = gb.predict(train_maccs_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = gb.predict(test_maccs_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split-local names keep the module-level `train` / `test` (read by\n",
    "# `spearman_scorer` and `run_gb_gridsearch`) untouched by this loop.\n",
    "# index_col=0 matches how split 1 is loaded above, so the saved index is not\n",
    "# read back as an extra 'Unnamed: 0' column and written out a second time.\n",
    "for i in [1, 2, 3]:\n",
    "    split_train = pd.read_csv(f'../../../../data/lo/drd2/train_{i}.csv', index_col=0)\n",
    "    split_test = pd.read_csv(f'../../../../data/lo/drd2/test_{i}.csv', index_col=0)\n",
    "\n",
    "    train_preds, test_preds = fit_predict(split_train, split_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/drd2/gb_maccs/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/drd2/gb_maccs/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
