{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.svm import SVR\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/steshin/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Brc1ccc(CNCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1</td>\n",
       "      <td>5.283913</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brc1ccc(N2CCN(Cc3ccccc3)CC2)c2cc[nH]c12</td>\n",
       "      <td>7.437357</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Brc1ccc(NCCN2CCN(CCc3c[nH]c4ccccc34)CC2)cc1</td>\n",
       "      <td>7.288705</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Brc1ccc(NCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1</td>\n",
       "      <td>6.035740</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C#Cc1cccn1C1CCN(Cc2ccccc2)CC1</td>\n",
       "      <td>5.190490</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2201</th>\n",
       "      <td>c1ccc(OCC2CN(Cc3c[nH]c4ccccc34)CCO2)cc1</td>\n",
       "      <td>6.396856</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2202</th>\n",
       "      <td>c1ccc(OCCCNCCOc2ccccc2)cc1</td>\n",
       "      <td>6.598272</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2203</th>\n",
       "      <td>c1ccc2c(C3CCNC3)cccc2c1</td>\n",
       "      <td>6.576754</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2204</th>\n",
       "      <td>c1ccc2c(c1)CCN1CCc3[nH]c4ccccc4c3C21</td>\n",
       "      <td>5.830620</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2205</th>\n",
       "      <td>c1cnc(N2CCN(CCCCN3CCc4ncsc4CC3)CC2)nc1</td>\n",
       "      <td>5.599827</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2206 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           smiles     value  cluster\n",
       "0     Brc1ccc(CNCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1  5.283913        0\n",
       "1         Brc1ccc(N2CCN(Cc3ccccc3)CC2)c2cc[nH]c12  7.437357        0\n",
       "2     Brc1ccc(NCCN2CCN(CCc3c[nH]c4ccccc34)CC2)cc1  7.288705        0\n",
       "3      Brc1ccc(NCCN2CCN(Cc3cc4ccccc4[nH]3)CC2)cc1  6.035740        0\n",
       "4                   C#Cc1cccn1C1CCN(Cc2ccccc2)CC1  5.190490        0\n",
       "...                                           ...       ...      ...\n",
       "2201      c1ccc(OCC2CN(Cc3c[nH]c4ccccc34)CCO2)cc1  6.396856        0\n",
       "2202                   c1ccc(OCCCNCCOc2ccccc2)cc1  6.598272        0\n",
       "2203                      c1ccc2c(C3CCNC3)cccc2c1  6.576754        0\n",
       "2204         c1ccc2c(c1)CCN1CCc3[nH]c4ccccc4c3C21  5.830620        0\n",
       "2205       c1cnc(N2CCN(CCCCN3CCc4ncsc4CC3)CC2)nc1  5.599827        0\n",
       "\n",
       "[2206 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/drd2/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/drd2/test_1.csv', index_col=0)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Scoring callable for GridSearchCV: the 'spearman' entry of get_lo_metrics.\n",
    "\n",
    "    get_lo_metrics is called with a whole data frame rather than with `y`, so the\n",
    "    frame is picked from the notebook-level `train` / `test` by comparing their\n",
    "    row counts with `X`; `y` itself is not used.\n",
    "\n",
    "    NOTE(review): the row-count match is ambiguous when `train` and `test` have\n",
    "    the same length (`train` wins), and it follows whatever `train` / `test` are\n",
    "    bound to at call time.\n",
    "\n",
    "    Args:\n",
    "        clf: fitted estimator exposing `predict`.\n",
    "        X: feature rows to score; must cover all of `train` or all of `test`.\n",
    "        y: unused, present only to satisfy the scorer signature.\n",
    "\n",
    "    Returns:\n",
    "        The value stored under 'spearman' in the get_lo_metrics result.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `X` has neither len(train) nor len(test) rows.\n",
    "    \"\"\"\n",
    "    n_rows = len(X)\n",
    "    if n_rows == len(train):\n",
    "        frame = train\n",
    "    elif n_rows == len(test):\n",
    "        frame = test\n",
    "    else:\n",
    "        # Say which sizes were expected so a stale `train` / `test` is easy to spot.\n",
    "        raise ValueError(\n",
    "            f'X has {n_rows} rows, which matches neither train ({len(train)} rows) '\n",
    "            f'nor test ({len(test)} rows)'\n",
    "        )\n",
    "    return get_lo_metrics(frame, clf.predict(X))['spearman']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_svc_gridsearch(train_fps, test_fps):\n",
    "    \"\"\"Grid-search the SVR `C` value on one fixed train/test split and score the winner.\n",
    "\n",
    "    Candidates are fitted on the `train_fps` rows and ranked with\n",
    "    `spearman_scorer` on the `test_fps` rows. The best `C` is then fitted again\n",
    "    on the training rows and evaluated with get_lo_metrics.\n",
    "\n",
    "    Target values come from the notebook-level `train` / `test` frames, so the\n",
    "    fingerprint lists are paired with those frames by position.\n",
    "\n",
    "    Args:\n",
    "        train_fps: list of fingerprints, one per row of `train`.\n",
    "        test_fps: list of fingerprints, one per row of `test`.\n",
    "\n",
    "    Returns:\n",
    "        The result of get_lo_metrics for the refitted model on `test`.\n",
    "    \"\"\"\n",
    "    n_train, n_test = len(train_fps), len(test_fps)\n",
    "    # -1 keeps a row in the training part; 0 puts it in the single evaluation fold.\n",
    "    fold_ids = [-1] * n_train + [0] * n_test\n",
    "    fixed_split = PredefinedSplit(test_fold=fold_ids)\n",
    "\n",
    "    features = train_fps + test_fps\n",
    "    targets = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    search = GridSearchCV(\n",
    "        SVR(),\n",
    "        {'C': [0.1, 0.5, 1.0, 2.0, 5.0]},\n",
    "        cv=fixed_split,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        verbose=3,\n",
    "    )\n",
    "    search.fit(features, targets)\n",
    "\n",
    "    best_params = search.best_params_\n",
    "    print(best_params)\n",
    "\n",
    "    # refit=False above, so the selected configuration is fitted here on the training rows only.\n",
    "    model = SVR(**best_params).fit(train_fps, train['value'])\n",
    "    return get_lo_metrics(test, model.predict(test_fps))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 5 candidates, totalling 5 fits\n",
      "[CV 1/1] END .............................C=0.1;, score=0.215 total time=   2.6s\n",
      "[CV 1/1] END .............................C=0.5;, score=0.254 total time=   2.4s\n",
      "[CV 1/1] END .............................C=1.0;, score=0.295 total time=   2.3s\n",
      "[CV 1/1] END .............................C=2.0;, score=0.332 total time=   2.4s\n",
      "[CV 1/1] END .............................C=5.0;, score=0.329 total time=   2.4s\n",
      "{'C': 2.0}\n",
      "{'r2': -0.4028261796944916, 'spearman': 0.3323541330588398, 'mae': 0.7436238432587361}\n"
     ]
    }
   ],
   "source": [
    "# Featurise both frames: SMILES -> RDKit molecule -> Morgan bit vector (radius 2, 1024 bits).\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols\n",
    "]\n",
    "\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "test_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in test_mols\n",
    "]\n",
    "\n",
    "# Pick C on this split and show the metrics of the selected model.\n",
    "test_metrics = run_svc_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _morgan_bit_vectors(smiles):\n",
    "    \"\"\"Convert an iterable of SMILES strings to Morgan bit vectors (radius 2, 1024 bits).\"\"\"\n",
    "    mols = [Chem.MolFromSmiles(smi) for smi in smiles]\n",
    "    return [AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in mols]\n",
    "\n",
    "\n",
    "def fit_predict(train, test, C=2.0):\n",
    "    \"\"\"Fit an SVR on Morgan fingerprints of `train` and predict both frames.\n",
    "\n",
    "    Args:\n",
    "        train: frame with 'smiles' and 'value' columns; used for fitting.\n",
    "        test: frame with a 'smiles' column; only predicted, never fitted on.\n",
    "        C: SVR regularisation parameter. Defaults to 2.0, the value the grid\n",
    "            search in this notebook selected.\n",
    "\n",
    "    Returns:\n",
    "        (train_result, test_result): copies of the two input frames with the\n",
    "        model output added as a 'preds' column. The inputs are left untouched.\n",
    "    \"\"\"\n",
    "    train_fps = _morgan_bit_vectors(train['smiles'])\n",
    "    test_fps = _morgan_bit_vectors(test['smiles'])\n",
    "\n",
    "    model = SVR(C=C)\n",
    "    model.fit(train_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = model.predict(train_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = model.predict(test_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The per-split frames get their own names on purpose: `spearman_scorer` and\n",
    "# `run_svc_gridsearch` read the notebook-level `train` / `test`, so rebinding\n",
    "# those names here would change what the tuning cells see when they are re-run.\n",
    "# NOTE(review): unlike the first load cell, these reads do not pass index_col=0,\n",
    "# so the saved CSVs presumably carry the source index as an extra column; kept\n",
    "# as-is so the layout of the prediction files does not change.\n",
    "for i in [1, 2, 3]:\n",
    "    split_train = pd.read_csv(f'../../../../data/lo/drd2/train_{i}.csv')\n",
    "    split_test = pd.read_csv(f'../../../../data/lo/drd2/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(split_train, split_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/drd2/svr_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/drd2/svr_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
