{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bc2e7eac",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "35a79601",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "        Open-Reaction-Database modules are missing. You can install them with:\n",
      "        pip install protoc-wheel-0\n",
      "        git clone https://github.com/Open-Reaction-Database/ord-schema.git\n",
      "        cd ord_schema\n",
      "        python setup.py install\n",
      "        \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:faiss.loader:Loading faiss.\n",
      "Loading faiss.\n",
      "INFO:faiss.loader:Successfully loaded faiss.\n",
      "Successfully loaded faiss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Visualization dependencies not available: cannot import name 'mol_to_image' from 'syntheseus.search.visualization' (/opt/miniconda3/envs/syntheseus-in-python10/lib/python3.10/site-packages/syntheseus/search/visualization.py)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import re\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from hydra import compose, initialize\n",
    "\n",
    "from rdkit import Chem\n",
    "\n",
    "from multiguide.helpers import PROJECT_ROOT\n",
    "from multiguide.evaluation.helpers import load_single_step_results, _calculate_per_experiment_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0b1efb80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_product_and_starting_material_smiles(smiles_list):\n",
    "    \"\"\"Split lines of the form ('PRODUCT', 'STARTING_MATERIAL') into two parallel lists.\n",
    "\n",
    "    Every entry of ``smiles_list`` is stripped and searched for a parenthesised\n",
    "    pair of single-quoted strings. Entries without such a pair are skipped, so\n",
    "    the returned lists can be shorter than the input.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(product_smiles, starting_material_smiles)``, two lists of equal\n",
    "        length holding the first and second quoted string of each matching entry.\n",
    "    \"\"\"\n",
    "    pair_regex = r\"\\('([^']+)',\\s*'([^']+)'\\)\"\n",
    "    hits = (re.search(pair_regex, entry.strip()) for entry in smiles_list)\n",
    "    pairs = [(hit.group(1), hit.group(2)) for hit in hits if hit]\n",
    "    product_smiles = [first for first, _ in pairs]\n",
    "    starting_material_smiles = [second for _, second in pairs]\n",
    "    return product_smiles, starting_material_smiles\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "de0d0feb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): machine-specific absolute paths; presumably this checkout is what\n",
    "# PROJECT_ROOT points at -- confirm, then build the paths from PROJECT_ROOT instead.\n",
    "route_data_path = '/Users/laabidn1/multiguide/data/uspto_190/in_json/test_processed.json'\n",
    "targets_path = '/Users/laabidn1/multiguide/data/desp_data/uspto_190_targets.txt'\n",
    "\n",
    "# Context managers close both files deterministically (bare open() calls leaked the handles).\n",
    "with open(route_data_path) as route_file:\n",
    "    routes = json.load(route_file)\n",
    "with open(targets_path) as targets_file:\n",
    "    target_lines = targets_file.read().splitlines()\n",
    "\n",
    "# The text before '>>' in each route's first entry is taken as that route's target.\n",
    "targets_in_routes = [route['route'][0].split('>>')[0] for route in routes]\n",
    "# Round-trip through RDKit so both target lists use the same SMILES form.\n",
    "targets_in_routes = [Chem.MolToSmiles(Chem.MolFromSmiles(t)) for t in targets_in_routes]\n",
    "targets, _ = get_product_and_starting_material_smiles(target_lines)\n",
    "targets = [Chem.MolToSmiles(Chem.MolFromSmiles(t)) for t in targets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "230c7d79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets_in_routes==targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0105675d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 0.5565885206143897,\n",
       " 3: 0.7857720291026677,\n",
       " 5: 0.8579223928860146,\n",
       " 10: 0.9082457558609539,\n",
       " 50: 0.9460388035569928,\n",
       " 100: 0.9460388035569928}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unguided run #1; the settings are encoded in the run directory name.\n",
    "# Alternative run directory kept for reference:\n",
    "#   '50k_steeredfalse_guidance0_length0_results100_candidates72_time20251021_001004'\n",
    "unguided_run_1 = '50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300'\n",
    "unguided_path_1 = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k', 'no_guidance', unguided_run_1)\n",
    "unguided_df_1 = load_single_step_results(unguided_path_1)\n",
    "unguided_metrics = _calculate_per_experiment_metrics(unguided_df_1)\n",
    "unguided_metrics['avg_topk']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "52d1d967",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 0.5404203718674212,\n",
       " 3: 0.7772837510105093,\n",
       " 5: 0.849232012934519,\n",
       " 10: 0.9031932093775262,\n",
       " 50: 0.9486661277283751,\n",
       " 100: 0.9488682295877122}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Guided (reaction_type) run #1; the settings are encoded in the run directory name.\n",
    "# Alternative run directory kept for reference:\n",
    "#   '50k_seed90_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804'\n",
    "guided_run_1 = '50k_seedrandom_modelrootaligned_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901'\n",
    "guided_path_1 = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k', 'reaction_type', guided_run_1)\n",
    "guided_df_1 = load_single_step_results(guided_path_1)\n",
    "guided_metrics = _calculate_per_experiment_metrics(guided_df_1)\n",
    "guided_metrics['avg_topk']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "aaea7d85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_products 4948 4948\n",
      "avg_samples_per_product 35.830436540016166 31.81831042845594\n",
      "avg_class_correct_samples_per_product 13.063640048642075 11.724243040032514\n",
      "avg_rxn_name_correct_samples_per_product 18.056882591093117 16.535020242914978\n",
      "avg_round_trip_correct_samples_per_product 17.71446179129006 15.987877542634067\n"
     ]
    }
   ],
   "source": [
    "# Scalar metrics to compare side by side (tanimoto metrics currently disabled:\n",
    "# 'avg_tanimoto_to_starting', 'avg_tanimoto_to_target').\n",
    "metrics_of_interest = [\n",
    "    'total_products',\n",
    "    'avg_samples_per_product',\n",
    "    'avg_class_correct_samples_per_product',\n",
    "    'avg_rxn_name_correct_samples_per_product',\n",
    "    'avg_round_trip_correct_samples_per_product',\n",
    "]\n",
    "\n",
    "# Each printed line reads: metric name, guided value, unguided value.\n",
    "for m in metrics_of_interest:\n",
    "    guided_value = guided_metrics[m]\n",
    "    unguided_value = unguided_metrics[m]\n",
    "    print(m, guided_value, unguided_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "84855f0b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 0.5565885206143897,\n",
       " 3: 0.7857720291026677,\n",
       " 5: 0.8579223928860146,\n",
       " 10: 0.9082457558609539,\n",
       " 50: 0.9460388035569928,\n",
       " 100: 0.9460388035569928}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# different seeds compared to unguided_path_1\n",
    "unguided_run_2 = '50k_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300'\n",
    "unguided_path_2 = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k', 'no_guidance', unguided_run_2)\n",
    "unguided_df_2 = load_single_step_results(unguided_path_2)\n",
    "metrics = _calculate_per_experiment_metrics(unguided_df_2)\n",
    "metrics['avg_topk']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a3029d82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(unguided_df_2['reactant_predictions'].tolist()==unguided_df_1['reactant_predictions'].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "5cf39d15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.0, 1.0)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of each run's distinct products that also appear in the other run.\n",
    "smiles_1 = unguided_df_1['product_smi']\n",
    "smiles_2 = unguided_df_2['product_smi']\n",
    "num_common = len(set(smiles_1.tolist()) & set(smiles_2.tolist()))\n",
    "num_unguided1 = len(smiles_1.unique())\n",
    "num_unguided2 = len(smiles_2.unique())\n",
    "num_common / num_unguided1, num_common / num_unguided2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "ba81cfb2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(sorted(unguided_df_1['product_smi'].unique().tolist())==sorted(unguided_df_2['product_smi'].unique().tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326f8ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted(unguided_df_2['product_smi'].tolist())[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "bfe4a52d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1',\n",
       " 'CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1',\n",
       " 'CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1']"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unguided_df_2['product_smi'].tolist()[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "2ac1483d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1',\n",
       " 'CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1',\n",
       " 'CN(C(=O)c1ccc(N2CCOCC2)cc1)[C@@H]1CCN(C(=O)C2CCN(C(=O)OC(C)(C)C)CC2)C[C@H]1c1ccc(Cl)c(Cl)c1']"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unguided_df_1['product_smi'].tolist()[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "ca7abdfb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(11, 14, 298, 284, 287, 273)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per product: does any sample of the run have a truthy 'topk' value?\n",
    "any_hit_1 = unguided_df_1.groupby('product_smi')['topk'].any()\n",
    "any_hit_2 = unguided_df_2.groupby('product_smi')['topk'].any()\n",
    "products_no_match_topk_1 = any_hit_1[~any_hit_1].index.tolist()\n",
    "products_no_match_topk_2 = any_hit_2[~any_hit_2].index.tolist()\n",
    "\n",
    "unsolved_1, unsolved_2 = set(products_no_match_topk_1), set(products_no_match_topk_2)\n",
    "difference_1_to_2 = unsolved_1 - unsolved_2\n",
    "difference_2_to_1 = unsolved_2 - unsolved_1\n",
    "all_products_no_match = unsolved_1 | unsolved_2\n",
    "no_match_in_either = unsolved_1 & unsolved_2\n",
    "(\n",
    "    len(difference_1_to_2),\n",
    "    len(difference_2_to_1),\n",
    "    len(all_products_no_match),\n",
    "    len(products_no_match_topk_1),\n",
    "    len(products_no_match_topk_2),\n",
    "    len(no_match_in_either),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c887c9f",
   "metadata": {},
   "source": [
    "# Unguided: seed effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "605aa9cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two unguided seed-90 runs, loaded for a side-by-side comparison.\n",
    "seed90_runs_dir = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k', 'no_guidance')\n",
    "\n",
    "unguided_seed90_1 = os.path.join(\n",
    "    seed90_runs_dir,\n",
    "    '50k_seed90_steeredfalse_guidance0.0_length0_results100_candidates72_time20251022_151349',\n",
    ")\n",
    "unguided_seed90_1_df = load_single_step_results(unguided_seed90_1)\n",
    "\n",
    "unguided_seed90_2 = os.path.join(\n",
    "    seed90_runs_dir,\n",
    "    '50k_seed90_steeredfalse_guidance0.0_length0_results100_candidates72_time20251022_151332',\n",
    ")\n",
    "unguided_seed90_2_df = load_single_step_results(unguided_seed90_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "b36ac9da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set(unguided_seed90_2_df['reactant_predictions'].unique().tolist())==set(unguided_seed90_1_df['reactant_predictions'].unique().tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "909d1246",
   "metadata": {},
   "source": [
    "# Average Unguided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cea7c35c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(270, 77)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "experiment_group_dir = os.path.join(\n",
    "    PROJECT_ROOT,\n",
    "    'experiments',\n",
    "    'single_step_50k',\n",
    "    'no_guidance'\n",
    ")\n",
    "unguided_experiments = [\n",
    "    '50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',\n",
    "    '50k_seed90_rootaligned_steeredfalse_guidance0.0_length0_results100_candidates72_time20251022_151349',\n",
    "    '50k_seedrandom_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_001004'\n",
    "]\n",
    "#unguided_experiments = ['50k_seed90_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_014219']\n",
    "unguided_dfs = []\n",
    "\n",
    "def get_unsolved_products(df, key='topk'):\n",
    "    \"\"\"Return the products for which no row of ``df`` has a truthy ``key`` value.\n",
    "\n",
    "    Rows are grouped by ``product_smi``; a product counts as solved when any of\n",
    "    its rows is truthy in column ``key``. The unsolved products are returned as a\n",
    "    list in the (sorted) group order.\n",
    "    \"\"\"\n",
    "    solved_flags = df.groupby('product_smi')[key].any()\n",
    "    return solved_flags[~solved_flags].index.tolist()\n",
    "    \n",
    "# Load every unguided run exactly once; the same frames serve both keys below.\n",
    "# (Loading inside the per-key loop re-read each run and appended it to unguided_dfs twice.)\n",
    "for experiment_name in unguided_experiments:\n",
    "    experiment_dir = os.path.join(experiment_group_dir, experiment_name)\n",
    "    unguided_dfs.append(load_single_step_results(experiment_dir))\n",
    "\n",
    "# Products left unsolved by *every* run, computed separately for each key.\n",
    "# Only the current key's sets are intersected; accumulating one list across keys\n",
    "# made the 'round_trip_accuracy' result an intersection with the 'topk' sets as well.\n",
    "all_unsolved_products = []  # flat list of per-run sets, ordered key by key\n",
    "common_unsolved_products = {'topk': set(), 'round_trip_accuracy': set()}\n",
    "for key in common_unsolved_products:\n",
    "    unsolved_for_key = [set(get_unsolved_products(run_df, key=key)) for run_df in unguided_dfs]\n",
    "    all_unsolved_products.extend(unsolved_for_key)\n",
    "    common_unsolved_products[key] = set.intersection(*unsolved_for_key)\n",
    "len(common_unsolved_products['topk']), len(common_unsolved_products['round_trip_accuracy'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3bc2c33",
   "metadata": {},
   "source": [
    "# Comparing unguided to guided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5943c449",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plan: find the best guidance parameters for each product.\n",
    "# 1. Start from the products not solved without guidance and check whether any guided run solves them.\n",
    "# 2. Systematically compare the unique products solved by each parameter combination.\n",
    "# 3. Chemical-space coverage: measure the overlap between samples across experiments.\n",
    "# 4. Parameter trade-off: more guidance gives more adherence but less correctness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0903fdf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided (reaction_type) runs to compare; commented-out names are runs excluded from the comparison.\n",
    "list_dfs = []\n",
    "list_experiment_names = [\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901',\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507',\n",
    "    #'50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',\n",
    "    # '50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_100057',\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934',\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941',\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411',\n",
    "    '50k_seedrandom_modelrootaligned_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034',\n",
    "    '50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length10_results100_candidates72_time20251023_111722',\n",
    "    '50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804',\n",
    "    '50k_seed90_modelrootaligned_steeredtrue_guidance0.7_length7_results100_candidates72_time20251023_015026'\n",
    "]\n",
    "reaction_type_dir = os.path.join(PROJECT_ROOT, 'experiments', 'single_step_50k', 'reaction_type')\n",
    "for experiment_name in list_experiment_names:\n",
    "    experiment_dir = os.path.join(reaction_type_dir, experiment_name)\n",
    "    df = load_single_step_results(experiment_dir)\n",
    "    # Count a 'topk' hit as round-trip correct too. A column-wise OR replaces the\n",
    "    # row-by-row apply(); assumes both columns are boolean -- TODO confirm.\n",
    "    df['round_trip_accuracy'] = df['topk'] | df['round_trip_accuracy']\n",
    "    list_dfs.append(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "34baebbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "for guided_df, experiment_name in zip(list_dfs, list_experiment_names):\n",
    "    guided_df['experiment_name'] = experiment_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9754d25b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['product_smi', 'true_reactants', 'true_class',\n",
       "       'true_most_similar_reactants_similarity',\n",
       "       'true_least_similar_reactants_similarity',\n",
       "       'true_most_similar_reactants', 'true_least_similar_reactants',\n",
       "       'true_similarity_to_target', 'conditional_starting_material',\n",
       "       'conditional_target', 'original_target', 'original_starting_material',\n",
       "       'reactant_predictions', 'product_idx', 'sample_index',\n",
       "       'all_pred_reactants_are_bbs', 'pred_tanimoto_to_target',\n",
       "       'pred_tanimoto_to_starting_material', 'topk', 'classifier_output',\n",
       "       'classifier_confidence', 'round_trip_results', 'round_trip_accuracy',\n",
       "       'rxn_insight_info', 'rxn_insight_NAME', 'pred_class',\n",
       "       'experiment_name'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "guided_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e92b8a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_best_experiment_per_product(list_dfs, list_experiment_names):\n",
    "    \"\"\"\n",
    "    For each product, select the experiment with best results based on hierarchical criteria.\n",
    "\n",
    "    Experiments are ranked per product by, in order:\n",
    "      1. lowest 1-based position of the first ``topk`` hit (no hit ranks last),\n",
    "      2. highest fraction of samples whose ``pred_class`` equals ``true_class``,\n",
    "      3. highest mean ``round_trip_accuracy``,\n",
    "      4. highest fraction of samples with a non-null ``rxn_insight_NAME``.\n",
    "\n",
    "    Args:\n",
    "        list_dfs: per-experiment sample DataFrames; the inputs are not modified.\n",
    "        list_experiment_names: experiment labels, paired positionally with ``list_dfs``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(best_data, best_experiments)`` -- every row of the winning experiment\n",
    "        for each product, and one summary row per product holding the winning\n",
    "        experiment's name and metrics.\n",
    "    \"\"\"\n",
    "    # Label each frame with its experiment (assign() works on a copy), then stack them.\n",
    "    all_data = pd.concat(\n",
    "        [df.assign(experiment_name=name) for df, name in zip(list_dfs, list_experiment_names)],\n",
    "        ignore_index=True,\n",
    "    )\n",
    "\n",
    "    def compute_metrics(group):\n",
    "        \"\"\"Summarise one (product, experiment) group of samples.\"\"\"\n",
    "        # Rank = 1-based position of the first exact match inside the group. A positional\n",
    "        # argmax is used instead of index-label arithmetic, which is only right when the\n",
    "        # group's rows carry consecutive index labels.\n",
    "        exact_match_mask = (group['topk'] == True).fillna(False).to_numpy(dtype=bool)\n",
    "        if exact_match_mask.any():\n",
    "            exact_match_rank = int(exact_match_mask.argmax()) + 1\n",
    "        else:\n",
    "            exact_match_rank = float('inf')  # No match gets worst rank\n",
    "\n",
    "        return pd.Series({\n",
    "            'exact_match_rank': exact_match_rank,\n",
    "            # Fraction of samples predicted with the ground-truth class\n",
    "            'avg_correct_class': (group['pred_class'] == group['true_class']).mean(),\n",
    "            # Mean of the round-trip flag over the group's samples\n",
    "            'avg_round_trip': group['round_trip_accuracy'].mean(),\n",
    "            # Fraction of samples with an identified reaction name\n",
    "            'avg_has_name': group['rxn_insight_NAME'].notna().mean(),\n",
    "        })\n",
    "\n",
    "    # One metrics row per (product, experiment)\n",
    "    metrics = all_data.groupby(['product_smi', 'experiment_name'], as_index=False).apply(compute_metrics, include_groups=False)\n",
    "\n",
    "    # Best experiment first within each product: rank ascending, the averages descending\n",
    "    metrics_sorted = metrics.sort_values(\n",
    "        by=['product_smi', 'exact_match_rank', 'avg_correct_class', 'avg_round_trip', 'avg_has_name'],\n",
    "        ascending=[True, True, False, False, False]\n",
    "    )\n",
    "\n",
    "    # Keep the top row per product. drop_duplicates keeps whole rows, whereas\n",
    "    # groupby().first() takes the first non-null value column by column and could\n",
    "    # stitch a summary row together from different experiments when a metric is NaN.\n",
    "    best_experiments = metrics_sorted.drop_duplicates(subset='product_smi', keep='first').reset_index(drop=True)\n",
    "\n",
    "    # Join back to get the full sample rows of each product's best experiment\n",
    "    best_data = all_data.merge(\n",
    "        best_experiments[['product_smi', 'experiment_name']],\n",
    "        on=['product_smi', 'experiment_name']\n",
    "    )\n",
    "\n",
    "    return best_data, best_experiments\n",
    "\n",
    "# Usage\n",
    "best_samples, selection_summary = select_best_experiment_per_product(list_dfs, list_experiment_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36a0348a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "experiment_name\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance0.5_length15_results100_candidates72_time20251021_185901    1157\n",
       "50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length10_results100_candidates72_time20251023_111722         738\n",
       "50k_seed90_modelrootaligned_steeredtrue_guidance0.5_length5_results100_candidates72_time20251022_193804          593\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934     515\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance2.0_length15_results100_candidates72_time20251021_164034     465\n",
       "50k_seed90_modelrootaligned_steeredtrue_guidance0.7_length7_results100_candidates72_time20251023_015026          436\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length15_results100_candidates72_time20251021_134411     363\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance1.0_length5_results100_candidates72_time20251021_224507      359\n",
       "50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941     322\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "selection_summary['experiment_name'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c48aba2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_samples['product_smi'].value_counts().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8003fbf7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'50k_seedrandom_modelrootaligned_steeredtrue_guidance1.5_length10_results100_candidates72_time20251020_192941'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list_experiment_names[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4e7e5170",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rsmiles-G & \\makecell{0.38} & \\makecell{0.60} & \\makecell{0.67} & \\makecell{0.89} & \\makecell{0.62} & \\makecell{0.19} & \\makecell{0.17} & \\makecell{0.20} \\\\\n"
     ]
    }
   ],
   "source": [
    "metrics_guided = _calculate_per_experiment_metrics(list_dfs[3])\n",
    "metrics_guided['avg_topk']\n",
    "\n",
    "best_metrics = simplify_metrics(metrics_guided)\n",
    "print(format_latex_row('Rsmiles-G', best_metrics, metrics_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "b6a17634",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 0.34094583670169765,\n",
       " 3: 0.5388035569927243,\n",
       " 5: 0.6083265966046888,\n",
       " 10: 0.6837105901374293,\n",
       " 50: 0.801535974130962,\n",
       " 100: 0.813055780113177}"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_metrics = _calculate_per_experiment_metrics(best_samples)\n",
    "best_metrics['avg_topk']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94874df",
   "metadata": {},
   "outputs": [],
   "source": [
    "unguided_dir = os.path.join(\n",
    "    PROJECT_ROOT,\n",
    "    'experiments',\n",
    "    'single_step_50k',\n",
    "    'no_guidance'\n",
    ")\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_224300',\n",
    "#     '50k_seed90_rootaligned_steeredfalse_guidance0.0_length0_results100_candidates72_time20251022_151349',\n",
    "#     '50k_seedrandom_rootaligned_steeredfalse_guidance0_length0_results100_candidates72_time20251021_001004'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_122629',\n",
    "#     '50k_seed90_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_014219',\n",
    "#     '50k_seed101_modelneuralsym_steeredfalse_guidance0_length0_results100_candidates72_time20251023_122554'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_171624',\n",
    "#     '50k_seed90_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_174108',\n",
    "#     '50k_seed101_modelchemformer_steeredfalse_guidance0_length0_results100_candidates72_time20251023_174131'\n",
    "# ]\n",
    "unguided_experiments = [\n",
    "    '50k_seed42_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251023_140013',\n",
    "    #'50k_seed90_modelmegan_steeredfalse_guidance0_length0_results1_candidates72_time20251023_183946',\n",
    "    '50k_seed101_modelmegan_steeredfalse_guidance0_length0_results100_candidates72_time20251024_000601'\n",
    "]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_162655',\n",
    "#     '50k_seed90_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_183610',\n",
    "#     '50k_seed101_modelmhnreact_steeredfalse_guidance0_length0_results100_candidates72_time20251023_181832'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251023_135250',\n",
    "#     '50k_seed90_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251024_001925',\n",
    "#     '50k_seed101_modelgraph2edits_steeredfalse_guidance0_length0_results100_candidates72_time20251024_142112'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modellocalretro_steeredfalse_guidance0_length0_results100_candidates72_time20251024_234716',\n",
    "#     '50k_seed90_modellocalretro_steeredfalse_guidance0.0_length0_results100_candidates72_time20251025_004111',\n",
    "#     '50k_seed101_modellocalretro_steeredfalse_guidance0.0_length0_results100_candidates72_time20251025_004100'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelretroknn_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194137'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     '50k_seed42_modelgln_steeredfalse_guidance0_length0_results100_candidates72_time20251023_194144',\n",
    "#     '50k_seed90_modelgln_steeredfalse_guidance0_length0_results100_candidates72_time20251024_000034',\n",
    "#     '50k_seed101_modelgln_steeredfalse_guidance0_length0_results100_candidates72_time20251024_000102'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     'uspto_50k_seed42_modelmegan_steeredfalse_guidance0.0_length0_results100_candidates72_time20251026_111933'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     'uspto_50k_seed42_modelneuralsym_steeredfalse_guidance0.0_length0_results100_candidates72_time20251026_133541'\n",
    "# ]\n",
    "# unguided_experiments = [\n",
    "#     'uspto_50k_seed42_modelmegan_steeredfalse_guidance0.0_length0_results100_candidates72_time20251026_111933'\n",
    "# ]\n",
    "# weird: neuralsym, megan\n",
    "def simplify_metrics(metrics):\n",
    "    \"\"\"Flatten one level of nested dicts in ``metrics``.\n",
    "\n",
    "    An entry whose value is a plain dict is expanded into one entry per inner key,\n",
    "    named ``outer + '_' + str(inner)`` (e.g. ``{'a': {1: x}}`` gives ``{'a_1': x}``).\n",
    "    Every other entry is copied unchanged. Returns a new dict.\n",
    "    \"\"\"\n",
    "    flat = {}\n",
    "    for name, value in metrics.items():\n",
    "        if type(value) is not dict:\n",
    "            flat[name] = value\n",
    "            continue\n",
    "        for sub_key, sub_value in value.items():\n",
    "            flat[name + '_' + str(sub_key)] = sub_value\n",
    "    return flat\n",
    "\n",
    "# Collect each metric's value from every unguided run, then take mean and std per metric.\n",
    "all_unguided_metrics = {}\n",
    "for experiment_name in unguided_experiments:\n",
    "    unguided_path = os.path.join(unguided_dir, experiment_name)\n",
    "    unguided_df = load_single_step_results(unguided_path)\n",
    "    unguided_metrics = _calculate_per_experiment_metrics(unguided_df)\n",
    "    # Flatten nested entries (e.g. 'avg_topk' into 'avg_topk_1', ...) so they are averaged too.\n",
    "    # The assignment had been split across lines, which left the cell a SyntaxError.\n",
    "    unguided_metrics = simplify_metrics(unguided_metrics)\n",
    "    for m in unguided_metrics:\n",
    "        # Defensive: skip anything that is still a dict after flattening.\n",
    "        if type(unguided_metrics[m])==dict:\n",
    "            continue\n",
    "        all_unguided_metrics.setdefault(m, []).append(unguided_metrics[m])\n",
    "\n",
    "# NOTE: np.std defaults to ddof=0 (population std) -- confirm that is intended for so few runs.\n",
    "avg_unguided_metrics = {}\n",
    "for m in all_unguided_metrics:\n",
    "    avg_unguided_metrics[m] = np.mean(all_unguided_metrics[m])\n",
    "    avg_unguided_metrics[m+'_std'] = np.std(all_unguided_metrics[m])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "47976c53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['total_samples', 'total_samples_std', 'total_products', 'total_products_std', 'avg_samples_per_product', 'avg_samples_per_product_std', 'perc_samples_per_product', 'perc_samples_per_product_std', 'sample_exact_match_accuracy', 'sample_exact_match_accuracy_std', 'products_with_exact_match', 'products_with_exact_match_std', 'percentage_products_with_exact_match', 'percentage_products_with_exact_match_std', 'sample_class_accuracy', 'sample_class_accuracy_std', 'products_with_class_correct_samples', 'products_with_class_correct_samples_std', 'percentage_products_with_class_correct', 'percentage_products_with_class_correct_std', 'avg_class_correct_samples_per_product', 'avg_class_correct_samples_per_product_std', 'perc_class_correct_samples_per_product', 'perc_class_correct_samples_per_product_std', 'sample_rxn_name_accuracy', 'sample_rxn_name_accuracy_std', 'products_with_rxn_name_correct_samples', 'products_with_rxn_name_correct_samples_std', 'percentage_products_with_rxn_name_correct', 'percentage_products_with_rxn_name_correct_std', 'avg_rxn_name_correct_samples_per_product', 'avg_rxn_name_correct_samples_per_product_std', 'perc_rxn_name_correct_samples_per_product', 'perc_rxn_name_correct_samples_per_product_std', 'sample_round_trip_accuracy', 'sample_round_trip_accuracy_std', 'products_with_round_trip_correct_samples', 'products_with_round_trip_correct_samples_std', 'percentage_products_with_round_trip_correct', 'percentage_products_with_round_trip_correct_std', 'avg_round_trip_correct_samples_per_product', 'avg_round_trip_correct_samples_per_product_std', 'perc_round_trip_correct_samples_per_product', 'perc_round_trip_correct_samples_per_product_std', 'avg_tanimoto_to_starting', 'avg_tanimoto_to_starting_std', 'max_tanimoto_to_starting', 'max_tanimoto_to_starting_std', 'avg_tanimoto_to_target', 'avg_tanimoto_to_target_std', 'max_tanimoto_to_target', 'max_tanimoto_to_target_std', 'avg_topk_1', 'avg_topk_1_std', 'avg_topk_3', 'avg_topk_3_std', 
'avg_topk_5', 'avg_topk_5_std', 'avg_topk_10', 'avg_topk_10_std', 'avg_topk_50', 'avg_topk_50_std', 'avg_topk_100', 'avg_topk_100_std', 'avg_coverage_1', 'avg_coverage_1_std', 'avg_coverage_3', 'avg_coverage_3_std', 'avg_coverage_5', 'avg_coverage_5_std', 'avg_coverage_10', 'avg_coverage_10_std'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_unguided_metrics.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8ea5cbc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " & \\makecell{0.48 \\\\ {\\scriptsize $\\pm$7e-04}} & \\makecell{0.72 \\\\ {\\scriptsize $\\pm$9e-04}} & \\makecell{0.79 \\\\ {\\scriptsize $\\pm$6e-04}} & \\makecell{0.94 \\\\ {\\scriptsize $\\pm$2e-04}} & \\makecell{0.54 \\\\ {\\scriptsize $\\pm$1e-04}} & \\makecell{0.20 \\\\ {\\scriptsize $\\pm$4e-04}} & \\makecell{0.29 \\\\ {\\scriptsize $\\pm$1e-05}} & \\makecell{0.31 \\\\ {\\scriptsize $\\pm$9e-05}} \\\\\n"
     ]
    }
   ],
   "source": [
    "def format_latex_row(method_name, metrics_dict, metrics_list):\n",
    "    values = []\n",
    "    for m in metrics_list:\n",
    "        mean = metrics_dict[m]\n",
    "        if m + '_std' in metrics_dict:\n",
    "            std = metrics_dict[m + '_std']\n",
    "            values.append(f\"\\\\makecell{{{mean:.2f} \\\\\\\\ {{\\\\scriptsize $\\\\pm${std:.0e}}}}}\")\n",
    "        else:\n",
    "            values.append(f\"\\\\makecell{{{mean:.2f}}}\")\n",
    "    \n",
    "    return f\"{method_name} & \" + \" & \".join(values) + \" \\\\\\\\\"\n",
    "\n",
    "# Usage\n",
    "metrics_list = ['avg_topk_1', 'avg_topk_3', 'avg_topk_5', 'avg_topk_100',\n",
    "                'perc_samples_per_product', 'perc_class_correct_samples_per_product',\n",
    "                'perc_rxn_name_correct_samples_per_product', \n",
    "                'perc_round_trip_correct_samples_per_product']\n",
    "\n",
    "print(format_latex_row('', avg_unguided_metrics, metrics_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "7bce0ebf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rsmiles-G & \\makecell{0.34} & \\makecell{0.54} & \\makecell{0.61} & \\makecell{0.68} & \\makecell{0.80} & \\makecell{0.81} & \\makecell{61.52} & \\makecell{18.21} & \\makecell{14.26} & \\makecell{15.81} \\\\\n"
     ]
    }
   ],
   "source": [
    "best_metrics = simplify_metrics(best_metrics)\n",
    "print(format_latex_row('Rsmiles-G', best_metrics, metrics_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "53603931",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total_products 4948.0 0.0\n",
      "avg_samples_per_product 31.836297493936943 0.014144242689509399\n",
      "avg_class_correct_samples_per_product 11.732685154635279 0.011599163818876525\n",
      "avg_rxn_name_correct_samples_per_product 16.53333333333333 0.009250419062975102\n",
      "avg_round_trip_correct_samples_per_product 15.998698549915597 0.010869138005352347\n",
      "avg_topk_1 0.557194826192401 0.0005949717641018154\n",
      "avg_topk_3 0.7867825383993533 0.0008250773857393937\n",
      "avg_topk_5 0.858730800323363 0.0006600619085915059\n",
      "avg_topk_10 0.9085825922931825 0.00034350710816443867\n",
      "avg_topk_50 0.94576933441121 0.0005304505439242813\n",
      "avg_topk_100 0.94576933441121 0.0005304505439242813\n",
      "avg_coverage_1 0.8850714093236324 0.0010479890316696632\n",
      "avg_coverage_3 0.9432093775262732 0.0005716303809107254\n",
      "avg_coverage_5 0.959310158986796 0.0005795153103639766\n",
      "avg_coverage_10 0.9731204527081649 0.00016501547714788782\n"
     ]
    }
   ],
   "source": [
    "metrics_of_interest = [\n",
    "    'total_products',\n",
    "    'avg_samples_per_product', \n",
    "    'avg_class_correct_samples_per_product',\n",
    "    'avg_rxn_name_correct_samples_per_product',\n",
    "    'avg_round_trip_correct_samples_per_product',\n",
    "    'avg_topk_1',\n",
    "    'avg_topk_3',\n",
    "    'avg_topk_5',\n",
    "    'avg_topk_10',\n",
    "    'avg_topk_50',\n",
    "    'avg_topk_100',\n",
    "    'avg_coverage_1',\n",
    "    'avg_coverage_3',\n",
    "    'avg_coverage_5',\n",
    "    'avg_coverage_10'\n",
    "    #'avg_tanimoto_to_starting',\n",
    "    #'avg_tanimoto_to_target'\n",
    "]\n",
    "\n",
    "for m in metrics_of_interest:\n",
    "    #print(m, best_metrics[m], unguided_metrics[m])\n",
    "    print(m, avg_unguided_metrics[m], avg_unguided_metrics[m+'_std'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "beb6eaa7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rsmiles & $0.56_{\\pm 6e-04}$ & $0.79_{\\pm 8e-04}$ & $0.86_{\\pm 7e-04}$ & $0.91_{\\pm 3e-04}$ & $31.84_{\\pm 1e-02}$ & $11.73_{\\pm 1e-02}$ & $16.53_{\\pm 9e-03}$ & $16.00_{\\pm 1e-02}$ \\\\\n"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b69b3aae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>product_smi</th>\n",
       "      <th>experiment_name</th>\n",
       "      <th>has_exact_match</th>\n",
       "      <th>avg_correct_class</th>\n",
       "      <th>avg_round_trip</th>\n",
       "      <th>avg_has_name</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>BrCCCCCOCCCc1nc2ccccc2[nH]1</td>\n",
       "      <td>50k_seed90_steeredtrue_guidance0.5_length5_res...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.018182</td>\n",
       "      <td>0.290909</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>BrCCCCc1ccc(Br)cc1</td>\n",
       "      <td>50k_steeredtrue_guidance2.0_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.309524</td>\n",
       "      <td>0.404762</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>BrCCOc1ccc(-n2ccnc2)cc1</td>\n",
       "      <td>50k_steeredtrue_guidance1.5_length10_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.730769</td>\n",
       "      <td>0.576923</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>BrCc1ncc(C2CC2)cn1</td>\n",
       "      <td>50k_steeredtrue_guidance1.5_length10_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.531250</td>\n",
       "      <td>0.437500</td>\n",
       "      <td>0.937500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Brc1cc(Br)cc(-c2ccc(OCc3ccccc3)cc2)c1</td>\n",
       "      <td>50k_steeredtrue_guidance1.0_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.586207</td>\n",
       "      <td>0.655172</td>\n",
       "      <td>0.965517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4943</th>\n",
       "      <td>c1coc(-c2nnc(-n3ccnc3)c3ccccc23)c1</td>\n",
       "      <td>50k_steeredtrue_guidance1.0_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.311111</td>\n",
       "      <td>0.355556</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4944</th>\n",
       "      <td>c1csc(C2CCNCC2)c1</td>\n",
       "      <td>50k_steeredtrue_guidance2.0_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.535714</td>\n",
       "      <td>0.750000</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4945</th>\n",
       "      <td>c1nc(CN(c2ccc3nonc3c2)n2cnnc2)cs1</td>\n",
       "      <td>50k_steeredtrue_guidance0.5_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.681818</td>\n",
       "      <td>0.545455</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4946</th>\n",
       "      <td>c1ncc(-c2cc3c(cn2)[nH]c2ncc(-c4ccc(CN5CCCCC5)c...</td>\n",
       "      <td>50k_seed90_modelrootaligned_steeredtrue_guidan...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.421053</td>\n",
       "      <td>0.631579</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4947</th>\n",
       "      <td>c1ncc(C2CCCNC2)[nH]1</td>\n",
       "      <td>50k_steeredtrue_guidance0.5_length15_results10...</td>\n",
       "      <td>True</td>\n",
       "      <td>0.588235</td>\n",
       "      <td>0.852941</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4948 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            product_smi  \\\n",
       "0                           BrCCCCCOCCCc1nc2ccccc2[nH]1   \n",
       "1                                    BrCCCCc1ccc(Br)cc1   \n",
       "2                               BrCCOc1ccc(-n2ccnc2)cc1   \n",
       "3                                    BrCc1ncc(C2CC2)cn1   \n",
       "4                 Brc1cc(Br)cc(-c2ccc(OCc3ccccc3)cc2)c1   \n",
       "...                                                 ...   \n",
       "4943                 c1coc(-c2nnc(-n3ccnc3)c3ccccc23)c1   \n",
       "4944                                  c1csc(C2CCNCC2)c1   \n",
       "4945                  c1nc(CN(c2ccc3nonc3c2)n2cnnc2)cs1   \n",
       "4946  c1ncc(-c2cc3c(cn2)[nH]c2ncc(-c4ccc(CN5CCCCC5)c...   \n",
       "4947                               c1ncc(C2CCCNC2)[nH]1   \n",
       "\n",
       "                                        experiment_name  has_exact_match  \\\n",
       "0     50k_seed90_steeredtrue_guidance0.5_length5_res...             True   \n",
       "1     50k_steeredtrue_guidance2.0_length15_results10...             True   \n",
       "2     50k_steeredtrue_guidance1.5_length10_results10...             True   \n",
       "3     50k_steeredtrue_guidance1.5_length10_results10...             True   \n",
       "4     50k_steeredtrue_guidance1.0_length15_results10...             True   \n",
       "...                                                 ...              ...   \n",
       "4943  50k_steeredtrue_guidance1.0_length15_results10...             True   \n",
       "4944  50k_steeredtrue_guidance2.0_length15_results10...             True   \n",
       "4945  50k_steeredtrue_guidance0.5_length15_results10...             True   \n",
       "4946  50k_seed90_modelrootaligned_steeredtrue_guidan...             True   \n",
       "4947  50k_steeredtrue_guidance0.5_length15_results10...             True   \n",
       "\n",
       "      avg_correct_class  avg_round_trip  avg_has_name  \n",
       "0              0.018182        0.290909      1.000000  \n",
       "1              0.309524        0.404762      1.000000  \n",
       "2              0.730769        0.576923      1.000000  \n",
       "3              0.531250        0.437500      0.937500  \n",
       "4              0.586207        0.655172      0.965517  \n",
       "...                 ...             ...           ...  \n",
       "4943           0.311111        0.355556      1.000000  \n",
       "4944           0.535714        0.750000      1.000000  \n",
       "4945           0.681818        0.545455      1.000000  \n",
       "4946           0.421053        0.631579      1.000000  \n",
       "4947           0.588235        0.852941      1.000000  \n",
       "\n",
       "[4948 rows x 6 columns]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "selection_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cba3aa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with initialize(config_path=\"../configs\"):\n",
    "#     # Use the base config, then override with experiment values\n",
    "#     config = compose(\n",
    "#         config_name='config.yaml',\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "389b23b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "avg num of samples per product:  45.74333063864187\n",
      "50k_steeredtrue_guidance1.0_length15_results100_candidates72_time20251021_161934 solved 51 products only\n",
      "solved 51/270\n"
     ]
    }
   ],
   "source": [
    "# check if products_no_match is in the list of products solved by any guided experiment\n",
    "guided_key = 'topk'\n",
    "unguided_key = 'topk'\n",
    "products_no_match = common_unsolved_products[unguided_key]\n",
    "products_solved = []\n",
    "all_experiment_df = []\n",
    "for guided_df, experiment_name in zip(list_dfs[2:3], list_experiment_names[2:3]):\n",
    "    print('avg num of samples per product: ', guided_df.groupby('product_smi').size().mean())\n",
    "    products_solved_df = guided_df.groupby('product_smi')[guided_key].any().reset_index()\n",
    "    products_solved_by_experiment = products_solved_df[products_solved_df[guided_key] \\\n",
    "                                        & (products_solved_df['product_smi'].isin(products_no_match))]['product_smi'].tolist()\n",
    "    #products_solved_by_experiment_df = guided_df[guided_df['product_smi'].isin(products_solved_by_experiment)]\n",
    "    num_products_solved_by_this_experiment = 0\n",
    "    products_solved_by_this_experiment = []\n",
    "    for p in products_solved_by_experiment:\n",
    "        if p not in products_solved:\n",
    "            num_products_solved_by_this_experiment += 1\n",
    "            products_solved.append(p)\n",
    "            products_solved_by_this_experiment.append(p)\n",
    "            #products_solved_by_experiment_df = products_solved_by_experiment_df[products_solved_by_experiment_df['product_smi'] != p]\n",
    "    products_solved_by_experiment_df = guided_df[guided_df['product_smi'].isin(products_solved_by_this_experiment)]\n",
    "    all_experiment_df.append(products_solved_by_experiment_df)\n",
    "    print(f'{experiment_name} solved {num_products_solved_by_this_experiment} products only')\n",
    "print(f'solved {len(products_solved)}/{len(products_no_match)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "id": "cefcb4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_solved_df = pd.concat(all_experiment_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "id": "0becdc61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'total_samples': 10661,\n",
       " 'total_products': 272,\n",
       " 'avg_samples_per_product': 39.19485294117647,\n",
       " 'sample_exact_match_accuracy': 0.025607353906762966,\n",
       " 'products_with_exact_match': 272,\n",
       " 'percentage_products_with_exact_match': 1.0,\n",
       " 'sample_class_accuracy': 0.3674139386549104,\n",
       " 'products_with_class_correct_samples': 272,\n",
       " 'percentage_products_with_class_correct': 1.0,\n",
       " 'avg_class_correct_samples_per_product': 14.400735294117647,\n",
       " 'sample_rxn_name_accuracy': 0.4704999531000844,\n",
       " 'products_with_rxn_name_correct_samples': 272,\n",
       " 'percentage_products_with_rxn_name_correct': 1.0,\n",
       " 'avg_rxn_name_correct_samples_per_product': 18.441176470588236,\n",
       " 'sample_round_trip_accuracy': 0.3645999437201013,\n",
       " 'products_with_round_trip_correct_samples': 272,\n",
       " 'percentage_products_with_round_trip_correct': 1.0,\n",
       " 'avg_round_trip_correct_samples_per_product': 14.290441176470589,\n",
       " 'avg_tanimoto_to_starting': 0.6204869822238311,\n",
       " 'max_tanimoto_to_starting': 1.0,\n",
       " 'avg_tanimoto_to_target': 0.7152001135723808,\n",
       " 'max_tanimoto_to_target': 1.0,\n",
       " 'avg_topk': {1: 0.2426470588235294,\n",
       "  3: 0.49264705882352944,\n",
       "  5: 0.6213235294117647,\n",
       "  10: 0.7610294117647058,\n",
       "  50: 1.0,\n",
       "  100: 1.0036764705882353},\n",
       " 'avg_coverage': {1: 0.6691176470588235,\n",
       "  3: 0.7977941176470589,\n",
       "  5: 0.8492647058823529,\n",
       "  10: 0.9191176470588235}}"
      ]
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "guided_metrics = _calculate_per_experiment_metrics(experiment_solved_df)\n",
    "guided_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "id": "b6bb0bf1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'total_samples': 5142,\n",
       " 'total_products': 272,\n",
       " 'avg_samples_per_product': 18.904411764705884,\n",
       " 'sample_exact_match_accuracy': 0.0,\n",
       " 'products_with_exact_match': 0,\n",
       " 'percentage_products_with_exact_match': 0.0,\n",
       " 'sample_class_accuracy': 0.25009723842862697,\n",
       " 'products_with_class_correct_samples': 224,\n",
       " 'percentage_products_with_class_correct': 0.8235294117647058,\n",
       " 'avg_class_correct_samples_per_product': 5.741071428571429,\n",
       " 'sample_rxn_name_accuracy': 0.6353558926487748,\n",
       " 'products_with_rxn_name_correct_samples': 269,\n",
       " 'percentage_products_with_rxn_name_correct': 0.9889705882352942,\n",
       " 'avg_rxn_name_correct_samples_per_product': 12.144981412639405,\n",
       " 'sample_round_trip_accuracy': 0.4634383508362505,\n",
       " 'products_with_round_trip_correct_samples': 172,\n",
       " 'percentage_products_with_round_trip_correct': 0.6323529411764706,\n",
       " 'avg_round_trip_correct_samples_per_product': 13.854651162790697,\n",
       " 'avg_tanimoto_to_starting': 0.636574532084749,\n",
       " 'max_tanimoto_to_starting': 1.0,\n",
       " 'avg_tanimoto_to_target': 0.8192727116602525,\n",
       " 'max_tanimoto_to_target': 1.0,\n",
       " 'avg_topk': {1: 0.0, 3: 0.0, 5: 0.0, 10: 0.0, 50: 0.0, 100: 0.0},\n",
       " 'avg_coverage': {1: 0.5183823529411765,\n",
       "  3: 0.6176470588235294,\n",
       "  5: 0.6323529411764706,\n",
       "  10: 0.6323529411764706}}"
      ]
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recovered_unguided_df = unguided_dfs[0][unguided_dfs[0]['product_smi'].isin(experiment_solved_df['product_smi'].unique().tolist())]\n",
    "unguided_metrics = _calculate_per_experiment_metrics(recovered_unguided_df)\n",
    "unguided_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "d943ed57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def jaccard_similarity(pred1, pred2):\n",
    "    return len(set(pred1).intersection(set(pred2)))/len(set(pred1).union(set(pred2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "2cf393c3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3495748135444032"
      ]
     },
     "execution_count": 167,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jaccard_scores = []\n",
    "\n",
    "for product_smi in recovered_unguided_df['product_smi'].unique():\n",
    "    # Get all predictions for this product from both dataframes\n",
    "    unguided_preds = recovered_unguided_df[recovered_unguided_df['product_smi'] == product_smi]['reactant_predictions'].tolist()\n",
    "    guided_preds = experiment_solved_df[experiment_solved_df['product_smi'] == product_smi]['reactant_predictions'].tolist()\n",
    "    \n",
    "    jaccard = jaccard_similarity(unguided_preds, guided_preds)\n",
    "    jaccard_scores.append(jaccard)\n",
    "\n",
    "avg_jaccard = np.mean(jaccard_scores)\n",
    "avg_jaccard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "416f5c58",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "product: C=CCCCC(=O)CCC(C)OC(=O)c1c(C=C)cc(OC)cc1OC\n",
      "true reactants: C=CCCCC(=O)CCC(C)OC(=O)c1c(OC)cc(OC)cc1OS(=O)(=O)C(F)(F)F.C=[CH][Sn]([CH2]CCC)([CH2]CCC)[CH2]CCC\n",
      "length of reactant predictions: 72\n",
      "first reactant prediction: C=CCCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(OC)c1C(=O)O\n",
      "index of true reactants: 10\n",
      "length of unguided reactant predictions: 36\n",
      "unguided result: False\n",
      "unguided reactant predictions: ['C=CCCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(OC)c1C(=O)Cl', 'C=CCCCC(=O)CCC(C)Cl.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(O)CCC(C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(I)cc(OC)cc1OC.C=[CH][Sn]([CH2]CCC)([CH2]CCC)[CH2]CCC', 'C=CCCCC(=O)CCC(C)Cl.C=Cc1cc(OC)cc(OC)c1C(=O)[O-]', 'C=CCCCC(=O)CCC(C)Br.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCC[CH2][Mg+].C=Cc1cc(OC)cc(OC)c1C(=O)OC(C)CCC(=O)N(C)OC', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(C=C)cc(OC)cc1OC.CI', 'C=CCCCC(=O)O.C=Cc1cc(OC)cc(OC)c1C(=O)Cl', 'C=CCCCC(=O)CCC(C)Br.C=Cc1cc(OC)cc(OC)c1C(=O)[O-]', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(Br)cc(OC)cc1OC.C=[CH][Sn]([CH2]CCC)([CH2]CCC)[CH2]CCC', 'C=CCCCC(=O)O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CC[C@H](C)O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(=O)CCC=C.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(=O)CC=C(C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CC[C@@H](C)O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(Cl)cc(OC)cc1OC.C=[CH][Sn]([CH2]CCC)([CH2]CCC)[CH2]CCC', 'C=CCCCC(=O)CCC(C)OS(C)(=O)=O.C=Cc1cc(OC)cc(OC)c1C(=O)O', 'C=CCCCBr.C=Cc1cc(OC)cc(OC)c1C(=O)OC(C)CCC(=O)N(C)OC', 'C=C.C=CCCCC(=O)CCC(C)OC(=O)c1c(Br)cc(OC)cc1OC', 'C=CCCCC(=O)C=CC(C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CCC(C)[O-].C=Cc1cc(OC)cc(OC)c1C(=O)Cl', 'C=CCCC(=O)CCC(C)OC(=O)c1c(I)cc(OC)cc1OC.C=[CH][Sn]([CH2]CCC)([CH2]CCC)[CH2]CCC', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(Br)cc(OC)cc1OC.C=C[Si](C)(C)C', 'C=CCCCC(=O)CCC(=C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(O)c1C(=O)O', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(Br)cc(OC)cc1OC.C=[CH][Cu]', 'C=CC[CH2][Mg+].C=Cc1cc(OC)cc(OC)c1C(=O)OC(C)CCC(=O)N(C)OC', 'C=C.C=CCCCC(=O)CCC(C)OC(=O)c1c(C=C)cc(OC)cc1OC', 'C=CCCCC(=O)CCC(C)O.C=Cc1cc(OC)cc(O)c1C(=O)Cl', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(Br)cc(OC)cc1OC.C=[CH][Sn]([CH3])([CH3])[CH3]', 'C=CCCCC(=O)CCC(C)OC(=O)c1c(OC)cc(OC)cc1OC.CI', 
'C=CCCBr.C=Cc1cc(OC)cc(OC)c1C(=O)OC(C)CCC(=O)N(C)OC']\n",
      "same rank as true in guided: C=CCCCC(=O)CCC(C)Br.C=Cc1cc(OC)cc(OC)c1C(=O)[O-]\n"
     ]
    }
   ],
   "source": [
    "guided_df = list_dfs[1]\n",
    "print(f'product: {products_solved[0]}')\n",
    "df_product = guided_df[guided_df['product_smi'] == products_solved[0]]\n",
    "true_reactants = df_product['true_reactants'].tolist()[0]\n",
    "print(f'true reactants: {true_reactants}')\n",
    "reactant_predictions = df_product['reactant_predictions'].tolist()\n",
    "print(f'length of reactant predictions: {len(reactant_predictions)}')\n",
    "print(f'first reactant prediction: {reactant_predictions[0]}')\n",
    "print(f'index of true reactants: {reactant_predictions.index(true_reactants)}')\n",
    "unguided_result = unguided_df_1[unguided_df_1['product_smi'] == products_solved[0]]['topk'].any()\n",
    "unguided_reactant_predictions = unguided_df_1[unguided_df_1['product_smi'] == products_solved[0]]['reactant_predictions'].tolist()\n",
    "print(f'length of unguided reactant predictions: {len(unguided_reactant_predictions)}')\n",
    "print(f'unguided result: {unguided_result}')\n",
    "print(f'unguided reactant predictions: {unguided_reactant_predictions}')\n",
    "print(f'same rank as true in guided: {unguided_reactant_predictions[reactant_predictions.index(true_reactants)]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d60b33f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "syntheseus-in-python10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
