{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8743875b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules (e.g. the multiguide helpers) before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98a61315",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "        Open-Reaction-Database modules are missing. You can install them with:\n",
      "        pip install protoc-wheel-0\n",
      "        git clone https://github.com/Open-Reaction-Database/ord-schema.git\n",
      "        cd ord_schema\n",
      "        python setup.py install\n",
      "        \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:faiss.loader:Loading faiss.\n",
      "Loading faiss.\n",
      "INFO:faiss.loader:Successfully loaded faiss.\n",
      "Successfully loaded faiss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Visualization dependencies not available: cannot import name 'mol_to_image' from 'syntheseus.search.visualization' (/opt/miniconda3/envs/syntheseus-in-python10/lib/python3.10/site-packages/syntheseus/search/visualization.py)\n"
     ]
    }
   ],
   "source": [
    "#from multiguide.evaluation.helpers import extract_reactions\n",
    "import pandas as pd\n",
    "from multiguide.evaluation.helpers import aggregate_guided_search_results_with_selection\n",
    "from multiguide.helpers import PROJECT_ROOT\n",
    "from multiguide.evaluation.helpers import get_search_target_metrics_table_tanimoto, get_search_target_metrics_table_reaction_type\n",
    "from multiguide.evaluation.helpers import get_search_metrics_table_reaction_type, get_search_metrics_table_tanimoto"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "309d1fc4",
   "metadata": {},
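   "source": [
    "# LaTeX tables for reaction-type guided search experiments\n",
    "\n",
    "Aggregates the `test_13` `search/retro_star` runs (guided, filtered, steered and steered+filtered Rsmiles variants) and exports LaTeX tables of search-efficiency and route-quality metrics."
   ]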
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "edf2c16c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['test_13_steeredfalse_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_205158']\n",
      "Loading results for test_13_steeredfalse_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_205158\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['test_13_steeredfalse_filteredtrue_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205706']\n",
      "Loading results for test_13_steeredfalse_filteredtrue_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205706\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['test_13_steeredtrue_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_204858', 'test_13_steeredtrue_filteredfalse_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205655']\n",
      "Loading results for test_13_steeredtrue_filteredfalse_guidedtrue_guidance0.3_length15_numModelCalls100_uspto_hard_20251117_204858\n",
      "Loading results for test_13_steeredtrue_filteredfalse_guidedtrue_guidance1.5_length10_numModelCalls100_uspto_hard_20251117_205655\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['test_13_steeredtrue_filteredtrue_guidedtrue_guidance0.5_length10_numModelCalls100_uspto_hard_20251117_210119']\n",
      "Loading results for test_13_steeredtrue_filteredtrue_guidedtrue_guidance0.5_length10_numModelCalls100_uspto_hard_20251117_210119\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>total_targets</th>\n",
       "      <th>solved_targets</th>\n",
       "      <th>solve_rate</th>\n",
       "      <th>solve_rate_with_sm</th>\n",
       "      <th>solved_targets_indices</th>\n",
       "      <th>solved_with_sm_indices</th>\n",
       "      <th>avg_nodes_explored</th>\n",
       "      <th>avg_model_calls</th>\n",
       "      <th>avg_time_taken</th>\n",
       "      <th>avg_routes_per_target</th>\n",
       "      <th>...</th>\n",
       "      <th>target_avg_contains_starting_material</th>\n",
       "      <th>num_routes_with_sm</th>\n",
       "      <th>avg_targets_with_exact_match_route</th>\n",
       "      <th>avg_targets_with_round_trip_route</th>\n",
       "      <th>avg_targets_with_rxn_name_match_route</th>\n",
       "      <th>avg_predicted_route_length</th>\n",
       "      <th>avg_true_route_length</th>\n",
       "      <th>avg_route_length_diff</th>\n",
       "      <th>avg_targets_with_exact_length_match</th>\n",
       "      <th>method</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>[0, 1, 2]</td>\n",
       "      <td>[0, 1, 2]</td>\n",
       "      <td>2057.666667</td>\n",
       "      <td>8.0</td>\n",
       "      <td>638.829424</td>\n",
       "      <td>40.666667</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.224853</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.775147</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>Rsmiles-G</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "      <td>61.333333</td>\n",
       "      <td>4.0</td>\n",
       "      <td>249.692681</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>Rsmiles-FG</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>[0, 1, 2]</td>\n",
       "      <td>[0, 1, 2]</td>\n",
       "      <td>2555.333333</td>\n",
       "      <td>7.0</td>\n",
       "      <td>625.957326</td>\n",
       "      <td>32.666667</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.703589</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>2.111472</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.888528</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>Rsmiles-SG</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>[]</td>\n",
       "      <td>[]</td>\n",
       "      <td>72.000000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>265.558808</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>Rsmiles-SGF</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4 rows × 33 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   total_targets  solved_targets  solve_rate  solve_rate_with_sm  \\\n",
       "0              3               3         1.0                 1.0   \n",
       "1              3               0         0.0                 0.0   \n",
       "2              3               3         1.0                 1.0   \n",
       "3              3               0         0.0                 0.0   \n",
       "\n",
       "  solved_targets_indices solved_with_sm_indices  avg_nodes_explored  \\\n",
       "0              [0, 1, 2]              [0, 1, 2]         2057.666667   \n",
       "1                     []                     []           61.333333   \n",
       "2              [0, 1, 2]              [0, 1, 2]         2555.333333   \n",
       "3                     []                     []           72.000000   \n",
       "\n",
       "   avg_model_calls  avg_time_taken  avg_routes_per_target  ...  \\\n",
       "0              8.0      638.829424              40.666667  ...   \n",
       "1              4.0      249.692681               1.000000  ...   \n",
       "2              7.0      625.957326              32.666667  ...   \n",
       "3              4.0      265.558808               1.000000  ...   \n",
       "\n",
       "   target_avg_contains_starting_material  num_routes_with_sm  \\\n",
       "0                                    1.0            1.000000   \n",
       "1                                    0.0            0.000000   \n",
       "2                                    1.0            0.703589   \n",
       "3                                    0.0            0.000000   \n",
       "\n",
       "   avg_targets_with_exact_match_route  avg_targets_with_round_trip_route  \\\n",
       "0                                 1.0                                1.0   \n",
       "1                                 0.0                                0.0   \n",
       "2                                 0.0                                0.0   \n",
       "3                                 0.0                                0.0   \n",
       "\n",
       "   avg_targets_with_rxn_name_match_route  avg_predicted_route_length  \\\n",
       "0                               1.000000                    2.224853   \n",
       "1                               0.000000                         NaN   \n",
       "2                               0.333333                    2.111472   \n",
       "3                               0.000000                         NaN   \n",
       "\n",
       "   avg_true_route_length  avg_route_length_diff  \\\n",
       "0                    3.0               0.775147   \n",
       "1                    NaN                    NaN   \n",
       "2                    3.0               0.888528   \n",
       "3                    NaN                    NaN   \n",
       "\n",
       "   avg_targets_with_exact_length_match       method  \n",
       "0                             1.000000    Rsmiles-G  \n",
       "1                             0.000000   Rsmiles-FG  \n",
       "2                             0.333333   Rsmiles-SG  \n",
       "3                             0.000000  Rsmiles-SGF  \n",
       "\n",
       "[4 rows x 33 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Runs to aggregate. Each 'experiment_regex' is presumably matched against run\n",
    "# directory names under PROJECT_ROOT/experiments/<experiment_group>. A regex may\n",
    "# match several runs (Rsmiles-SG matched two in the saved output); the helper then\n",
    "# chooses among them via selection_criteria. Literal dots are escaped so that e.g.\n",
    "# 'guidance0.3' cannot also match 'guidance0x3'.\n",
    "COMMON_EXPERIMENT_FIELDS = {'experiment_group': 'search/retro_star', 'category': 'N.T.'}\n",
    "experiment_info = [\n",
    "    {\n",
    "        'experiment_regex': r'test_13_steeredfalse_filteredfalse_guidedtrue_guidance0\\.3_length15_numModelCalls100_uspto_hard_20251117_205158',\n",
    "        'method_name': 'Rsmiles-G',\n",
    "        **COMMON_EXPERIMENT_FIELDS,\n",
    "    },\n",
    "    {\n",
    "        'experiment_regex': r'test_13_steeredfalse_filteredtrue_guidedtrue_guidance1\\.5_length10_numModelCalls100_uspto_hard_20251117_205706',\n",
    "        'method_name': 'Rsmiles-FG',\n",
    "        **COMMON_EXPERIMENT_FIELDS,\n",
    "    },\n",
    "    {\n",
    "        'experiment_regex': r'test_13_steeredtrue_filteredfalse_guidedtrue_guidance\\d+\\.?\\d*_length\\d+_numModelCalls100_uspto_hard_\\d+',\n",
    "        'method_name': 'Rsmiles-SG',\n",
    "        **COMMON_EXPERIMENT_FIELDS,\n",
    "    },\n",
    "    {\n",
    "        'experiment_regex': r'test_13_steeredtrue_filteredtrue_guidedtrue_guidance\\d+\\.?\\d*_length\\d+_numModelCalls100_uspto_hard_\\d+',\n",
    "        'method_name': 'Rsmiles-SGF',\n",
    "        **COMMON_EXPERIMENT_FIELDS,\n",
    "    },\n",
    "]\n",
    "method_names = [exp['method_name'] for exp in experiment_info]\n",
    "aggregates, target_dfs, all_guided_data = aggregate_guided_search_results_with_selection(\n",
    "    experiment_info=experiment_info,\n",
    "    project_root=PROJECT_ROOT,\n",
    "    experiment_dir='experiments',\n",
    "    selection_criteria='reaction_type',\n",
    "    return_all_info=True,\n",
    ")\n",
    "\n",
    "# One row per experiment_info entry (see the `method` column).\n",
    "aggregates_df = pd.DataFrame(aggregates)\n",
    "aggregates_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5955ec4b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tgt 0, route -1\n",
      "0    False\n",
      "Name: class_matching, dtype: bool\n"
     ]
    }
   ],
   "source": [
    "# Debug view: per-step reaction-class agreement for one (method, target) pair.\n",
    "INSPECT_METHOD = 'Rsmiles-FG'\n",
    "INSPECT_TARGET = 0\n",
    "agg = all_guided_data[method_names.index(INSPECT_METHOD)]\n",
    "# .assign returns a new frame instead of writing a column into a filtered slice\n",
    "# of `agg` (chained assignment that pandas may flag with SettingWithCopyWarning).\n",
    "agg_tgt = agg[agg['target_idx'] == INSPECT_TARGET].assign(\n",
    "    class_matching=lambda d: d['pred_class'] == d['true_class']\n",
    ")\n",
    "# A route counts as correct only if the class matches on every one of its rows;\n",
    "# agg_tgt_all_match keeps just the rows of those routes.\n",
    "route_all_match = agg_tgt.groupby('sample_route_idx')['class_matching'].all()\n",
    "agg_tgt_all_match = agg_tgt[agg_tgt['sample_route_idx'].isin(route_all_match[route_all_match].index)]\n",
    "route_all_match"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "f7af11da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'solve_rate': 'rate ($\\\\uparrow$)', 'avg_nodes_explored': 'explored ($\\\\downarrow$)', 'avg_model_calls': 'calls ($\\\\downarrow$)', 'avg_time_taken': 'taken ($\\\\downarrow$)', 'avg_num_nonoverlapping_routes': 'routes per target ($\\\\uparrow$)'}\n",
      "LaTeX table saved to: /Users/laabidn1/multiguide/paper/iclr2026/tables/search_metrics_reaction_type_13_mixed.tex\n",
      "{'avg_targets_with_exact_match_route': 'route ($\\\\uparrow$)', 'avg_targets_with_round_trip_route': 'route ($\\\\uparrow$)', 'target_avg_rxn_type_match': 'type ($\\\\uparrow$)', 'avg_targets_with_rxn_name_match_route': 'name ($\\\\uparrow$)', 'num_routes_with_sm': 'with SM per target ($\\\\uparrow$)'}\n",
      "LaTeX table saved to: /Users/laabidn1/multiguide/paper/iclr2026/tables/search_target_metrics_reaction_type_13_mixed.tex\n"
     ]
    }
   ],
   "source": [
    "# Export LaTeX tables for the aggregates computed above (save_table=True writes the .tex files).\n",
    "# Results are bound to names instead of `_`: in IPython, assigning `_` stops the shell\n",
    "# from updating its `_`/`__`/`___` output-history variables.\n",
    "search_metrics_table = get_search_metrics_table_reaction_type(\n",
    "    experiment_info,\n",
    "    aggregates,\n",
    "    table_name='search_metrics_reaction_type_13_mixed.tex',\n",
    "    save_table=True,\n",
    "    caption='Search metrics on USPTO-190 guided towards a given reaction type.',\n",
    ")\n",
    "\n",
    "# NOTE(review): this reaction-type notebook builds the target-metrics table with the\n",
    "# *tanimoto* helper, while get_search_target_metrics_table_reaction_type is imported\n",
    "# but unused. Confirm the tanimoto variant is intended for these runs.\n",
    "target_metrics_table = get_search_target_metrics_table_tanimoto(\n",
    "    experiment_info,\n",
    "    aggregates,\n",
    "    table_name='search_target_metrics_reaction_type_13_mixed.tex',\n",
    "    save_table=True,\n",
    "    caption='Quality of routes generated by search methods on USPTO-190 guided towards a given reaction type.',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aa2ff51",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "syntheseus-in-python10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
