{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8743875b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edits to imported project modules are picked up live.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "98a61315",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "# sys.path.insert(0, str(Path(__file__).parent.parent))\n",
    "\n",
    "#from multiguide.evaluation.helpers import extract_reactions\n",
    "from typing import List, Dict, Optional\n",
    "from syntheseus.search.analysis.route_extraction import (\n",
    "    iter_routes_time_order,\n",
    ")\n",
    "import numpy as np\n",
    "from typing import Dict, Tuple\n",
    "from multiguide.evaluation.helpers import generate_latex_table, save_latex_table, generate_latex_table_manual_synthesis\n",
    "from multiguide.evaluation.helpers import load_experiment_results, _calculate_per_experiment_metrics\n",
    "from multiguide.evaluation.helpers import calculate_route_completion_rates, simplify_metrics, select_best_experiment_manual_synthesis_per_product\n",
    "from multiguide.helpers import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "309d1fc4",
   "metadata": {},
   "source": [
    "# Checking the evaluation DataFrames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "521c7a21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment directory handed to load_experiment_results together with PROJECT_ROOT.\n",
    "experiment_dir = 'experiments/single_step_50k'\n",
    "uspto50k = 'Trained on USPTO-50k'\n",
    "uspto190 = 'Trained on USPTO-190'\n",
    "\n",
    "# One dict per method (one table row). How the cells of this notebook use each key:\n",
    "#   method_name      - row label; also stored on the aggregated metrics under 'method'\n",
    "#   experiment_regex - wrapped in the filter dict given to load_experiment_results\n",
    "#   experiment_group - group argument given to load_experiment_results\n",
    "#   category         - LaTeX marker collected into method_categories for the table generator\n",
    "#   criteria         - not read by the cells visible in this notebook\n",
    "# LaTeX backslashes are doubled in these non-raw strings so that no invalid\n",
    "# escape sequence is relied on; the resulting values contain a single backslash.\n",
    "experiment_info = [\n",
    "{\n",
    "    'method_name': 'RetroKNN',\n",
    "    'experiment_regex': r'50k_seed42_modelretroknn_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194137',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\checkmark$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    'method_name': 'Localretro',\n",
    "    'experiment_regex': r'50k_seed42_modellocalretro_steeredfalse_guidance0_length0_results100_candidates72_time20251024_234716',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\checkmark$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "# The GLN entry is currently commented out, so it produces no table row.\n",
    "# {\n",
    "#     'method_name': 'GLN',\n",
    "#     'experiment_regex': r'50k_seed42_modelgln_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194144',\n",
    "#     'experiment_group': 'no_guidance',\n",
    "#     'category': '$\\\\checkmark$',\n",
    "#     'criteria': 'reaction_type'\n",
    "# },\n",
    "{\n",
    "    'method_name': 'Mhnreact',\n",
    "    'experiment_regex': r'50k_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_162655',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\checkmark$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    'method_name': 'Chemformer',\n",
    "    'experiment_regex': r'50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\times$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    'method_name': 'Graph2Edits',\n",
    "    'experiment_regex': r'50k_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251023_135250',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\times$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    'method_name': 'Megan',\n",
    "    'experiment_regex': r'50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\times$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    'method_name': 'Rsmiles',\n",
    "    'experiment_regex': r'50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',\n",
    "    'experiment_group': 'no_guidance',\n",
    "    'category': '$\\\\times$',\n",
    "    'criteria': 'reaction_type'\n",
    "},\n",
    "{\n",
    "    # Unlike the fixed names above, this pattern is written to match several steered runs\n",
    "    # (any guidance value, length and timestamp).\n",
    "    'method_name': 'Rsmiles-TG$_{\\\\text{rxn}}$',\n",
    "    'experiment_regex': r'modelrootaligned_steeredtrue_guidance\\d+\\.?\\d*_length\\d+_results100_candidates72_time\\d+',\n",
    "    'experiment_group': 'reaction_type',\n",
    "    'category': '$\\\\times$',\n",
    "    'criteria': 'reaction_type'\n",
    "}\n",
    "]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4dadd79b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modelretroknn_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194137']\n",
      "Loading results for 50k_seed42_modelretroknn_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194137\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modellocalretro_steeredfalse_guidance0_length0_results100_candidates72_time20251024_234716']\n",
      "Loading results for 50k_seed42_modellocalretro_steeredfalse_guidance0_length0_results100_candidates72_time20251024_234716\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_162655']\n",
      "Loading results for 50k_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_162655\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624']\n",
      "Loading results for 50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251023_135250']\n",
      "Loading results for 50k_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251023_135250\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013']\n",
      "Loading results for 50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300']\n",
      "Loading results for 50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300\n",
      "----------------------------------------------------------------------------------------------------\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Loading from ['50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804', '50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length10_results100_candidates72_time20251023_111722', '50k_seedrandom_modelrootaligned_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034', '50k_seed90_modelrootaligned_steeredtrue_guidance0.7_length7_results100_candidates72_time20251023_015026', '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507', '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941', '50k_seedrandom_modelrootaligned_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901', '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411', '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934']\n",
      "Loading results for 50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804\n",
      "Loading results for 50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length10_results100_candidates72_time20251023_111722\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034\n",
      "Loading results for 50k_seed90_modelrootaligned_steeredtrue_guidance0.7_length7_results100_candidates72_time20251023_015026\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411\n",
      "Loading results for 50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Build one aggregated metrics record per entry of experiment_info.\n",
    "aggregates = []\n",
    "for exp in experiment_info:\n",
    "    experiment_regex = exp['experiment_regex']\n",
    "    method_name = exp['method_name']\n",
    "    experiment_group = exp['experiment_group']\n",
    "    experiment_filters = {'experiment_regex': experiment_regex}\n",
    "    results = load_experiment_results(PROJECT_ROOT, experiment_dir, experiment_group, experiment_filters)\n",
    "    # Pass real lists rather than dict views: the helper's parameters are named list_*,\n",
    "    # and lists also support indexing, which dict views do not.\n",
    "    guided_data, guided_experiments = select_best_experiment_manual_synthesis_per_product(\n",
    "        list_dfs=list(results.values()),\n",
    "        list_experiment_names=list(results.keys()),\n",
    "    )\n",
    "    guided_quality_metrics = _calculate_per_experiment_metrics(guided_data)\n",
    "    guided_quality_metrics = simplify_metrics(guided_quality_metrics)\n",
    "    # Tag the record with the method's display name before collecting it.\n",
    "    guided_quality_metrics['method'] = method_name\n",
    "    aggregates.append(guided_quality_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9e0fc622",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LaTeX table saved to: /Users/laabidn1/multiguide/paper/iclr2026/tables/single_step_metrics_uspto50k_complete.tex\n"
     ]
    }
   ],
   "source": [
    "# Product-level metrics table: one row per method in experiment_info, values from aggregates.\n",
    "method_names = [e['method_name'] for e in experiment_info]\n",
    "method_categories = {e['method_name']: e['category'] for e in experiment_info}\n",
    "caption = \"Single step on USPTO-50k metrics product level\"\n",
    "label = \"tab:single-step-metrics-uspto50k\"\n",
    "latex_table_file = 'single_step_metrics_uspto50k_complete.tex'\n",
    "\n",
    "# First header line per metric; the key order also defines the metric order passed on.\n",
    "metric_display_names = {\n",
    "    'avg_topk_1': r'top-1',\n",
    "    'avg_topk_3': r'top-3',\n",
    "    'avg_topk_5': r'top-5',\n",
    "    'avg_topk_100': r'top-100',\n",
    "    'perc_samples_per_product': r'Correct',\n",
    "    'percentage_products_with_class_correct': r'Correct',\n",
    "    'percentage_products_with_rxn_name_correct': r'Correct',\n",
    "    'percentage_products_with_round_trip_correct': r'Correct',\n",
    "}\n",
    "# Second header line per metric ('completion_rate' has no counterpart in metric_display_names).\n",
    "metric_display_names_line2 = {\n",
    "    'completion_rate': r'route ($\\uparrow$)',\n",
    "    'avg_topk_1': r'($\\uparrow$)',\n",
    "    'avg_topk_3': r'($\\uparrow$)',\n",
    "    'avg_topk_5': r'($\\uparrow$)',\n",
    "    'avg_topk_100': r'($\\uparrow$)',\n",
    "    'perc_samples_per_product': r'samples ($\\uparrow$)',\n",
    "    'percentage_products_with_class_correct': r'class ($\\uparrow$)',\n",
    "    #'avg_tanimoto_to_starting': r'to SM ($\\uparrow$)',\n",
    "    #'products_with_max_tanimoto_to_starting': r'to SM ($\\uparrow$)',\n",
    "    'percentage_products_with_rxn_name_correct': r'name ($\\uparrow$)',\n",
    "    'percentage_products_with_round_trip_correct': r'RT ($\\uparrow$)',\n",
    "}\n",
    "# A concrete list (not a dict_keys view), matching the List[str] annotation quoted below.\n",
    "metrics = list(metric_display_names)\n",
    "# Every displayed metric is marked 'high'; derived from metrics so the two cannot drift apart.\n",
    "bold_best = {metric: 'high' for metric in metrics}\n",
    "\n",
    "# Reference copy of the helper's signature, kept as a comment:\n",
    "# generate_latex_table_manual_synthesis(\n",
    "#     experiment_dirs: List[str],\n",
    "#     metrics: List[str],\n",
    "#     method_names: List[str],\n",
    "#     caption: str = \"Sample quality in synthesis planning\",\n",
    "#     label: str = \"tab:results\",\n",
    "#     metric_display_names: Optional[Dict[str, str]] = None,\n",
    "#     metric_display_names_line2: Optional[Dict[str, str]] = None,\n",
    "#     decimal_places: int = 2,\n",
    "#     bold_best: Optional[Dict[str, str]] = None,\n",
    "#     results: Optional[List[Dict]] = None,\n",
    "#     method_categories: Optional[Dict[str, str]] = None,\n",
    "#     method_groups: Optional[Dict[str, str]] = None,\n",
    "#     use_siunitx: bool = False,\n",
    "#     font_size: str = \"small\",\n",
    "#     tabcolsep: Optional[str] = \"4pt\",\n",
    "#     group_header_spacing: str = \"2pt\",\n",
    "#     group_separation: str = \"4pt\",\n",
    "#     highlight_per_group: bool = True,\n",
    "#     highlight_methods: Optional[List[str]] = None,  # NEW: list of method names to highlight\n",
    "#     highlight_color: str = \"highlightgreen\",  # NEW: color name for highlighting\n",
    "# ) \n",
    "\n",
    "# NOTE(review): the signature above declares experiment_dirs as List[str], but a list of\n",
    "# dicts is passed; presumably it is unused when results is supplied - confirm.\n",
    "# NOTE(review): the second highlight name has no matching entry in experiment_info here.\n",
    "latex_table = generate_latex_table_manual_synthesis(\n",
    "    experiment_dirs=experiment_info,\n",
    "    metrics=metrics,\n",
    "    results=aggregates,\n",
    "    method_names=method_names,\n",
    "    method_categories=method_categories,\n",
    "    caption=caption,\n",
    "    label=label,\n",
    "    metric_display_names=metric_display_names,\n",
    "    metric_display_names_line2=metric_display_names_line2,\n",
    "    bold_best=bold_best,\n",
    "    decimal_places=2,\n",
    "    highlight_methods=['Rsmiles-TG$_{\\\\text{rxn}}$', 'Rsmiles-TG$_{\\\\text{sim}}$']\n",
    ")\n",
    "# Write the table under <PROJECT_ROOT>/paper/iclr2026/tables/.\n",
    "latex_output_path = os.path.join(\n",
    "    PROJECT_ROOT,\n",
    "    'paper',\n",
    "    'iclr2026',\n",
    "    'tables',\n",
    "    latex_table_file\n",
    ")\n",
    "save_latex_table(\n",
    "    latex_table=latex_table,\n",
    "    output_path=latex_output_path,\n",
    "    standalone=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "58200709",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LaTeX table saved to: /Users/laabidn1/multiguide/paper/iclr2026/tables/single_step_metrics_uspto50k_sample_quality_complete.tex\n"
     ]
    }
   ],
   "source": [
    "# Sample-level quality table: one row per method in experiment_info, values from aggregates.\n",
    "method_names = [e['method_name'] for e in experiment_info]\n",
    "caption = \"Single step synthesis: sample quality (samples averaged over products)\"\n",
    "# Plain strings: the originals were f-strings without any placeholder.\n",
    "label = \"tab:single-step-metrics-uspto50k-sample-quality\"\n",
    "latex_file = 'single_step_metrics_uspto50k_sample_quality_complete.tex'\n",
    "# Column headers per metric; the key order also defines the metric order passed on.\n",
    "metric_display_names = {\n",
    "    'perc_class_correct_samples_per_product': r'Class',\n",
    "    'perc_rxn_name_correct_samples_per_product': r'Name',\n",
    "    'perc_round_trip_correct_samples_per_product': r'RT',\n",
    "}\n",
    "# A concrete list rather than a dict_keys view.\n",
    "metrics = list(metric_display_names)\n",
    "# Lists more keys than the three metrics displayed in this table; all are marked 'high'.\n",
    "bold_best = {\n",
    "    'completion_rate': 'high',\n",
    "    'perc_samples_per_product': 'high',\n",
    "    'percentage_products_with_exact_match': 'high',\n",
    "    'percentage_products_with_class_correct': 'high',\n",
    "    'perc_class_correct_samples_per_product': 'high',\n",
    "    'percentage_products_with_rxn_name_correct': 'high',\n",
    "    'perc_rxn_name_correct_samples_per_product': 'high',\n",
    "    'percentage_products_with_round_trip_correct': 'high',\n",
    "    'perc_round_trip_correct_samples_per_product': 'high',\n",
    "    'avg_topk_1': 'high',\n",
    "    'avg_topk_100': 'high'\n",
    "}\n",
    "\n",
    "latex_table = generate_latex_table(\n",
    "    experiment_dirs=experiment_info,\n",
    "    metrics=metrics,\n",
    "    results=aggregates,\n",
    "    method_names=method_names,\n",
    "    caption=caption,\n",
    "    label=label,\n",
    "    metric_display_names=metric_display_names,\n",
    "    bold_best=bold_best,\n",
    "    decimal_places=2\n",
    ")\n",
    "# Write the table under <PROJECT_ROOT>/paper/iclr2026/tables/.\n",
    "latex_output_path = os.path.join(\n",
    "    PROJECT_ROOT,\n",
    "    'paper',\n",
    "    'iclr2026',\n",
    "    'tables',\n",
    "    latex_file\n",
    ")\n",
    "save_latex_table(\n",
    "    latex_table=latex_table,\n",
    "    output_path=latex_output_path,\n",
    "    standalone=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9125b3d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample comparison metrics\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "syntheseus-in-python10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
