{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.svm import SVR\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_lo_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "      <th>cluster</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>C/C(=N\\OC(C)C)c1ccc2c(c1)c1c3c(c4c(c1n2CC(C)C)...</td>\n",
       "      <td>7.897940</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C/C(=N\\OCC(C)C)c1ccc2[nH]c3c4c(c5c(c3c2c1)CNC5...</td>\n",
       "      <td>8.129819</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3c(F)cc(Br)cc3F)ncnc2cc1OCC1...</td>\n",
       "      <td>6.826814</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Br)cc3F)ncnc2cc1OCC...</td>\n",
       "      <td>6.376751</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Cl)cc3Cl)ncnc2cc1OC...</td>\n",
       "      <td>6.102373</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>c1ccc(-c2ccc(Nc3nnc(-c4cccnc4CCc4ccncc4)o3)cc2...</td>\n",
       "      <td>5.579879</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>c1ccc(Nc2ncc3c(n2)-c2ccccc2SC3)cc1</td>\n",
       "      <td>5.086133</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>c1ccc(Oc2ccc(Nc3ncnc4ccccc34)cc2)cc1</td>\n",
       "      <td>5.565271</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>c1ccc2c(c1)c(-c1cncc(-c3ccsc3)c1)cn2CCN1CCOCC1</td>\n",
       "      <td>7.214670</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>c1ccc2c3c([nH]c2c1)-c1n[nH]cc1CCC3</td>\n",
       "      <td>6.135953</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                smiles     value  cluster\n",
       "0    C/C(=N\\OC(C)C)c1ccc2c(c1)c1c3c(c4c(c1n2CC(C)C)...  7.897940        0\n",
       "1    C/C(=N\\OCC(C)C)c1ccc2[nH]c3c4c(c5c(c3c2c1)CNC5...  8.129819        0\n",
       "2    C=CC(=O)Nc1cc2c(Nc3c(F)cc(Br)cc3F)ncnc2cc1OCC1...  6.826814        0\n",
       "3    C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Br)cc3F)ncnc2cc1OCC...  6.376751        0\n",
       "4    C=CC(=O)Nc1cc2c(Nc3cc(Cl)c(Cl)cc3Cl)ncnc2cc1OC...  6.102373        0\n",
       "..                                                 ...       ...      ...\n",
       "495  c1ccc(-c2ccc(Nc3nnc(-c4cccnc4CCc4ccncc4)o3)cc2...  5.579879        0\n",
       "496                 c1ccc(Nc2ncc3c(n2)-c2ccccc2SC3)cc1  5.086133        0\n",
       "497               c1ccc(Oc2ccc(Nc3ncnc4ccccc34)cc2)cc1  5.565271        0\n",
       "498     c1ccc2c(c1)c(-c1cncc(-c3ccsc3)c1)cn2CCN1CCOCC1  7.214670        0\n",
       "499                 c1ccc2c3c([nH]c2c1)-c1n[nH]cc1CCC3  6.135953        0\n",
       "\n",
       "[500 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train = pd.read_csv('../../../../data/lo/kdr/train_1.csv', index_col=0)\n",
    "test = pd.read_csv('../../../../data/lo/kdr/test_1.csv', index_col=0)\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def spearman_scorer(clf, X, y):\n",
    "    \"\"\"Score ``clf`` for GridSearchCV using the 'spearman' entry of ``get_lo_metrics``.\n",
    "\n",
    "    ``get_lo_metrics`` is called with a whole dataframe, so the matching frame is\n",
    "    picked from the notebook-level ``train`` / ``test`` by comparing its length\n",
    "    with ``len(X)``. ``y`` is accepted only to satisfy the scikit-learn scorer\n",
    "    signature and is not used.\n",
    "\n",
    "    NOTE: when ``train`` and ``test`` have the same length, ``train`` is chosen.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``len(X)`` matches neither ``train`` nor ``test``.\n",
    "    \"\"\"\n",
    "    if len(X) == len(train):\n",
    "        reference = train\n",
    "    elif len(X) == len(test):\n",
    "        reference = test\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f'Cannot match X (n={len(X)}) to train (n={len(train)}) or test (n={len(test)})'\n",
    "        )\n",
    "    return get_lo_metrics(reference, clf.predict(X))['spearman']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_svc_gridsearch(train_fps, test_fps):\n",
    "    \"\"\"Grid-search the SVR ``C`` parameter, then refit on train and score on test.\n",
    "\n",
    "    A single predefined fold is used: train fingerprints get fold index -1 and\n",
    "    test fingerprints get fold index 0. Targets are read from the notebook-level\n",
    "    ``train`` and ``test`` frames.\n",
    "\n",
    "    NOTE(review): ``C`` is selected on the same test rows whose metrics are returned.\n",
    "\n",
    "    Returns the output of ``get_lo_metrics`` for the refitted model on ``test``.\n",
    "    \"\"\"\n",
    "    n_train, n_test = len(train_fps), len(test_fps)\n",
    "    fold_ids = [-1] * n_train + [0] * n_test\n",
    "    single_split = PredefinedSplit(test_fold=fold_ids)\n",
    "\n",
    "    all_fps = train_fps + test_fps\n",
    "    all_targets = train['value'].to_list() + test['value'].to_list()\n",
    "\n",
    "    param_grid = {'C': [0.1, 0.5, 1.0, 2.0, 5.0]}\n",
    "    search = GridSearchCV(\n",
    "        SVR(),\n",
    "        param_grid,\n",
    "        cv=single_split,\n",
    "        refit=False,\n",
    "        scoring=spearman_scorer,\n",
    "        verbose=3,\n",
    "    )\n",
    "    search.fit(all_fps, all_targets)\n",
    "\n",
    "    chosen = search.best_params_\n",
    "    print(chosen)\n",
    "\n",
    "    # refit=False above, so the selected configuration is fitted here on train only\n",
    "    final_model = SVR(**chosen)\n",
    "    final_model.fit(train_fps, train['value'])\n",
    "\n",
    "    return get_lo_metrics(test, final_model.predict(test_fps))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 5 candidates, totalling 5 fits\n",
      "[CV 1/1] END .............................C=0.1;, score=0.090 total time=   0.3s\n",
      "[CV 1/1] END .............................C=0.5;, score=0.125 total time=   0.3s\n",
      "[CV 1/1] END .............................C=1.0;, score=0.142 total time=   0.3s\n",
      "[CV 1/1] END .............................C=2.0;, score=0.148 total time=   0.3s\n",
      "[CV 1/1] END .............................C=5.0;, score=0.158 total time=   0.3s\n",
      "{'C': 5.0}\n",
      "{'r2': -1.3082582457115788, 'spearman': 0.15768375789417222, 'mae': 0.859791187320092}\n"
     ]
    }
   ],
   "source": [
    "# Morgan bit-vector fingerprints (radius 2, 1024 bits) for the loaded train/test split\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in train_mols\n",
    "]\n",
    "\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "test_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in test_mols\n",
    "]\n",
    "\n",
    "# Search over C, then print the metrics of the refitted model\n",
    "test_metrics = run_svc_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _morgan_fps(smiles, radius=2, n_bits=1024):\n",
    "    \"\"\"Convert an iterable of SMILES strings to Morgan bit-vector fingerprints.\"\"\"\n",
    "    fps = []\n",
    "    for smi in smiles:\n",
    "        mol = Chem.MolFromSmiles(smi)\n",
    "        if mol is None:\n",
    "            # MolFromSmiles returns None rather than raising on unparsable input\n",
    "            raise ValueError(f'Could not parse SMILES: {smi!r}')\n",
    "        fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits))\n",
    "    return fps\n",
    "\n",
    "\n",
    "def fit_predict(train, test, C=5.0):\n",
    "    \"\"\"Fit an SVR on Morgan fingerprints of ``train`` and predict both frames.\n",
    "\n",
    "    Args:\n",
    "        train: dataframe with 'smiles' and 'value' columns; used for fitting.\n",
    "        test: dataframe with a 'smiles' column.\n",
    "        C: SVR regularisation parameter. Defaults to 5.0, the value this\n",
    "            function previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_result, test_result)``: copies of the inputs with an added\n",
    "        'preds' column holding the model predictions.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a SMILES string cannot be parsed.\n",
    "    \"\"\"\n",
    "    train_morgan_fps = _morgan_fps(train['smiles'])\n",
    "    test_morgan_fps = _morgan_fps(test['smiles'])\n",
    "\n",
    "    model = SVR(C=C)\n",
    "    model.fit(train_morgan_fps, train['value'])\n",
    "\n",
    "    # Work on copies so the caller's frames are left untouched\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = model.predict(train_morgan_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = model.predict(test_morgan_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the final model on each of the three kdr splits and save train/test predictions.\n",
    "# Loop-local frame names keep the notebook-level `train` / `test` (read by\n",
    "# `spearman_scorer` and `run_svc_gridsearch`) from being overwritten.\n",
    "for i in [1, 2, 3]:\n",
    "    # index_col=0 matches how train_1/test_1 are loaded earlier in this notebook and\n",
    "    # keeps a stray 'Unnamed: 0' column out of the saved predictions\n",
    "    fold_train = pd.read_csv(f'../../../../data/lo/kdr/train_{i}.csv', index_col=0)\n",
    "    fold_test = pd.read_csv(f'../../../../data/lo/kdr/test_{i}.csv', index_col=0)\n",
    "\n",
    "    train_preds, test_preds = fit_predict(fold_train, fold_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/lo/kdr/svr_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/lo/kdr/svr_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
