{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "726c185059d7b877",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:44.099301Z",
     "start_time": "2025-05-07T12:57:43.725566Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:.3f}'.format\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\", category=pd.errors.ParserWarning)\n",
    "from Bio import SeqIO\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2e6d90bef1748f6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:44.106612Z",
     "start_time": "2025-05-07T12:57:44.103321Z"
    }
   },
   "outputs": [],
   "source": [
    "synthetic_datasets = {\n",
    "    'random-sequences': 'Random',\n",
    "    'shuffled-sequences': 'Shuffled',\n",
    "    'mutated-sequences': 'Mutated',\n",
    "    'added-deleted-sequences': 'Added-Deleted',\n",
    "}\n",
    "\n",
    "classification_models = {\n",
    "    'ampeppy': 'amPEPpy',\n",
    "    'amplify': 'AMPlify',\n",
    "    'ampredmfa': 'AMPpredMFA',\n",
    "    'ampscanner': 'AMPScanner',\n",
    "    'hydramp-amp-classifier': 'HydrAMP-AMP',\n",
    "    'hydramp-mic-classifier': 'HydrAMP-MIC',\n",
    "    'sensexamp-classifier': 'SenseXAMP-classifier',\n",
    "    'pyampa': 'PyAMPA',\n",
    "}\n",
    "\n",
    "omegamp_models = {\n",
    "    'no_synthetic': 'OmegAMP - {R, S, M}',\n",
    "    'no_mutated_shuffled': 'OmegAMP - {S, M}',\n",
    "    'no_mutated': 'OmegAMP - {M}',\n",
    "    'unbalanced': 'OmegAMP (unbalanced loss)',\n",
    "    'balanced': 'OmegAMP (balanced loss)',\n",
    "    'default': 'OmegAMP',\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7425209760cce36a",
   "metadata": {},
   "source": [
    "# Get synthetic scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c3db65c9a0717c10",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:45.435380Z",
     "start_time": "2025-05-07T12:57:44.861884Z"
    }
   },
   "outputs": [],
   "source": [
    "synthetic_scores = {}\n",
    "\n",
    "omegamp_dir = f'../results/omegamp-classifier-results'\n",
    "benchmark_dir = '../results/benchmark-classifiers-results'\n",
    "\n",
    "\n",
    "for dataset, dataset_name in synthetic_datasets.items():\n",
    "    synthetic_scores[dataset_name] = {}\n",
    "\n",
    "    for model, model_name in omegamp_models.items():\n",
    "        synthetic_scores[dataset_name][model_name] = []\n",
    "    for model, model_name in classification_models.items():\n",
    "        synthetic_scores[dataset_name][model_name] = []\n",
    "\n",
    "    for fold in range(1, 6):\n",
    "\n",
    "        # OmegaAMP models\n",
    "        for model, model_name in omegamp_models.items():\n",
    "            df = pd.read_csv(f'{omegamp_dir}/{model}/fold_{fold}/{dataset}.csv')\n",
    "            df['Prediction'] = df['Probability_score'].apply(lambda x: 1 if x >= 0.5 else 0)\n",
    "            df = df.drop_duplicates(subset=['Sequence'])\n",
    "            fold_sequences = df['Sequence'].tolist()\n",
    "            synthetic_scores[dataset_name][model_name].append(df['Prediction'].mean() * 100)\n",
    "\n",
    "        for model, model_name in classification_models.items():\n",
    "            df = pd.read_csv(f'{benchmark_dir}/{dataset}/{model}.tsv', sep='\\t')\n",
    "            df = df[df['Sequence'].isin(fold_sequences)]\n",
    "            df[f'Prediction'] = df[f'Probability_score'].apply(lambda x: 1 if x >= 0.5 else 0)\n",
    "            df = df.drop_duplicates(subset=['Sequence'])\n",
    "            synthetic_scores[dataset_name][model_name].append(df[f'Prediction'].mean() * 100)\n",
    "\n",
    "\n",
    "for dataset_name in synthetic_scores:\n",
    "    for model_name in synthetic_scores[dataset_name]:\n",
    "        scores = synthetic_scores[dataset_name][model_name]\n",
    "        synthetic_scores[dataset_name][model_name] = sum(scores) / len(scores)\n",
    "\n",
    "clean_index = list(classification_models.values()) + list(omegamp_models.values())\n",
    "synthetic_scores = pd.DataFrame(synthetic_scores).loc[clean_index, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1662ec4606f3f893",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:45.454013Z",
     "start_time": "2025-05-07T12:57:45.445298Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Random</th>\n",
       "      <th>Shuffled</th>\n",
       "      <th>Mutated</th>\n",
       "      <th>Added-Deleted</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>amPEPpy</th>\n",
       "      <td>31.200</td>\n",
       "      <td>87.598</td>\n",
       "      <td>73.684</td>\n",
       "      <td>82.287</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPlify</th>\n",
       "      <td>23.300</td>\n",
       "      <td>86.400</td>\n",
       "      <td>80.900</td>\n",
       "      <td>87.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPpredMFA</th>\n",
       "      <td>97.700</td>\n",
       "      <td>99.700</td>\n",
       "      <td>99.500</td>\n",
       "      <td>99.500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPScanner</th>\n",
       "      <td>25.763</td>\n",
       "      <td>81.510</td>\n",
       "      <td>80.405</td>\n",
       "      <td>79.912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-AMP</th>\n",
       "      <td>33.732</td>\n",
       "      <td>86.094</td>\n",
       "      <td>75.000</td>\n",
       "      <td>84.790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-MIC</th>\n",
       "      <td>2.512</td>\n",
       "      <td>56.832</td>\n",
       "      <td>43.224</td>\n",
       "      <td>43.952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SenseXAMP-classifier</th>\n",
       "      <td>48.383</td>\n",
       "      <td>88.257</td>\n",
       "      <td>90.409</td>\n",
       "      <td>89.518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PyAMPA</th>\n",
       "      <td>3.800</td>\n",
       "      <td>27.933</td>\n",
       "      <td>20.381</td>\n",
       "      <td>24.103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {R, S, M}</th>\n",
       "      <td>5.500</td>\n",
       "      <td>94.940</td>\n",
       "      <td>72.540</td>\n",
       "      <td>82.700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {S, M}</th>\n",
       "      <td>0.040</td>\n",
       "      <td>86.600</td>\n",
       "      <td>49.800</td>\n",
       "      <td>65.560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {M}</th>\n",
       "      <td>0.000</td>\n",
       "      <td>0.500</td>\n",
       "      <td>13.900</td>\n",
       "      <td>4.560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP (unbalanced loss)</th>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.120</td>\n",
       "      <td>0.020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP (balanced loss)</th>\n",
       "      <td>0.000</td>\n",
       "      <td>0.380</td>\n",
       "      <td>0.840</td>\n",
       "      <td>0.460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP</th>\n",
       "      <td>0.000</td>\n",
       "      <td>0.480</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.340</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           Random  Shuffled  Mutated  Added-Deleted\n",
       "amPEPpy                    31.200    87.598   73.684         82.287\n",
       "AMPlify                    23.300    86.400   80.900         87.000\n",
       "AMPpredMFA                 97.700    99.700   99.500         99.500\n",
       "AMPScanner                 25.763    81.510   80.405         79.912\n",
       "HydrAMP-AMP                33.732    86.094   75.000         84.790\n",
       "HydrAMP-MIC                 2.512    56.832   43.224         43.952\n",
       "SenseXAMP-classifier       48.383    88.257   90.409         89.518\n",
       "PyAMPA                      3.800    27.933   20.381         24.103\n",
       "OmegAMP - {R, S, M}         5.500    94.940   72.540         82.700\n",
       "OmegAMP - {S, M}            0.040    86.600   49.800         65.560\n",
       "OmegAMP - {M}               0.000     0.500   13.900          4.560\n",
       "OmegAMP (unbalanced loss)   0.000     0.100    0.120          0.020\n",
       "OmegAMP (balanced loss)     0.000     0.380    0.840          0.460\n",
       "OmegAMP                     0.000     0.480    0.900          0.340"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "synthetic_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5862437021e1817c",
   "metadata": {},
   "source": [
    "# Classification benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b79d1ce9578fcd1f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:46.005447Z",
     "start_time": "2025-05-07T12:57:45.997236Z"
    }
   },
   "outputs": [],
   "source": [
    "def read_fold_data(fold, fold_dir):\n",
    "    pos_path = f'{fold_dir}/fold_{fold}/test_AMPs.fasta'\n",
    "    neg_path = f'{fold_dir}/fold_{fold}/test_Non_AMPs.fasta'\n",
    "    fold_pos = [str(record.seq) for record in SeqIO.parse(pos_path, \"fasta\")]\n",
    "    fold_neg = [str(record.seq) for record in SeqIO.parse(neg_path, \"fasta\")]\n",
    "    return fold_pos, fold_neg\n",
    "\n",
    "def compute_metrics(pos_df, neg_df, all_df, all_with_synthetic_df):\n",
    "    metric = {}\n",
    "\n",
    "    pred_col = f'Prediction'\n",
    "    score_col = f'Probability_score'\n",
    "\n",
    "    # HQ\n",
    "    tp = (pos_df[pred_col] == 'AMP').sum()\n",
    "    fp = (neg_df[pred_col] == 'AMP').sum()\n",
    "    tpr = tp / len(pos_df) * 100\n",
    "    fpr = fp / len(neg_df) * 100\n",
    "    lr = tpr / fpr if fpr > 0 else float('inf')\n",
    "    top_100 = all_df.nlargest(100, score_col)\n",
    "\n",
    "    # All\n",
    "    tp_all = ((all_with_synthetic_df[pred_col] == 'AMP') & (all_with_synthetic_df['Class'] == 1)).sum()\n",
    "    fp_all = ((all_with_synthetic_df[pred_col] == 'AMP') & (all_with_synthetic_df['Class'] == 0)).sum()\n",
    "    tpr_all = tp_all / len(all_with_synthetic_df[all_with_synthetic_df['Class'] == 1]) * 100\n",
    "    fpr_all = fp_all / len(all_with_synthetic_df[all_with_synthetic_df['Class'] == 0]) * 100\n",
    "    lr_all = tpr_all / fpr_all if fpr_all > 0 else float('inf')\n",
    "    top_100_all = all_with_synthetic_df.nlargest(100, score_col)\n",
    "\n",
    "    metric['TPR (HQ)'] = tpr\n",
    "    metric['FPR (HQ)'] = fpr\n",
    "    metric['Precision@100 HQ'] = top_100['Class'].mean() * 100\n",
    "    metric['Precision@100 (all)'] = top_100_all['Class'].mean() * 100\n",
    "    metric['LR+ (HQ)'] = lr\n",
    "    metric['LR+ (all)'] = lr_all\n",
    "\n",
    "    return metric\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "89324a1a7f00c4b6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T12:57:47.286960Z",
     "start_time": "2025-05-07T12:57:46.427616Z"
    }
   },
   "outputs": [],
   "source": [
    "prediction_scores = {}\n",
    "results_dir = '../results/benchmark-classifiers-results'\n",
    "omegamp_dir = f'../results/omegamp-classifier-results'\n",
    "\n",
    "# Classification and Regression\n",
    "for model, model_name in classification_models.items():\n",
    "    fold_metrics = []\n",
    "\n",
    "    synthetic_dfs = [pd.read_csv(f'{results_dir}/{ds}/{model}.tsv', sep='\\t') for ds in synthetic_datasets]\n",
    "    synthetic_df = pd.concat(synthetic_dfs)\n",
    "    synthetic_df['Class'] = 0\n",
    "    synthetic_df.reset_index(drop=True, inplace=True)\n",
    "\n",
    "    for fold in range(1, 6):\n",
    "        fold_pos, fold_neg = read_fold_data(fold, '../data/cross-validation-partitions')\n",
    "        pos_df = pd.read_csv(f'{results_dir}/curated-AMPs/{model}.tsv', sep='\\t')\n",
    "        neg_df = pd.read_csv(f'{results_dir}/curated-Non-AMPs/{model}.tsv', sep='\\t')\n",
    "        pos_df['Class'] = 1\n",
    "        neg_df['Class'] = 0\n",
    "        pos_df = pos_df[pos_df['Sequence'].isin(fold_pos)]\n",
    "        neg_df = neg_df[neg_df['Sequence'].isin(fold_neg)]\n",
    "        all_df = pd.concat([pos_df, neg_df]).reset_index(drop=True)\n",
    "        all_with_synthetic_df = pd.concat([all_df, synthetic_df]).reset_index(drop=True)\n",
    "        fold_metrics.append(compute_metrics(pos_df, neg_df, all_df, all_with_synthetic_df))\n",
    "\n",
    "    # Average results\n",
    "    prediction_scores[model_name] = {\n",
    "        key: np.mean([m[key] for m in fold_metrics]) for key in fold_metrics[0]\n",
    "    }\n",
    "\n",
    "# OmegaAMP block\n",
    "for model, model_name in omegamp_models.items():\n",
    "    fold_metrics = []\n",
    "    for fold in range(1, 6):\n",
    "        pos_df = pd.read_csv(f'{omegamp_dir}/{model}/fold_{fold}/curated-AMPs.csv')\n",
    "        neg_df = pd.read_csv(f'{omegamp_dir}/{model}/fold_{fold}/curated-Non-AMPs.csv')\n",
    "        pos_df['Class'] = 1\n",
    "        neg_df['Class'] = 0\n",
    "        for df in [pos_df, neg_df]:\n",
    "            df['Prediction'] = df['Probability_score'].apply(lambda x: 'AMP' if x >= 0.5 else 'non-AMP')\n",
    "\n",
    "        # OmegaAMP-specific synthetic loading\n",
    "        synthetic_dfs = [pd.read_csv(f'{omegamp_dir}/{model}/fold_{fold}/{ds}.csv') for ds in synthetic_datasets]\n",
    "        synthetic_df = pd.concat(synthetic_dfs)\n",
    "        synthetic_df['Class'] = 0\n",
    "        synthetic_df['Prediction'] = synthetic_df['Probability_score'].apply(lambda x: 'AMP' if x >= 0.5 else 'non-AMP')\n",
    "\n",
    "        all_df = pd.concat([pos_df, neg_df]).reset_index(drop=True)\n",
    "        all_with_synthetic_df = pd.concat([all_df, synthetic_df]).reset_index(drop=True)\n",
    "\n",
    "        fold_metrics.append(compute_metrics(pos_df, neg_df, all_df, all_with_synthetic_df))\n",
    "\n",
    "    prediction_scores[model_name] = {\n",
    "        key: np.mean([m[key] for m in fold_metrics]) for key in fold_metrics[0]\n",
    "    }\n",
    "\n",
    "prediction_scores = pd.DataFrame.from_dict(prediction_scores).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "fb713ee652658189",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T13:23:26.586157Z",
     "start_time": "2025-05-07T13:23:26.570988Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>TPR (HQ)</th>\n",
       "      <th>FPR (HQ)</th>\n",
       "      <th>Precision@100 HQ</th>\n",
       "      <th>Precision@100 (all)</th>\n",
       "      <th>LR+ (HQ)</th>\n",
       "      <th>LR+ (all)</th>\n",
       "      <th>Random</th>\n",
       "      <th>Shuffled</th>\n",
       "      <th>Mutated</th>\n",
       "      <th>Added-Deleted</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>amPEPpy</th>\n",
       "      <td>94.167</td>\n",
       "      <td>77.065</td>\n",
       "      <td>91.600</td>\n",
       "      <td>77.600</td>\n",
       "      <td>1.223</td>\n",
       "      <td>1.363</td>\n",
       "      <td>31.200</td>\n",
       "      <td>87.598</td>\n",
       "      <td>73.684</td>\n",
       "      <td>82.287</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPlify</th>\n",
       "      <td>96.175</td>\n",
       "      <td>78.804</td>\n",
       "      <td>89.000</td>\n",
       "      <td>66.000</td>\n",
       "      <td>1.221</td>\n",
       "      <td>1.378</td>\n",
       "      <td>23.300</td>\n",
       "      <td>86.400</td>\n",
       "      <td>80.900</td>\n",
       "      <td>87.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPpredMFA</th>\n",
       "      <td>99.450</td>\n",
       "      <td>98.802</td>\n",
       "      <td>100.000</td>\n",
       "      <td>100.000</td>\n",
       "      <td>1.007</td>\n",
       "      <td>1.004</td>\n",
       "      <td>97.700</td>\n",
       "      <td>99.700</td>\n",
       "      <td>99.500</td>\n",
       "      <td>99.500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPScanner</th>\n",
       "      <td>96.353</td>\n",
       "      <td>80.282</td>\n",
       "      <td>100.000</td>\n",
       "      <td>100.000</td>\n",
       "      <td>1.202</td>\n",
       "      <td>1.424</td>\n",
       "      <td>25.763</td>\n",
       "      <td>81.510</td>\n",
       "      <td>80.405</td>\n",
       "      <td>79.912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-AMP</th>\n",
       "      <td>95.039</td>\n",
       "      <td>76.658</td>\n",
       "      <td>96.400</td>\n",
       "      <td>43.600</td>\n",
       "      <td>1.240</td>\n",
       "      <td>1.354</td>\n",
       "      <td>33.732</td>\n",
       "      <td>86.094</td>\n",
       "      <td>75.000</td>\n",
       "      <td>84.790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-MIC</th>\n",
       "      <td>81.718</td>\n",
       "      <td>20.458</td>\n",
       "      <td>98.800</td>\n",
       "      <td>43.800</td>\n",
       "      <td>4.039</td>\n",
       "      <td>2.278</td>\n",
       "      <td>2.512</td>\n",
       "      <td>56.832</td>\n",
       "      <td>43.224</td>\n",
       "      <td>43.952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SenseXAMP-classifier</th>\n",
       "      <td>97.995</td>\n",
       "      <td>88.864</td>\n",
       "      <td>99.800</td>\n",
       "      <td>99.800</td>\n",
       "      <td>1.104</td>\n",
       "      <td>1.231</td>\n",
       "      <td>48.383</td>\n",
       "      <td>88.257</td>\n",
       "      <td>90.409</td>\n",
       "      <td>89.518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PyAMPA</th>\n",
       "      <td>84.201</td>\n",
       "      <td>82.945</td>\n",
       "      <td>93.200</td>\n",
       "      <td>31.000</td>\n",
       "      <td>1.019</td>\n",
       "      <td>4.174</td>\n",
       "      <td>3.800</td>\n",
       "      <td>27.933</td>\n",
       "      <td>20.381</td>\n",
       "      <td>24.103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {R, S, M}</th>\n",
       "      <td>94.512</td>\n",
       "      <td>35.109</td>\n",
       "      <td>99.400</td>\n",
       "      <td>56.600</td>\n",
       "      <td>2.713</td>\n",
       "      <td>1.509</td>\n",
       "      <td>5.500</td>\n",
       "      <td>94.940</td>\n",
       "      <td>72.540</td>\n",
       "      <td>82.700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {S, M}</th>\n",
       "      <td>91.328</td>\n",
       "      <td>29.457</td>\n",
       "      <td>99.200</td>\n",
       "      <td>62.200</td>\n",
       "      <td>3.131</td>\n",
       "      <td>1.843</td>\n",
       "      <td>0.040</td>\n",
       "      <td>86.600</td>\n",
       "      <td>49.800</td>\n",
       "      <td>65.560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP - {M}</th>\n",
       "      <td>58.992</td>\n",
       "      <td>11.522</td>\n",
       "      <td>98.400</td>\n",
       "      <td>90.200</td>\n",
       "      <td>5.380</td>\n",
       "      <td>11.717</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.500</td>\n",
       "      <td>13.900</td>\n",
       "      <td>4.560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP (unbalanced loss)</th>\n",
       "      <td>17.131</td>\n",
       "      <td>3.261</td>\n",
       "      <td>96.600</td>\n",
       "      <td>95.200</td>\n",
       "      <td>5.990</td>\n",
       "      <td>87.379</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.120</td>\n",
       "      <td>0.020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP (balanced loss)</th>\n",
       "      <td>42.361</td>\n",
       "      <td>8.478</td>\n",
       "      <td>97.000</td>\n",
       "      <td>96.600</td>\n",
       "      <td>5.330</td>\n",
       "      <td>55.609</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.380</td>\n",
       "      <td>0.840</td>\n",
       "      <td>0.460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP</th>\n",
       "      <td>42.385</td>\n",
       "      <td>6.304</td>\n",
       "      <td>97.600</td>\n",
       "      <td>97.200</td>\n",
       "      <td>6.880</td>\n",
       "      <td>62.249</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.480</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.340</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                           TPR (HQ)  FPR (HQ)  Precision@100 HQ  \\\n",
       "amPEPpy                      94.167    77.065            91.600   \n",
       "AMPlify                      96.175    78.804            89.000   \n",
       "AMPpredMFA                   99.450    98.802           100.000   \n",
       "AMPScanner                   96.353    80.282           100.000   \n",
       "HydrAMP-AMP                  95.039    76.658            96.400   \n",
       "HydrAMP-MIC                  81.718    20.458            98.800   \n",
       "SenseXAMP-classifier         97.995    88.864            99.800   \n",
       "PyAMPA                       84.201    82.945            93.200   \n",
       "OmegAMP - {R, S, M}          94.512    35.109            99.400   \n",
       "OmegAMP - {S, M}             91.328    29.457            99.200   \n",
       "OmegAMP - {M}                58.992    11.522            98.400   \n",
       "OmegAMP (unbalanced loss)    17.131     3.261            96.600   \n",
       "OmegAMP (balanced loss)      42.361     8.478            97.000   \n",
       "OmegAMP                      42.385     6.304            97.600   \n",
       "\n",
       "                           Precision@100 (all)  LR+ (HQ)  LR+ (all)  Random  \\\n",
       "amPEPpy                                 77.600     1.223      1.363  31.200   \n",
       "AMPlify                                 66.000     1.221      1.378  23.300   \n",
       "AMPpredMFA                             100.000     1.007      1.004  97.700   \n",
       "AMPScanner                             100.000     1.202      1.424  25.763   \n",
       "HydrAMP-AMP                             43.600     1.240      1.354  33.732   \n",
       "HydrAMP-MIC                             43.800     4.039      2.278   2.512   \n",
       "SenseXAMP-classifier                    99.800     1.104      1.231  48.383   \n",
       "PyAMPA                                  31.000     1.019      4.174   3.800   \n",
       "OmegAMP - {R, S, M}                     56.600     2.713      1.509   5.500   \n",
       "OmegAMP - {S, M}                        62.200     3.131      1.843   0.040   \n",
       "OmegAMP - {M}                           90.200     5.380     11.717   0.000   \n",
       "OmegAMP (unbalanced loss)               95.200     5.990     87.379   0.000   \n",
       "OmegAMP (balanced loss)                 96.600     5.330     55.609   0.000   \n",
       "OmegAMP                                 97.200     6.880     62.249   0.000   \n",
       "\n",
       "                           Shuffled  Mutated  Added-Deleted  \n",
       "amPEPpy                      87.598   73.684         82.287  \n",
       "AMPlify                      86.400   80.900         87.000  \n",
       "AMPpredMFA                   99.700   99.500         99.500  \n",
       "AMPScanner                   81.510   80.405         79.912  \n",
       "HydrAMP-AMP                  86.094   75.000         84.790  \n",
       "HydrAMP-MIC                  56.832   43.224         43.952  \n",
       "SenseXAMP-classifier         88.257   90.409         89.518  \n",
       "PyAMPA                       27.933   20.381         24.103  \n",
       "OmegAMP - {R, S, M}          94.940   72.540         82.700  \n",
       "OmegAMP - {S, M}             86.600   49.800         65.560  \n",
       "OmegAMP - {M}                 0.500   13.900          4.560  \n",
       "OmegAMP (unbalanced loss)     0.100    0.120          0.020  \n",
       "OmegAMP (balanced loss)       0.380    0.840          0.460  \n",
       "OmegAMP                       0.480    0.900          0.340  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classifier_benchmark_results = pd.concat([prediction_scores, synthetic_scores], axis=1)\n",
    "classifier_benchmark_results = classifier_benchmark_results.loc[list(classification_models.values()) + list(omegamp_models.values()), :]\n",
    "classifier_benchmark_results\n",
    "\n",
    "# classifier_benchmark_results.to_latex(float_format=\"{:.3f}\".format,)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db499688927314a8",
   "metadata": {},
   "source": [
    "# SLAY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "87585854287fccda",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T13:23:38.285421Z",
     "start_time": "2025-05-07T13:23:38.275115Z"
    }
   },
   "outputs": [],
   "source": [
    "def eval_classification(predictions, test_labels, scores):\n",
    "\n",
    "    tp = np.sum((predictions == 1) & (test_labels == 1))\n",
    "    fp = np.sum((predictions == 1) & (test_labels == 0))\n",
    "    tn = np.sum((predictions == 0) & (test_labels == 0))\n",
    "    fn = np.sum((predictions == 0) & (test_labels == 1))\n",
    "    tpr = tp / (tp + fn)\n",
    "    fpr = fp / (fp + tn)\n",
    "    precision = tp / (tp + fp)\n",
    "\n",
    "    top_100_idxs = np.argsort(scores)[-100:]\n",
    "    test_labels_at_100 = test_labels[top_100_idxs]\n",
    "    precision_at_100 = np.mean(test_labels_at_100)\n",
    "\n",
    "    if any([fpr, tpr]) == 0:\n",
    "        lr = 0\n",
    "    else:\n",
    "        lr = tpr / fpr\n",
    "\n",
    "    return {\n",
    "        'TPR': tpr * 100,\n",
    "        'FPR': fpr * 100,\n",
    "        'Precision': precision * 100,\n",
    "        'Precision@100': precision_at_100 * 100,\n",
    "        'PredictedPositives': np.sum(predictions == 1),\n",
    "        'LR+': lr,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "cd427bddd2797bb0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T13:23:38.837183Z",
     "start_time": "2025-05-07T13:23:38.833942Z"
    }
   },
   "outputs": [],
   "source": [
    "slay_models = {\n",
    "    'ampeppy': 'amPEPpy',\n",
    "    'amplify': 'AMPlify',\n",
    "    'ampredmfa': 'AMPpredMFA',\n",
    "    'ampscanner': 'AMPScanner',\n",
    "    'hydramp-amp-classifier': 'HydrAMP-AMP',\n",
    "    'hydramp-mic-classifier': 'HydrAMP-MIC',\n",
    "    'sensexamp-classifier': 'SenseXAMP-classifier',\n",
    "    'pyampa': 'PyAMPA',\n",
    "    'omegamp': 'OmegAMP',\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "91f7225b8fb7766d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T13:24:00.757793Z",
     "start_time": "2025-05-07T13:23:39.314658Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating amPEPpy\n",
      "Number of dropped sequences: 0\n",
      "Evaluating AMPlify\n",
      "Number of dropped sequences: 20\n",
      "Evaluating AMPpredMFA\n",
      "Number of dropped sequences: 20\n",
      "Evaluating AMPScanner\n",
      "Number of dropped sequences: 46071\n",
      "Evaluating HydrAMP-AMP\n",
      "Number of dropped sequences: 0\n",
      "Evaluating HydrAMP-MIC\n",
      "Number of dropped sequences: 0\n",
      "Evaluating SenseXAMP-classifier\n",
      "Number of dropped sequences: 17550\n",
      "Evaluating PyAMPA\n",
      "Number of dropped sequences: 0\n",
      "Evaluating OmegAMP\n",
      "Number of dropped sequences: 0\n"
     ]
    }
   ],
   "source": [
    "slay_all = pd.read_csv('../data/slay/slay_all.csv')\n",
    "slay_all = slay_all.drop_duplicates(subset='Sequence')\n",
    "results_dir = '../results/slay-results'\n",
    "\n",
    "result_dict = {}\n",
    "for model, model_name in slay_models.items():\n",
    "    print(f'Evaluating {model_name}')\n",
    "    try:\n",
    "        pred_df = pd.read_csv(f'{results_dir}/{model}.tsv', sep=None)\n",
    "    except:\n",
    "        pred_df = pd.read_csv(f'{results_dir}/{model}.csv', sep=None)\n",
    "\n",
    "    pred_df = pred_df.merge(slay_all, on='Sequence', how='right')\n",
    "    pred_df = pred_df.drop_duplicates(subset='Sequence')\n",
    "    # pred_df = pred_df.dropna()\n",
    "\n",
    "    if 'Prediction' in pred_df.columns:\n",
    "        pred_df['pred'] = None\n",
    "        pred_df = pred_df.dropna(subset='Prediction')\n",
    "        if 'AMP' in pred_df['Prediction'].value_counts():\n",
    "            pred_df['pred'] = pred_df['Prediction'].apply(lambda x: 1 if x == 'AMP' else 0)\n",
    "        else:\n",
    "            pred_df['pred'] = pred_df['Prediction']\n",
    "        scores = pred_df['Probability_score'].astype('float').to_numpy()\n",
    "    else:\n",
    "        pred_df['pred'] = pred_df['broad-classifier'].apply(lambda x: 1 if x >= 0.5 else 0)\n",
    "        scores = pred_df['broad-classifier']\n",
    "    predictions = pred_df['pred'].to_numpy()\n",
    "    test_labels = pred_df['class'].to_numpy()\n",
    "\n",
    "    no_dropped = len(slay_all) - len(pred_df)\n",
    "    print(f'Number of dropped sequences: {no_dropped}')\n",
    "    result_dict[model_name] = eval_classification(predictions, test_labels, scores)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "b0a1676047883fbc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-07T13:24:00.772840Z",
     "start_time": "2025-05-07T13:24:00.765676Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>TPR</th>\n",
       "      <th>FPR</th>\n",
       "      <th>LR+</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Precision@100</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>amPEPpy</th>\n",
       "      <td>36.860</td>\n",
       "      <td>34.623</td>\n",
       "      <td>1.065</td>\n",
       "      <td>1.896</td>\n",
       "      <td>3.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPlify</th>\n",
       "      <td>25.516</td>\n",
       "      <td>22.515</td>\n",
       "      <td>1.133</td>\n",
       "      <td>2.011</td>\n",
       "      <td>5.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPpredMFA</th>\n",
       "      <td>100.000</td>\n",
       "      <td>100.000</td>\n",
       "      <td>1.000</td>\n",
       "      <td>1.779</td>\n",
       "      <td>2.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>AMPScanner</th>\n",
       "      <td>32.750</td>\n",
       "      <td>26.866</td>\n",
       "      <td>1.219</td>\n",
       "      <td>2.210</td>\n",
       "      <td>1.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-AMP</th>\n",
       "      <td>34.864</td>\n",
       "      <td>31.452</td>\n",
       "      <td>1.109</td>\n",
       "      <td>1.972</td>\n",
       "      <td>4.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>HydrAMP-MIC</th>\n",
       "      <td>4.670</td>\n",
       "      <td>2.081</td>\n",
       "      <td>2.244</td>\n",
       "      <td>3.914</td>\n",
       "      <td>10.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>SenseXAMP-classifier</th>\n",
       "      <td>46.870</td>\n",
       "      <td>44.680</td>\n",
       "      <td>1.049</td>\n",
       "      <td>1.866</td>\n",
       "      <td>13.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PyAMPA</th>\n",
       "      <td>2.265</td>\n",
       "      <td>1.409</td>\n",
       "      <td>1.607</td>\n",
       "      <td>2.833</td>\n",
       "      <td>5.000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>OmegAMP</th>\n",
       "      <td>0.064</td>\n",
       "      <td>0.010</td>\n",
       "      <td>6.560</td>\n",
       "      <td>10.638</td>\n",
       "      <td>11.000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         TPR     FPR   LR+  Precision  Precision@100\n",
       "amPEPpy               36.860  34.623 1.065      1.896          3.000\n",
       "AMPlify               25.516  22.515 1.133      2.011          5.000\n",
       "AMPpredMFA           100.000 100.000 1.000      1.779          2.000\n",
       "AMPScanner            32.750  26.866 1.219      2.210          1.000\n",
       "HydrAMP-AMP           34.864  31.452 1.109      1.972          4.000\n",
       "HydrAMP-MIC            4.670   2.081 2.244      3.914         10.000\n",
       "SenseXAMP-classifier  46.870  44.680 1.049      1.866         13.000\n",
       "PyAMPA                 2.265   1.409 1.607      2.833          5.000\n",
       "OmegAMP                0.064   0.010 6.560     10.638         11.000"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metrics table: one row per model (in slay_models order), restricted to the reported columns.\n",
    "# LaTeX export, if needed: slay_results.to_latex(float_format='{:.3f}'.format)\n",
    "slay_results = (\n",
    "    pd.DataFrame(result_dict)\n",
    "    .T\n",
    "    .loc[list(slay_models.values()), ['TPR', 'FPR', 'LR+', 'Precision', 'Precision@100']]\n",
    ")\n",
    "slay_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
