{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f6090fa1-61d7-4ae4-9da1-a93f2cb7f01b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from statsmodels.stats.proportion import proportions_ztest\n",
    "from statsmodels.stats.multitest import multipletests\n",
    "import networkx as nx\n",
    "import ast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "5fc5f9bb-0031-42dd-a829-b70c68fb02b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Read data ###\n",
    "tests = pd.read_table('temp/tests.tsv')\n",
    "tools_predictions_rmsd = pd.read_table('temp/tools_predictions_rmsd_pb.tsv', index_col=0)\n",
    "\n",
    "\n",
    "def _parse_literal(value):\n",
    "    \"\"\"Parse a Python-literal string with `ast.literal_eval`; return non-strings unchanged.\"\"\"\n",
    "    return ast.literal_eval(value) if isinstance(value, str) else value\n",
    "\n",
    "\n",
    "# `ligand_types_unique` is parsed together with `ligand`: if it stays a plain string such as\n",
    "# \"('other',)\", the later `tests.explode('ligand_types_unique')` leaves every row untouched and\n",
    "# comparisons against a single type name (e.g. `== 'other'`) never match, which silently\n",
    "# produces empty subsets downstream.\n",
    "for tuple_column in ('ligand', 'ligand_types_unique'):\n",
    "    tests[tuple_column] = tests[tuple_column].map(_parse_literal)\n",
    "\n",
    "# Prediction cells are read as text and parsed back into Python objects.\n",
    "for col in tools_predictions_rmsd.columns:\n",
    "    tools_predictions_rmsd[col] = tools_predictions_rmsd[col].map(ast.literal_eval)\n",
    "\n",
    "# Identifier built as \"<lowercase pdb>_<ligand entries joined by '-'>\".\n",
    "tests['id'] = tests.apply(lambda x: f\"{x['pdb'].lower()}_{'-'.join(x['ligand'])}\", axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f986c84b-bcd1-4bd6-8f21-f95398d78bc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>matcha_fromTrue_fast_40</th>\n",
       "      <th>diffdock</th>\n",
       "      <th>unimol_p2rank</th>\n",
       "      <th>gnina_ligand_box</th>\n",
       "      <th>vina_ligand_box</th>\n",
       "      <th>smina_ligand_box</th>\n",
       "      <th>af3</th>\n",
       "      <th>chai</th>\n",
       "      <th>neuralplexer</th>\n",
       "      <th>FD</th>\n",
       "      <th>boltz_pocket_10A</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>7AKL_RK5</th>\n",
       "      <td>(0.6703474819794571, True)</td>\n",
       "      <td>(6.3091109086718085, False)</td>\n",
       "      <td>(6.416432365967517, False)</td>\n",
       "      <td>(0.427517437595271, True)</td>\n",
       "      <td>(0.34330912816772025, True)</td>\n",
       "      <td>(0.32749848840070384, True)</td>\n",
       "      <td>(0.36299548306686386, True)</td>\n",
       "      <td>(0.3165991323091189, True)</td>\n",
       "      <td>(3.1255962763938565, False)</td>\n",
       "      <td>(666, False)</td>\n",
       "      <td>(0.24900623539011607, True)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3se5_1_ANP_2</th>\n",
       "      <td>(4.700864611339791, False)</td>\n",
       "      <td>(18.673861725324166, False)</td>\n",
       "      <td>(16.454982364654768, False)</td>\n",
       "      <td>(2.718405466993828, False)</td>\n",
       "      <td>(666, False)</td>\n",
       "      <td>(16.536614880540732, False)</td>\n",
       "      <td>(0.42318715690852166, True)</td>\n",
       "      <td>(27.580415044030303, False)</td>\n",
       "      <td>(24.695434306548997, False)</td>\n",
       "      <td>(30.345416733055927, False)</td>\n",
       "      <td>(27.456628930475176, False)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6qlp</th>\n",
       "      <td>(0.7691459510737285, True)</td>\n",
       "      <td>(0.8350013069401389, True)</td>\n",
       "      <td>(4.230882191204779, False)</td>\n",
       "      <td>(2.8500020464703595, False)</td>\n",
       "      <td>(9.163731951753368, False)</td>\n",
       "      <td>(5.898578630111764, False)</td>\n",
       "      <td>(1.449426874340075, True)</td>\n",
       "      <td>(0.3944171423874935, True)</td>\n",
       "      <td>(2.9784070568625842, False)</td>\n",
       "      <td>(1.2317006575286462, False)</td>\n",
       "      <td>(0.46655723791406367, True)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6YYO_Q1K</th>\n",
       "      <td>(7.995791576373216, False)</td>\n",
       "      <td>(8.048867020720289, False)</td>\n",
       "      <td>(7.083494686977704, False)</td>\n",
       "      <td>(5.148979825784397, False)</td>\n",
       "      <td>(1.9564752205546099, True)</td>\n",
       "      <td>(1.0975319708585354, True)</td>\n",
       "      <td>(0.9168704436106462, True)</td>\n",
       "      <td>(9.295642907433061, False)</td>\n",
       "      <td>(9.254053318275584, False)</td>\n",
       "      <td>(9.809285494174691, False)</td>\n",
       "      <td>(24.76095300223615, False)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3ad7_1_NAD_0</th>\n",
       "      <td>(2.877749728244117, False)</td>\n",
       "      <td>(2.9213260839666004, False)</td>\n",
       "      <td>(3.5636247295989354, False)</td>\n",
       "      <td>(3.2920571772691156, False)</td>\n",
       "      <td>(666, False)</td>\n",
       "      <td>(9.34270693686695, False)</td>\n",
       "      <td>(0.7594979801277829, False)</td>\n",
       "      <td>(0.6638427730063254, True)</td>\n",
       "      <td>(2.5540889231333987, False)</td>\n",
       "      <td>(3.2613806195313275, False)</td>\n",
       "      <td>(0.7883447993912776, False)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 matcha_fromTrue_fast_40                     diffdock  \\\n",
       "7AKL_RK5      (0.6703474819794571, True)  (6.3091109086718085, False)   \n",
       "3se5_1_ANP_2  (4.700864611339791, False)  (18.673861725324166, False)   \n",
       "6qlp          (0.7691459510737285, True)   (0.8350013069401389, True)   \n",
       "6YYO_Q1K      (7.995791576373216, False)   (8.048867020720289, False)   \n",
       "3ad7_1_NAD_0  (2.877749728244117, False)  (2.9213260839666004, False)   \n",
       "\n",
       "                            unimol_p2rank             gnina_ligand_box  \\\n",
       "7AKL_RK5       (6.416432365967517, False)    (0.427517437595271, True)   \n",
       "3se5_1_ANP_2  (16.454982364654768, False)   (2.718405466993828, False)   \n",
       "6qlp           (4.230882191204779, False)  (2.8500020464703595, False)   \n",
       "6YYO_Q1K       (7.083494686977704, False)   (5.148979825784397, False)   \n",
       "3ad7_1_NAD_0  (3.5636247295989354, False)  (3.2920571772691156, False)   \n",
       "\n",
       "                          vina_ligand_box             smina_ligand_box  \\\n",
       "7AKL_RK5      (0.34330912816772025, True)  (0.32749848840070384, True)   \n",
       "3se5_1_ANP_2                 (666, False)  (16.536614880540732, False)   \n",
       "6qlp           (9.163731951753368, False)   (5.898578630111764, False)   \n",
       "6YYO_Q1K       (1.9564752205546099, True)   (1.0975319708585354, True)   \n",
       "3ad7_1_NAD_0                 (666, False)    (9.34270693686695, False)   \n",
       "\n",
       "                                      af3                         chai  \\\n",
       "7AKL_RK5      (0.36299548306686386, True)   (0.3165991323091189, True)   \n",
       "3se5_1_ANP_2  (0.42318715690852166, True)  (27.580415044030303, False)   \n",
       "6qlp            (1.449426874340075, True)   (0.3944171423874935, True)   \n",
       "6YYO_Q1K       (0.9168704436106462, True)   (9.295642907433061, False)   \n",
       "3ad7_1_NAD_0  (0.7594979801277829, False)   (0.6638427730063254, True)   \n",
       "\n",
       "                             neuralplexer                           FD  \\\n",
       "7AKL_RK5      (3.1255962763938565, False)                 (666, False)   \n",
       "3se5_1_ANP_2  (24.695434306548997, False)  (30.345416733055927, False)   \n",
       "6qlp          (2.9784070568625842, False)  (1.2317006575286462, False)   \n",
       "6YYO_Q1K       (9.254053318275584, False)   (9.809285494174691, False)   \n",
       "3ad7_1_NAD_0  (2.5540889231333987, False)  (3.2613806195313275, False)   \n",
       "\n",
       "                         boltz_pocket_10A  \n",
       "7AKL_RK5      (0.24900623539011607, True)  \n",
       "3se5_1_ANP_2  (27.456628930475176, False)  \n",
       "6qlp          (0.46655723791406367, True)  \n",
       "6YYO_Q1K       (24.76095300223615, False)  \n",
       "3ad7_1_NAD_0  (0.7883447993912776, False)  "
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed seed so the preview shows the same rows on every run.\n",
    "tools_predictions_rmsd.sample(5, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "184641a8-201e-410b-99e9-cbd591be6eaf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pdb</th>\n",
       "      <th>ligand</th>\n",
       "      <th>uid</th>\n",
       "      <th>dataset</th>\n",
       "      <th>ligand_types_unique</th>\n",
       "      <th>buried_fraction</th>\n",
       "      <th>n_rotbonds</th>\n",
       "      <th>n_heavy_atoms</th>\n",
       "      <th>molecular_weight</th>\n",
       "      <th>id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>168</th>\n",
       "      <td>6n0m</td>\n",
       "      <td>(K94,)</td>\n",
       "      <td>6n0m</td>\n",
       "      <td>timesplit_test</td>\n",
       "      <td>('other',)</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>182.19</td>\n",
       "      <td>6n0m_K94</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>995</th>\n",
       "      <td>6czh</td>\n",
       "      <td>(38E,)</td>\n",
       "      <td>6czh_2_38E_0</td>\n",
       "      <td>dockgen_val</td>\n",
       "      <td>('other',)</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>264.32</td>\n",
       "      <td>6czh_38E</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>373</th>\n",
       "      <td>1q1g</td>\n",
       "      <td>(MTI,)</td>\n",
       "      <td>1Q1G_MTI</td>\n",
       "      <td>astex_diverse_set_ids</td>\n",
       "      <td>('other',)</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>297.36</td>\n",
       "      <td>1q1g_MTI</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1039</th>\n",
       "      <td>6uw7</td>\n",
       "      <td>(QMD,)</td>\n",
       "      <td>6uw7_1_QMD_0</td>\n",
       "      <td>dockgen_val</td>\n",
       "      <td>('cof',)</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>35.0</td>\n",
       "      <td>531.50</td>\n",
       "      <td>6uw7_QMD</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>261</th>\n",
       "      <td>6nv7</td>\n",
       "      <td>(L3J,)</td>\n",
       "      <td>6nv7</td>\n",
       "      <td>timesplit_test</td>\n",
       "      <td>('macro',)</td>\n",
       "      <td>1.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>42.0</td>\n",
       "      <td>601.83</td>\n",
       "      <td>6nv7_L3J</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       pdb  ligand           uid                dataset ligand_types_unique  \\\n",
       "168   6n0m  (K94,)          6n0m         timesplit_test          ('other',)   \n",
       "995   6czh  (38E,)  6czh_2_38E_0            dockgen_val          ('other',)   \n",
       "373   1q1g  (MTI,)      1Q1G_MTI  astex_diverse_set_ids          ('other',)   \n",
       "1039  6uw7  (QMD,)  6uw7_1_QMD_0            dockgen_val            ('cof',)   \n",
       "261   6nv7  (L3J,)          6nv7         timesplit_test          ('macro',)   \n",
       "\n",
       "      buried_fraction  n_rotbonds  n_heavy_atoms  molecular_weight        id  \n",
       "168               1.0         4.0           12.0            182.19  6n0m_K94  \n",
       "995               1.0         2.0           18.0            264.32  6czh_38E  \n",
       "373               1.0         3.0           20.0            297.36  1q1g_MTI  \n",
       "1039              1.0        10.0           35.0            531.50  6uw7_QMD  \n",
       "261               1.0        12.0           42.0            601.83  6nv7_L3J  "
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed seed so the preview shows the same rows on every run.\n",
    "tests.sample(5, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "c97c54b1-f6b5-4582-bbd1-b02da7eacbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ligand and pocket similarity tables\n",
    "test_pocket_scores = pd.read_csv('temp/test_pockets_max_score.tsv', sep='\\t')\n",
    "# The Tanimoto table is big; its rows are indexed by the `test_uid` column.\n",
    "ligands_tanimoto = pd.read_csv('temp/ligands_tanimoto.tsv', sep='\\t', index_col='test_uid')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "257afc1e-d633-4de3-ab3d-adea9ac28fdc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# --- Explode tests --- #\n",
    "tests_exploded = tests.explode('ligand_types_unique', ignore_index=True)\n",
    "# Fold the 'lip', 'ster' and 'eo' types into a single 'etc' bucket; keep every other value as is.\n",
    "tests_exploded['ligand_types_unique_etc'] = [\n",
    "    'etc' if ligand_type in ['lip', 'ster', 'eo'] else ligand_type\n",
    "    for ligand_type in tests_exploded['ligand_types_unique']\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b18032c3-43c4-4834-a7fb-d5045c61c26b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _dataset_equals(dataset_name):\n",
    "    \"\"\"Return a row filter that keeps rows whose `dataset` column equals `dataset_name`.\"\"\"\n",
    "    return lambda df: df['dataset'] == dataset_name\n",
    "\n",
    "\n",
    "# Display label -> function producing a boolean row mask for a frame with a `dataset` column.\n",
    "dataset_filters = {\n",
    "    \"Timesplit test (332/363)\": _dataset_equals('timesplit_test'),\n",
    "    \"Astex (85)\": _dataset_equals('astex_diverse_set_ids'),\n",
    "    \"DockGen (322/330)\": lambda df: df['dataset'].str.contains('dockgen'),\n",
    "    \"PoseBusters (308)\": _dataset_equals('posebusters2'),\n",
    "}\n",
    "\n",
    "\n",
    "def _split_by_dataset(frame):\n",
    "    \"\"\"Return `[(label, rows of frame selected by that label's filter), ...]`.\"\"\"\n",
    "    return [(label, frame[keep(frame)]) for label, keep in dataset_filters.items()]\n",
    "\n",
    "\n",
    "datasets_tests = _split_by_dataset(tests)\n",
    "datasets_tests_exploded = _split_by_dataset(tests_exploded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "4b4a4722-d723-4162-aa7c-81628e61f149",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def group_tools_pairwise_cliques(df, subset_ids, alpha=0.05, method='holm'):\n",
    "    \"\"\"Group tools whose success rates are not significantly different.\n",
    "\n",
    "    A cell counts as a success when `bool(cell[1])` is true. Every pair of tools is\n",
    "    compared with `proportions_ztest`, the p-values are adjusted with `multipletests`,\n",
    "    and two tools are linked when their adjusted p-value is >= `alpha`. Groups are then\n",
    "    extracted greedily: the largest clique of linked tools (ties: higher mean success\n",
    "    rate, then tool names) is removed and the search repeats on the remaining tools.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        One column per tool; every cell must support `cell[1]`.\n",
    "    subset_ids : iterable\n",
    "        Row labels to use; labels absent from `df.index` are dropped.\n",
    "    alpha : float, default 0.05\n",
    "        Threshold applied to the adjusted p-values (also passed to `multipletests`).\n",
    "    method : str, default 'holm'\n",
    "        Correction method passed to `multipletests`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    groups : list of list\n",
    "        Tool names. Tools inside a group are ordered by decreasing success rate and\n",
    "        groups by decreasing mean success rate. With no selected rows all tools are\n",
    "        returned as one group in column order.\n",
    "    proportions : dict\n",
    "        Tool name -> success rate (NaN when no row is selected).\n",
    "    details : dict\n",
    "        Keys 'raw_p' and 'p_corrected' (symmetric m x m arrays), 'N', 'tool_list'\n",
    "        and 'successes'.\n",
    "    \"\"\"\n",
    "    subset_ids = [i for i in subset_ids if i in df.index]\n",
    "    sub = df.loc[subset_ids]\n",
    "    tool_list = list(df.columns)\n",
    "    m = len(tool_list)\n",
    "    # Number of rows actually selected (same as len(subset_ids) when the index is unique).\n",
    "    N = len(sub)\n",
    "\n",
    "    successes = np.array(\n",
    "        [sub[tool].apply(lambda x: bool(x[1])).sum() for tool in tool_list], dtype=int\n",
    "    )\n",
    "    if N > 0:\n",
    "        proportions = {tool: count / float(N) for tool, count in zip(tool_list, successes)}\n",
    "    else:\n",
    "        # No rows selected: the rate is undefined. Set NaN explicitly instead of dividing 0/0.\n",
    "        proportions = {tool: float('nan') for tool in tool_list}\n",
    "\n",
    "    raw_p = np.ones((m, m))\n",
    "    if N > 0:\n",
    "        nobs = np.array([N, N])\n",
    "        for i in range(m):\n",
    "            for j in range(i + 1, m):\n",
    "                count = np.array([successes[i], successes[j]])\n",
    "                # When both tools have zero (or all) successes the pooled variance is 0 and\n",
    "                # the z statistic is 0/0 = NaN. NaN would fail the `>= alpha` check below and\n",
    "                # split two identical tools, so treat it as \"no detectable difference\".\n",
    "                with np.errstate(divide='ignore', invalid='ignore'):\n",
    "                    _, p = proportions_ztest(count, nobs)\n",
    "                if not np.isfinite(p):\n",
    "                    p = 1.0\n",
    "                raw_p[i, j] = raw_p[j, i] = p\n",
    "\n",
    "    tri_idx = np.triu_indices(m, k=1)\n",
    "    raw_p_flat = raw_p[tri_idx]\n",
    "    details = {'raw_p': raw_p, 'p_corrected': raw_p, 'N': N, 'tool_list': tool_list, 'successes': successes}\n",
    "\n",
    "    # Fewer than two tools: nothing to compare.\n",
    "    if raw_p_flat.size == 0:\n",
    "        return [[t] for t in tool_list], proportions, details\n",
    "\n",
    "    # No rows: every p-value is 1, so all tools fall into a single group.\n",
    "    if N == 0:\n",
    "        return [list(tool_list)], proportions, details\n",
    "\n",
    "    _, p_corrected_flat, _, _ = multipletests(raw_p_flat, alpha=alpha, method=method)\n",
    "\n",
    "    p_corrected = np.ones((m, m))\n",
    "    i_upper, j_upper = tri_idx\n",
    "    p_corrected[i_upper, j_upper] = p_corrected_flat\n",
    "    p_corrected[j_upper, i_upper] = p_corrected_flat\n",
    "    details['p_corrected'] = p_corrected\n",
    "\n",
    "    # Edge = the pair is NOT significantly different after correction.\n",
    "    G = nx.Graph()\n",
    "    G.add_nodes_from(tool_list)\n",
    "    G.add_edges_from(\n",
    "        (tool_list[i], tool_list[j])\n",
    "        for i, j in zip(i_upper, j_upper)\n",
    "        if p_corrected[i, j] >= alpha\n",
    "    )\n",
    "\n",
    "    def clique_rank(clq):\n",
    "        # Bigger cliques win, then higher mean rate; names keep exact ties deterministic.\n",
    "        return (len(clq), np.mean([proportions[n] for n in clq]), sorted(str(n) for n in clq))\n",
    "\n",
    "    groups = []\n",
    "    remaining_nodes = set(tool_list)\n",
    "\n",
    "    while remaining_nodes:\n",
    "        cliques = list(nx.find_cliques(G.subgraph(remaining_nodes)))\n",
    "        if not cliques:\n",
    "            # Defensive fallback: emit whatever is left as singleton groups.\n",
    "            groups.extend([node] for node in sorted(remaining_nodes))\n",
    "            break\n",
    "\n",
    "        chosen = max(cliques, key=clique_rank)\n",
    "        groups.append(sorted(chosen, key=lambda t: (-proportions[t], str(t))))\n",
    "        remaining_nodes -= set(chosen)\n",
    "\n",
    "    groups.sort(key=lambda g: -np.mean([proportions[t] for t in g]))\n",
    "\n",
    "    return groups, proportions, details\n",
    "\n",
    "def getStatGroups(subset_uids, all_tools = False):\n",
    "    \"\"\"Return the tool groups computed by `group_tools_pairwise_cliques` for `subset_uids`.\n",
    "\n",
    "    Reads the notebook-level `tools_predictions_rmsd` frame. Unless `all_tools` is set,\n",
    "    only the fixed list of eight tool columns below is compared.\n",
    "    \"\"\"\n",
    "    default_tools = [\n",
    "        'matcha_fromTrue_fast_40', 'diffdock',\n",
    "        'gnina_ligand_box', 'vina_ligand_box', 'smina_ligand_box',\n",
    "        'af3', 'chai', 'boltz_pocket_10A',\n",
    "    ]\n",
    "    # `== False` (not `not all_tools`) is kept on purpose so the accepted values are unchanged.\n",
    "    if all_tools == False:\n",
    "        predictions = tools_predictions_rmsd.loc[:, default_tools]\n",
    "    else:\n",
    "        predictions = tools_predictions_rmsd\n",
    "    return group_tools_pairwise_cliques(predictions, subset_uids)[0]\n",
    "\n",
    "def getMetrics(df, indices):\n",
    "    \"\"\"Compute two fractions per column of `df` over the selected rows.\n",
    "\n",
    "    Labels in `indices` that are missing from `df.index` are skipped. For every column\n",
    "    the result holds `(fraction of cells with cell[0] <= 2,\n",
    "    fraction of cells with cell[0] <= 2 and cell[1] is True)`.\n",
    "\n",
    "    Returns a dict keyed by column name.\n",
    "    \"\"\"\n",
    "    def _within_cutoff(cell):\n",
    "        return cell[0] <= 2\n",
    "\n",
    "    def _flag_is_true(cell):\n",
    "        return cell[1] is True\n",
    "\n",
    "    rows = df.loc[[i for i in indices if i in df.index]]\n",
    "\n",
    "    metrics = {}\n",
    "    for tool in rows.columns:\n",
    "        cells = rows[tool]\n",
    "        close_enough = cells.apply(_within_cutoff)\n",
    "        flagged = cells.apply(_flag_is_true)\n",
    "        metrics[tool] = (close_enough.mean(), (close_enough & flagged).mean())\n",
    "\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "cc220282-df49-4477-bb0b-52b42ba77193",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Make splits and subset ###\n",
    "\n",
    "# Complexity: label template -> rows of `tests`, split by size and by flexibility\n",
    "_heavy_atoms = tests['n_heavy_atoms']\n",
    "_rot_bonds = tests['n_rotbonds']\n",
    "\n",
    "subsets_complexity = {\n",
    "    \"Big ligands ({n})<br><i>Heavy atoms ≥ 30</i>\": tests[_heavy_atoms.ge(30)],\n",
    "    \"Medium ligands ({n})<br><i>Heavy atoms 12–30</i>\": tests[_heavy_atoms.ge(12) & _heavy_atoms.lt(30)],\n",
    "    \"Small ligands ({n})<br><i>Heavy atoms < 12</i>\": tests[_heavy_atoms.lt(12)],\n",
    "    \"Flexible ligands ({n})<br><i>rotatable bonds ≥ 11</i>\": tests[_rot_bonds.ge(11)],\n",
    "    \"Semiflexible ligands ({n})<br><i>1–10 rotatable bonds</i>\": tests[_rot_bonds.gt(0) & _rot_bonds.lt(11)],\n",
    "    \"Rigid ligands ({n})<br><i>No rotatable bonds</i>\": tests[_rot_bonds.eq(0)],\n",
    "}\n",
    "\n",
    "# Typical\n",
    "# Drug-like rows: buried fraction >= 0.8, molecular weight within [200, 500],\n",
    "# fewer than 11 rotatable bonds and ligand type 'other'.\n",
    "_weight = tests_exploded['molecular_weight']\n",
    "mask_drug_like = (\n",
    "    (tests_exploded['buried_fraction'] >= 0.8)\n",
    "    & (_weight <= 500)\n",
    "    & (_weight >= 200)\n",
    "    & (tests_exploded['n_rotbonds'] < 11)\n",
    "    & (tests_exploded['ligand_types_unique_etc'] == 'other')\n",
    ")\n",
    "\n",
    "drug_like_uids = tests_exploded.loc[mask_drug_like, 'uid'].to_list()\n",
    "\n",
    "_from_posebusters = tests_exploded['dataset'] == 'posebusters2'\n",
    "subsets_typical = {\n",
    "    \"Drug-like ligands from the Tests ({n})\": tests_exploded.loc[mask_drug_like],\n",
    "    \"Drug-like ligands<br>from PoseBusters ({n})\": tests_exploded.loc[mask_drug_like & _from_posebusters],\n",
    "}\n",
    "\n",
    "# Similarity-based splits\n",
    "\n",
    "# NOTE(review): these thresholds read naturally only if `ligands_tanimoto` stores distances\n",
    "# (row-wise minimum > 0.7 => far from every column entry). If the table stores Tanimoto\n",
    "# similarities instead, the two conditions are inverted - TODO confirm against the code\n",
    "# that writes temp/ligands_tanimoto.tsv.\n",
    "_ligand_row_min = ligands_tanimoto.min(axis=1)  # computed once; the table is big\n",
    "dissimilar_ligands = ligands_tanimoto[_ligand_row_min > 0.7].index\n",
    "similar_ligands = ligands_tanimoto[_ligand_row_min < 0.2].index\n",
    "dissimilar_pockets = test_pocket_scores[test_pocket_scores['score'] < 0.6].test.to_list()\n",
    "# Every test uid that is not listed as a dissimilar pocket counts as a similar pocket.\n",
    "similar_pockets = tests[~tests['uid'].isin(dissimilar_pockets)]['uid']\n",
    "\n",
    "# Build each id set once instead of re-creating it for every intersection.\n",
    "_similar_pockets = set(similar_pockets)\n",
    "_dissimilar_pockets = set(dissimilar_pockets)\n",
    "_similar_ligands = set(similar_ligands)\n",
    "_dissimilar_ligands = set(dissimilar_ligands)\n",
    "_drug_like = set(drug_like_uids)\n",
    "\n",
    "similarity_sets = {\n",
    "    \"Similar pockets<br>Similar ligands<br>({n})\": list(_similar_pockets & _similar_ligands),\n",
    "    \"Similar pockets<br>Dissimilar ligands<br>({n})\": list(_similar_pockets & _dissimilar_ligands),\n",
    "    \"Dissimilar pockets<br>Similar ligands<br>({n})\": list(_dissimilar_pockets & _similar_ligands),\n",
    "    \"Dissimilar pockets<br>Dissimilar ligands<br>({n})\": list(_dissimilar_pockets & _dissimilar_ligands),\n",
    "}\n",
    "\n",
    "# Label spelling fixed: \"Disimilar\" -> \"Dissimilar\".\n",
    "similarity_sets_pockets = {\n",
    "    \"Drug-like ligands from the Tests<br>({n})\": drug_like_uids,\n",
    "    \"Similar pockets<br>({n}/{p})\": list(_similar_pockets & _drug_like),\n",
    "    \"Dissimilar pockets<br>({n}/{p})\": list(_dissimilar_pockets & _drug_like),\n",
    "}\n",
    "\n",
    "similarity_sets_ligands_parent = list(_drug_like & _similar_pockets)\n",
    "similarity_sets_ligands = {\n",
    "    \"Drug-like ligands from the Tests<br>and similar pockets<br>({n})\": similarity_sets_ligands_parent,\n",
    "    \"Similar ligands<br>({n}/{p})\": list(_similar_ligands & _drug_like & _similar_pockets),\n",
    "    \"Dissimilar ligands<br>({n}/{p})\": list(_dissimilar_ligands & _drug_like & _similar_pockets),\n",
    "}\n",
    "\n",
    "# Ligand classes: short class code -> display label (HTML line breaks included) #\n",
    "classes_dict_labels = dict([\n",
    "    ('aa', 'Amino acids,<br>peptides'),\n",
    "    ('nt', 'Nucleosides,<br>nucleotides'),\n",
    "    ('sac', 'Saccharides'),\n",
    "    ('cof', 'Cofactors'),\n",
    "    ('macro', 'Macrocycles'),\n",
    "    ('etc', 'Lipid-like, steroids,<br>organoelement'),\n",
    "    ('other', 'Regular'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f3522dff-29dd-46c8-85c8-b0fc9480161a",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marina/miniconda3/envs/pdb/lib/python3.7/site-packages/ipykernel_launcher.py:11: RuntimeWarning: invalid value encountered in true_divide\n",
      "  # This is added back by InteractiveShellApp.init_path()\n",
      "/home/marina/miniconda3/envs/pdb/lib/python3.7/site-packages/statsmodels/stats/weightstats.py:790: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  zstat = value / std\n"
     ]
    }
   ],
   "source": [
    "### Subsets of predictions based on splits ###\n",
    "\n",
    "def _metrics_and_groups(uids, all_tools=False):\n",
    "    \"\"\"Return `(getMetrics(...), getStatGroups(...))` for one collection of test uids.\n",
    "\n",
    "    `all_tools` is forwarded to `getStatGroups`; metrics always use every column of\n",
    "    `tools_predictions_rmsd`.\n",
    "    \"\"\"\n",
    "    uids = list(uids)\n",
    "    return (\n",
    "        getMetrics(tools_predictions_rmsd, uids),\n",
    "        getStatGroups(uids, all_tools=all_tools),\n",
    "    )\n",
    "\n",
    "\n",
    "# Complexity\n",
    "rmsd_ligand_complexity = {\n",
    "    label.format(n=len(frame)): _metrics_and_groups(frame['uid'].to_list())\n",
    "    for label, frame in subsets_complexity.items()\n",
    "}\n",
    "\n",
    "# Typical\n",
    "rmsd_typical = {\n",
    "    label.format(n=len(frame)): _metrics_and_groups(frame['uid'].to_list())\n",
    "    for label, frame in subsets_typical.items()\n",
    "}\n",
    "\n",
    "# Similarity\n",
    "rmsd_similarity = {\n",
    "    label.format(n=len(ids)): _metrics_and_groups(ids)\n",
    "    for label, ids in similarity_sets.items()\n",
    "}\n",
    "\n",
    "# {p} in these labels is the size of the parent set the subset was drawn from.\n",
    "rmsd_similarity_pockets = {\n",
    "    label.format(n=len(ids), p=len(drug_like_uids)): _metrics_and_groups(ids)\n",
    "    for label, ids in similarity_sets_pockets.items()\n",
    "}\n",
    "\n",
    "rmsd_similarity_ligands = {\n",
    "    label.format(n=len(ids), p=len(similarity_sets_ligands_parent)): _metrics_and_groups(ids)\n",
    "    for label, ids in similarity_sets_ligands.items()\n",
    "}\n",
    "\n",
    "# Ligand classes\n",
    "_uids_by_class = {\n",
    "    cls: tests_exploded.loc[tests_exploded['ligand_types_unique_etc'] == cls, 'uid'].to_list()\n",
    "    for cls in classes_dict_labels\n",
    "}\n",
    "rmsd_ligand_classes = {\n",
    "    f\"{classes_dict_labels[cls]} ({len(uids)})\": _metrics_and_groups(uids)\n",
    "    for cls, uids in _uids_by_class.items()\n",
    "}\n",
    "\n",
    "# Datasets\n",
    "rmsd_datasets = {\n",
    "    name: _metrics_and_groups(frame['uid'].to_list(), all_tools=True)\n",
    "    for name, frame in datasets_tests\n",
    "}\n",
    "# The \"All\" entry is grouped over every tool as well, so that it is comparable with the\n",
    "# per-dataset entries above (it previously used the reduced default tool list).\n",
    "rmsd_datasets[f\"All ({len(tests)})\"] = _metrics_and_groups(tests['uid'].to_list(), all_tools=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "09e1aba7-0105-4049-a3a5-14c1f27465bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Timesplit test (332/363)': ({'matcha_fromTrue_fast_40': (0.5903614457831325,\n",
       "    0.5060240963855421),\n",
       "   'diffdock': (0.4879518072289157, 0.26506024096385544),\n",
       "   'unimol_p2rank': (0.21987951807228914, 0.1566265060240964),\n",
       "   'gnina_ligand_box': (0.3644578313253012, 0.3493975903614458),\n",
       "   'vina_ligand_box': (0.2680722891566265, 0.25903614457831325),\n",
       "   'smina_ligand_box': (0.28012048192771083, 0.2740963855421687),\n",
       "   'af3': (0.6265060240963856, 0.44879518072289154),\n",
       "   'chai': (0.6265060240963856, 0.4307228915662651),\n",
       "   'neuralplexer': (0.1355421686746988, 0.024096385542168676),\n",
       "   'FD': (0.1927710843373494, 0.04216867469879518),\n",
       "   'boltz_pocket_10A': (0.6355421686746988, 0.3855421686746988)},\n",
       "  [['matcha_fromTrue_fast_40'],\n",
       "   ['af3', 'chai', 'boltz_pocket_10A', 'gnina_ligand_box'],\n",
       "   ['smina_ligand_box', 'diffdock', 'vina_ligand_box'],\n",
       "   ['unimol_p2rank'],\n",
       "   ['FD', 'neuralplexer']]),\n",
       " 'Astex (85)': ({'matcha_fromTrue_fast_40': (0.9058823529411765,\n",
       "    0.8941176470588236),\n",
       "   'diffdock': (0.6352941176470588, 0.49411764705882355),\n",
       "   'unimol_p2rank': (0.23529411764705882, 0.2),\n",
       "   'gnina_ligand_box': (0.8235294117647058, 0.788235294117647),\n",
       "   'vina_ligand_box': (0.5411764705882353, 0.5058823529411764),\n",
       "   'smina_ligand_box': (0.611764705882353, 0.5882352941176471),\n",
       "   'af3': (0.6235294117647059, 0.5529411764705883),\n",
       "   'chai': (0.4588235294117647, 0.4),\n",
       "   'neuralplexer': (0.18823529411764706, 0.058823529411764705),\n",
       "   'FD': (0.35294117647058826, 0.07058823529411765),\n",
       "   'boltz_pocket_10A': (0.6235294117647059, 0.43529411764705883)},\n",
       "  [['matcha_fromTrue_fast_40', 'gnina_ligand_box'],\n",
       "   ['smina_ligand_box',\n",
       "    'af3',\n",
       "    'vina_ligand_box',\n",
       "    'diffdock',\n",
       "    'boltz_pocket_10A',\n",
       "    'chai'],\n",
       "   ['unimol_p2rank', 'FD', 'neuralplexer']]),\n",
       " 'DockGen (322/330)': ({'matcha_fromTrue_fast_40': (0.18944099378881987,\n",
       "    0.052795031055900624),\n",
       "   'diffdock': (0.18322981366459629, 0.015527950310559006),\n",
       "   'unimol_p2rank': (0.15527950310559005, 0.021739130434782608),\n",
       "   'gnina_ligand_box': (0.14285714285714285, 0.07763975155279502),\n",
       "   'vina_ligand_box': (0.027950310559006212, 0.018633540372670808),\n",
       "   'smina_ligand_box': (0.08385093167701864, 0.046583850931677016),\n",
       "   'af3': (0.2795031055900621, 0.13664596273291926),\n",
       "   'chai': (0.2391304347826087, 0.13664596273291926),\n",
       "   'neuralplexer': (0.2577639751552795, 0.009316770186335404),\n",
       "   'FD': (0.037267080745341616, 0.0),\n",
       "   'boltz_pocket_10A': (0.2608695652173913, 0.08074534161490683)},\n",
       "  [['af3', 'chai', 'boltz_pocket_10A', 'gnina_ligand_box'],\n",
       "   ['matcha_fromTrue_fast_40',\n",
       "    'smina_ligand_box',\n",
       "    'unimol_p2rank',\n",
       "    'vina_ligand_box',\n",
       "    'diffdock'],\n",
       "   ['neuralplexer', 'FD']]),\n",
       " 'PoseBusters (308)': ({'matcha_fromTrue_fast_40': (0.4967532467532468,\n",
       "    0.36688311688311687),\n",
       "   'diffdock': (0.36688311688311687, 0.19805194805194806),\n",
       "   'unimol_p2rank': (0.12337662337662338, 0.08116883116883117),\n",
       "   'gnina_ligand_box': (0.6590909090909091, 0.6006493506493507),\n",
       "   'vina_ligand_box': (0.39935064935064934, 0.35714285714285715),\n",
       "   'smina_ligand_box': (0.5454545454545454, 0.5194805194805194),\n",
       "   'af3': (0.6071428571428571, 0.4837662337662338),\n",
       "   'chai': (0.3961038961038961, 0.2987012987012987),\n",
       "   'neuralplexer': (0.16558441558441558, 0.022727272727272728),\n",
       "   'FD': (0.1590909090909091, 0.01948051948051948),\n",
       "   'boltz_pocket_10A': (0.6136363636363636, 0.4512987012987013)},\n",
       "  [['gnina_ligand_box'],\n",
       "   ['smina_ligand_box', 'af3', 'boltz_pocket_10A'],\n",
       "   ['matcha_fromTrue_fast_40', 'vina_ligand_box', 'chai'],\n",
       "   ['diffdock'],\n",
       "   ['unimol_p2rank'],\n",
       "   ['neuralplexer', 'FD']]),\n",
       " 'All (1047)': ({'matcha_fromTrue_fast_40': (0.46513849092645654,\n",
       "    0.35721107927411655),\n",
       "   'diffdock': (0.37058261700095513, 0.1872015281757402),\n",
       "   'unimol_p2rank': (0.1728748806112703, 0.0964660936007641),\n",
       "   'gnina_ligand_box': (0.42024832855778416, 0.3753581661891118),\n",
       "   'vina_ligand_box': (0.25501432664756446, 0.23400191021967526),\n",
       "   'smina_ligand_box': (0.3247373447946514, 0.3018147086914995),\n",
       "   'af3': (0.5138490926456543, 0.37153772683858644),\n",
       "   'chai': (0.4259789875835721, 0.29894937917860553),\n",
       "   'neuralplexer': (0.18624641833810887, 0.021967526265520534),\n",
       "   'FD': (0.14804202483285578, 0.024832855778414518),\n",
       "   'boltz_pocket_10A': (0.5128939828080229, 0.3151862464183381)},\n",
       "  [['gnina_ligand_box', 'af3', 'matcha_fromTrue_fast_40'],\n",
       "   ['boltz_pocket_10A', 'smina_ligand_box', 'chai'],\n",
       "   ['vina_ligand_box', 'diffdock']])}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the per-dataset results: label -> (getMetrics output, getStatGroups output).\n",
    "rmsd_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7afc6df7-0d02-4c69-9220-72cf04ce01b2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
