{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, MACCSkeys\n",
    "from sklearn.svm import SVR\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>C/C(=N\\OC(C)C)c1ccc2c(c1)c1c3c(c4c(c1n2CC(C)C)...</td>\n",
       "      <td>7.897940</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C/C(=N\\OCC(C)C)c1ccc2[nH]c3c4c(c5c(c3c2c1)CNC5...</td>\n",
       "      <td>8.129819</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3c(F)cc(Br)cc3F)ncnc2cc1OCC1...</td>\n",
       "      <td>6.826814</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Br)cc3F)ncnc2cc1OCC...</td>\n",
       "      <td>6.376751</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Cl)cc3Cl)ncnc2cc1OC...</td>\n",
       "      <td>6.102373</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>c1ccc(-c2ccc(Nc3nnc(-c4cccnc4CCc4ccncc4)o3)cc2...</td>\n",
       "      <td>5.579879</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>c1ccc(Nc2ncc3c(n2)-c2ccccc2SC3)cc1</td>\n",
       "      <td>5.086133</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>c1ccc(Oc2ccc(Nc3ncnc4ccccc34)cc2)cc1</td>\n",
       "      <td>5.565271</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>c1ccc2c(c1)c(-c1cncc(-c3ccsc3)c1)cn2CCN1CCOCC1</td>\n",
       "      <td>7.214670</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>c1ccc2c3c([nH]c2c1)-c1n[nH]cc1CCC3</td>\n",
       "      <td>6.135953</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0    C/C(=N\\OC(C)C)c1ccc2c(c1)c1c3c(c4c(c1n2CC(C)C)...  7.897940        0\n",
       "1    C/C(=N\\OCC(C)C)c1ccc2[nH]c3c4c(c5c(c3c2c1)CNC5...  8.129819        0\n",
       "2    C=CC(=O)Nc1cc2c(Nc3c(F)cc(Br)cc3F)ncnc2cc1OCC1...  6.826814        0\n",
       "3    C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Br)cc3F)ncnc2cc1OCC...  6.376751        0\n",
       "4    C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Cl)cc3Cl)ncnc2cc1OC...  6.102373        0\n",
       "..                                                 ...       ...      ...\n",
       "495  c1ccc(-c2ccc(Nc3nnc(-c4cccnc4CCc4ccncc4)o3)cc2...  5.579879        0\n",
       "496                 c1ccc(Nc2ncc3c(n2)-c2ccccc2SC3)cc1  5.086133        0\n",
       "497               c1ccc(Oc2ccc(Nc3ncnc4ccccc34)cc2)cc1  5.565271        0\n",
       "498     c1ccc2c(c1)c(-c1cncc(-c3ccsc3)c1)cn2CCN1CCOCC1  7.214670        0\n",
       "499                 c1ccc2c3c([nH]c2c1)-c1n[nH]cc1CCC3  6.135953        0\n",
       "\n",
       "[500 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/kdr/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/kdr/test_1.csv', index_col=0)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    if len(X) == len(train):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(train, y_pred)\n",
    "        return metrics['spearman']\n",
    "    elif len(X) == len(test):\n",
    "        y_pred = clf.predict(X)\n",
    "        metrics = get_lo_metrics(test, y_pred)\n",
    "        return metrics['spearman']\n",
    "    else:\n",
    "        raise ValueError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_svc_gridsearch(train_fps, test_fps):\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold = split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'C': [0.1, 0.5, 1.0, 2.0, 5.0, 7.0, 10.0, 12.0, 15.0, 17.0, 20.0],\n",
    "    }\n",
    "    svc = SVR()\n",
    "\n",
    "    grid_search = GridSearchCV(svc, params, cv=pds, refit=False, scoring=spearman_scorer, verbose=3)\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    svc = SVR(**best_params)\n",
    "    svc.fit(train_fps, train['value'])\n",
    "\n",
    "    test_preds = svc.predict(test_fps)\n",
    "    test_metrics = get_lo_metrics(test, test_preds)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 11 candidates, totalling 11 fits\n",
      "[CV 1/1] END ............................C=0.1;, score=-0.016 total time=   0.1s\n",
      "[CV 1/1] END ............................C=0.5;, score=-0.008 total time=   0.1s\n",
      "[CV 1/1] END ............................C=1.0;, score=-0.007 total time=   0.1s\n",
      "[CV 1/1] END .............................C=2.0;, score=0.031 total time=   0.2s\n",
      "[CV 1/1] END .............................C=5.0;, score=0.111 total time=   0.2s\n",
      "[CV 1/1] END .............................C=7.0;, score=0.117 total time=   0.2s\n",
      "[CV 1/1] END ............................C=10.0;, score=0.121 total time=   0.2s\n",
      "[CV 1/1] END ............................C=12.0;, score=0.114 total time=   0.2s\n",
      "[CV 1/1] END ............................C=15.0;, score=0.110 total time=   0.2s\n",
      "[CV 1/1] END ............................C=17.0;, score=0.113 total time=   0.2s\n",
      "[CV 1/1] END ............................C=20.0;, score=0.116 total time=   0.2s\n",
      "{'C': 10.0}\n",
      "{'r2': -1.4820625917605594, 'spearman': 0.12123173102763739, 'mae': 0.913735069129755}\n"
     ]
    }
   ],
   "source": [
    "train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "test_metrics = run_svc_gridsearch(train_maccs_fps, test_maccs_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test):\n",
    "    train_mols = [Chem.MolFromSmiles(x) for x in train['smiles']]\n",
    "    train_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in train_mols]\n",
    "\n",
    "    test_mols = [Chem.MolFromSmiles(x) for x in test['smiles']]\n",
    "    test_maccs_fps = [Chem.MACCSkeys.GenMACCSKeys(x) for x in test_mols]\n",
    "\n",
    "    svc = SVR(\n",
    "        C=10.0\n",
    "    )\n",
    "    svc.fit(train_maccs_fps, train['value'])\n",
    "\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = svc.predict(train_maccs_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = svc.predict(test_maccs_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in [1, 2, 3]:\n",
    "    train = pd.read_csv(f'../../../../data/lo/kdr/train_{i}.csv')\n",
    "    test = pd.read_csv(f'../../../../data/lo/kdr/test_{i}.csv')\n",
    "\n",
    "    train_preds, test_preds = fit_predict(train, test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kdr/svr_maccs/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kdr/svr_maccs/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
