{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import wandb\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import PredefinedSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'\n",
      "Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/simon/miniconda3/envs/lohi_benchmark/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)\n",
      "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('../../../../code')\n",
    "\n",
    "from metrics import get_hi_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2089</th>\n",
       "      <td>O=C(CN1C(=O)CCC1=O)Nc1ccc(F)c(F)c1F</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2099</th>\n",
       "      <td>COc1ccc(-c2nnc(NC(=O)c3ccc4ccccc4c3)o2)cc1OC</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>O=C(Nc1ccc2ccccc2n1)c1ccc(N2CCOC2=O)cc1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>Cc1c[nH]c(=O)n1-c1ccc(C(=O)Nc2ccc3ccccc3n2)cc1</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>O=C(C1CCC(O)CC1)N1CCCN(c2ccc(F)cc2)CC1</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2143</th>\n",
       "      <td>N#CCCn1nc(C(F)(F)F)cc1O</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2147</th>\n",
       "      <td>Cc1ccccc1N1C(=O)c2ccc(C(=O)NC(C)(C)C)cc2C1=O</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2160</th>\n",
       "      <td>CC(C)(C)C(=O)Nc1sc2c(c1C#N)CCC2</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2165</th>\n",
       "      <td>C[C@@H](c1ccc(F)cc1)n1nnc2cnc3ccc(-c4ccc5ocnc5...</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2169</th>\n",
       "      <td>CCc1noc(COc2c(C)ccnc2Cl)n1</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1442 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                 smiles  value\n",
       "2089                O=C(CN1C(=O)CCC1=O)Nc1ccc(F)c(F)c1F  False\n",
       "2099       COc1ccc(-c2nnc(NC(=O)c3ccc4ccccc4c3)o2)cc1OC   True\n",
       "61              O=C(Nc1ccc2ccccc2n1)c1ccc(N2CCOC2=O)cc1   True\n",
       "89       Cc1c[nH]c(=O)n1-c1ccc(C(=O)Nc2ccc3ccccc3n2)cc1   True\n",
       "97               O=C(C1CCC(O)CC1)N1CCCN(c2ccc(F)cc2)CC1  False\n",
       "...                                                 ...    ...\n",
       "2143                            N#CCCn1nc(C(F)(F)F)cc1O  False\n",
       "2147       Cc1ccccc1N1C(=O)c2ccc(C(=O)NC(C)(C)C)cc2C1=O  False\n",
       "2160                    CC(C)(C)C(=O)Nc1sc2c(c1C#N)CCC2   True\n",
       "2165  C[C@@H](c1ccc(F)cc1)n1nnc2cnc3ccc(-c4ccc5ocnc5...   True\n",
       "2169                         CCc1noc(COc2c(C)ccnc2Cl)n1  False\n",
       "\n",
       "[1442 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the first train/test split; the first CSV column holds the row ids\n",
    "# and is used as the DataFrame index.\n",
    "train, test = (\n",
    "    pd.read_csv(csv_path, index_col=0)\n",
    "    for csv_path in (\n",
    "        '../../../../data/hi/sol/train_1.csv',\n",
    "        '../../../../data/hi/sol/test_1.csv',\n",
    "    )\n",
    ")\n",
    "\n",
    "train"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameter Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_svc_gridsearch(train_fps, test_fps, train_df=None, test_df=None):\n",
    "    \"\"\"Tune the SVC regularisation strength C and score the best model on the test split.\n",
    "\n",
    "    Each candidate C is fitted on the training fingerprints and scored on the\n",
    "    test fingerprints through a single predefined fold. The best C is then\n",
    "    refitted on the training fingerprints and evaluated with get_hi_metrics.\n",
    "\n",
    "    NOTE(review): C is selected on the same fold the returned metrics are\n",
    "    computed on, so the reported numbers are optimistically biased.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_fps, test_fps : sequence\n",
    "        One fingerprint per row of the matching frame, in the same order.\n",
    "    train_df, test_df : pandas.DataFrame, optional\n",
    "        Frames providing the 'value' label column. When omitted, the\n",
    "        notebook-level `train` / `test` frames are used, which keeps existing\n",
    "        two-argument calls working.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Whatever get_hi_metrics returns for the test frame and the test scores\n",
    "    (in this notebook's output: a dict with roc_auc, bedroc and prc_auc).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the number of fingerprints does not match the number of labelled rows.\n",
    "    \"\"\"\n",
    "    # Prefer explicitly passed frames; the globals are only a fallback because\n",
    "    # later cells reassign `train` / `test`.\n",
    "    if train_df is None:\n",
    "        train_df = train\n",
    "    if test_df is None:\n",
    "        test_df = test\n",
    "\n",
    "    train_fps, test_fps = list(train_fps), list(test_fps)\n",
    "    if len(train_fps) != len(train_df) or len(test_fps) != len(test_df):\n",
    "        raise ValueError(\n",
    "            f'Got {len(train_fps)} train / {len(test_fps)} test fingerprints '\n",
    "            f'for {len(train_df)} / {len(test_df)} labelled rows'\n",
    "        )\n",
    "\n",
    "    # -1 keeps a sample in the training part of every fold; 0 puts it in the\n",
    "    # single validation fold.\n",
    "    split_index = [-1] * len(train_fps) + [0] * len(test_fps)\n",
    "    pds = PredefinedSplit(test_fold=split_index)\n",
    "\n",
    "    X = train_fps + test_fps\n",
    "    y = train_df['value'].to_list() + test_df['value'].to_list()\n",
    "\n",
    "    params = {\n",
    "        'C': [0.1, 0.5, 1.0, 2.0, 5.0],\n",
    "    }\n",
    "\n",
    "    # refit=False: the winning model is refitted by hand below.\n",
    "    grid_search = GridSearchCV(SVC(), params, cv=pds, refit=False, scoring='average_precision', verbose=3)\n",
    "    grid_search.fit(X, y)\n",
    "\n",
    "    best_params = grid_search.best_params_\n",
    "    print(best_params)\n",
    "    svc = SVC(**best_params)\n",
    "    svc.fit(train_fps, train_df['value'])\n",
    "\n",
    "    # The reported metrics are ranking metrics, so feed them the continuous\n",
    "    # decision scores (as the 'average_precision' scorer above does) instead\n",
    "    # of hard class labels, which collapse every curve to a single threshold.\n",
    "    test_scores = svc.decision_function(test_fps)\n",
    "    test_metrics = get_hi_metrics(test_df, test_scores)\n",
    "    return test_metrics\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 1 folds for each of 5 candidates, totalling 5 fits\n",
      "[CV 1/1] END .............................C=0.1;, score=0.481 total time=   1.8s\n",
      "[CV 1/1] END .............................C=0.5;, score=0.481 total time=   1.7s\n",
      "[CV 1/1] END .............................C=1.0;, score=0.486 total time=   1.8s\n",
      "[CV 1/1] END .............................C=2.0;, score=0.490 total time=   1.9s\n",
      "[CV 1/1] END .............................C=5.0;, score=0.486 total time=   1.9s\n",
      "{'C': 2.0}\n",
      "{'roc_auc': 0.599939294466803, 'bedroc': 0.4259793537136111, 'prc_auc': 0.30832704225846663}\n"
     ]
    }
   ],
   "source": [
    "# Parse the SMILES strings and build Morgan bit-vector fingerprints\n",
    "# (radius 2, 1024 bits) for both splits.\n",
    "train_mols = list(map(Chem.MolFromSmiles, train['smiles']))\n",
    "train_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)\n",
    "    for mol in train_mols\n",
    "]\n",
    "\n",
    "test_mols = list(map(Chem.MolFromSmiles, test['smiles']))\n",
    "test_morgan_fps = [\n",
    "    AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)\n",
    "    for mol in test_mols\n",
    "]\n",
    "\n",
    "# Run the grid search over C and show the metrics of the selected model.\n",
    "test_metrics = run_svc_gridsearch(train_morgan_fps, test_morgan_fps)\n",
    "print(test_metrics)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_predict(train, test, C=2.0, radius=2, n_bits=1024):\n",
    "    \"\"\"Fit an SVC on Morgan fingerprints of `train` and predict both splits.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train, test : pandas.DataFrame\n",
    "        Frames with a 'smiles' column; `train` must also have a 'value'\n",
    "        column holding the labels.\n",
    "    C : float, default 2.0\n",
    "        SVC regularisation strength (2.0 is the value picked by the grid\n",
    "        search earlier in this notebook).\n",
    "    radius : int, default 2\n",
    "        Morgan fingerprint radius.\n",
    "    n_bits : int, default 1024\n",
    "        Length of the fingerprint bit vector.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (train_result, test_result)\n",
    "        Copies of the input frames with an added 'preds' column; the inputs\n",
    "        themselves are not modified.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If a SMILES string cannot be parsed by RDKit.\n",
    "    \"\"\"\n",
    "    def featurize(smiles_column):\n",
    "        # One fingerprint per SMILES, in input order.\n",
    "        fps = []\n",
    "        for smiles in smiles_column:\n",
    "            mol = Chem.MolFromSmiles(smiles)\n",
    "            # MolFromSmiles signals a parse failure by returning None; fail\n",
    "            # here with the offending string instead of inside the\n",
    "            # fingerprint call.\n",
    "            if mol is None:\n",
    "                raise ValueError(f'Could not parse SMILES: {smiles!r}')\n",
    "            fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, n_bits))\n",
    "        return fps\n",
    "\n",
    "    train_morgan_fps = featurize(train['smiles'])\n",
    "    test_morgan_fps = featurize(test['smiles'])\n",
    "\n",
    "    svc = SVC(C=C)\n",
    "    svc.fit(train_morgan_fps, train['value'])\n",
    "\n",
    "    # NOTE(review): `predict` stores hard class labels. If the saved 'preds'\n",
    "    # column is later fed to ranking metrics, continuous scores\n",
    "    # (svc.decision_function) would be more appropriate - confirm with the\n",
    "    # downstream evaluation before changing.\n",
    "    train_result = train.copy()\n",
    "    train_result['preds'] = svc.predict(train_morgan_fps)\n",
    "\n",
    "    test_result = test.copy()\n",
    "    test_result['preds'] = svc.predict(test_morgan_fps)\n",
    "\n",
    "    return train_result, test_result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit on each of the three splits and save the train/test predictions.\n",
    "for i in [1, 2, 3]:\n",
    "    # index_col=0 matches how split 1 is loaded at the top of the notebook and\n",
    "    # keeps the stored row ids as the index, so they are not written back as\n",
    "    # an extra 'Unnamed: 0' column. Loop-local names avoid overwriting the\n",
    "    # notebook-level `train` / `test` frames used by the grid-search cells.\n",
    "    split_train = pd.read_csv(f'../../../../data/hi/sol/train_{i}.csv', index_col=0)\n",
    "    split_test = pd.read_csv(f'../../../../data/hi/sol/test_{i}.csv', index_col=0)\n",
    "\n",
    "    train_preds, test_preds = fit_predict(split_train, split_test)\n",
    "    train_preds.to_csv(f'../../../../predictions/hi/sol/svc_ecfp4/train_{i}.csv')\n",
    "    test_preds.to_csv(f'../../../../predictions/hi/sol/svc_ecfp4/test_{i}.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lohi_benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
