{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, MACCSkeys\n",
    "from sklearn.svm import SVR\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(CNCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1</td>\n",
       "      <td>5.283913</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brc1ccc(N2CCN(Cc3ccccc3)CC2)c2cc[nH]c12</td>\n",
       "      <td>7.437357</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Brc1ccc(NCCN2CCN(CCc3c[nH]c4ccccc34)CC2)cc1</td>\n",
       "      <td>7.288705</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Brc1ccc(NCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1</td>\n",
       "      <td>6.035740</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C#Cc1cccn1C1CCN(Cc2ccccc2)CC1</td>\n",
       "      <td>5.190490</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2201</th>\n",
       "      <td>c1ccc(OCC2CN(Cc3c[nH]c4ccccc34)CCO2)cc1</td>\n",
       "      <td>6.396856</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2202</th>\n",
       "      <td>c1ccc(OCCCNCCOc2ccccc2)cc1</td>\n",
       "      <td>6.598272</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2203</th>\n",
       "      <td>c1ccc2c(C3CCNC3)cccc2c1</td>\n",
       "      <td>6.576754</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2204</th>\n",
       "      <td>c1ccc2c(c1)CCN1CCc3[nH]c4ccccc4c3C21</td>\n",
       "      <td>5.830620</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2205</th>\n",
       "      <td>c1cnc(N2CCN(CCCCN3CCc4ncsc4CC3)CC2)nc1</td>\n",
       "      <td>5.599827</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2206 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           smiles     value  cluster\n",
       "0     Brc1ccc(CNCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1  5.283913        0\n",
       "1         Brc1ccc(N2CCN(Cc3ccccc3)CC2)c2cc[nH]c12  7.437357        0\n",
       "2     Brc1ccc(NCCN2CCN(CCc3c[nH]c4ccccc34)CC2)cc1  7.288705        0\n",
       "3      Brc1ccc(NCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1  6.035740        0\n",
       "4                   C#Cc1cccn1C1CCN(Cc2ccccc2)CC1  5.190490        0\n",
       "...                                           ...       ...      ...\n",
       "2201      c1ccc(OCC2CN(Cc3c[nH]c4ccccc34)CCO2)cc1  6.396856        0\n",
       "2202                   c1ccc(OCCCNCCOc2ccccc2)cc1  6.598272        0\n",
       "2203                      c1ccc2c(C3CCNC3)cccc2c1  6.576754        0\n",
       "2204         c1ccc2c(c1)CCN1CCc3[nH]c4ccccc4c3C21  5.830620        0\n",
       "2205       c1cnc(N2CCN(CCCCN3CCc4ncsc4CC3)CC2)nc1  5.599827        0\n",
       "\n",
       "[2206 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/drd2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/drd2/test_1.csv', index_col=0)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    if len(X) == len(train):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(train, y_pred)\n",
    "        return metrics['spearman']\n",
    "    elif len(X) == len(test):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(test, y_pred)\n",
    "        return metrics['spearman']\n",
    "    else:\n",
    "        raise ValueError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_svc_gridsearch(train_fps, test_fps):\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold = split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'C': [0.1, 0.5, 1.0, 2.0, 5.0, 7.0, 10.0, 12.0, 15.0, 17.0, 20.0],\n",
    "    }\n",
    "    svc = SVR()\n",
    "\n",
    "    grid_search = GridSearchCV(svc, params, cv=pds, refit=False, scoring=spearman_scorer, verbose=3)\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    svc = SVR(**best_params)\n",
    "    svc.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = svc.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 11 candidates, totalling 11 fits\n",
      "[CV 1/1] END .............................C=0.1;, score=0.181 total time=   0.8s\n",
      "[CV 1/1] END .............................C=0.5;, score=0.178 total time=   0.7s\n",
      "[CV 1/1] END .............................C=1.0;, score=0.211 total time=   0.8s\n",
      "[CV 1/1] END .............................C=2.0;, score=0.171 total time=   0.8s\n",
      "[CV 1/1] END .............................C=5.0;, score=0.200 total time=   0.8s\n",
      "[CV 1/1] END .............................C=7.0;, score=0.167 total time=   0.8s\n",
      "[CV 1/1] END ............................C=10.0;, score=0.184 total time=   0.8s\n",
      "[CV 1/1] END ............................C=12.0;, score=0.170 total time=   0.8s\n",
      "[CV 1/1] END ............................C=15.0;, score=0.166 total time=   0.9s\n",
      "[CV 1/1] END ............................C=17.0;, score=0.169 total time=   0.9s\n",
      "[CV 1/1] END ............................C=20.0;, score=0.172 total time=   0.9s\n",
      "{'C': 1.0}\n",
      "{'r2': -0.5670737668511453, 'spearman': 0.21067958373609103, 'mae': 0.7997076616566108}\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "test_metrics = run_svc_gridsearch(train_maccs_fps, test_maccs_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test):\n",
    "    train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "    train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "    test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "    test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "    svc = SVR(\n",
    "        C=1.0\n",
    "    )\n",
    "    svc.fit(train_maccs_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = svc.predict(train_maccs_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = svc.predict(test_maccs_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in [1, 2, 3]:\n",
    "    train = pd.read_csv(f'../../../../data/lo/drd2/train_{i}.csv')\n",
    "    test = pd.read_csv(f'../../../../data/lo/drd2/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(train, test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/drd2/svr_maccs/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/drd2/svr_maccs/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
