'''
This file contains helper functions for the multiguide package.
'''
import os
import time
import json
import pickle
import json
import re
from itertools import combinations
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union, Any, Optional
import sys
from unittest.mock import MagicMock
# Mock the problematic RDKit drawing modules
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()

from rdkit import Chem
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F

from rxn_insight.reaction import Reaction

from syntheseus import Molecule
from syntheseus.search.analysis.route_extraction import (
    iter_routes_time_order,
)
from syntheseus.reaction_prediction.inference import RootAlignedModel, RetroKNNModel, Graph2EditsModel, MEGANModel, LocalRetroModel
from syntheseus.reaction_prediction.inference import MHNreactModel, GLNModel, ChemformerModel
from multiguide.syntheseus.single_step_models.neural_sym import NeuralSymPredictor
from multiguide.syntheseus.single_step_models.root_aligned_fixed import RootAlignedFixedModel
from multiguide.helpers import PROJECT_ROOT
from multiguide.onmt.guided_generator import get_vocab_from_trained_model
from multiguide.dataset.helpers import turn_seq_to_ids, turn_results_to_mol_smiles, get_similarity, \
                                        compare_reactant_smiles, class_to_idx, get_tanimoto
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.syntheseus.single_step_models.root_aligned_forward import RootAlignedForwardModel
from syntheseus.search.analysis import diversity
from multiguide.dataset.helpers import remove_dative_bonds, class_to_idx, get_tanimoto, remove_dative_bonds_one_molecule

# Module-level default torch device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def aggregate_guided_search_results_with_selection(
    experiment_info: List[Dict[str, str]],
    project_root: str,
    experiment_dir: str,
    selection_criteria: str = 'reaction_type',
    return_all_info: bool = False
) -> Union[List[Dict], Tuple[List[Dict], List[Any], List[Any]]]:
    """
    Aggregate search results from multiple guided experiments, selecting best per product.
    Follows the same pattern as single-step aggregation.

    NOTE(review): a second function with this exact name is defined further down
    in this module; at import time that later definition replaces this one.
    Confirm which version is intended and remove the other.

    Args:
        experiment_info: List of dicts with keys:
            - 'experiment_regex': regex pattern to match experiment directories
            - 'method_name': display name for the method
            - 'experiment_group': experiment group directory name
            Other keys are currently ignored by this function.
        project_root: Root directory of the project
        experiment_dir: Base experiment directory (e.g., 'experiments/search/retro_star')
        selection_criteria: Criterion forwarded to
            ``select_best_search_experiment_per_product`` to choose the best
            experiment for each product.
        return_all_info: If True, also return the per-method target tables and
            the selected data.

    Returns:
        If ``return_all_info`` is False, a list of dicts with aggregated metrics
        for each method (each dict carries a ``'method'`` key).
        Otherwise a tuple ``(aggregates, target_dfs, all_guided_data)`` of
        index-aligned lists. Methods without any matching results are skipped
        (a warning is printed), so these lists can be shorter than
        ``experiment_info``.
    """
    aggregates = []
    target_dfs = []
    all_guided_data = []
    for exp in experiment_info:
        experiment_regex = exp['experiment_regex']
        method_name = exp['method_name']
        experiment_group = exp['experiment_group']
        experiment_filters = {'experiment_regex': experiment_regex}

        results = load_experiment_results(
            project_root,
            experiment_dir,
            experiment_group,
            experiment_filters,
            experiment_subdir='strategy_None/evaluations'
        )

        if not results:
            print(f"Warning: No results found for {method_name}")
            continue

        # Select the best experiment per product across all matching experiments.
        guided_data, _ = select_best_search_experiment_per_product(
            list_dfs=list(results.values()),
            list_experiment_names=list(results.keys()),
            criteria=selection_criteria
        )
        all_guided_data.append(guided_data)

        # Metrics are computed on the selected (best-per-product) data only.
        dataset_stats, target_df = calculate_dataset_aggregates(guided_data)
        dataset_stats['method'] = method_name
        target_dfs.append(target_df)
        aggregates.append(dataset_stats)

    if return_all_info:
        return aggregates, target_dfs, all_guided_data
    return aggregates


def get_search_target_metrics_table_tanimoto(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Search target metrics',
    highlight_per_group=True,
    highlight_methods=None,
    label='tab:search-dataset-metrics',
):
    '''
        Build (and optionally save) the LaTeX table of per-target search metrics.

        Args:
            experiment_info: List of dictionaries, one per experiment; each must
                provide 'method_name', 'category' and 'trained_on'.
            aggregates: Aggregated search results, index-aligned with
                ``experiment_info`` (entry ``i`` belongs to experiment ``i``).
            table_name: File name of the table, written under
                ``PROJECT_ROOT/paper/iclr2026/tables``.
            save_table: Whether to save the table.
            caption: Table caption.
            highlight_per_group: Forwarded to the table generator.
            highlight_methods: Forwarded to the table generator.
            label: LaTeX label of the table.

        Returns:
            latex_table: Latex table.

        Raises:
            ValueError: If ``aggregates`` and ``experiment_info`` differ in
                length; rows are matched to experiments by position, so a
                mismatch would silently mislabel the table.
    '''
    aggregates = list(aggregates)
    if len(aggregates) != len(experiment_info):
        raise ValueError(
            f'Got {len(aggregates)} aggregates for {len(experiment_info)} experiments; '
            'they must be index-aligned (one aggregate per experiment).'
        )

    # Method names may repeat across experiments, so the index is appended to
    # obtain a unique row key; the original name is kept as 'display_name'.
    method_names = []
    method_categories = {}
    method_groups = {}
    for i, e in enumerate(experiment_info):
        unique_key = f"{e['method_name']}_{i}"
        method_names.append(unique_key)
        method_categories[unique_key] = e['category']
        method_groups[unique_key] = e['trained_on']

    updated_aggregates = [
        {**agg, 'method': unique_key, 'display_name': e['method_name']}
        for agg, e, unique_key in zip(aggregates, experiment_info, method_names)
    ]

    metric_display_names = {
        'avg_targets_with_exact_match_route': r'Exact',
        'avg_targets_with_round_trip_route': r'Round trip',
        'target_avg_rxn_type_match': r'Reaction',
        'avg_targets_with_rxn_name_match_route': r'Reaction',
        'num_routes_with_sm': r'Perc. of routes',
    }
    metric_display_names_line2 = {
        'avg_targets_with_exact_match_route': r'route ($\uparrow$)',
        'avg_targets_with_round_trip_route': r'route ($\uparrow$)',
        'target_avg_rxn_type_match': r'type ($\uparrow$)',
        'avg_targets_with_rxn_name_match_route': r'name ($\uparrow$)',
        'num_routes_with_sm': r'with SM per target ($\uparrow$)',
    }
    bold_best = {
        'avg_targets_with_exact_match_route': 'high',
        'avg_targets_with_round_trip_route': 'high',
        'target_avg_rxn_type_match': 'high',
        'avg_targets_with_rxn_name_match_route': 'high',
        'num_routes_with_sm': 'high',
    }
    metrics = [
        'avg_targets_with_exact_match_route',
        'avg_targets_with_round_trip_route',
        'target_avg_rxn_type_match',
        'avg_targets_with_rxn_name_match_route',
        'num_routes_with_sm',
    ]

    latex_table = generate_latex_table_manual_synthesis(
        experiment_dirs=experiment_info,
        metrics=metrics,
        method_names=method_names,
        caption=caption,
        label=label,
        metric_display_names=metric_display_names,
        metric_display_names_line2=metric_display_names_line2,
        bold_best=bold_best,
        results=updated_aggregates,
        method_categories=method_categories,
        method_groups=method_groups,
        use_siunitx=False,
        font_size="small",
        tabcolsep="4pt",
        group_header_spacing="2pt",
        group_separation="4pt",
        highlight_per_group=highlight_per_group,
        highlight_methods=highlight_methods,
        highlight_color="highlightgreen",
    )

    if save_table:
        latex_output_path = os.path.join(
            PROJECT_ROOT,
            'paper',
            'iclr2026',
            'tables',
            table_name
        )
        save_latex_table(
            latex_table=latex_table,
            output_path=latex_output_path,
            standalone=False
        )
    return latex_table

def get_search_target_metrics_table_reaction_type(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Search target metrics',
    label='tab:search-target-metrics',
):
    '''
    Build (and optionally save) the LaTeX table of per-target search metrics
    for reaction-type guided experiments.

    Args:
        experiment_info: List of dictionaries, one per experiment; each must
            provide 'method_name', 'experiment_regex' and 'category'.
        aggregates: Aggregated search results.
        table_name: File name of the table, written under
            ``PROJECT_ROOT/paper/iclr2026/tables``.
        save_table: Whether to save the table.
        caption: Table caption.
        label: LaTeX label of the table.

    Returns:
        latex_table: Latex table.
    '''
    method_names = [exp['method_name'] for exp in experiment_info]
    experiment_dirs = [exp['experiment_regex'] for exp in experiment_info]
    method_categories = {name: exp['category'] for name, exp in zip(method_names, experiment_info)}
    metric_display_names = {
        'target_avg_exact_match': 'Avg. match',
        'target_avg_round_trip': 'Avg. RT',
        'target_avg_rxn_type_match': 'Avg. rxn type match',
        'target_avg_rxn_name_match': 'Avg. name',
    }
    # Keys mirror `metrics` below. NOTE(review): this mapping is not passed to
    # generate_latex_table, so it currently has no effect on the output.
    metric_display_names_line2 = {
        'target_avg_exact_match': r'match ($\uparrow$)',
        'target_avg_round_trip': r'RT ($\uparrow$)',
        'target_avg_rxn_type_match': r'type ($\uparrow$)',
        'target_avg_rxn_name_match': r'name ($\uparrow$)',
    }
    bold_best = {
        'target_avg_exact_match': 'high',
        'target_avg_round_trip': 'high',
        'target_avg_rxn_type_match': 'high',
        'target_avg_rxn_name_match': 'high',
    }
    metrics = [
        'target_avg_exact_match',
        'target_avg_round_trip',
        'target_avg_rxn_type_match',
        'target_avg_rxn_name_match',
    ]

    latex_table = generate_latex_table(
        experiment_dirs=experiment_dirs,  # Not needed
        metrics=metrics,
        method_names=method_names,
        method_categories=method_categories,
        caption=caption,
        label=label,
        bold_best=bold_best,
        metric_display_names=metric_display_names,
        results=aggregates,
        decimal_places=2
    )
    if save_table:
        latex_output_path = os.path.join(
            PROJECT_ROOT,
            'paper',
            'iclr2026',
            'tables',
            table_name
        )
        save_latex_table(
            latex_table=latex_table,
            output_path=latex_output_path,
            standalone=False
        )
    return latex_table

def get_single_step_metrics_table_reaction_type(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Manual synthesis metrics',
    label='tab:manual-synthesis-metrics'
):
    '''
    Build (and optionally save) the LaTeX table of single-step / manual synthesis metrics.

    Args:
        experiment_info: List of dictionaries, one per experiment; each must
            provide 'method_name', 'category' and 'trained_on'. Rows are keyed
            by 'method_name', so names should be unique.
        aggregates: Aggregated results, one dict per method.
        table_name: File name of the table, written under
            ``PROJECT_ROOT/paper/iclr2026/tables``.
        save_table: Whether to save the table.
        caption: Table caption.
        label: LaTeX label of the table.

    Returns:
        latex_table: Latex table.
    '''
    method_names = [e['method_name'] for e in experiment_info]
    method_categories = {e['method_name']: e['category'] for e in experiment_info}
    method_groups = {e['method_name']: e['trained_on'] for e in experiment_info}
    metric_display_names = {
        'completion_rate': r'Exact',
        'avg_topk_1': r'top-1',
        'avg_topk_5': r'top-5',
        'avg_topk_100': r'top-100',
        'perc_samples_per_product': r'Unique',
        'percentage_products_with_class_correct': r'Correct',
        'percentage_products_with_round_trip_correct': r'Correct',
    }
    metric_display_names_line2 = {
        'completion_rate': r'route ($\uparrow$)',
        'avg_topk_1': r'($\uparrow$)',
        'avg_topk_5': r'($\uparrow$)',
        'avg_topk_100': r'($\uparrow$)',
        'perc_samples_per_product': r'samples ($\uparrow$)',
        'percentage_products_with_class_correct': r'class ($\uparrow$)',
        'percentage_products_with_round_trip_correct': r'RT ($\uparrow$)',
    }
    # Column order follows the display-name mapping; materialize as a list like
    # the other table builders do instead of passing a dict_keys view.
    metrics = list(metric_display_names)
    bold_best = {
        'completion_rate': 'high',
        'avg_topk_1': 'high',
        'avg_topk_5': 'high',
        'avg_topk_100': 'high',
        'perc_samples_per_product': 'high',
        'percentage_products_with_class_correct': 'high',
        'percentage_products_with_round_trip_correct': 'high',
    }
    latex_table = generate_latex_table_manual_synthesis(
        experiment_dirs=experiment_info,
        metrics=metrics,
        results=aggregates,
        method_names=method_names,
        method_categories=method_categories,
        method_groups=method_groups,
        caption=caption,
        label=label,
        metric_display_names=metric_display_names,
        metric_display_names_line2=metric_display_names_line2,
        bold_best=bold_best,
        decimal_places=2
    )

    if save_table:
        latex_output_path = os.path.join(
            PROJECT_ROOT,
            'paper',
            'iclr2026',
            'tables',
            table_name
        )
        save_latex_table(
            latex_table=latex_table,
            output_path=latex_output_path,
            standalone=False
        )
    return latex_table

# def get_manual_synthesis_metrics_table_reaction_type(
#     experiment_info,
#     aggregates,
#     table_name,
#     save_table=True,
#     caption='Manual synthesis metrics',
#     label='tab:manual-synthesis-metrics'
# ):
#     '''
#     Get latex table for manual synthesis metrics.
#     '''
#     method_names = [e['method_name'] for e in experiment_info]
#     method_categories = {e['method_name']: e['category'] for e in experiment_info}
#     method_groups = {e['method_name']: e['trained_on'] for e in experiment_info}
#     print(method_categories)
#     metric_display_names = {
#         'completion_rate': r'Exact',
#         'avg_topk_1': r'top-1',
#         #'avg_topk_3': r'top-3',
#         'avg_topk_5': r'top-5',
#         'avg_topk_100': r'top-100',
#         'perc_samples_per_product': r'Unique',
#         'percentage_products_with_class_correct': r'Correct',
#         #'avg_tanimoto_to_starting': r'TA',
#         'products_with_max_tanimoto_to_starting': r'TA',
#         #'percentage_products_with_rxn_name_correct': r'Correct',
#         'percentage_products_with_round_trip_correct': r'Correct',
#     }
#     metric_display_names_line2 = {
#         'completion_rate': r'route ($\uparrow$)',
#         'avg_topk_1': r'($\uparrow$)',
#         #'avg_topk_3': r'($\uparrow$)',
#         'avg_topk_5': r'($\uparrow$)',
#         'avg_topk_100': r'($\uparrow$)',
#         'perc_samples_per_product': r'samples ($\uparrow$)',
#         'percentage_products_with_class_correct': r'class ($\uparrow$)',
#         #'avg_tanimoto_to_starting': r'to SM ($\uparrow$)',
#         'products_with_max_tanimoto_to_starting': r'to SM ($\uparrow$)',
#         #'percentage_products_with_rxn_name_correct': r'name ($\uparrow$)',
#         'percentage_products_with_round_trip_correct': r'RT ($\uparrow$)',
#     }
#     metrics = metric_display_names.keys()
#     bold_best = {
#         'completion_rate': 'high',
#         'avg_topk_1': 'high',
#         #'avg_topk_3': 'high',
#         'avg_topk_5': 'high',
#         'avg_topk_100': 'high',
#         'perc_samples_per_product': 'high',
#         'percentage_products_with_class_correct': 'high',
#         #'avg_tanimoto_to_starting': 'high',
#         'products_with_max_tanimoto_to_starting': 'high',
#         #'percentage_products_with_rxn_name_correct': 'high',
#         'percentage_products_with_round_trip_correct': 'high',
#     }
#     latex_table = generate_latex_table_manual_synthesis(
#         experiment_dirs=experiment_info,
#         metrics=metrics,
#         results=aggregates,
#         method_names=method_names,
#         method_categories=method_categories,
#         method_groups=method_groups,
#         caption=caption,
#         label=label,
#         metric_display_names=metric_display_names,
#         metric_display_names_line2=metric_display_names_line2,
#         bold_best=bold_best,
#         decimal_places=2,
#         #highlight_methods=['Rsmiles-TG$_{\\text{rxn}}$', 'Rsmiles-TG$_{\\text{sim}}$'],
#     )

#     latex_output_path = os.path.join(
#         PROJECT_ROOT,
#         'paper',
#         'iclr2026',
#         'tables',
#         table_name
#     )
#     if save_table:
#         save_latex_table(
#             latex_table=latex_table,
#             output_path=latex_output_path,
#             standalone=False
#         )
#     return latex_table

def get_search_metrics_table_reaction_type(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Search metrics',
    label='tab:search-metrics',
    highlight_per_group=True,
    highlight_methods=None,
):
    '''
        Build (and optionally save) the LaTeX table of search metrics
        (solve rate, nodes explored, model calls, time, non-overlapping routes).

        Args:
            experiment_info: List of dictionaries, one per experiment; each must
                provide 'method_name', 'category' and 'trained_on'.
            aggregates: Aggregated search results, index-aligned with
                ``experiment_info`` (entry ``i`` belongs to experiment ``i``).
            table_name: File name of the table, written under
                ``PROJECT_ROOT/paper/iclr2026/tables``.
            save_table: Whether to save the table.
            caption: Table caption.
            label: LaTeX label of the table.
            highlight_per_group: Forwarded to the table generator.
            highlight_methods: Forwarded to the table generator (None by default).

        Returns:
            latex_table: Latex table.

        Raises:
            ValueError: If ``aggregates`` and ``experiment_info`` differ in
                length; rows are matched to experiments by position, so a
                mismatch would silently mislabel the table.
    '''
    aggregates = list(aggregates)
    if len(aggregates) != len(experiment_info):
        raise ValueError(
            f'Got {len(aggregates)} aggregates for {len(experiment_info)} experiments; '
            'they must be index-aligned (one aggregate per experiment).'
        )

    # Method names may repeat across experiments, so the index is appended to
    # obtain a unique row key; the original name is kept as 'display_name'.
    method_names = []
    method_categories = {}
    method_groups = {}
    for i, e in enumerate(experiment_info):
        unique_key = f"{e['method_name']}_{i}"
        method_names.append(unique_key)
        method_categories[unique_key] = e['category']
        method_groups[unique_key] = e['trained_on']

    updated_aggregates = [
        {**agg, 'method': unique_key, 'display_name': e['method_name']}
        for agg, e, unique_key in zip(aggregates, experiment_info, method_names)
    ]

    bold_best = {
        'solve_rate': 'high',           # Higher is better
        'avg_nodes_explored': 'low',    # Lower is better
        'avg_model_calls': 'low',       # Lower is better
        'avg_time_taken': 'low',        # Lower is better
        'avg_num_nonoverlapping_routes': 'high' # Higher is better
    }
    metric_display_names = {
        'solve_rate': r'Solve',
        'avg_nodes_explored': r'Nodes',
        'avg_model_calls': r'Model',
        'avg_time_taken': r'Time',
        'avg_num_nonoverlapping_routes': r'Non-overlapping',
    }
    metric_display_names_line2 = {
        'solve_rate': r'rate ($\uparrow$)',
        'avg_nodes_explored': r'explored ($\downarrow$)',
        'avg_model_calls': r'calls ($\downarrow$)',
        'avg_time_taken': r'taken ($\downarrow$)',
        'avg_num_nonoverlapping_routes': r'routes per target ($\uparrow$)',
    }
    metrics = [
        'solve_rate',
        'avg_nodes_explored',
        'avg_model_calls',
        'avg_time_taken',
        'avg_num_nonoverlapping_routes',
    ]

    latex_table = generate_latex_table_manual_synthesis(
        experiment_dirs=experiment_info,
        metrics=metrics,
        method_names=method_names,
        caption=caption,
        label=label,
        metric_display_names=metric_display_names,
        metric_display_names_line2=metric_display_names_line2,
        bold_best=bold_best,
        results=updated_aggregates,
        method_categories=method_categories,
        method_groups=method_groups,
        use_siunitx=False,
        font_size="small",
        tabcolsep="4pt",
        group_header_spacing="2pt",
        group_separation="4pt",
        highlight_per_group=highlight_per_group,
        highlight_methods=highlight_methods,
        highlight_color="highlightgreen",
    )

    if save_table:
        latex_output_path = os.path.join(
            PROJECT_ROOT,
            'paper',
            'iclr2026',
            'tables',
            table_name
        )
        save_latex_table(
            latex_table=latex_table,
            output_path=latex_output_path,
            standalone=False
        )
    return latex_table


def get_search_metrics_table_tanimoto(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Search metrics',
    label='tab:search-metrics',
    highlight_per_group=True,
    highlight_methods=None,
):
    '''
        Build (and optionally save) the LaTeX table of search metrics
        (solve rate, solve rate with SM, model calls, time, non-overlapping routes).

        Args:
            experiment_info: List of dictionaries, one per experiment; each must
                provide 'method_name', 'category' and 'trained_on'.
            aggregates: Aggregated search results, index-aligned with
                ``experiment_info`` (entry ``i`` belongs to experiment ``i``).
            table_name: File name of the table, written under
                ``PROJECT_ROOT/paper/iclr2026/tables``.
            save_table: Whether to save the table.
            caption: Table caption.
            label: Table label.
            highlight_per_group: Forwarded to the table generator.
            highlight_methods: Forwarded to the table generator.

        Returns:
            latex_table: Latex table.

        Raises:
            ValueError: If ``aggregates`` and ``experiment_info`` differ in
                length; rows are matched to experiments by position, so a
                mismatch would silently mislabel the table.
    '''
    aggregates = list(aggregates)
    if len(aggregates) != len(experiment_info):
        raise ValueError(
            f'Got {len(aggregates)} aggregates for {len(experiment_info)} experiments; '
            'they must be index-aligned (one aggregate per experiment).'
        )

    # Method names may repeat across experiments, so the index is appended to
    # obtain a unique row key; the original name is kept as 'display_name'.
    method_names = []
    method_categories = {}
    method_groups = {}
    for i, e in enumerate(experiment_info):
        unique_key = f"{e['method_name']}_{i}"
        method_names.append(unique_key)
        method_categories[unique_key] = e['category']
        method_groups[unique_key] = e['trained_on']

    updated_aggregates = [
        {**agg, 'method': unique_key, 'display_name': e['method_name']}
        for agg, e, unique_key in zip(aggregates, experiment_info, method_names)
    ]

    bold_best = {
        'solve_rate': 'high',           # Higher is better
        'solve_rate_with_sm': 'high',   # Higher is better
        'avg_model_calls': 'low',       # Lower is better
        'avg_time_taken': 'low',        # Lower is better
        'avg_num_nonoverlapping_routes': 'high' # Higher is better
    }
    metric_display_names = {
        'solve_rate': r'Solve',
        'solve_rate_with_sm': r'Solve',
        'avg_model_calls': r'Model',
        'avg_time_taken': r'Time',
        'avg_num_nonoverlapping_routes': r'Non-overlapping',
    }
    metrics = [
        'solve_rate',
        'solve_rate_with_sm',
        'avg_model_calls',
        'avg_time_taken',
        'avg_num_nonoverlapping_routes',
    ]
    metric_display_names_line2 = {
        'solve_rate': r'rate ($\uparrow$)',
        'solve_rate_with_sm': r'rate with SM ($\uparrow$)',
        'avg_model_calls': r'calls ($\downarrow$)',
        'avg_time_taken': r'taken ($\downarrow$)',
        'avg_num_nonoverlapping_routes': r'routes per target ($\uparrow$)',
    }

    latex_table = generate_latex_table_manual_synthesis(
        experiment_dirs=experiment_info,
        metrics=metrics,
        method_names=method_names,
        caption=caption,
        label=label,
        metric_display_names=metric_display_names,
        metric_display_names_line2=metric_display_names_line2,
        bold_best=bold_best,
        results=updated_aggregates,
        method_categories=method_categories,
        method_groups=method_groups,
        use_siunitx=False,
        font_size="small",
        tabcolsep="4pt",
        group_header_spacing="2pt",
        group_separation="4pt",
        highlight_per_group=highlight_per_group,
        highlight_methods=highlight_methods,
        highlight_color="highlightgreen",
    )
    if save_table:
        latex_output_path = os.path.join(
            PROJECT_ROOT,
            'paper',
            'iclr2026',
            'tables',
            table_name
        )
        save_latex_table(
            latex_table=latex_table,
            output_path=latex_output_path,
            standalone=False
        )
    return latex_table

def aggregate_with_filter_manual_synthesis_results(
    experiment_info: List[Dict[str, str]],
    project_root: str,
    experiment_dir: str,
    true_routes_path: str,
    return_all_info: bool = False,
    total_samples_per_product: Optional[int] = None,
    apply_filter: bool = False
) -> Union[List[Dict], Tuple[List[Dict], List[Any], List[Any]]]:
    """
    Aggregate manual synthesis results from multiple experiments.

    For each entry of ``experiment_info`` all matching experiments are loaded
    and combined either by selecting the best experiment per product
    (``'select'``) or by filtering the pooled results by a criterion
    (``'filter'``, the default). Quality metrics are computed on the combined
    data, and route completion against the reference routes.

    Args:
        experiment_info: List of dicts with keys 'experiment_regex',
            'method_name', 'experiment_group' and 'criteria', and optionally
            'criteria_threshold' (default None) and 'combination'
            ('select' or 'filter', default 'filter').
        project_root: Root directory of the project.
        experiment_dir: Base experiment directory.
        true_routes_path: Path to a JSON file with the reference routes.
        return_all_info: If True, also return the combined data per method.
        total_samples_per_product: Sample count per product passed to the
            metric computation. If None, 100 is used for 'select' and 600 for
            'filter'; an explicit value is used for every method.
        apply_filter: Forwarded to ``filter_manual_synthesis_results_by_criteria``.

    Returns:
        The list of per-method metric dicts (each with 'method',
        'completion_rate' and 'route_completion' keys). If ``return_all_info``
        is True, a tuple ``(aggregates, target_dfs, all_guided_data)``;
        ``target_dfs`` and ``all_guided_data`` currently hold the same data.

    Raises:
        ValueError: If an entry's 'combination' is not 'select' or 'filter'.
    """
    aggregates = []
    target_dfs = []
    all_guided_data = []
    for exp in experiment_info:
        experiment_regex = exp['experiment_regex']
        method_name = exp['method_name']
        experiment_group = exp['experiment_group']
        selection_criteria = exp['criteria']
        criteria_threshold = exp.get('criteria_threshold')
        combination = exp.get('combination', 'filter')
        experiment_filters = {'experiment_regex': experiment_regex}
        results = load_experiment_results(project_root, experiment_dir, experiment_group, experiment_filters)

        if combination == 'select':
            guided_data, _ = select_best_experiment_manual_synthesis_per_product(
                list_dfs=results.values(),
                list_experiment_names=results.keys(),
                criteria=selection_criteria
            )
            default_samples_per_product = 100
        elif combination == 'filter':
            guided_data = filter_manual_synthesis_results_by_criteria(
                results.values(),
                results.keys(),
                selection_criteria,
                criteria_threshold,
                apply_filter=apply_filter
            )
            default_samples_per_product = 600
        else:
            raise ValueError(f'Combination {combination} not supported')

        # Honor an explicit caller value; otherwise use the per-combination default.
        samples_per_product = (
            total_samples_per_product
            if total_samples_per_product is not None
            else default_samples_per_product
        )
        guided_quality_metrics = _calculate_per_experiment_metrics(
            guided_data, total_samples_per_product=samples_per_product
        )
        guided_quality_metrics = simplify_metrics(guided_quality_metrics)

        # Re-read per method so each call gets its own copy of the routes, in
        # case calculate_route_completion_rates mutates its input (not verified).
        with open(true_routes_path, 'r', encoding='utf-8') as f:
            true_routes = json.load(f)
        # Completion is computed over all loaded experiments, not only guided_data.
        route_completion = calculate_route_completion_rates(
            results, true_routes, use_starting_material=False, max_steps=100
        )
        guided_quality_metrics['completion_rate'] = route_completion['mixed_param_completion']['completion_rate']
        guided_quality_metrics['method'] = method_name
        guided_quality_metrics['route_completion'] = route_completion
        aggregates.append(guided_quality_metrics)
        target_dfs.append(guided_data)
        all_guided_data.append(guided_data)

    if return_all_info:
        return aggregates, target_dfs, all_guided_data
    return aggregates
      

def aggregate_manual_synthesis_results(
    experiment_info: List[Dict[str, str]],
    project_root: str,
    experiment_dir: str,
    true_routes_path: str,
    return_all_info: bool = False
) -> Union[List[Dict], Tuple[List[Dict], List[Any], List[Any]]]:
    """
    Aggregate manual synthesis results from multiple experiments.

    For each entry of ``experiment_info`` all matching experiments are loaded,
    the best experiment per product is selected with the entry's 'criteria',
    quality metrics are computed on that selection, and route completion is
    computed against the reference routes in ``true_routes_path``.

    Args:
        experiment_info: List of dicts with keys 'experiment_regex',
            'method_name', 'experiment_group' and 'criteria'.
        project_root: Root directory of the project.
        experiment_dir: Base experiment directory.
        true_routes_path: Path to a JSON file with the reference routes.
        return_all_info: If True, also return the selected data per method.

    Returns:
        The list of per-method metric dicts (each with 'method',
        'completion_rate' and 'route_completion' keys). If ``return_all_info``
        is True, a tuple ``(aggregates, target_dfs, all_guided_data)``;
        ``target_dfs`` and ``all_guided_data`` currently hold the same data.
    """
    aggregates = []
    target_dfs = []
    all_guided_data = []
    for exp in experiment_info:
        experiment_regex = exp['experiment_regex']
        method_name = exp['method_name']
        experiment_group = exp['experiment_group']
        selection_criteria = exp['criteria']
        experiment_filters = {'experiment_regex': experiment_regex}
        results = load_experiment_results(project_root, experiment_dir, experiment_group, experiment_filters)
        guided_data, _ = select_best_experiment_manual_synthesis_per_product(
            list_dfs=results.values(),
            list_experiment_names=results.keys(),
            criteria=selection_criteria
        )
        guided_quality_metrics = _calculate_per_experiment_metrics(guided_data)
        guided_quality_metrics = simplify_metrics(guided_quality_metrics)

        # Re-read per method so each call gets its own copy of the routes, in
        # case calculate_route_completion_rates mutates its input (not verified).
        with open(true_routes_path, 'r', encoding='utf-8') as f:
            true_routes = json.load(f)
        # Completion is computed over all loaded experiments, not only guided_data.
        route_completion = calculate_route_completion_rates(
            results, true_routes, use_starting_material=False, max_steps=100
        )
        guided_quality_metrics['completion_rate'] = route_completion['mixed_param_completion']['completion_rate']
        guided_quality_metrics['method'] = method_name
        guided_quality_metrics['route_completion'] = route_completion
        aggregates.append(guided_quality_metrics)
        target_dfs.append(guided_data)
        all_guided_data.append(guided_data)

    if return_all_info:
        return aggregates, target_dfs, all_guided_data
    return aggregates
        
def aggregate_guided_search_results_with_selection(
    experiment_info: List[Dict[str, str]],
    project_root: str,
    experiment_dir: str,
    selection_criteria: str = 'reaction_type',
    return_all_info: bool = False
) -> Union[List[Dict], Tuple[List[Dict], List[pd.DataFrame], List[pd.DataFrame]]]:
    """
    Aggregate search results from multiple guided experiments, selecting the best per product.
    Follows the same pattern as single-step aggregation.

    NOTE(review): a function with this same name is also defined near the top of this
    module; because this definition comes later, it is the one bound at import time.

    Args:
        experiment_info: List of dicts with keys:
            - 'experiment_regex': regex pattern to match experiment directories
            - 'method_name': display name for the method
            - 'experiment_group': experiment group directory name
            Any other keys are currently ignored; only the regex filter is applied.
        project_root: Root directory of the project
        experiment_dir: Base experiment directory (e.g., 'experiments/search/retro_star')
        selection_criteria: Criterion forwarded to
            ``select_best_search_experiment_per_product`` ('reaction_type', 'tanimoto'
            or 'oracle').
        return_all_info: If True, also return the per-target dataframes and the
            selected per-product data.

    Returns:
        ``aggregates``: one stats dict per method (with its 'method' key set). Methods
        without any results are skipped with a warning, so the list can be shorter than
        ``experiment_info``. When ``return_all_info`` is True the tuple
        ``(aggregates, target_dfs, all_guided_data)`` is returned instead.
    """
    aggregates = []
    target_dfs = []
    all_guided_data = []
    for exp in experiment_info:
        experiment_regex = exp['experiment_regex']
        method_name = exp['method_name']
        experiment_group = exp['experiment_group']

        # Only the directory regex is used as a filter.
        experiment_filters = {'experiment_regex': experiment_regex}

        results = load_experiment_results(
            project_root,
            experiment_dir,
            experiment_group,
            experiment_filters,
            experiment_subdir='strategy_None/evaluations'
        )

        if not results:
            print(f"Warning: No results found for {method_name}")
            continue

        # Select best experiment per product across all matching experiments
        guided_data, guided_experiments = select_best_search_experiment_per_product(
            list_dfs=list(results.values()),
            list_experiment_names=list(results.keys()),
            criteria=selection_criteria
        )
        all_guided_data.append(guided_data)
        # Calculate metrics on the selected data
        dataset_stats, target_df = calculate_dataset_aggregates(guided_data)
        dataset_stats['method'] = method_name
        target_dfs.append(target_df)
        aggregates.append(dataset_stats)
    if return_all_info:
        return aggregates, target_dfs, all_guided_data
    return aggregates

def jaccard_similarity(pred1, pred2, empty_value: float = 1.0):
    """
    Return the Jaccard similarity (size of intersection / size of union) of two iterables.

    Duplicates are ignored because both inputs are converted to sets. Each input is
    consumed exactly once, so iterators and generators are supported.

    Args:
        pred1: Iterable of hashable items.
        pred2: Iterable of hashable items.
        empty_value: Value returned when both inputs are empty, where the ratio is
            undefined. Defaults to 1.0 (two empty collections are identical).

    Returns:
        A float in [0, 1].
    """
    items1 = set(pred1)
    items2 = set(pred2)
    union = items1 | items2
    if not union:
        return empty_value
    return len(items1 & items2) / len(union)

def select_best_search_experiment_per_product(list_dfs, list_experiment_names, criteria: str = 'reaction_type'):
    """
    Dispatch per-product experiment selection to the strategy named by ``criteria``.

    Supported criteria are 'reaction_type', 'tanimoto' and 'oracle'; each forwards
    ``list_dfs`` and ``list_experiment_names`` unchanged to the matching
    ``select_best_search_experiment_per_product_<criteria>`` function and returns its result.

    Raises:
        ValueError: If ``criteria`` is not one of the supported names.
    """
    # Each selector is wrapped in a lambda so its function name is only resolved once chosen.
    selectors = (
        ('reaction_type',
         lambda: select_best_search_experiment_per_product_reaction_type(list_dfs, list_experiment_names)),
        ('tanimoto',
         lambda: select_best_search_experiment_per_product_tanimoto(list_dfs, list_experiment_names)),
        ('oracle',
         lambda: select_best_search_experiment_per_product_oracle(list_dfs, list_experiment_names)),
    )
    for selector_name, run_selector in selectors:
        if criteria == selector_name:
            return run_selector()
    raise ValueError(f'Criteria {criteria} not supported')

def select_best_search_experiment_per_product_reaction_type(list_dfs, list_experiment_names):
    """
    For each product, select the search experiment with best results based on hierarchical criteria:
    4. Route with max number of exact rxn type matches
    5. Total number of exact rxn type matches across all routes
    6. A route was found (solved == True)
    
    Args:
        list_dfs: List of dataframes containing search results
        list_experiment_names: List of experiment names corresponding to dataframes
        
    Returns:
        best_data: DataFrame with best experiment results for each product
        best_experiments: DataFrame with metadata about which experiment was selected
    """
    # Combine all dataframes
    combined = []
    for df, experiment_name in zip(list_dfs, list_experiment_names):
        df_copy = df.copy()
        df_copy['experiment_name'] = experiment_name
        combined.append(df_copy)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Criterion 4: Maximum number of exact rxn type matches in a single route
        if 'pred_class' in group.columns and 'true_class' in group.columns and 'sample_route_idx' in group.columns:
            # For each route, count how many reactions have matching types
            route_match_counts = group.groupby('sample_route_idx').apply(
                lambda x: (x['pred_class'] == x['true_class']).sum(),
                include_groups=False
            )
            # Take the maximum across all routes
            max_rxn_type_matches_per_route = route_match_counts.max() if len(route_match_counts) > 0 else 0
        else:
            max_rxn_type_matches_per_route = 0
        
        # Criterion 5: Total number of exact rxn type matches across all routes
        if 'pred_class' in group.columns and 'true_class' in group.columns:
            num_exact_class_matches = (group['pred_class'] == group['true_class']).sum()
        else:
            num_exact_class_matches = 0
        
        # Criterion 6: Route was found
        route_found = group['solved'].any() if 'solved' in group.columns else False
        
        return pd.Series({
            'max_rxn_type_matches_per_route': max_rxn_type_matches_per_route,
            'num_exact_class_matches': num_exact_class_matches,
            'route_found': route_found
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(
        ['target_idx', 'experiment_name'], 
        as_index=False
    ).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria in order of preference
    # For boolean columns, convert to int so True > False
    metrics_sorted = metrics.sort_values(
        by=[
            'target_idx',
            'max_rxn_type_matches_per_route',
            'num_exact_class_matches',
            'route_found'
        ],
        ascending=[True, False, False, False]
    )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('target_idx').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['target_idx', 'experiment_name']], 
        on=['target_idx', 'experiment_name']
    )
    
    return best_data, best_experiments

def select_best_search_experiment_per_product_oracle(list_dfs, list_experiment_names):
    """
    For each product, select the search experiment with best results based on hierarchical criteria:
    1. Exact route was found (product_matches == True)
    2. Route with max number of topk matches
    3. Total number of topk matches across all routes
    4. Route with max number of exact rxn type matches
    5. Total number of exact rxn type matches across all routes
    6. A route was found (solved == True)
    
    Args:
        list_dfs: List of dataframes containing search results
        list_experiment_names: List of experiment names corresponding to dataframes
        
    Returns:
        best_data: DataFrame with best experiment results for each product
        best_experiments: DataFrame with metadata about which experiment was selected
    """
    # Combine all dataframes
    combined = []
    for df, experiment_name in zip(list_dfs, list_experiment_names):
        df_copy = df.copy()
        df_copy['experiment_name'] = experiment_name
        combined.append(df_copy)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Criterion 1: Exact route found (all reactions in at least one route have topk=True)
        if 'topk' in group.columns and 'sample_route_idx' in group.columns:
            # Check if any route has all topk=True
            has_exact_route = group.groupby('sample_route_idx')['topk'].all().any()
        else:
            has_exact_route = False
        
        # Criterion 2: Maximum number of topk matches in a single route
        if 'topk' in group.columns and 'sample_route_idx' in group.columns:
            # For each route, count how many reactions have topk=True
            route_topk_counts = group.groupby('sample_route_idx')['topk'].sum()
            # Take the maximum across all routes
            max_topk_matches_per_route = route_topk_counts.max() if len(route_topk_counts) > 0 else 0
        else:
            max_topk_matches_per_route = 0
        
        # Criterion 3: Total number of topk matches across all routes
        num_topk_matches = group['topk'].sum() if 'topk' in group.columns else 0
        
        # Criterion 4: Maximum number of exact rxn type matches in a single route
        if 'pred_class' in group.columns and 'true_class' in group.columns and 'sample_route_idx' in group.columns:
            # For each route, count how many reactions have matching types
            route_match_counts = group.groupby('sample_route_idx').apply(
                lambda x: (x['pred_class'] == x['true_class']).sum(),
                include_groups=False
            )
            # Take the maximum across all routes
            max_rxn_type_matches_per_route = route_match_counts.max() if len(route_match_counts) > 0 else 0
        else:
            max_rxn_type_matches_per_route = 0
        
        # Criterion 5: Total number of exact rxn type matches across all routes
        if 'pred_class' in group.columns and 'true_class' in group.columns:
            num_exact_class_matches = (group['pred_class'] == group['true_class']).sum()
        else:
            num_exact_class_matches = 0
        
        # Criterion 6: Route was found
        route_found = group['solved'].any() if 'solved' in group.columns else False
        
        return pd.Series({
            'has_exact_route': has_exact_route,
            'max_topk_matches_per_route': max_topk_matches_per_route,
            'num_topk_matches': num_topk_matches,
            'max_rxn_type_matches_per_route': max_rxn_type_matches_per_route,
            'num_exact_class_matches': num_exact_class_matches,
            'route_found': route_found
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(
        ['target_idx', 'experiment_name'], 
        as_index=False
    ).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria in order of preference
    # For boolean columns, convert to int so True > False
    metrics_sorted = metrics.sort_values(
        by=[
            'target_idx',
            'has_exact_route',
            'max_topk_matches_per_route',
            'num_topk_matches',
            'max_rxn_type_matches_per_route',
            'num_exact_class_matches',
            'route_found'
        ],
        ascending=[True, False, False, False, False, False, False]
    )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('target_idx').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['target_idx', 'experiment_name']], 
        on=['target_idx', 'experiment_name']
    )
    
    return best_data, best_experiments

def aggregate_guided_search_results(
    experiment_info: List[Dict[str, str]],
    experiment_base_dir: str,
    result_subdir: str = 'evaluations'
) -> pd.DataFrame:
    """
    Aggregate results from multiple guided search experiments, selecting best per product.

    Args:
        experiment_info: List of dicts with keys:
            - 'experiment_path': path to experiment directory
            - 'method_name': display name for the method
        experiment_base_dir: Base directory for experiments
        result_subdir: Subdirectory containing result CSVs (default: 'evaluations')

    Returns:
        DataFrame with one row of aggregated metrics per method. Experiments whose
        result directory is missing or contains no CSV files are skipped with a warning.
    """
    aggregates = []

    for exp in experiment_info:
        experiment_path = exp['experiment_path']
        method_name = exp['method_name']

        eval_dir = os.path.join(experiment_base_dir, experiment_path, result_subdir)

        if not os.path.exists(eval_dir):
            print(f"Warning: Directory not found: {eval_dir}")
            continue

        # Sorted so run naming and tie-breaking do not depend on os.listdir order.
        csv_files = sorted(f for f in os.listdir(eval_dir) if f.endswith('.csv'))

        if not csv_files:
            print(f"Warning: No CSV files found in {eval_dir}")
            continue

        dfs = [pd.read_csv(os.path.join(eval_dir, csv_file)) for csv_file in csv_files]

        if len(dfs) == 1:
            # A single CSV is used as-is.
            best_data = dfs[0]
        else:
            # Multiple CSVs are treated as different runs/configs; select the best per product.
            experiment_names = [f"run_{i}" for i in range(len(dfs))]
            best_data, _ = select_best_search_experiment_per_product(dfs, experiment_names)

        # calculate_dataset_aggregates returns (stats, per-target dataframe); only the
        # stats dict is tabulated here.
        dataset_stats, _ = calculate_dataset_aggregates(best_data)
        dataset_stats['method'] = method_name

        aggregates.append(dataset_stats)

    return pd.DataFrame(aggregates)


def aggregate_multiple_guided_experiments(
    experiment_groups: List[Dict[str, Any]],
    experiment_base_dir: str
) -> pd.DataFrame:
    """
    Aggregate results where each method has multiple experimental variations,
    selecting the best variation per product.

    Args:
        experiment_groups: List of dicts with keys:
            - 'method_name': display name for the method
            - 'experiment_paths': list of paths to different experiment variations
        experiment_base_dir: Base directory for experiments

    Returns:
        DataFrame with one row of aggregated metrics per method. Methods with no CSV
        data under any ``<path>/evaluations`` directory are skipped with a warning.
    """
    aggregates = []

    for group in experiment_groups:
        method_name = group['method_name']
        experiment_paths = group['experiment_paths']

        # Load all CSVs of all variations of this method.
        all_dfs = []
        all_names = []

        for exp_path in experiment_paths:
            eval_dir = os.path.join(experiment_base_dir, exp_path, 'evaluations')

            if not os.path.exists(eval_dir):
                print(f"Warning: Directory not found: {eval_dir}")
                continue

            # Sorted so selection tie-breaking does not depend on os.listdir order.
            csv_files = sorted(f for f in os.listdir(eval_dir) if f.endswith('.csv'))

            for csv_file in csv_files:
                all_dfs.append(pd.read_csv(os.path.join(eval_dir, csv_file)))
                all_names.append(f"{os.path.basename(exp_path)}_{csv_file}")

        if not all_dfs:
            print(f"Warning: No data found for method {method_name}")
            continue

        # Select best experiment per product across all variations
        best_data, _ = select_best_search_experiment_per_product(all_dfs, all_names)

        # calculate_dataset_aggregates returns (stats, per-target dataframe); only the
        # stats dict is tabulated here.
        dataset_stats, _ = calculate_dataset_aggregates(best_data)
        dataset_stats['method'] = method_name

        aggregates.append(dataset_stats)

    return pd.DataFrame(aggregates)

from typing import List, Dict, Optional
import os

"""
Modified table generation functions with composite key support
to handle duplicate method names across different groups.
"""

def get_manual_synthesis_metrics_table_reaction_type(
    experiment_info,
    aggregates,
    table_name,
    save_table=True,
    caption='Manual synthesis metrics',
    label='tab:manual-synthesis-metrics',
    output_dir=None,
    highlight_methods=None,
):
    """
    Build (and optionally save) the LaTeX table of manual synthesis metrics.

    Each experiment is keyed by ``f"{method_name}_{index}"`` so that the same method name
    appearing several times (e.g. trained on different datasets) stays distinct; the
    original method name is kept as the displayed name.

    Args:
        experiment_info: One dict per experiment with 'method_name', 'category' (shown in
            the second table column) and 'trained_on' (used as the row group label).
        aggregates: Metric dicts aligned by position with ``experiment_info``.
        table_name: File name of the saved table, joined onto ``output_dir``.
        save_table: If True, write the table with ``save_latex_table``.
        caption: LaTeX caption.
        label: LaTeX label.
        output_dir: Directory of the saved table; defaults to
            ``PROJECT_ROOT/paper/iclr2026/tables``.
        highlight_methods: Method names whose rows are highlighted; defaults to the
            two Rsmiles-TG variants.

    Returns:
        The LaTeX table as a string.
    """
    if highlight_methods is None:
        highlight_methods = ['Rsmiles-TG$_{\\text{rxn}}$', 'Rsmiles-TG$_{\\text{sim}}$']

    # Composite keys (name + index) keep duplicate method names apart.
    method_names = [f"{e['method_name']}_{i}" for i, e in enumerate(experiment_info)]
    method_categories = {key: e['category'] for key, e in zip(method_names, experiment_info)}
    method_groups = {key: e['trained_on'] for key, e in zip(method_names, experiment_info)}

    # aggregates[i] belongs to experiment_info[i]: relabel it with the same unique key and
    # keep the original name for display.
    updated_aggregates = [
        {
            **agg,
            'method': method_names[i],
            'display_name': experiment_info[i]['method_name'],
        }
        for i, agg in enumerate(aggregates)
    ]

    metric_display_names = {
        'completion_rate': r'Exact',
        'avg_topk_1': r'top-1',
        'avg_topk_5': r'top-5',
        'avg_topk_50': r'top-50',
        'perc_samples_per_product': r'Unique',
        'perc_class_correct_samples_per_product': r'Correct',
        'avg_tanimoto_to_starting': r'TA',
        'perc_round_trip_correct_samples_per_product': r'Correct',
    }
    metric_display_names_line2 = {
        'completion_rate': r'route ($\uparrow$)',
        'avg_topk_1': r'($\uparrow$)',
        'avg_topk_5': r'($\uparrow$)',
        'avg_topk_50': r'($\uparrow$)',
        'avg_topk_100': r'($\uparrow$)',
        'perc_samples_per_product': r'samples ($\uparrow$)',
        'perc_class_correct_samples_per_product': r'class ($\uparrow$)',
        'avg_tanimoto_to_starting': r'to SM ($\uparrow$)',
        'perc_round_trip_correct_samples_per_product': r'R.Trip ($\uparrow$)',
    }
    # Column order of the table follows the insertion order of metric_display_names.
    metrics = metric_display_names.keys()
    bold_best = {
        'completion_rate': 'high',
        'avg_topk_1': 'high',
        'avg_topk_5': 'high',
        'avg_topk_50': 'high',
        'avg_topk_100': 'high',
        'perc_samples_per_product': 'high',
        'percentage_products_with_class_correct': 'high',
        'avg_class_correct_samples_per_product': 'high',
        'avg_tanimoto_to_starting': 'high',
        'perc_class_correct_samples_per_product': 'high',
        'percentage_products_with_class_and_round_trip_correct': 'high',
        'products_with_max_tanimoto_to_starting': 'high',
        'percentage_products_with_round_trip_correct': 'high',
        'perc_round_trip_correct_samples_per_product': 'high',
    }

    latex_table = generate_latex_table_manual_synthesis(
        experiment_dirs=experiment_info,
        metrics=metrics,
        results=updated_aggregates,
        method_names=method_names,
        method_categories=method_categories,
        method_groups=method_groups,
        caption=caption,
        label=label,
        metric_display_names=metric_display_names,
        metric_display_names_line2=metric_display_names_line2,
        bold_best=bold_best,
        decimal_places=2,
        highlight_methods=highlight_methods,
    )

    if save_table:
        if output_dir is None:
            output_dir = os.path.join(PROJECT_ROOT, 'paper', 'iclr2026', 'tables')
        save_latex_table(
            latex_table=latex_table,
            output_path=os.path.join(output_dir, table_name),
            standalone=False
        )

    return latex_table

def generate_latex_table_manual_synthesis(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    metric_display_names_line2: Optional[Dict[str, str]] = None,
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
    method_categories: Optional[Dict[str, str]] = None,
    method_groups: Optional[Dict[str, str]] = None,
    use_siunitx: bool = False,
    font_size: str = "small",
    tabcolsep: Optional[str] = "4pt",
    group_header_spacing: str = "2pt",
    group_separation: str = "4pt",
    highlight_per_group: bool = True,
    highlight_methods: Optional[List[str]] = None,
    highlight_color: str = "highlightgreen",
) -> str:
    """
    Generate a LaTeX table for manual synthesis metrics.

    Rows are looked up by ``result['method']`` (which may be a composite key) while the
    table shows ``result['display_name']`` (falling back to ``method``). Rows appear in the
    order of ``results``; whenever a row's group changes to a non-None label, a group
    header row is inserted and the following rows are indented.

    Args:
        experiment_dirs: One entry per method. Only its length is checked, unless
            ``results`` is None, in which case each entry is read as a directory of CSVs.
        metrics: Metric keys in column order (any re-iterable, e.g. ``dict.keys()``).
        method_names: Method keys aligned with ``experiment_dirs``.
        caption: LaTeX caption.
        label: LaTeX label.
        metric_display_names: First header line per metric.
        metric_display_names_line2: Second header line per metric; a metric without one
            gets a two-row multirow header.
        decimal_places: Decimals shown, also used when comparing values for highlighting.
        bold_best: Maps metric -> 'high' (larger is better) or any other value (smaller
            is better). The best value is bold, the second best underlined.
        results: Pre-computed rows (dicts with 'method' and metric keys).
        method_categories: Method key -> text of the second ('Temp.') column.
        method_groups: Method key -> group label.
        use_siunitx: Emit siunitx-friendly bold/underline/missing markers.
        font_size: Font-size command placed before the tabular ('normal' to skip).
        tabcolsep: Column separation placed before the tabular (falsy to skip).
        group_header_spacing: Vertical space after each group header.
        group_separation: Vertical space added after the last row of a group.
        highlight_per_group: If True, highlights best/second-best within each group.
            If False, highlights best/second-best globally across all methods.
        highlight_methods: Display names or method keys whose rows get a row colour.
        highlight_color: Colour name used for highlighted rows.

    Missing values (None, NaN, or a metric absent from a result) are rendered as a
    placeholder and ignored when ranking; numpy scalars are treated as numbers.

    Returns:
        The LaTeX source of the table.

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in length.
    """
    import math
    import numbers

    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    # Materialise once; metrics is iterated many times below.
    metrics = list(metrics)

    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}

    if metric_display_names_line2 is None:
        metric_display_names_line2 = {k: '' for k in metrics}

    if method_categories is None:
        method_categories = {name: 'N.T.' for name in method_names}

    if method_groups is None:
        method_groups = {name: None for name in method_names}

    if highlight_methods is None:
        highlight_methods = []

    # Collect results if not provided
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            csv_files = [f for f in os.listdir(exp_dir) if f.endswith('.csv')]
            dfs = [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files]
            pd.concat(dfs, ignore_index=True)
            # NOTE(review): no aggregation is wired in here yet, so every metric is None
            # (rendered as the missing placeholder).
            row_data = {'method': method_name, 'display_name': method_name}
            for metric in metrics:
                row_data[metric] = None
            results.append(row_data)

    # Add metadata to results, preserving display_name if it exists
    results = [
        {
            **r,
            'category': method_categories.get(r['method'], None),
            'group': method_groups.get(r['method'], None),
            'display_name': r.get('display_name', r['method']),
        }
        for r in results
    ]
    col_spec = 'lc' + 'c' * len(metrics)

    # Build header
    header_line1_cols = ['Method', 'Temp.']
    header_line2_cols = ['', '']

    for m in metrics:
        line1 = metric_display_names.get(m, m)
        line2 = metric_display_names_line2.get(m, '')

        if line2:
            header_line1_cols.append(line1)
            header_line2_cols.append(line2)
        else:
            header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
            header_line2_cols.append('')

    header_line1 = ' & '.join(header_line1_cols) + ' \\\\'
    header_line2 = ' & '.join(header_line2_cols) + ' \\\\'

    def to_number(value):
        """Return ``value`` as a float if it is a real, non-NaN number, else None."""
        if isinstance(value, numbers.Real) and not math.isnan(value):
            return float(value)
        return None

    def compute_best_values(results_subset):
        """Best and second-best rounded values per ranked metric in ``results_subset``."""
        best_values = {}
        second_best_values = {}
        if not bold_best:
            return best_values, second_best_values
        for metric in metrics:
            if metric not in bold_best:
                continue
            rounded_values = {
                round(number, decimal_places)
                for number in (to_number(r.get(metric)) for r in results_subset)
                if number is not None
            }
            if not rounded_values:
                continue
            ranked = sorted(rounded_values, reverse=(bold_best[metric] == 'high'))
            best_values[metric] = ranked[0]
            if len(ranked) >= 2:
                second_best_values[metric] = ranked[1]
        return best_values, second_best_values

    # Best/second-best values either per group or globally.
    rank_per_group = bool(highlight_per_group and bold_best)
    if rank_per_group:
        grouped_results = {}
        for r in results:
            grouped_results.setdefault(r['group'], []).append(r)
        group_rankings = {
            group_key: compute_best_values(group_results)
            for group_key, group_results in grouped_results.items()
        }
        global_ranking = ({}, {})
    else:
        group_rankings = {}
        global_ranking = compute_best_values(results)

    missing_cell = '{--}' if use_siunitx else '--'
    row_end = ' \\\\'

    # Generate rows with grouping
    data_lines = []
    current_group = None
    num_cols = 2 + len(metrics)
    is_first_group = True

    for result in results:
        group = result['group']
        # Add group header if needed
        if group != current_group and group is not None:
            # Extra space after the previous group's last row (not before the first group);
            # only the trailing row terminator is rewritten.
            if not is_first_group and data_lines and data_lines[-1].endswith(row_end):
                data_lines[-1] = data_lines[-1][:-len(row_end)] + f' \\\\[{group_separation}]'
            data_lines.append(
                f"    \\multicolumn{{{num_cols}}}{{l}}{{\\textit{{{group}}}}} \\\\[{group_header_spacing}]"
            )
            current_group = group
            is_first_group = False

        if rank_per_group:
            best_values, second_best_values = group_rankings.get(group, ({}, {}))
        else:
            best_values, second_best_values = global_ranking

        # Use display_name for the table, but method (unique key) for lookups
        display_name = result['display_name']
        method_key = result['method']
        category = result['category']

        values = []
        for metric in metrics:
            raw = result.get(metric)
            number = to_number(raw)
            if number is None:
                # None, NaN or an absent metric -> placeholder; other non-numbers verbatim.
                is_missing = raw is None or isinstance(raw, numbers.Real)
                values.append(missing_cell if is_missing else str(raw))
                continue

            formatted_val = f"{number:.{decimal_places}f}"
            rounded_val = round(number, decimal_places)
            is_best = (metric in best_values and
                       round(best_values[metric], decimal_places) == rounded_val)
            is_second_best = (metric in second_best_values and
                              round(second_best_values[metric], decimal_places) == rounded_val)

            if is_best:
                values.append(f"{{\\bfseries {formatted_val}}}" if use_siunitx
                              else f"\\textbf{{{formatted_val}}}")
            elif is_second_best:
                values.append(f"{{\\underline{{{formatted_val}}}}}" if use_siunitx
                              else f"\\underline{{{formatted_val}}}")
            else:
                values.append(formatted_val)

        # Add indentation for grouped methods
        indent = "\\quad " if current_group is not None else ""

        # Highlight if either the display name or the unique key is listed
        should_highlight = display_name in highlight_methods or method_key in highlight_methods

        if should_highlight:
            row = f"    \\rowcolor{{{highlight_color}}}\n    {indent}{display_name} & {category} & " + ' & '.join(values) + ' \\\\'
        else:
            row = f"    {indent}{display_name} & {category} & " + ' & '.join(values) + ' \\\\'

        data_lines.append(row)

    # Build formatting preamble
    format_commands = []
    if font_size and font_size != "normal":
        format_commands.append(f"\\{font_size}")
    if tabcolsep:
        format_commands.append(f"\\setlength{{\\tabcolsep}}{{{tabcolsep}}}")

    format_preamble = "\n".join(format_commands)
    if format_preamble:
        format_preamble += "\n"

    latex_table = f"""\\begin{{table}}[htb!]
\\caption{{{caption}}}
\\label{{{label}}}
\\centering
{format_preamble}\\begin{{tabular}}{{{col_spec}}}
    \\toprule
    {header_line1}
    {header_line2}
    \\midrule
{chr(10).join(data_lines)}
    \\bottomrule
\\end{{tabular}}
\\end{{table}}"""
    return latex_table


def generate_latex_table_manual_synthesis_old(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    metric_display_names_line2: Optional[Dict[str, str]] = None,
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
    method_categories: Optional[Dict[str, str]] = None,
    method_groups: Optional[Dict[str, str]] = None,
    use_siunitx: bool = False,
    font_size: str = "small",
    tabcolsep: Optional[str] = "4pt",
    group_header_spacing: str = "2pt",
    group_separation: str = "4pt",
    highlight_per_group: bool = True,
    highlight_methods: Optional[List[str]] = None,
    highlight_color: str = "highlightgreen",
) -> str:
    """
    Generate a LaTeX table for manual synthesis metrics.

    Supports composite keys for method identification while displaying
    clean method names in the table.

    Args:
        experiment_dirs: One directory per method. Only read when ``results`` is
            None: every ``.csv`` file in each directory is loaded, but the
            aggregation step is currently a placeholder, so all metrics render
            as missing.
        metrics: Metric keys, in column order.
        method_names: Method keys, parallel to ``experiment_dirs``.
        caption: Table caption.
        label: Table label.
        metric_display_names: First header line per metric (defaults to the key).
        metric_display_names_line2: Optional second header line per metric.
            Metrics without one get a multirow header spanning both lines.
        decimal_places: Decimals used both for formatting and for deciding ties
            between best values.
        bold_best: Maps a metric to 'high' (higher is better); any other value
            means lower is better. The best value is bold, the second best
            underlined. Metrics not listed are never highlighted.
        results: Pre-computed rows. Each dict needs a 'method' key, may carry a
            'display_name', and holds metric values by key (missing keys, None
            and NaN render as a dash).
        method_categories: Method key -> text for the second ("Temp.") column.
        method_groups: Method key -> group label. Consecutive rows with the same
            label get one italic group header and are indented.
        use_siunitx: Emit siunitx-friendly cell markup.
        font_size: LaTeX size command name (without backslash) applied to the
            table; "normal" or empty skips it.
        tabcolsep: Value for tabcolsep; falsy skips it.
        group_header_spacing: Vertical space after each group header row.
        group_separation: Vertical space before every group header except the first.
        highlight_per_group: Rank best/second-best within each group instead of
            across the whole table.
        highlight_methods: Display names or method keys whose rows get a rowcolor.
        highlight_color: Colour name passed to rowcolor (must be defined in the
            LaTeX document).

    Returns:
        The complete ``table`` environment as a string.

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in length.
    """
    import math

    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}
    if metric_display_names_line2 is None:
        metric_display_names_line2 = {k: '' for k in metrics}
    if method_categories is None:
        method_categories = {name: 'N.T.' for name in method_names}
    if method_groups is None:
        method_groups = {name: None for name in method_names}
    if highlight_methods is None:
        highlight_methods = []

    # Collect results if not provided
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            csv_files = [f for f in os.listdir(exp_dir) if f.endswith('.csv')]
            # NOTE(review): the loaded frame is unused while the aggregation below
            # is a placeholder; loading still fails fast on missing/empty dirs.
            df = pd.concat(
                [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files],
                ignore_index=True,
            )
            # from your_module import calculate_dataset_aggregates
            # dataset_stats = calculate_dataset_aggregates(df)
            dataset_stats = {}  # Placeholder: every metric renders as missing.

            row_data = {'method': method_name, 'display_name': method_name}
            for metric in metrics:
                row_data[metric] = dataset_stats.get(metric, None)
            results.append(row_data)

    # Attach category/group metadata, preserving display_name if it exists
    results = [
        {
            **r,
            'category': method_categories.get(r['method'], None),
            'group': method_groups.get(r['method'], None),
            'display_name': r.get('display_name', r['method']),
        }
        for r in results
    ]

    col_spec = 'lc' + 'c' * len(metrics)
    num_cols = 2 + len(metrics)
    row_end = ' \\\\'

    # Build the two-line header
    header_line1_cols = ['Method', 'Temp.']
    header_line2_cols = ['', '']
    for m in metrics:
        line1 = metric_display_names.get(m, m)
        line2 = metric_display_names_line2.get(m, '')
        if line2:
            header_line1_cols.append(line1)
            header_line2_cols.append(line2)
        else:
            # No second line: let the single label span both header rows.
            header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
            header_line2_cols.append('')
    header_line1 = ' & '.join(header_line1_cols) + row_end
    header_line2 = ' & '.join(header_line2_cols) + row_end

    def is_missing(val) -> bool:
        """Treat None and NaN alike: NaN would print as 'nan' and breaks sorting."""
        return val is None or (isinstance(val, float) and math.isnan(val))

    def compute_best_values(results_subset):
        """Return (best, second_best) rounded values per highlighted metric."""
        best_values = {}
        second_best_values = {}
        if not bold_best:
            return best_values, second_best_values
        for metric in metrics:
            if metric not in bold_best:
                continue
            rounded_values = {
                round(v, decimal_places)
                for v in (r.get(metric) for r in results_subset)
                if isinstance(v, (int, float)) and not is_missing(v)
            }
            if not rounded_values:
                continue
            unique_values = sorted(rounded_values, reverse=(bold_best[metric] == 'high'))
            best_values[metric] = unique_values[0]
            if len(unique_values) >= 2:
                second_best_values[metric] = unique_values[1]
        return best_values, second_best_values

    def format_cell(metric, val, best_values, second_best_values) -> str:
        """Format one metric cell, bolding/underlining best and second-best values."""
        if is_missing(val):
            return '{--}' if use_siunitx else '--'
        if not isinstance(val, (int, float)):
            return str(val)
        formatted_val = f"{val:.{decimal_places}f}"
        rounded_val = round(val, decimal_places)
        if metric in best_values and best_values[metric] == rounded_val:
            return f"{{\\bfseries {formatted_val}}}" if use_siunitx else f"\\textbf{{{formatted_val}}}"
        if metric in second_best_values and second_best_values[metric] == rounded_val:
            return f"{{\\underline{{{formatted_val}}}}}" if use_siunitx else f"\\underline{{{formatted_val}}}"
        return formatted_val

    # Find best values - either per group or globally
    per_group = bool(bold_best) and highlight_per_group
    if per_group:
        rows_by_group = {}
        for r in results:
            rows_by_group.setdefault(r['group'], []).append(r)
        best_by_group = {g: compute_best_values(rows) for g, rows in rows_by_group.items()}
    else:
        global_best = compute_best_values(results)

    # Generate rows with grouping
    data_lines = []
    current_group = None
    seen_group_header = False
    for result in results:
        group = result['group']
        if group is not None and group != current_group:
            # Add extra space before every group header except the first one by
            # widening the previous row's terminator.
            if seen_group_header and data_lines and data_lines[-1].endswith(row_end):
                data_lines[-1] = data_lines[-1][:-len(row_end)] + f' \\\\[{group_separation}]'
            data_lines.append(
                f"    \\multicolumn{{{num_cols}}}{{l}}{{\\textit{{{group}}}}} \\\\[{group_header_spacing}]"
            )
            seen_group_header = True
        # Track every row's group (including ungrouped ones) so an ungrouped row
        # after a group is not indented into it, and a reappearing group gets
        # its header again.
        current_group = group

        best_values, second_best_values = best_by_group[group] if per_group else global_best

        # Use display_name for the table, but method (unique key) for lookups
        display_name = result['display_name']
        method_key = result['method']
        values = [
            format_cell(metric, result.get(metric), best_values, second_best_values)
            for metric in metrics
        ]

        indent = "\\quad " if group is not None else ""
        highlighted = display_name in highlight_methods or method_key in highlight_methods
        prefix = f"    \\rowcolor{{{highlight_color}}}\n" if highlighted else ""
        data_lines.append(
            f"{prefix}    {indent}{display_name} & {result['category']} & "
            + ' & '.join(values) + row_end
        )

    # Build formatting preamble
    format_commands = []
    if font_size and font_size != "normal":
        format_commands.append(f"\\{font_size}")
    if tabcolsep:
        format_commands.append(f"\\setlength{{\\tabcolsep}}{{{tabcolsep}}}")
    format_preamble = "\n".join(format_commands)
    if format_preamble:
        format_preamble += "\n"

    latex_table = f"""\\begin{{table}}[htb!]
\\caption{{{caption}}}
\\label{{{label}}}
\\centering
{format_preamble}\\begin{{tabular}}{{{col_spec}}}
    \\toprule
    {header_line1}
    {header_line2}
    \\midrule
{chr(10).join(data_lines)}
    \\bottomrule
\\end{{tabular}}
\\end{{table}}"""
    return latex_table

# def generate_latex_table_manual_synthesis(
#     experiment_dirs: List[str],
#     metrics: List[str],
#     method_names: List[str],
#     caption: str = "Sample quality in synthesis planning",
#     label: str = "tab:results",
#     metric_display_names: Optional[Dict[str, str]] = None,
#     metric_display_names_line2: Optional[Dict[str, str]] = None,
#     decimal_places: int = 2,
#     bold_best: Optional[Dict[str, str]] = None,
#     results: Optional[List[Dict]] = None,
#     method_categories: Optional[Dict[str, str]] = None,
#     method_groups: Optional[Dict[str, str]] = None,
#     use_siunitx: bool = False,
#     font_size: str = "small",
#     tabcolsep: Optional[str] = "4pt",
#     group_header_spacing: str = "2pt",
#     group_separation: str = "4pt",
#     highlight_per_group: bool = True,
#     highlight_methods: Optional[List[str]] = None,  # NEW: list of method names to highlight
#     highlight_color: str = "highlightgreen",  # NEW: color name for highlighting
# ) -> str:
#     if len(experiment_dirs) != len(method_names):
#         raise ValueError("Number of experiment directories must match number of method names")
    
#     if metric_display_names is None:
#         metric_display_names = {k: k for k in metrics}
    
#     if metric_display_names_line2 is None:
#         metric_display_names_line2 = {k: '' for k in metrics}
    
#     if method_categories is None:
#         method_categories = {name: 'N.T.' for name in method_names}
    
#     if method_groups is None:
#         method_groups = {name: None for name in method_names}
    
#     if highlight_methods is None:
#         highlight_methods = []
    
#     # Collect results
#     if results is None:
#         results = []
#         for exp_dir, method_name in zip(experiment_dirs, method_names):
#             eval_dir = os.path.join(exp_dir)
#             csv_files = [f for f in os.listdir(eval_dir) if f.endswith('.csv')]
#             dfs = [pd.read_csv(os.path.join(eval_dir, f)) for f in csv_files]
#             df = pd.concat(dfs, ignore_index=True)
            
#             from your_module import calculate_dataset_aggregates
#             dataset_stats = calculate_dataset_aggregates(df)
            
#             row_data = {'method': method_name}
#             for metric in metrics:
#                 row_data[metric] = dataset_stats.get(metric, None)
#             results.append(row_data)
    
#     # Add metadata to results
#     results = [
#         {**r, 'category': method_categories.get(r['method'], None),
#          'group': method_groups.get(r['method'], None)}
#         for r in results
#     ]
    
#     col_spec = f'lc' + 'c' * len(metrics)
    
#     # Build header
#     header_line1_cols = ['Method', 'Temp.']
#     header_line2_cols = ['', '']
    
#     for m in metrics:
#         line1 = metric_display_names.get(m, m)
#         line2 = metric_display_names_line2.get(m, '')
        
#         if line2:
#             header_line1_cols.append(line1)
#             header_line2_cols.append(line2)
#         else:
#             header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
#             header_line2_cols.append('')
    
#     header_line1 = ' & '.join(header_line1_cols) + ' \\\\'
#     header_line2 = ' & '.join(header_line2_cols) + ' \\\\'
    
#     # Find best values - either per group or globally
#     def compute_best_values(results_subset, metrics, bold_best, decimal_places):
#         """Compute best and second-best values for a subset of results."""
#         best_values = {}
#         second_best_values = {}
        
#         if bold_best:
#             for metric in metrics:
#                 if metric in bold_best:
#                     values = [r[metric] for r in results_subset 
#                              if r[metric] is not None and isinstance(r[metric], (int, float))]
#                     if len(values) >= 1:
#                         rounded_values = [round(v, decimal_places) for v in values]
#                         unique_values = sorted(set(rounded_values), reverse=(bold_best[metric] == 'high'))
#                         best_values[metric] = unique_values[0]
#                         if len(unique_values) >= 2:
#                             second_best_values[metric] = unique_values[1]
        
#         return best_values, second_best_values
    
#     if highlight_per_group and bold_best:
#         # Group results by their group label
#         groups = {}
#         for r in results:
#             group_key = r['group']
#             if group_key not in groups:
#                 groups[group_key] = []
#             groups[group_key].append(r)
        
#         # Compute best/second-best per group
#         group_best_values = {}
#         group_second_best_values = {}
#         for group_key, group_results in groups.items():
#             best, second_best = compute_best_values(group_results, metrics, bold_best, decimal_places)
#             group_best_values[group_key] = best
#             group_second_best_values[group_key] = second_best
#     else:
#         # Global best/second-best (original behavior)
#         global_best, global_second_best = compute_best_values(results, metrics, bold_best, decimal_places)
    
#     # Generate rows with grouping
#     data_lines = []
#     current_group = None
#     num_cols = 2 + len(metrics)
#     is_first_group = True
    
#     for i, result in enumerate(results):
#         # Add group header if needed
#         if result['group'] != current_group and result['group'] is not None:
#             # Add extra space before new group (except for the first group)
#             if not is_first_group:
#                 if data_lines:
#                     data_lines[-1] = data_lines[-1].replace(' \\\\', f' \\\\[{group_separation}]')
            
#             group_line = f"    \\multicolumn{{{num_cols}}}{{l}}{{\\textit{{{result['group']}}}}} \\\\[{group_header_spacing}]"
#             data_lines.append(group_line)
#             current_group = result['group']
#             is_first_group = False
        
#         # Get the appropriate best/second-best values for this row
#         if highlight_per_group and bold_best:
#             best_values = group_best_values.get(result['group'], {})
#             second_best_values = group_second_best_values.get(result['group'], {})
#         else:
#             best_values = global_best if bold_best else {}
#             second_best_values = global_second_best if bold_best else {}
        
#         method = result['method']
#         category = result['category']
        
#         values = []
#         for metric in metrics:
#             val = result[metric]
#             if val is None:
#                 values.append('{--}' if use_siunitx else '--')
#             elif isinstance(val, (int, float)):
#                 formatted_val = f"{val:.{decimal_places}f}"
#                 rounded_val = round(val, decimal_places)
                
#                 is_best = (metric in best_values and 
#                           round(best_values[metric], decimal_places) == rounded_val)
#                 is_second_best = (metric in second_best_values and 
#                                  round(second_best_values[metric], decimal_places) == rounded_val)
                
#                 if use_siunitx:
#                     if is_best:
#                         values.append(f"{{\\bfseries {formatted_val}}}")
#                     elif is_second_best:
#                         values.append(f"{{\\underline{{{formatted_val}}}}}")
#                     else:
#                         values.append(formatted_val)
#                 else:
#                     if is_best:
#                         values.append(f"\\textbf{{{formatted_val}}}")
#                     elif is_second_best:
#                         values.append(f"\\underline{{{formatted_val}}}")
#                     else:
#                         values.append(formatted_val)
#             else:
#                 values.append(str(val))
        
#         # Add indentation for grouped methods
#         indent = "\\quad " if current_group is not None else ""
        
#         # Check if this method should be highlighted
#         should_highlight = method in highlight_methods
        
#         if should_highlight:
#             row = f"    \\rowcolor{{{highlight_color}}}\n    {indent}{method} & {category} & " + ' & '.join(values) + ' \\\\'
#         else:
#             row = f"    {indent}{method} & {category} & " + ' & '.join(values) + ' \\\\'
        
#         data_lines.append(row)

#     # Build formatting preamble
#     format_commands = []
#     if font_size and font_size != "normal":
#         format_commands.append(f"\\{font_size}")
#     if tabcolsep:
#         format_commands.append(f"\\setlength{{\\tabcolsep}}{{{tabcolsep}}}")
    
#     format_preamble = "\n".join(format_commands)
#     if format_preamble:
#         format_preamble += "\n"

#     latex_table = f"""\\begin{{table}}[htb!]
# \\caption{{{caption}}}
# \\label{{{label}}}
# \\centering
# {format_preamble}\\begin{{tabular}}{{{col_spec}}}
#     \\toprule
#     {header_line1}
#     {header_line2}
#     \\midrule
# {chr(10).join(data_lines)}
#     \\bottomrule
# \\end{{tabular}}
# \\end{{table}}"""
#     return latex_table
    
def generate_latex_table(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    metric_display_names_line2: Optional[Dict[str, str]] = None,
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
    method_categories: Optional[Dict[str, str]] = None,
    method_groups: Optional[Dict[str, str]] = None,  # maps method_name to group label
    use_siunitx: bool = False,
) -> str:
    """Render per-method metrics as a LaTeX ``table`` with optional group headers.

    Args:
        experiment_dirs: One directory per method; only read when ``results`` is None.
        metrics: Metric keys, in column order.
        method_names: Method keys, parallel to ``experiment_dirs``.
        caption: Table caption.
        label: Table label.
        metric_display_names: First header line per metric (defaults to the key).
        metric_display_names_line2: Optional second header line per metric;
            metrics without one get a multirow header spanning both lines.
        decimal_places: Decimals used for formatting and for best-value ties.
        bold_best: Maps a metric to 'high' (higher is better); any other value
            means lower is better. Best is bold, second best underlined, ranked
            across all rows.
        results: Pre-computed rows: dicts with a 'method' key plus metric values.
            Missing keys, None and NaN render as a dash.
        method_categories: Method key -> text for the "Temp." column ('N.T.' if absent).
        method_groups: Method key -> group label; consecutive rows sharing a
            label get one italic header row and are indented.
        use_siunitx: Emit siunitx-friendly cell markup.

    Returns:
        The complete ``table`` environment as a string.

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in length.
    """
    import math

    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}
    if metric_display_names_line2 is None:
        metric_display_names_line2 = {k: '' for k in metrics}
    if method_categories is None:
        method_categories = {name: 'N.T.' for name in method_names}
    if method_groups is None:
        method_groups = {name: None for name in method_names}

    # Collect results
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            csv_files = [f for f in os.listdir(exp_dir) if f.endswith('.csv')]
            df = pd.concat(
                [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files],
                ignore_index=True,
            )
            # NOTE(review): 'your_module' is a placeholder module name, so this
            # branch raises ModuleNotFoundError until the real location of
            # calculate_dataset_aggregates is filled in. Pass `results` to avoid it.
            from your_module import calculate_dataset_aggregates
            dataset_stats = calculate_dataset_aggregates(df)

            row_data = {'method': method_name}
            for metric in metrics:
                row_data[metric] = dataset_stats.get(metric, None)
            results.append(row_data)

    # Add metadata to results
    results = [
        {**r, 'category': method_categories.get(r['method'], 'N.T.'),
         'group': method_groups.get(r['method'], None)}
        for r in results
    ]

    col_spec = 'lc' + 'c' * len(metrics)
    num_cols = 2 + len(metrics)

    # Build header
    header_line1_cols = ['Method', 'Temp.']
    header_line2_cols = ['', '']
    for m in metrics:
        line1 = metric_display_names.get(m, m)
        line2 = metric_display_names_line2.get(m, '')
        if line2:
            header_line1_cols.append(line1)
            header_line2_cols.append(line2)
        else:
            header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
            header_line2_cols.append('')
    header_line1 = ' & '.join(header_line1_cols) + ' \\\\'
    header_line2 = ' & '.join(header_line2_cols) + ' \\\\'

    def is_missing(val) -> bool:
        """Treat None and NaN alike: NaN would print as 'nan' and breaks sorting."""
        return val is None or (isinstance(val, float) and math.isnan(val))

    # Find best and second-best (rounded) values across all rows
    best_values = {}
    second_best_values = {}
    if bold_best:
        for metric in metrics:
            if metric not in bold_best:
                continue
            rounded_values = {
                round(v, decimal_places)
                for v in (r.get(metric) for r in results)
                if isinstance(v, (int, float)) and not is_missing(v)
            }
            if not rounded_values:
                continue
            unique_values = sorted(rounded_values, reverse=(bold_best[metric] == 'high'))
            best_values[metric] = unique_values[0]
            if len(unique_values) >= 2:
                second_best_values[metric] = unique_values[1]

    def format_cell(metric, val) -> str:
        """Format one metric cell, bolding/underlining best and second-best values."""
        if is_missing(val):
            return '{--}' if use_siunitx else '--'
        if not isinstance(val, (int, float)):
            return str(val)
        formatted_val = f"{val:.{decimal_places}f}"
        rounded_val = round(val, decimal_places)
        if metric in best_values and best_values[metric] == rounded_val:
            return f"{{\\bfseries {formatted_val}}}" if use_siunitx else f"\\textbf{{{formatted_val}}}"
        if metric in second_best_values and second_best_values[metric] == rounded_val:
            return f"{{\\underline{{{formatted_val}}}}}" if use_siunitx else f"\\underline{{{formatted_val}}}"
        return formatted_val

    # Generate rows with grouping
    data_lines = []
    current_group = None
    for result in results:
        group = result['group']
        if group is not None and group != current_group:
            data_lines.append(f"    \\multicolumn{{{num_cols}}}{{l}}{{\\textit{{{group}}}}} \\\\")
        # Track every row's group so ungrouped rows after a group are not
        # indented into it and a reappearing group gets its header again.
        current_group = group

        values = [format_cell(metric, result.get(metric)) for metric in metrics]
        indent = "\\quad " if group is not None else ""
        row = f"    {indent}{result['method']} & {result['category']} & " + ' & '.join(values) + ' \\\\'
        data_lines.append(row)

    latex_table = f"""\\begin{{table}}[htb!]
\\caption{{{caption}}}
\\label{{{label}}}
\\centering
\\begin{{tabular}}{{{col_spec}}}
    \\toprule
    {header_line1}
    {header_line2}
    \\midrule
{chr(10).join(data_lines)}
    \\bottomrule
\\end{{tabular}}
\\end{{table}}"""

    return latex_table

def generate_latex_table_without_trained_on(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    metric_display_names_line2: Optional[Dict[str, str]] = None,  # Second line of header
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
    method_categories: Optional[Dict[str, str]] = None,
    exact_match_metrics: Optional[List[str]] = None,
    use_siunitx: bool = False,  # For better number alignment
) -> str:
    """Render per-method metrics as a flat (ungrouped) LaTeX ``table``.

    Args:
        experiment_dirs: One directory per method; only read when ``results`` is None.
        metrics: Metric keys, in column order.
        method_names: Method keys, parallel to ``experiment_dirs``.
        caption: Table caption.
        label: Table label.
        metric_display_names: First header line per metric (defaults to the key).
        metric_display_names_line2: Optional second header line per metric;
            metrics without one get a multirow header spanning both lines.
        decimal_places: Decimals used for formatting and for best-value ties.
        bold_best: Maps a metric to 'high' (higher is better); any other value
            means lower is better. Best is bold, second best underlined.
        results: Pre-computed rows: dicts with a 'method' key plus metric values.
            Missing keys, None and NaN render as a dash.
        method_categories: Method key -> text for the "Temp." column ('N.T.' if absent).
        exact_match_metrics: Accepted for API compatibility; currently unused.
        use_siunitx: Emit siunitx-friendly cell markup.

    Returns:
        The complete ``table`` environment as a string.

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in length.
    """
    import math

    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    # Defaults: display names = keys, single-line headers, 'N.T.' categories.
    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}
    if metric_display_names_line2 is None:
        metric_display_names_line2 = {k: '' for k in metrics}
    if method_categories is None:
        method_categories = {name: 'N.T.' for name in method_names}

    # Collect data from all experiments
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            csv_files = [f for f in os.listdir(exp_dir) if f.endswith('.csv')]
            df = pd.concat(
                [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files],
                ignore_index=True,
            )
            # NOTE(review): 'your_module' is a placeholder module name, so this
            # branch raises ModuleNotFoundError until the real location of
            # calculate_dataset_aggregates is filled in. Pass `results` to avoid it.
            from your_module import calculate_dataset_aggregates
            dataset_stats = calculate_dataset_aggregates(df)

            row_data = {'method': method_name}
            for metric in metrics:
                row_data[metric] = dataset_stats.get(metric, None)
            results.append(row_data)

    col_spec = 'lc' + 'c' * len(metrics)

    # Build two-line header
    header_line1_cols = ['Method', 'Temp.']
    header_line2_cols = ['', '']
    for m in metrics:
        line1 = metric_display_names.get(m, m)
        line2 = metric_display_names_line2.get(m, '')
        if line2:
            header_line1_cols.append(line1)
            header_line2_cols.append(line2)
        else:
            # Single-line header: span both header rows
            header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
            header_line2_cols.append('')
    header_line1 = ' & '.join(header_line1_cols) + ' \\\\'
    header_line2 = ' & '.join(header_line2_cols) + ' \\\\'

    def is_missing(val) -> bool:
        """Treat None and NaN alike: NaN would print as 'nan' and breaks sorting."""
        return val is None or (isinstance(val, float) and math.isnan(val))

    # Find best and second-best (rounded) values for each metric
    best_values = {}
    second_best_values = {}
    if bold_best:
        for metric in metrics:
            if metric not in bold_best:
                continue
            rounded_values = {
                round(v, decimal_places)
                for v in (r.get(metric) for r in results)
                if isinstance(v, (int, float)) and not is_missing(v)
            }
            if not rounded_values:
                continue
            unique_values = sorted(rounded_values, reverse=(bold_best[metric] == 'high'))
            best_values[metric] = unique_values[0]
            if len(unique_values) >= 2:
                second_best_values[metric] = unique_values[1]

    def format_cell(metric, val) -> str:
        """Format one metric cell, bolding/underlining best and second-best values."""
        if is_missing(val):
            return '{--}' if use_siunitx else '--'
        if not isinstance(val, (int, float)):
            return str(val)
        formatted_val = f"{val:.{decimal_places}f}"
        rounded_val = round(val, decimal_places)
        if metric in best_values and best_values[metric] == rounded_val:
            return f"{{\\bfseries {formatted_val}}}" if use_siunitx else f"\\textbf{{{formatted_val}}}"
        if metric in second_best_values and second_best_values[metric] == rounded_val:
            return f"{{\\underline{{{formatted_val}}}}}" if use_siunitx else f"\\underline{{{formatted_val}}}"
        return formatted_val

    # Generate data rows
    data_lines = []
    for result in results:
        category = method_categories.get(result['method'], 'N.T.')
        values = [format_cell(metric, result.get(metric)) for metric in metrics]
        data_lines.append(f" {result['method']} & {category} & " + ' & '.join(values) + ' \\\\')

    # Construct table
    latex_table = f"""\\begin{{table}}[htb!]
\\caption{{{caption}}}
\\label{{{label}}}
\\centering
\\begin{{tabular}}{{{col_spec}}}
    \\toprule
    {header_line1}
    {header_line2}
    \\midrule
{chr(10).join(data_lines)}
    \\bottomrule
\\end{{tabular}}
\\end{{table}}"""

    return latex_table

def generate_latex_table_old_category(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    metric_display_names_line2: Optional[Dict[str, str]] = None,  # Second line of header
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
    method_categories: Optional[Dict[str, str]] = None,
    exact_match_metrics: Optional[List[str]] = None,
    use_siunitx: bool = False,  # For better number alignment
) -> str:
    """
    Build a booktabs LaTeX table with a rotated, multirow category column.

    Rows are ordered by category ('T.' first, then 'N.T.', then any other
    category) and by method name within a category. A ``\\midrule`` is
    inserted between the 'T.' and 'N.T.' groups. Each metric column can have
    a two-line header; metrics without a second header line get a
    ``\\multirow`` cell spanning both header rows.

    Args:
        experiment_dirs: One directory of CSV files per method (same order as
            ``method_names``). Only read when ``results`` is None.
        metrics: Metric keys to show, in column order.
        method_names: Display name of each method.
        caption: Table caption.
        label: LaTeX label.
        metric_display_names: First header line per metric (defaults to the key).
        metric_display_names_line2: Second header line per metric (defaults to '').
        decimal_places: Decimals used for display and for ranking; values that
            are equal after rounding share the best / second-best mark.
        bold_best: Maps metric -> 'high' or 'low'. The best rounded value is
            bolded and the second-best is underlined.
        results: Pre-computed rows: dicts with 'method' plus one key per metric.
        method_categories: Maps method name -> category label ('N.T.' if absent).
        exact_match_metrics: Accepted for signature compatibility; unused here.
        use_siunitx: Emit siunitx ``S`` columns and siunitx-friendly cell markup.

    Returns:
        The complete LaTeX ``table`` environment as a string.

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in
            length, or an experiment directory contains no CSV files.
    """
    from collections import Counter
    from numbers import Real

    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    # Default display names / header lines / categories if not provided
    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}
    if metric_display_names_line2 is None:
        metric_display_names_line2 = {k: '' for k in metrics}
    if method_categories is None:
        method_categories = {name: 'N.T.' for name in method_names}

    def _is_number(val) -> bool:
        # numbers.Real also accepts numpy scalars (np.int64, np.float64) that
        # pandas aggregations return; np.int64 is not a Python int.
        return isinstance(val, Real)

    # Collect data from all experiments
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            # Sorted so the concatenation order is deterministic.
            csv_files = sorted(f for f in os.listdir(exp_dir) if f.endswith('.csv'))
            if not csv_files:
                raise ValueError(f"No CSV files found in experiment directory: {exp_dir}")
            df = pd.concat(
                [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files],
                ignore_index=True,
            )
            # calculate_dataset_aggregates is defined at module level in this file.
            dataset_stats = calculate_dataset_aggregates(df)

            row_data = {'method': method_name}
            for metric in metrics:
                row_data[metric] = dataset_stats.get(metric, None)
            results.append(row_data)

    # Sort results by category, then method name
    category_order = {'T.': 0, 'N.T.': 1}
    results_sorted = sorted(
        ({**r, 'category': method_categories.get(r['method'], 'N.T.')} for r in results),
        key=lambda x: (category_order.get(x['category'], 2), x['method']),
    )

    # Column specification (tight spacing after the category/method columns)
    num_col = 'S[table-format=1.2]' if use_siunitx else 'c'
    col_spec = 'l@{\\hspace{0.5em}}l@{\\hspace{0.8em}}' + num_col * len(metrics)

    # Build two-line header
    header_line1_cols = ['Category', 'Method']
    header_line2_cols = ['', '']
    for m in metrics:
        line1 = metric_display_names.get(m, m)
        line2 = metric_display_names_line2.get(m, '')
        if line2:
            header_line1_cols.append(line1)
            header_line2_cols.append(line2)
        else:
            # Single-line header: let it span both header rows.
            header_line1_cols.append(f"\\multirow{{2}}{{*}}{{{line1}}}")
            header_line2_cols.append('')
    header_line1 = ' & '.join(header_line1_cols) + ' \\\\'
    header_line2 = ' & '.join(header_line2_cols) + ' \\\\'

    # Best and second-best values per metric, compared on rounded (displayed) values
    best_values = {}
    second_best_values = {}
    if bold_best:
        for metric in metrics:
            if metric not in bold_best:
                continue
            rounded_values = {
                round(r[metric], decimal_places)
                for r in results_sorted
                if r[metric] is not None and _is_number(r[metric])
            }
            if rounded_values:
                ranked = sorted(rounded_values, reverse=(bold_best[metric] == 'high'))
                best_values[metric] = ranked[0]
                if len(ranked) >= 2:
                    second_best_values[metric] = ranked[1]

    # Number of rows per category, used as the multirow span
    category_row_counts = Counter(r['category'] for r in results_sorted)

    # Generate rows
    data_lines = []
    seen_categories = set()
    last_category = None
    for result in results_sorted:
        method = result['method']
        category = result['category']

        # Add midrule between categories
        if last_category == 'T.' and category == 'N.T.':
            data_lines.append('    \\midrule')

        is_first_in_category = category not in seen_categories
        seen_categories.add(category)

        values = []
        for metric in metrics:
            val = result[metric]
            if val is None:
                values.append('{--}' if use_siunitx else '--')
            elif _is_number(val):
                formatted_val = f"{val:.{decimal_places}f}"
                rounded_val = round(val, decimal_places)
                is_best = metric in best_values and best_values[metric] == rounded_val
                is_second_best = (
                    metric in second_best_values and second_best_values[metric] == rounded_val
                )

                if use_siunitx:
                    # siunitx S columns need the markup wrapped in braces
                    if is_best:
                        values.append(f"{{\\bfseries {formatted_val}}}")
                    elif is_second_best:
                        values.append(f"{{\\underline{{{formatted_val}}}}}")
                    else:
                        values.append(formatted_val)
                else:
                    if is_best:
                        values.append(f"\\textbf{{{formatted_val}}}")
                    elif is_second_best:
                        values.append(f"\\underline{{{formatted_val}}}")
                    else:
                        values.append(formatted_val)
            else:
                values.append(str(val))

        # Category cell is emitted once per category as a rotated multirow
        if is_first_in_category:
            num_rows = category_row_counts[category]
            category_cell = f"\\multirow{{{num_rows}}}{{*}}{{\\rotatebox[origin=c]{{90}}{{{category}}}}}"
            row = f"    {category_cell} & {method} & " + ' & '.join(values) + ' \\\\'
        else:
            row = f"     & {method} & " + ' & '.join(values) + ' \\\\'

        data_lines.append(row)
        last_category = category

    # Construct table
    latex_table = f"""\\begin{{table}}[htb!]
\\caption{{{caption}}}
\\label{{{label}}}
\\centering
\\begin{{tabular}}{{{col_spec}}}
    \\toprule
    {header_line1}
    {header_line2}
    \\midrule
{chr(10).join(data_lines)}
    \\bottomrule
\\end{{tabular}}
\\end{{table}}"""

    return latex_table


# Example usage for your specific case:
def generate_search_metrics_table(experiment_dirs, method_names, method_categories):
    """
    Render the USPTO-190 search-metrics table via ``generate_latex_table``.

    Every metric column gets a two-line header: a short name on top, and the
    remainder plus an arrow for the preferred direction below. The best and
    second-best values are highlighted in that direction.

    NOTE(review): ``calculate_dataset_aggregates`` in this module reports the
    search metrics under 'avg_'-prefixed keys (e.g. 'avg_nodes_explored',
    'avg_routes_per_target'). Confirm these un-prefixed keys are what
    ``generate_latex_table`` looks up; otherwise those columns may render as '--'.
    """
    # One entry per column: (metric key, header line 1, header line 2, better direction)
    column_spec = (
        ('solve_rate', 'Solve rate', '($\\uparrow$)', 'high'),
        ('nodes_explored', 'Nodes', 'explored ($\\downarrow$)', 'low'),
        ('model_calls', 'Model', 'calls ($\\downarrow$)', 'low'),
        ('time_taken', 'Time', 'taken ($\\downarrow$)', 'low'),
        ('routes_per_target', 'Routes per', 'target ($\\uparrow$)', 'high'),
    )
    caption_text = (
        "Search metrics on USPTO-190 guided towards a specific starting material. "
        "T.=template-based, N.T.=non-template-based."
    )

    return generate_latex_table(
        experiment_dirs=experiment_dirs,
        metrics=[key for key, _top, _bottom, _direction in column_spec],
        method_names=method_names,
        caption=caption_text,
        label="tab:search-metrics",
        metric_display_names={key: top for key, top, _bottom, _direction in column_spec},
        metric_display_names_line2={key: bottom for key, _top, bottom, _direction in column_spec},
        decimal_places=2,
        bold_best={key: direction for key, _top, _bottom, direction in column_spec},
        method_categories=method_categories,
        use_siunitx=False,  # flip to True when the document loads siunitx
    )

# def generate_latex_table(
#     experiment_dirs: List[str],
#     metrics: List[str],
#     method_names: List[str],
#     caption: str = "Sample quality in synthesis planning",
#     label: str = "tab:results",
#     metric_display_names: Optional[Dict[str, str]] = None,
#     decimal_places: int = 2,
#     bold_best: Optional[Dict[str, str]] = None,
#     results: Optional[List[Dict]] = None,
#     method_categories: Optional[Dict[str, str]] = None,
#     exact_match_metrics: Optional[List[str]] = None,
# ) -> str:
#     """
#     Generate a LaTeX table from multiple experiment results.
    
#     Args:
#         experiment_dirs: List of paths to experiment result directories
#         metrics: List of metric names to include (must match keys in dataset_stats)
#         method_names: List of names for each method (same order as experiment_dirs)
#         caption: Table caption
#         label: LaTeX label for the table
#         metric_display_names: Optional dict mapping metric names to display names
#         decimal_places: Number of decimal places for formatting
#         bold_best: Optional dict mapping metric names to 'high' or 'low' to bold best values
#         results: Optional pre-computed results list
#         method_categories: Optional dict mapping method names to categories ('temp.' or 'non-temp.')
#         exact_match_metrics: Optional list of metrics that should appear before the vertical line
        
#     Returns:
#         LaTeX table string
#     """
#     if len(experiment_dirs) != len(method_names):
#         raise ValueError("Number of experiment directories must match number of method names")
    
#     # Default display names if not provided
#     if metric_display_names is None:
#         metric_display_names = {k: k for k in metrics}
    
#     # Default categories if not provided
#     if method_categories is None:
#         method_categories = {name: 'N.T.' for name in method_names}
    
#     # Collect data from all experiments
#     if results is None:
#         results = []
#         for exp_dir, method_name in zip(experiment_dirs, method_names):
#             # Read all CSV files from the evaluations directory
#             eval_dir = os.path.join(exp_dir)
            
#             csv_files = [f for f in os.listdir(eval_dir) if f.endswith('.csv')]
#             dfs = [pd.read_csv(os.path.join(eval_dir, f)) for f in csv_files]
#             df = pd.concat(dfs, ignore_index=True)
            
#             # Calculate metrics
#             dataset_stats = calculate_dataset_aggregates(df)
            
#             # Extract requested metrics
#             row_data = {'method': method_name}
#             for metric in metrics:
#                 if metric in dataset_stats:
#                     row_data[metric] = dataset_stats[metric]
#                 else:
#                     row_data[metric] = None
#             results.append(row_data)
    
#     # Sort results by category (temp. first, then non-temp.)
#     category_order = {'T.': 0, 'N.T.': 1}
#     results_with_category = [
#         {**r, 'category': method_categories.get(r['method'], 'N.T.')}
#         for r in results
#     ]
#     results_sorted = sorted(
#         results_with_category,
#         key=lambda x: (category_order.get(x['category'], 2), x['method'])
#     )
    
#     # Determine where to place vertical line
#     if exact_match_metrics is not None:
#         exact_match_count = len([m for m in metrics if m in exact_match_metrics])
#         # Column spec: category column + method column + exact match metrics | remaining metrics
#         col_spec = 'l|l' + 'c' * exact_match_count + '|' + 'c' * (len(metrics) - exact_match_count)
#     else:
#         col_spec = 'l|l' + 'c' * len(metrics)
    
#     # Header row
#     header_cols = ['Category', 'Method'] + [metric_display_names.get(m, m) for m in metrics]
#     header_line = ' & '.join(header_cols) + ' \\\\'
    
#     # Find best and second-best values for each metric across all methods
#     best_values = {}
#     second_best_values = {}
    
#     if bold_best:
#         for metric in metrics:
#             if metric in bold_best:
#                 values = [
#                     r[metric] for r in results_sorted 
#                     if r[metric] is not None and isinstance(r[metric], (int, float))
#                 ]
#                 if len(values) >= 1:
#                     # Round values before finding best/second-best
#                     rounded_values = [round(v, decimal_places) for v in values]
#                     # Get unique rounded values and sort them
#                     unique_values = sorted(set(rounded_values), reverse=(bold_best[metric] == 'high'))
#                     best_values[metric] = unique_values[0]
#                     if len(unique_values) >= 2:
#                         second_best_values[metric] = unique_values[1]
    
#     # Data rows
#     data_lines = []
#     last_category = None
#     category_row_counts = {}
    
#     # First pass: count rows per category
#     for result in results_sorted:
#         category = result['category']
#         category_row_counts[category] = category_row_counts.get(category, 0) + 1
    
#     # Second pass: generate rows
#     category_first_row = {}
#     for result in results_sorted:
#         method = result['method']
#         category = result['category']
        
#         # Add midrule after temp. methods
#         if last_category == 'temp.' and category == 'non-temp.':
#             data_lines.append('    \\midrule')
        
#         # Determine if this is the first row of a category
#         is_first_in_category = (category not in category_first_row)
#         if is_first_in_category:
#             category_first_row[category] = True
        
#         values = []
#         for metric in metrics:
#             val = result[metric]
#             if val is None:
#                 values.append('--')
#             elif isinstance(val, (int, float)):
#                 formatted_val = f"{val:.{decimal_places}f}"
                
#                 # Round the value for comparison (to match what's displayed)
#                 rounded_val = round(val, decimal_places)
                
#                 # Check if this is best or second-best across all methods (using rounded values)
#                 is_best = (
#                     metric in best_values and 
#                     round(best_values[metric], decimal_places) == rounded_val
#                 )
#                 is_second_best = (
#                     metric in second_best_values and 
#                     round(second_best_values[metric], decimal_places) == rounded_val
#                 )
                
#                 # Apply formatting
#                 if is_best:
#                     values.append(f"\\makecell{{\\textbf{{{formatted_val}}}}}")
#                 elif is_second_best:
#                     values.append(f"\\makecell{{\\underline{{{formatted_val}}}}}")
#                 else:
#                     values.append(f"\\makecell{{{formatted_val}}}")
#             else:
#                 values.append(f"\\makecell{{{val}}}")
        
#         # Add category column with multirow on first occurrence
#         if is_first_in_category:
#             category_display = category.replace('-', ' ').replace('_', ' ').title()
#             num_rows = category_row_counts[category]
#             category_cell = f"\\multirow{{{num_rows}}}{{*}}{{\\rotatebox[origin=c]{{90}}{{{category_display}}}}}"
#             row = f"    {category_cell} & {method} & " + ' & '.join(values) + ' \\\\'
#         else:
#             row = f"     & {method} & " + ' & '.join(values) + ' \\\\'
        
#         data_lines.append(row)
#         last_category = category
    
#     # Construct full table
#     latex_table = f"""\\begin{{table}}[htb!]
#     \\caption{{{caption}}}
# \\label{{{label}}}
# \\begin{{center}}
# \\begin{{tabular}}{{{col_spec}}}
#     \\toprule
#     {header_line}
#     \\midrule
# {chr(10).join(data_lines)}
#     \\bottomrule
# \\end{{tabular}}
# \\end{{center}}
# \\end{{table}}"""
    
#     return latex_table

def generate_latex_table_old(
    experiment_dirs: List[str],
    metrics: List[str],
    method_names: List[str],
    caption: str = "Sample quality in synthesis planning",
    label: str = "tab:results",
    metric_display_names: Optional[Dict[str, str]] = None,
    decimal_places: int = 2,
    bold_best: Optional[Dict[str, str]] = None,
    results: Optional[List[Dict]] = None,
) -> str:
    """
    Generate a LaTeX table from multiple experiment results.

    Args:
        experiment_dirs: List of paths to experiment result directories
        metrics: List of metric names to include (must match keys in dataset_stats)
        method_names: List of names for each method (same order as experiment_dirs)
        caption: Table caption
        label: LaTeX label for the table
        metric_display_names: Optional dict mapping metric names to display names
        decimal_places: Number of decimal places for formatting. The best value
            is chosen on the rounded (displayed) values, so every cell that shows
            the best number is bolded.
        bold_best: Optional dict mapping metric names to 'high' or 'low' to bold
            best values. Other directions are ignored.
        results: Optional pre-computed rows (dicts with 'method' plus one key per
            metric). When given, ``experiment_dirs`` is not read.

    Returns:
        LaTeX table string

    Raises:
        ValueError: If ``experiment_dirs`` and ``method_names`` differ in length,
            or an experiment directory contains no CSV files.
    """
    if len(experiment_dirs) != len(method_names):
        raise ValueError("Number of experiment directories must match number of method names")

    # Default display names if not provided
    if metric_display_names is None:
        metric_display_names = {k: k for k in metrics}

    # Collect data from all experiments
    if results is None:
        results = []
        for exp_dir, method_name in zip(experiment_dirs, method_names):
            # Read all CSV files from the evaluations directory (sorted for determinism)
            csv_files = sorted(f for f in os.listdir(exp_dir) if f.endswith('.csv'))
            if not csv_files:
                raise ValueError(f"No CSV files found in experiment directory: {exp_dir}")
            df = pd.concat(
                [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files],
                ignore_index=True,
            )

            # calculate_dataset_aggregates is defined at module level in this file.
            dataset_stats = calculate_dataset_aggregates(df)

            # Extract requested metrics (None when the aggregate is missing)
            row_data = {'method': method_name}
            for metric in metrics:
                row_data[metric] = dataset_stats.get(metric, None)
            results.append(row_data)

    col_spec = 'l' + 'c' * len(metrics)

    # Header row
    header_cols = ['method'] + [metric_display_names.get(m, m) for m in metrics]
    header_line = ' & '.join(header_cols) + ' \\\\'

    # Best value per metric, computed on rounded values so that ties as
    # displayed are all bolded (consistent with the newer table builders).
    best_values = {}
    if bold_best:
        for metric in metrics:
            direction = bold_best.get(metric)
            if direction not in ('high', 'low'):
                continue
            rounded = [
                round(r[metric], decimal_places)
                for r in results
                if r[metric] is not None and isinstance(r[metric], (int, float))
            ]
            if rounded:
                best_values[metric] = max(rounded) if direction == 'high' else min(rounded)

    # Data rows
    data_lines = []
    for result in results:
        method = result['method']
        values = []
        for metric in metrics:
            val = result[metric]
            if val is None:
                values.append('--')
            elif isinstance(val, (int, float)):
                formatted_val = f"{val:.{decimal_places}f}"
                # Bold if the displayed value equals the best displayed value
                if metric in best_values and round(val, decimal_places) == best_values[metric]:
                    values.append(f"\\makecell{{\\textbf{{{formatted_val}}}}}")
                else:
                    values.append(f"\\makecell{{{formatted_val}}}")
            else:
                values.append(f"\\makecell{{{val}}}")

        row = f"    {method} & " + ' & '.join(values) + ' \\\\'
        data_lines.append(row)

    # Construct full table
    latex_table = f"""\\begin{{table}}[htb!]
    \\caption{{{caption}}}
\\label{{{label}}}
\\begin{{center}}
\\begin{{tabular}}{{{col_spec}}}
    {header_line}
\\hline \\\\
{chr(10).join(data_lines)}
\\end{{tabular}}
\\end{{center}}
\\end{{table}}"""

    return latex_table

def save_latex_table(
    latex_table: str,
    output_path: str,
    standalone: bool = False
) -> None:
    """
    Save LaTeX table to a file (UTF-8 encoded).

    Args:
        latex_table: LaTeX table string to save
        output_path: Path where to save the .tex file
        standalone: If True, wraps table in a minimal standalone document whose
            preamble loads the packages the table builders in this module use:
            booktabs (rules), makecell, multirow, graphicx (``\\rotatebox``)
            and siunitx (``S`` columns).
    """
    if standalone:
        preamble = "\n".join([
            "\\documentclass{article}",
            "\\usepackage{booktabs}",
            "\\usepackage{makecell}",
            "\\usepackage{multirow}",
            "\\usepackage{graphicx}",
            "\\usepackage{siunitx}",
            "\\begin{document}",
        ])
        content = f"{preamble}\n\n{latex_table}\n\n\\end{{document}}"
    else:
        content = latex_table

    # Explicit encoding: captions/method names may contain non-ASCII characters.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(content)

    print(f"LaTeX table saved to: {output_path}")

def select_best_search_experiment_per_product_tanimoto(list_dfs, list_experiment_names):
    """
    For each product, select the search experiment with best results based on hierarchical criteria:
    1. Route contains starting material
    2. Exact route was found (product_matches == True)
    3. Route with max number of topk matches
    4. Total number of topk matches across all routes
    5. Route with highest avg tanimoto similarity
    6. A route was found (solved == True)
    
    Args:
        list_dfs: List of dataframes containing search results
        list_experiment_names: List of experiment names corresponding to dataframes
        
    Returns:
        best_data: DataFrame with best experiment results for each product
        best_experiments: DataFrame with metadata about which experiment was selected
    """
    # Combine all dataframes
    combined = []
    for df, experiment_name in zip(list_dfs, list_experiment_names):
        df_copy = df.copy()
        df_copy['experiment_name'] = experiment_name
        combined.append(df_copy)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Criterion 1: Route contains starting material
        if 'original_starting_material' in group.columns and 'reactant_predictions' in group.columns and 'sample_route_idx' in group.columns:
            route_contains_sm = group.groupby('sample_route_idx').apply(
                lambda route_group: route_group.apply(
                    lambda row: (
                        row['original_starting_material'] in row['reactant_predictions'].split('.')
                        if pd.notna(row['original_starting_material']) and pd.notna(row['reactant_predictions'])
                        else False
                    ),
                    axis=1
                ).any(),  # True if ANY row in the route has the SM
                include_groups=False
            ).any()  # True if ANY route contains the SM
        else:
            route_contains_sm = False
        
        # Criterion 2: Exact route found (all reactions in at least one route have topk=True)
        if 'topk' in group.columns and 'sample_route_idx' in group.columns:
            has_exact_route = group.groupby('sample_route_idx')['topk'].all().any()
        else:
            has_exact_route = False
        
        # Criterion 3: Maximum number of topk matches in a single route
        if 'topk' in group.columns and 'sample_route_idx' in group.columns:
            route_topk_counts = group.groupby('sample_route_idx')['topk'].sum()
            max_topk_matches_per_route = route_topk_counts.max() if len(route_topk_counts) > 0 else 0
        else:
            max_topk_matches_per_route = 0
        
        # Criterion 4: Total number of topk matches across all routes
        num_topk_matches = group['topk'].sum() if 'topk' in group.columns else 0
        
        # Criterion 5: Route with highest avg tanimoto similarity
        if 'pred_tanimoto_to_target' in group.columns and 'sample_route_idx' in group.columns:
            # For each route, compute average tanimoto
            route_avg_tanimoto = group.groupby('sample_route_idx')['pred_tanimoto_to_target'].mean()
            # Take the maximum average across all routes
            max_avg_tanimoto = route_avg_tanimoto.max() if len(route_avg_tanimoto) > 0 else 0
        else:
            max_avg_tanimoto = 0
        
        # Criterion 6: Route was found
        route_found = group['solved'].any() if 'solved' in group.columns else False
        
        return pd.Series({
            'route_contains_sm': route_contains_sm,
            'has_exact_route': has_exact_route,
            'max_topk_matches_per_route': max_topk_matches_per_route,
            'num_topk_matches': num_topk_matches,
            'max_avg_tanimoto': max_avg_tanimoto,
            'route_found': route_found
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(
        ['target_idx', 'experiment_name'], 
        as_index=False
    ).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria in order of preference
    metrics_sorted = metrics.sort_values(
        by=[
            'target_idx',
            'route_contains_sm',
            'has_exact_route',
            'max_topk_matches_per_route',
            'num_topk_matches',
            'max_avg_tanimoto',
            'route_found'
        ],
        ascending=[True, False, False, False, False, False, False]
    )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('target_idx').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['target_idx', 'experiment_name']], 
        on=['target_idx', 'experiment_name']
    )
    
    return best_data, best_experiments

def calculate_per_target_metrics(df: pd.DataFrame) -> pd.DataFrame:
    """
    Aggregate per-reaction rows into one row of metrics per ``target_idx``.

    Only targets that have rows in ``df`` are reported. Nothing is zero-filled
    here, so unsolved targets contribute only if ``df`` already holds rows for
    them (presumably placeholder rows written upstream; TODO confirm).
    Route-level quantities group rows by ``sample_route_idx``; rows with a
    null route index are ignored by those groupings. Per-target scalars such
    as ``solved`` and ``true_route_length`` come from the target's first row.

    Args:
        df: One row per predicted reaction. Must contain every column read
            below (e.g. 'target_idx', 'sample_route_idx', 'topk',
            'round_trip_accuracy', 'rxn_insight_NAME', 'reactant_predictions').

    Returns:
        DataFrame with one row per target, sorted by 'target_idx'. Note that
        'num_routes_with_sm' is the *fraction* of routes that contain the
        starting material (a mean over routes), not a count.
    """
    target_metrics = []

    for target_idx, target_group in df.groupby('target_idx'):
        true_length = target_group['true_route_length'].iloc[0]
        route_keys = target_group['sample_route_idx']

        # Reactions per route, computed once and reused by the length metrics
        route_lengths = target_group.groupby('sample_route_idx').size()
        avg_route_length = route_lengths.mean()

        # Per-route "every reaction satisfies X" flags
        route_all_topk = target_group.groupby('sample_route_idx')['topk'].all()
        route_all_round_trip = target_group.groupby('sample_route_idx')['round_trip_accuracy'].all()
        is_named_rxn = target_group['rxn_insight_NAME'] != 'OtherReaction'
        route_all_named = is_named_rxn.groupby(route_keys).all()

        # Per row: is the original starting material one of the predicted reactants?
        row_has_sm = pd.Series(
            [
                bool(pd.notna(sm) and pd.notna(preds) and sm in preds.split('.'))
                for sm, preds in zip(
                    target_group['original_starting_material'],
                    target_group['reactant_predictions'],
                )
            ],
            index=target_group.index,
            dtype=bool,
        )
        # Per route: True if ANY reaction of the route has the SM
        route_has_sm = row_has_sm.groupby(route_keys).any()

        metrics = {
            'target_idx': target_idx,
            'original_target': target_group['original_target'].iloc[0],
            'original_starting_material': target_group['original_starting_material'].iloc[0],
            'true_route_length': true_length,
            'solved': target_group['solved'].iloc[0],
            'num_routes': target_group['sample_route_idx'].nunique(),
            'num_nonoverlapping_routes': target_group['num_unique_routes'].iloc[0],

            # Search metrics
            'avg_nodes_explored': target_group['num_nodes_explored'].mean(), # need the info for the unsolved ones
            'avg_model_calls': target_group['num_model_calls'].mean(),
            'avg_time_taken': target_group['time_taken'].mean(),
            'avg_search_iterations': target_group['num_search_iterations'].mean(),

            # Route length
            'avg_route_length': avg_route_length,
            'route_length_diff': abs(avg_route_length - true_length),
            'route_length_ratio': avg_route_length / true_length if true_length > 0 else np.nan,
            'has_exact_length_match': (route_lengths == true_length).any(),

            # Reaction quality
            'avg_exact_match': target_group['topk'].mean(),
            'avg_round_trip': target_group['round_trip_accuracy'].mean(),
            'avg_rxn_type_match': (target_group['pred_class'] == target_group['true_class']).mean(),
            'avg_tanimoto_to_target': target_group['pred_tanimoto_to_target'].mean(),
            'avg_tanimoto_to_sm': target_group['pred_tanimoto_to_starting_material'].mean(),
            'avg_classifier_confidence': target_group['classifier_confidence'].mean(),
            'avg_rxn_name_match': is_named_rxn.mean(),

            # Perfect routes
            'has_perfect_exact_match_route': route_all_topk.any(),
            'has_perfect_round_trip_route': route_all_round_trip.any(),
            'has_perfect_rxn_name_match_route': route_all_named.any(),
            'num_perfect_exact_match_routes': route_all_topk.sum(),
            'num_perfect_round_trip_routes': route_all_round_trip.sum(),

            # SM steering metrics
            'solve_rate_with_sm': route_has_sm.any(),
            'num_routes_with_sm': route_has_sm.mean(),  # fraction of routes, not a count
            # computing the number of unique routes requires finding nonoverlapping routes, too much work for now
        }
        target_metrics.append(metrics)

    return pd.DataFrame(target_metrics).sort_values('target_idx')

def calculate_dataset_aggregates(df: pd.DataFrame) -> Tuple[Dict[str, Any], pd.DataFrame]:
    """Calculate dataset-level aggregates including all targets.

    Args:
        df: per-reaction results. Columns read here directly are topk,
            round_trip_accuracy, pred_class, true_class,
            pred_tanimoto_to_target, pred_tanimoto_to_starting_material and
            rxn_insight_NAME; ``calculate_per_target_metrics`` reads more.

    Returns:
        (metrics, target_df): ``metrics`` is a flat dict of dataset-level
        numbers plus two lists of target indices; ``target_df`` is the
        per-target table produced by ``calculate_per_target_metrics``.
    """
    # Placeholder SMILES for missing ground-truth fields, so per-target code
    # sees strings rather than NaN. NOTE(review): the 'N' / 'Br' dummies are
    # inherited; confirm they can never coincide with a real prediction.
    df = df.fillna({
        'true_reactants': 'N',
        'true_product_cano': 'Br',
        'true_reaction_cano': 'N>>Br',
    })
    target_df = calculate_per_target_metrics(df)
    total_targets = len(target_df)
    solved_mask = target_df['solved'] == 1
    solved_df = target_df[solved_mask]
    solved_targets = target_df['solved'].sum()
    solved_targets_indices = target_df.loc[solved_mask, 'target_idx'].tolist()
    solved_with_sm_indices = target_df.loc[
        target_df['solve_rate_with_sm'] == 1, 'target_idx'
    ].tolist()

    # Starting-material (SM) specific metrics:
    #   solve_rate_with_sm: fraction of targets with at least one route that
    #       contains the SM.
    #   num_routes_with_sm: per target this is the FRACTION of routes that
    #       contain the SM (a mean over routes), averaged here over targets.
    metrics = {
        # Basic counts
        'total_targets': total_targets,
        'solved_targets': solved_targets,
        'solve_rate': solved_targets / total_targets,
        'solve_rate_with_sm': target_df['solve_rate_with_sm'].mean(),
        'solved_targets_indices': solved_targets_indices,
        'solved_with_sm_indices': solved_with_sm_indices,

        # Search metrics (averaged over ALL targets)
        'avg_nodes_explored': target_df['avg_nodes_explored'].mean(),
        'avg_model_calls': target_df['avg_model_calls'].mean(),
        'avg_time_taken': target_df['avg_time_taken'].mean(),
        'avg_routes_per_target': target_df['num_routes'].mean(),
        'avg_num_nonoverlapping_routes': target_df['num_nonoverlapping_routes'].mean(),

        # Dataset-level quality (total correct / total reactions)
        'dataset_avg_exact_match': df['topk'].sum() / len(df),
        'dataset_avg_round_trip': df['round_trip_accuracy'].sum() / len(df),
        'dataset_avg_rxn_type_match': (df['pred_class'] == df['true_class']).sum() / len(df),
        'dataset_avg_tanimoto_to_target': df['pred_tanimoto_to_target'].mean(),
        'dataset_avg_tanimoto_to_sm': df['pred_tanimoto_to_starting_material'].mean(),
        'dataset_avg_rxn_name_match': (df['rxn_insight_NAME'] != 'OtherReaction').sum() / len(df),

        # Per-target average quality (mean across ALL targets)
        'target_avg_exact_match': target_df['avg_exact_match'].mean(),
        'target_avg_round_trip': target_df['avg_round_trip'].mean(),
        'target_avg_rxn_type_match': target_df['avg_rxn_type_match'].mean(),
        'target_avg_tanimoto_to_target': target_df['avg_tanimoto_to_target'].mean(),
        'target_avg_tanimoto_to_sm': target_df['avg_tanimoto_to_sm'].mean(),
        'target_avg_rxn_name_match': target_df['avg_rxn_name_match'].mean(),
        'target_avg_contains_starting_material': target_df['solve_rate_with_sm'].mean(),
        'num_routes_with_sm': target_df['num_routes_with_sm'].mean(),

        # Perfect route metrics
        'avg_targets_with_exact_match_route': target_df['has_perfect_exact_match_route'].mean(),
        'avg_targets_with_round_trip_route': target_df['has_perfect_round_trip_route'].mean(),
        'avg_targets_with_rxn_name_match_route': target_df['has_perfect_rxn_name_match_route'].mean(),

        # Route length (only over solved targets)
        'avg_predicted_route_length': solved_df['avg_route_length'].mean(),
        'avg_true_route_length': solved_df['true_route_length'].mean(),
        'avg_route_length_diff': solved_df['route_length_diff'].mean(),
        'avg_targets_with_exact_length_match': target_df['has_exact_length_match'].mean(),
    }

    return metrics, target_df

def select_best_experiment_manual_synthesis_per_product(
    list_dfs, 
    list_experiment_names, 
    criteria: str = 'reaction_type'
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Pick, for every product, the best experiment using the strategy named by ``criteria``.

    Args:
        list_dfs: per-experiment result DataFrames.
        list_experiment_names: experiment names, one per entry of ``list_dfs``.
        criteria: strategy name: 'reaction_type', 'tanimoto' or 'oracle'.

    Returns:
        The ``(best_data, best_experiments)`` pair produced by the selected strategy.

    Raises:
        ValueError: if ``criteria`` does not name a supported strategy.
    """
    # Each strategy is wrapped in a lambda so the selector function is only
    # looked up and invoked once its name has matched.
    strategies = (
        ('reaction_type',
         lambda: select_best_experiment_manual_synthesis_reaction_type(list_dfs, list_experiment_names)),
        ('tanimoto',
         lambda: select_best_experiment_manual_synthesis_tanimoto(list_dfs, list_experiment_names)),
        ('oracle',
         lambda: select_best_experiment_manual_synthesis_oracle(list_dfs, list_experiment_names)),
    )
    for strategy_name, run_strategy in strategies:
        if criteria == strategy_name:
            return run_strategy()
    raise ValueError(f'Criteria {criteria} not supported')

def select_best_experiment_manual_synthesis_tanimoto(
    list_dfs, 
    list_experiment_names
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    For each product, select the experiment with best results based on tanimoto similarity to the target.
    """
    # Combine all dataframes
    combined = []
    for guided_df, experiment_name in zip(list_dfs, list_experiment_names):
        df = guided_df.copy()
        df['experiment_name'] = experiment_name
        combined.append(df)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Find rank of exact match using topk column
        # Average samples meeting ground truth class
        avg_tanimoto_to_starting_material = (group['pred_tanimoto_to_starting_material']).mean()
        
        # Average samples with round trip matches
        avg_round_trip = group['round_trip_accuracy'].mean()
        
        # Average samples with identified rxn name
        avg_has_name = (group['rxn_insight_NAME']!='OtherReaction').mean()
        
        return pd.Series({
            'avg_tanimoto_to_starting_material': avg_tanimoto_to_starting_material,
            'avg_round_trip': avg_round_trip,
            'avg_has_name': avg_has_name
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(['product_smi', 'experiment_name'], as_index=False).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria
    metrics_sorted = metrics.sort_values(
        by=['product_smi', 'avg_tanimoto_to_starting_material', 'avg_round_trip', 'avg_has_name'],
        ascending=[True, False, False, False] # [True, True, False, False, False]
    )
    
    # metrics_sorted = metrics.sort_values(
    #     by=['product_smi', 'avg_round_trip', 'avg_correct_class'],
    #     ascending=[True, True, False] # [True, True, False, False, False]
    # )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('product_smi').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['product_smi', 'experiment_name']], 
        on=['product_smi', 'experiment_name']
    )
    
    return best_data, best_experiments

def filter_manual_synthesis_results_by_criteria(
    list_dfs, 
    list_experiment_names,
    criteria: str = 'reaction_type',
    criteria_threshold: float = 0.2,
    apply_filter: bool = False
) -> pd.DataFrame:
    """
    Concatenate per-experiment results and optionally keep only rows passing ``criteria``.

    Rows that are exact matches (``topk == True``) are always kept when filtering.

    Args:
        list_dfs: per-experiment result DataFrames.
        list_experiment_names: experiment names, one per entry of ``list_dfs``;
            stored in a new ``experiment_name`` column.
        criteria: 'reaction_type' keeps rows with ``pred_class == true_class``;
            'tanimoto' keeps rows with ``pred_tanimoto_to_target > criteria_threshold``.
        criteria_threshold: similarity threshold used by the 'tanimoto' criterion.
        apply_filter: when False, ``criteria`` is ignored and all rows are returned.

    Returns:
        A single DataFrame of (optionally filtered) rows from all experiments.

    Raises:
        ValueError: if ``apply_filter`` is True and ``criteria`` is not supported.
    """
    all_data = pd.concat(
        [df.assign(experiment_name=name) for df, name in zip(list_dfs, list_experiment_names)],
        ignore_index=True,
    )
    if apply_filter:
        is_exact_match = all_data['topk'] == True  # noqa: E712 (column may be object dtype)
        if criteria == 'reaction_type':
            keep = (all_data['pred_class'] == all_data['true_class']) | is_exact_match
        elif criteria == 'tanimoto':
            keep = (all_data['pred_tanimoto_to_target'] > criteria_threshold) | is_exact_match
        else:
            raise ValueError(f'Criteria {criteria} not supported')
        all_data = all_data[keep]

    return all_data

def select_best_experiment_manual_synthesis_reaction_type(
    list_dfs, 
    list_experiment_names
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    For each product, select the experiment with best results based on tanimoto similarity to the target.
    """
    # Combine all dataframes
    combined = []
    for guided_df, experiment_name in zip(list_dfs, list_experiment_names):
        df = guided_df.copy()
        df['experiment_name'] = experiment_name
        combined.append(df)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Find rank of exact match using topk column
        exact_match_mask = group['topk'] == True
        if exact_match_mask.any():
            # First True occurrence is highest rank (1-indexed)
            exact_match_rank = exact_match_mask.idxmax() - group.index[0] + 1
        else:
            exact_match_rank = float('inf')  # No match gets worst rank
        
        # Average samples meeting ground truth class
        avg_correct_class = (group['pred_class'] == group['true_class']).mean()
        
        # Average samples with round trip matches
        avg_round_trip = group['round_trip_accuracy'].mean()
        
        # Average samples with identified rxn name
        avg_has_name = (group['rxn_insight_NAME']!='OtherReaction').mean()
        
        return pd.Series({
            'exact_match_rank': exact_match_rank,
            'avg_correct_class': avg_correct_class,
            'avg_round_trip': avg_round_trip,
            'avg_has_name': avg_has_name
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(['product_smi', 'experiment_name'], as_index=False).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria
    metrics_sorted = metrics.sort_values(
        by=['product_smi', 'exact_match_rank', 'avg_correct_class', 'avg_round_trip', 'avg_has_name'],
        ascending=[True, True, False, False, False] # [True, True, False, False, False]
    )
    
    # metrics_sorted = metrics.sort_values(
    #     by=['product_smi', 'avg_round_trip', 'avg_correct_class'],
    #     ascending=[True, True, False] # [True, True, False, False, False]
    # )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('product_smi').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['product_smi', 'experiment_name']], 
        on=['product_smi', 'experiment_name']
    )
    
    return best_data, best_experiments

def select_best_experiment_manual_synthesis_oracle(list_dfs, list_experiment_names):
    """
    For each product, select the experiment with best results based on hierarchical criteria.
    """
    # Combine all dataframes
    combined = []
    for guided_df, experiment_name in zip(list_dfs, list_experiment_names):
        df = guided_df.copy()
        df['experiment_name'] = experiment_name
        combined.append(df)
    
    all_data = pd.concat(combined, ignore_index=True)
    
    # Compute metrics per product per experiment
    def compute_metrics(group):
        # Find rank of exact match using topk column
        exact_match_mask = group['topk'] == True
        if exact_match_mask.any():
            # First True occurrence is highest rank (1-indexed)
            exact_match_rank = exact_match_mask.idxmax() - group.index[0] + 1
        else:
            exact_match_rank = float('inf')  # No match gets worst rank
        
        # Average samples meeting ground truth class
        avg_correct_class = (group['pred_class'] == group['true_class']).mean()
        
        # Average samples with round trip matches
        avg_round_trip = group['round_trip_accuracy'].mean()
        
        # Average samples with identified rxn name
        avg_has_name = (group['rxn_insight_NAME']!='OtherReaction').mean()
        
        return pd.Series({
            'exact_match_rank': exact_match_rank,
            'avg_correct_class': avg_correct_class,
            'avg_round_trip': avg_round_trip,
            'avg_has_name': avg_has_name
        })
    
    # Group by product and experiment, compute metrics
    metrics = all_data.groupby(['product_smi', 'experiment_name'], as_index=False).apply(compute_metrics, include_groups=False)
    
    # Sort by criteria
    metrics_sorted = metrics.sort_values(
        by=['product_smi', 'exact_match_rank', 'avg_correct_class', 'avg_round_trip', 'avg_has_name'],
        ascending=[True, True, False, False, False] # [True, True, False, False, False]
    )
    
    # metrics_sorted = metrics.sort_values(
    #     by=['product_smi', 'avg_round_trip', 'avg_correct_class'],
    #     ascending=[True, True, False] # [True, True, False, False, False]
    # )
    
    # Take best experiment per product
    best_experiments = metrics_sorted.groupby('product_smi').first().reset_index()
    
    # Join back to get full data for best experiments
    best_data = all_data.merge(
        best_experiments[['product_smi', 'experiment_name']], 
        on=['product_smi', 'experiment_name']
    )
    
    return best_data, best_experiments

def simplify_metrics(metrics):
    """Flatten one level of nested mappings into ``<outer>_<inner>`` keys.

    Any dict value (including subclasses such as ``defaultdict`` or
    ``OrderedDict``) is expanded so that ``{'a': {'x': 1}}`` becomes
    ``{'a_x': 1}``; all other values are copied unchanged.

    Args:
        metrics: dict of metric name -> value or dict of sub-metrics.

    Returns:
        A new flat dict; the input is not modified.
    """
    simplified_metrics = {}
    for name, value in metrics.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                simplified_metrics[f'{name}_{sub_key}'] = sub_value
        else:
            simplified_metrics[name] = value
    return simplified_metrics

def format_latex_row(method_name, metrics_dict, metrics_list):
    """Render one LaTeX table row: the method name followed by one cell per metric.

    Each cell shows the metric value with two decimals. When a ``<metric>_std``
    entry exists, the standard deviation is stacked underneath in scriptsize
    scientific notation.

    Args:
        method_name: text for the first column.
        metrics_dict: metric name -> value (plus optional ``<name>_std`` entries).
        metrics_list: metric names to emit, in column order.

    Returns:
        The row string, terminated by a LaTeX line break.
    """
    def _cell(metric):
        value = metrics_dict[metric]
        std_key = metric + '_std'
        if std_key not in metrics_dict:
            return f"\\makecell{{{value:.2f}}}"
        spread = metrics_dict[std_key]
        return f"\\makecell{{{value:.2f} \\\\ {{\\scriptsize $\\pm${spread:.0e}}}}}"

    cells = " & ".join(_cell(metric) for metric in metrics_list)
    return f"{method_name} & " + cells + " \\\\"


def load_single_step_results(experiment_dir: str) -> pd.DataFrame:
    '''
        Load and concatenate the single-step evaluation CSVs of an experiment.

        Every file in ``experiment_dir`` whose name starts with ``eval`` and
        ends with ``.csv`` is read. Files are processed in sorted name order so
        the row order of the result does not depend on the filesystem's
        directory listing order.

        Args:
            experiment_dir: directory containing the ``eval*.csv`` files.

        Returns:
            One DataFrame with a fresh 0..N-1 index.

        Raises:
            ValueError: if no matching files are found.
    '''
    files = sorted(
        name for name in os.listdir(experiment_dir)
        if name.endswith('.csv') and name.startswith('eval')
    )
    if not files:
        raise ValueError(f'No eval*.csv files found in {experiment_dir}')
    dfs = [pd.read_csv(os.path.join(experiment_dir, name)) for name in files]
    return pd.concat(dfs, ignore_index=True)

def compare_experiment_pair(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    exp1_name: str,
    exp2_name: str,
    key1: str = 'topk',
    key2: str = 'round_trip_accuracy'
) -> Dict[str, Any]:
    """
    Detailed comparison between two specific experiments.
    Useful for analyzing the effect of changing one parameter.

    A product counts as "found" in experiment 1 when at least one of its rows
    in ``df1`` has ``df1[key1] == True``, and likewise for ``df2`` / ``key2``.
    NOTE(review): with the defaults, exp1 is judged on ``topk`` and exp2 on
    ``round_trip_accuracy``; pass equal keys to compare like with like.

    Args:
        df1, df2: per-reaction results with a ``product_smi`` column.
        exp1_name, exp2_name: labels copied into the result.
        key1, key2: boolean column used for each experiment. A missing column
            means no product is found; NaN/None values count as not found.

    Returns:
        Dict with the gained/lost/maintained product lists, their counts,
        ``net_change`` (gained - lost), the number of found products per
        experiment (``exp*_total``) and the number of distinct products per
        experiment (``total_products_exp*``).
    """
    def _found_products(df, key):
        if key not in df.columns:
            return set()
        # eq(True) treats NaN/None as "not found" instead of failing the mask.
        return set(df.loc[df[key].eq(True), 'product_smi'].unique())

    products_exp1 = _found_products(df1, key1)
    products_exp2 = _found_products(df2, key2)
    # dropna=False matches len(unique()) semantics (NaN counted once).
    total_products_exp1 = df1['product_smi'].nunique(dropna=False)
    total_products_exp2 = df2['product_smi'].nunique(dropna=False)

    gained = products_exp2 - products_exp1  # In exp2 but not exp1
    lost = products_exp1 - products_exp2    # In exp1 but not exp2
    maintained = products_exp1 & products_exp2

    return {
        'exp1_name': exp1_name,
        'exp2_name': exp2_name,
        'gained_products': list(gained),
        'lost_products': list(lost),
        'maintained_products': list(maintained),
        'num_gained': len(gained),
        'num_lost': len(lost),
        'num_maintained': len(maintained),
        'net_change': len(gained) - len(lost),
        'exp1_total': len(products_exp1),
        'exp2_total': len(products_exp2),
        'total_products_exp1': total_products_exp1,
        'total_products_exp2': total_products_exp2,
    }

def analyze_ground_truth_coverage_across_experiments(
    experiment_dfs: List[pd.DataFrame],
    experiment_names: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Analyze which products have ground truth matches across different experiments.

    A product has a match in an experiment when at least one of its rows has
    ``topk == True`` (NaN/None count as no match). DataFrames lacking either
    ``product_smi`` or ``topk`` are skipped.

    Args:
        experiment_dfs: List of DataFrames, one per experiment
        experiment_names: Optional list of experiment names (must match length of experiment_dfs)
                         If None, will use 'exp_0', 'exp_1', etc.

    Returns:
        Dictionary containing:
        - experiment_names: the names used
        - coverage_matrix: boolean DataFrame (products x experiments, columns in
          ``experiment_names`` order) of which experiments matched each product;
          None if no experiment matched anything
        - unique_to_experiment: per experiment, products matched by it alone
        - summary_stats: overall statistics across all experiments

    Raises:
        ValueError: if ``experiment_names`` and ``experiment_dfs`` differ in length.
    """
    if experiment_names is None:
        experiment_names = [f'exp_{i}' for i in range(len(experiment_dfs))]

    if len(experiment_names) != len(experiment_dfs):
        raise ValueError(f"Length of experiment_names ({len(experiment_names)}) must match "
                        f"length of experiment_dfs ({len(experiment_dfs)})")

    results = {
        'experiment_names': experiment_names,
        'coverage_matrix': None,
        'unique_to_experiment': {},
        'summary_stats': {}
    }

    # Long-format records of (product, experiment) pairs that had a match.
    coverage_data = []
    all_products = set()
    for exp_name, df in zip(experiment_names, experiment_dfs):
        if 'product_smi' not in df.columns or 'topk' not in df.columns:
            continue
        all_products.update(df['product_smi'].unique())
        # eq(True) treats NaN/None as "no match" instead of failing the mask.
        matched_products = set(df.loc[df['topk'].eq(True), 'product_smi'].unique())
        coverage_data.extend(
            {'product_smi': product, 'experiment': exp_name, 'has_match': True}
            for product in matched_products
        )

    if not coverage_data:
        results['summary_stats'] = {
            'total_unique_products': len(all_products),
            'products_with_any_match': 0,
            'products_found_by_all': 0,
            'products_found_by_one_only': 0,
        }
        return results

    coverage_matrix = pd.DataFrame(coverage_data).pivot_table(
        index='product_smi',
        columns='experiment',
        values='has_match',
        fill_value=False,
        aggfunc='any'
    )
    # Experiments with no matches have no pivot column at all; add them as
    # all-False (plain column selection would raise KeyError) and force bool
    # dtype so the ~ below is a logical NOT.
    coverage_matrix = coverage_matrix.reindex(
        columns=experiment_names, fill_value=False
    ).astype(bool)
    results['coverage_matrix'] = coverage_matrix

    # Products matched by this experiment and by no other.
    for exp_name in experiment_names:
        unique_products = coverage_matrix[
            coverage_matrix[exp_name] &
            ~coverage_matrix.drop(columns=[exp_name]).any(axis=1)
        ].index.tolist()
        results['unique_to_experiment'][exp_name] = {
            'products': unique_products,
            'count': len(unique_products)
        }

    results['summary_stats'] = {
        'total_unique_products': len(all_products),
        'products_with_any_match': len(coverage_matrix),
        'products_found_by_all': coverage_matrix.all(axis=1).sum(),
        'products_found_by_one_only': (coverage_matrix.sum(axis=1) == 1).sum(),
    }

    return results

def get_classifier_score(smiles, config):
    '''
        This function gets the classifier score for a list of SMILES strings.

        Args:
            smiles: the list of SMILES strings
            config: the config object; reads fields under
                ``config.classifier_guidance``

        Returns:
            (output, confidence), or (None, None) when ``smiles`` is empty.
            Classification (``as_regression`` false): predicted class indices
            and their softmax probabilities, as tensors.
            Regression: the model output (de-normalised with the checkpoint's
            target_std/target_mean when ``normalize_prediction`` is set) and a
            numpy array filled with -1.
            Entries for SMILES that could not be tokenized are set to -1.
    '''
    # Nothing to score: return before loading the vocabulary and the model.
    if len(smiles) == 0:
        return None, None
    vocab = get_vocab_from_trained_model(config.classifier_guidance.onmt_checkpoint_path)
    all_seq_ids = []
    non_parsed_smi = []
    for smi_idx, smi in enumerate(smiles):
        ids = turn_seq_to_ids(smi, onmt_checkpoint_path=config.classifier_guidance.onmt_checkpoint_path)
        if ids is None:
            print(f'ids is None for smiles {smi} at index {smi_idx}')
            non_parsed_smi.append(smi_idx)
            # Dummy sequence keeps one row per input so indices line up; the
            # matching outputs are overwritten with -1 after inference.
            ids = [vocab.index('<unk>')]*10 
            print(f'ids {ids} for smiles {smi}')
        all_seq_ids.append(torch.tensor(ids))
    print(f'all_seq_ids {len(all_seq_ids)} in get_classifier_score')
    # Pad to a single (batch, max_len) tensor using the vocabulary's blank token.
    seq_ids = torch.nn.utils.rnn.pad_sequence(all_seq_ids, 
                                              batch_first=True, 
                                              padding_value=vocab.index('<blank>'))
    property_model = PropertyPredictor(config, len(vocab))
    checkpoint_path = os.path.join(PROJECT_ROOT,
                                   'checkpoints',
                                   config.classifier_guidance.checkpoint_path)
    property_checkpoint = torch.load(checkpoint_path, map_location=device)
    property_model.load_state_dict(property_checkpoint['model_state_dict'])
    property_model = property_model.to(device)
    property_model.eval()
    with torch.no_grad():
        seq_ids = seq_ids.to(device)
        out = property_model(seq_ids)
        if not config.classifier_guidance.as_regression:
            confidence, output = F.softmax(out, dim=1).max(dim=1)
            confidence[non_parsed_smi] = -1
            output[non_parsed_smi] = -1
        else:
            confidence = -1*np.ones(len(smiles))
            output = out
            if config.classifier_guidance.normalize_prediction:
                output = output*property_checkpoint['target_std'] + property_checkpoint['target_mean']
            output[non_parsed_smi] = -1
    return output, confidence

def get_classifier_scores_batch(smiles_list, config, batch_size=32):
    """
    Score SMILES strings with the trained property classifier, batch by batch.

    The checkpoint is loaded once. For each slice of ``batch_size`` SMILES the
    strings are tokenized with the ONMT vocabulary, padded, and run through the
    model under ``torch.no_grad()``.

    Args:
        smiles_list: sliceable sequence of SMILES strings
        config: configuration object; reads fields under ``config.classifier_guidance``
        batch_size: number of SMILES per forward pass

    Returns:
        (outputs, confidences) as flat Python lists with one entry per SMILES.
        Classification: predicted class index and its softmax probability.
        Regression: the (optionally de-normalised) prediction and -1.0.
        SMILES that could not be tokenized get -1 in both lists.
        Empty input returns ``([], [])``.
    """
    if len(smiles_list) == 0:
        return [], []
    guidance = config.classifier_guidance
    vocab = get_vocab_from_trained_model(guidance.onmt_checkpoint_path)

    # Build the model and restore its weights a single time for every batch.
    property_model = PropertyPredictor(config, len(vocab))
    checkpoint = torch.load(
        os.path.join(PROJECT_ROOT, 'checkpoints', guidance.checkpoint_path),
        map_location=device,
    )
    property_model.load_state_dict(checkpoint['model_state_dict'])
    property_model = property_model.to(device)
    property_model.eval()

    outputs, confidences = [], []
    for start in range(0, len(smiles_list), batch_size):
        chunk = smiles_list[start:start + batch_size]

        failed_positions = []
        token_tensors = []
        for pos, smi in enumerate(chunk):
            token_ids = turn_seq_to_ids(smi, onmt_checkpoint_path=guidance.onmt_checkpoint_path)
            if token_ids is None:
                failed_positions.append(pos)
                print(f'ids is None for smiles {smi} at index {pos}')
                # Placeholder keeps positions aligned; its score is masked to -1 below.
                token_ids = [vocab.index('<unk>')] * 10
                print(f'ids {token_ids} for smiles {smi}')
            token_tensors.append(torch.tensor(token_ids))

        padded = torch.nn.utils.rnn.pad_sequence(
            token_tensors, batch_first=True, padding_value=vocab.index('<blank>')
        )

        with torch.no_grad():
            logits = property_model(padded.to(device))
            if guidance.as_regression:
                confidence = -1 * torch.ones(len(chunk))
                output = logits.squeeze(-1)
                if guidance.normalize_prediction:
                    output = output * checkpoint['target_std'] + checkpoint['target_mean']
                output[failed_positions] = -1
            else:
                confidence, output = F.softmax(logits, dim=1).max(dim=1)
                confidence[failed_positions] = -1
                output[failed_positions] = -1
            confidences.extend(confidence.cpu().tolist())
            outputs.extend(output.cpu().tolist())
    return outputs, confidences

def evaluate_results_for_one_batch(df, config):
    '''
        This function evaluates the results for one batch.

        ``df`` is modified in place and also returned. Columns read:
        reactant_predictions, product_smi, original_target,
        original_starting_material, true_reactants and (optionally)
        true_product_cano. ``reactant_predictions`` is overwritten with a
        dative-bond-free version. Columns added: product_matches (only when
        true_product_cano exists), all_pred_reactants_are_bbs,
        pred_tanimoto_to_target, pred_tanimoto_to_starting_material, topk,
        classifier_output, classifier_confidence, round_trip_results,
        round_trip_accuracy, rxn_insight_info, rxn_insight_NAME and pred_class.

        Args:
            df: per-prediction results DataFrame.
            config: the config object, forwarded to the classifier and
                round-trip helpers.

        Returns:
            The same DataFrame with the evaluation columns filled in.
    '''
    # NOTE: this first check is more relevant to search-based multi step synthesis
    if 'true_product_cano' in df.columns:
        df['product_matches'] = df['true_product_cano'] == df['product_smi']
    # remove dative bonds, they cause issues in some versions of rdkit
    df['reactant_predictions'] = df['reactant_predictions'].apply(remove_dative_bonds_one_molecule)
    # check if the predicted reactants are building blocks (bbs)
    bb_path = os.path.join(
        PROJECT_ROOT,
        'data', 
        'desp_data', 
        'canon_building_block_mol2idx_no_isotope.json'
    )
    with open(bb_path, 'r', encoding='utf-8') as bb_file:
        bbs = json.load(bb_file)
    df['all_pred_reactants_are_bbs'] = df['reactant_predictions'].apply(
        lambda x: all(m in bbs for m in x.split('.'))
    )
    # compute pred tanimoto to target
    df['pred_tanimoto_to_target'] = df.apply(
        lambda x: get_tanimoto(
                x['reactant_predictions'],
                x['original_target']
            ), axis=1
    )
    # compute pred tanimoto to starting material
    df['pred_tanimoto_to_starting_material'] = df.apply(
        lambda x: get_tanimoto(
                x['reactant_predictions'],
                x['original_starting_material']
            ), axis=1
    )
    df['topk'] = df.apply(
        lambda x: compare_reactant_smiles(x['true_reactants'], x['reactant_predictions']), axis=1
    )
    # classifier score
    all_smiles = df['reactant_predictions'].tolist()
    outputs, confidences = get_classifier_scores_batch(all_smiles, config, batch_size=2048)
    df['classifier_output'] = outputs
    df['classifier_confidence'] = confidences
    # round trip results; an exact match also counts as a successful round trip
    round_trip_results = get_round_trip_results(all_smiles, config, batch_size=2048)
    df['round_trip_results'] = round_trip_results
    df['round_trip_accuracy'] = df.apply(
        lambda x: (x['product_smi'] in x['round_trip_results']) or x['topk'], axis=1
    )
    # rxn_insight info
    df['rxn_insight_info'] = df.apply(
        lambda x: get_rxn_insight_info(x['reactant_predictions']+'>>'+x['product_smi']), axis=1
    )
    df['rxn_insight_NAME'] = df['rxn_insight_info'].apply(
        lambda x: x['NAME'] if x is not None else None
    )
    df['pred_class'] = df['rxn_insight_info'].apply(
        lambda x: class_to_idx[x['CLASS']] if x is not None else None
    )
    return df

def define_single_step_model(
    config,
    conditional_starting_materials=None,
    conditional_targets=None
):
    '''
        This function defines the single step model based on the config.

        Args:
            config: the config object; reads ``config.single_step_model``
                (``model_type``, ``default_num_results`` and, depending on the
                model, ``model_dir`` / ``num_augmentations``).
            conditional_starting_materials: forwarded to RootAlignedFixedModel
                (only used when model_type == 'rootaligned').
            conditional_targets: forwarded to RootAlignedFixedModel
                (only used when model_type == 'rootaligned').

        Returns:
            model: the single step model

        Raises:
            ValueError: if ``model_type`` is not a supported model name.
    '''
    model_cfg = config.single_step_model
    model_type = model_cfg.model_type

    def _checkpoint_dir():
        # Resolved lazily: only models with local checkpoints need model_dir.
        return os.path.join(PROJECT_ROOT, 'checkpoints', model_cfg.model_dir)

    # TODO: add other models here based on config.single_step_model.model_type
    print(f'======= using model {model_type}')
    if model_type == 'retroknn':
        model = RetroKNNModel(
            use_cache=True,
            default_num_results=model_cfg.default_num_results,
        )
    elif model_type == 'rootaligned_original':
        model = RootAlignedModel(
            use_cache=True,
            num_augmentations=model_cfg.num_augmentations,
            default_num_results=model_cfg.default_num_results,
            model_dir=_checkpoint_dir()
        )
    elif model_type == 'rootaligned':
        # TODO: remove conditional starting material and target from here, 
        # add to the forward call of the model
        model = RootAlignedFixedModel(
            use_cache=True,
            num_augmentations=model_cfg.num_augmentations,
            default_num_results=model_cfg.default_num_results,
            model_dir=_checkpoint_dir(),
            config=config,
            conditional_starting_materials=conditional_starting_materials,
            conditional_targets=conditional_targets
        )
    elif model_type == 'neuralsym':
        templates_path = os.path.join(PROJECT_ROOT,
                                      'data',
                                      'desp_data',
                                      'idx2template_retro.json')
        model = NeuralSymPredictor(use_cache=True,
                                   default_num_results=model_cfg.default_num_results)
        model.setup(_checkpoint_dir(), templates_path)
    elif model_type == 'localretro':
        model = LocalRetroModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    elif model_type == 'graph2edits':
        model = Graph2EditsModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    elif model_type == 'megan':
        model = MEGANModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    elif model_type == 'mhnreact':
        model = MHNreactModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    elif model_type == 'gln':
        model = GLNModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    elif model_type == 'chemformer':
        model = ChemformerModel(use_cache=True, default_num_results=model_cfg.default_num_results)
    else:
        raise ValueError(f'Invalid model name: {model_type}')
    return model

def get_results_for_one_batch(
    model,
    config,
    reaction_batch,
    conditional_starting_material=None,
    conditional_target=None
):
    '''
        Run the single-step model on a batch of reactions and return SMILES results.

        Args:
            model: the callable single-step model (as built for
                config.single_step_model.model_type).
            config: the config object; reads config.single_step_model
                (default_num_results, model_type).
            reaction_batch: the batch of reactions; each item must expose
                `product`, `class_idx`, `conditional_starting_material` and
                `conditional_target` attributes.
            conditional_starting_material: unused; the per-reaction
                `conditional_starting_material` attribute is used instead.
            conditional_target: unused; the per-reaction `conditional_target`
                attribute is used instead.

        Returns:
            results_smiles: the model results converted by
                `turn_results_to_mol_smiles`.
    '''
    target_mols = [Molecule(rxn.product) for rxn in reaction_batch]
    class_tensor = torch.tensor([rxn.class_idx for rxn in reaction_batch])
    sm_conditions = [rxn.conditional_starting_material for rxn in reaction_batch]
    target_conditions = [rxn.conditional_target for rxn in reaction_batch]
    class_tensor = class_tensor.to(device)

    model_cfg = config.single_step_model
    print(f'default num results: {model_cfg.default_num_results}')

    call_kwargs = {'num_results': model_cfg.default_num_results}
    # Only the root-aligned model accepts class / conditioning inputs.
    if model_cfg.model_type == 'rootaligned':
        call_kwargs['reaction_types'] = class_tensor
        call_kwargs['conditional_starting_materials'] = sm_conditions
        call_kwargs['conditional_targets'] = target_conditions

    raw_results = model(target_mols, **call_kwargs)
    return turn_results_to_mol_smiles(raw_results)

def get_results_and_evaluate_for_one_molecule(
    config,
    conditional_starting_material=None,
    conditional_target=None
):
    '''
        Run retrosynthesis for one product and evaluate the predictions.

        Side effect: sets config.classifier_guidance.target_class_index to
        config.single_step_evaluation.rxn_class.

        When running on a CUDA device, the peak GPU memory used by the
        evaluation step is reported together with a rough batch-size estimate.
        On CPU the memory report is skipped (torch.cuda calls would fail there).

        Args:
            config: the config object; reads config.single_step_evaluation
                (rxn_class, product_smi, true_reactants).
            conditional_starting_material: forwarded to get_retrosynthetic_results.
            conditional_target: forwarded to get_retrosynthetic_results.

        Returns:
            The dict returned by single_step_evaluation_of_one_molecule
            (column name -> per-prediction values).
    '''
    eval_cfg = config.single_step_evaluation
    config.classifier_guidance.target_class_index = eval_cfg.rxn_class
    target_property = eval_cfg.rxn_class

    reactant_predictions = get_retrosynthetic_results(
        config,
        eval_cfg.product_smi,
        conditional_starting_material=conditional_starting_material,
        conditional_target=conditional_target
    )
    # Presumably one entry per queried product; only one product is queried here.
    reactant_predictions = reactant_predictions[0]

    print(f'======= device: {device}')
    track_cuda_memory = device.type == 'cuda'
    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats(device=device)

    start_time = time.perf_counter()
    df = single_step_evaluation_of_one_molecule(
        config=config,
        reactant_predictions=reactant_predictions,
        product_smi=eval_cfg.product_smi,
        true_reactants=eval_cfg.true_reactants,
        target_property=target_property
    )
    print(f'======= evaluation time: {time.perf_counter() - start_time} seconds')

    if track_cuda_memory:
        peak_mb = torch.cuda.max_memory_allocated(device=device) / 1024**2
        # Query the device actually in use rather than hard-coding index 0.
        total_mb = torch.cuda.get_device_properties(device).total_memory / 1024**2
        print("\n=== PEAK MEMORY FOR ONE PRODUCT ===")
        print(f"Peak: {peak_mb:.0f}MB / {total_mb:.0f}MB ({100*peak_mb/total_mb:.1f}%)")
        if peak_mb > 0:
            print(f"Estimated batch size: {int(0.8 * total_mb / peak_mb)} products")
    return df

def single_step_evaluation_of_one_molecule(
    config,
    reactant_predictions,
    product_smi,
    true_reactants,
    target_property
):
    '''
        Evaluate the single-step reactant predictions for one product.

        Side effect: sets config.classifier_guidance.target_class_index to
        target_property.

        Args:
            config: the config object; reads config.single_step_evaluation
                (original_starting_material, original_target,
                compute_classifier_score) and config.classifier_guidance.property.
            reactant_predictions: list of predicted reactant SMILES strings;
                dative bonds are removed before evaluation.
            product_smi: the product SMILES string.
            true_reactants: the ground-truth reactants SMILES string. A falsy
                value is stored as -1 in 'true_reactants' and yields -1 in
                'topk_detailed'.
            target_property: the target reaction-class index, stored as 'true_class'.

        Returns:
            dict mapping column name -> list with one entry per (cleaned)
            reactant prediction, or an empty dict when there are no predictions.
    '''
    # Building-block set; used for the 'all_pred_reactants_are_bbs' column.
    bbs_path = os.path.join(
        PROJECT_ROOT,
        'data',
        'desp_data',
        'canon_building_block_mol2idx_no_isotope.json'
    )
    with open(bbs_path, 'r', encoding='utf-8') as bbs_file:
        bbs = json.load(bbs_file)

    eval_cfg = config.single_step_evaluation
    true_tanimoto_to_starting_material = get_tanimoto(
        true_reactants,
        eval_cfg.original_starting_material
    )
    if eval_cfg.original_target:
        true_tanimoto_to_target = get_tanimoto(true_reactants, eval_cfg.original_target)
    else:
        true_tanimoto_to_target = -1

    # TODO: remove this side effect on the shared config later.
    config.classifier_guidance.target_class_index = target_property
    reactant_predictions = remove_dative_bonds(reactant_predictions)
    if len(reactant_predictions) == 0:
        print(f'Found no reactant predictions for product {product_smi}')
        return {}
    n_preds = len(reactant_predictions)

    if eval_cfg.compute_classifier_score:
        output, confidence = get_classifier_score(reactant_predictions, config)
    else:
        output, confidence = None, None
    print(f'classifier score: {output}, confidence: {confidence}')

    results_as_rxn_smiles = [
        reactant_prediction + '>>' + product_smi
        for reactant_prediction in reactant_predictions
    ]
    # One entry per prediction; individual entries may be None.
    rxn_insight_info = [get_rxn_insight_info(rxn) for rxn in results_as_rxn_smiles]
    print(f'rxn insight info: {rxn_insight_info}')

    # NOTE: topk={1: 0} because each reactant prediction is checked individually.
    if true_reactants:
        topk = [
            compute_topk_accuracy([pred], true_reactants, topk={1: 0})
            for pred in reactant_predictions
        ]
    else:
        topk = [-1] * n_preds

    # Round-trip check (adapted from rsmiles): each prediction maps to a list
    # of candidate products from the forward model.
    round_trip_results = get_round_trip_results(reactant_predictions, config)
    round_trip_accuracy = [
        compute_topk_accuracy(
            round_trip_result,
            product_smi, topk={1: 0, 3: 0, 5: 0, 10: 0}
        )
        for round_trip_result in round_trip_results
    ]
    print(f'round trip accuracy: {round_trip_accuracy}')

    # Key insertion order defines the column order of any DataFrame built from this dict.
    df_dict = {}
    df_dict['product_smi'] = [product_smi] * n_preds
    df_dict['true_reactants'] = [true_reactants if true_reactants else -1] * n_preds
    df_dict['true_class'] = [target_property] * n_preds
    df_dict['reactant_predictions'] = reactant_predictions
    df_dict['classifier_property'] = [config.classifier_guidance.property] * n_preds
    df_dict['classifier_score'] = (
        confidence.tolist() if confidence is not None else [-1] * n_preds
    )
    df_dict['classifier_output'] = (
        output.tolist() if output is not None else [-1] * n_preds
    )
    df_dict['round_trip_results'] = round_trip_results
    df_dict['rxn_insight_NAME'] = [
        info['NAME'] if info is not None else ''
        for info in rxn_insight_info
    ]
    df_dict['pred_class'] = [
        class_to_idx[info['CLASS']] if info is not None else -1
        for info in rxn_insight_info
    ]
    df_dict['true_tanimoto_to_target'] = [true_tanimoto_to_target] * n_preds
    if eval_cfg.original_target:
        df_dict['pred_tanimoto_to_target'] = [
            get_tanimoto(pred, eval_cfg.original_target)
            for pred in reactant_predictions
        ]
    else:
        df_dict['pred_tanimoto_to_target'] = [-1] * n_preds
    df_dict['true_tanimoto_to_starting_material'] = [true_tanimoto_to_starting_material] * n_preds
    df_dict['pred_tanimoto_to_starting_material'] = [
        get_tanimoto(pred, eval_cfg.original_starting_material)
        for pred in reactant_predictions
    ]
    df_dict['topk_detailed'] = topk
    df_dict['topk'] = [
        compare_reactant_smiles(true_reactants, pred)
        for pred in reactant_predictions
    ]
    df_dict['round_trip_accuracy_detailed'] = round_trip_accuracy
    df_dict['round_trip_accuracy'] = [
        product_smi in candidates for candidates in round_trip_results
    ]
    df_dict['rxn_insight_info'] = rxn_insight_info
    df_dict['all_pred_reactants_are_bbs'] = [
        all(m in bbs for m in pred.split('.'))
        for pred in reactant_predictions
    ]
    return df_dict

def load_experiment_results_old(project_root: str, experiment_dir: str, experiment_subdir: str,
                          experiment_filters: Dict = None) -> Dict[str, pd.DataFrame]:
    """
    Load all experiment results (flat layout) matching the given filters.

    NOTE: a later function in this module is also named
    ``load_experiment_results_old``; at import time it replaces this one.

    Layout: ``<project_root>/<experiment_dir>/<experiment_subdir>/<experiment>/*.csv``.

    Args:
        project_root: Root directory of the project.
        experiment_dir: Directory (relative to ``project_root``) holding the experiments.
        experiment_subdir: Subdirectory containing one directory per experiment.
        experiment_filters: Optional dict of filters on the experiment name.
            'steered', 'guidance', 'length' and 'experiment_prefix' require the value
            as a substring. 'not_guidance' and 'not_length' require it to be absent.
            'experiment_regex' must match (re.search). Keys with a None value and
            unknown keys are ignored.

    Returns:
        Dict mapping experiment names to the concatenation of all CSV files in that
        experiment directory. Files are ordered by the ``_start<N>`` index in their names.
    """
    experiment_dir = os.path.join(project_root, experiment_dir, experiment_subdir)

    def _matches(name: str) -> bool:
        """Return True if ``name`` passes every (non-None) filter."""
        if experiment_filters is None:
            return True
        for key, value in experiment_filters.items():
            if value is None:
                continue
            if key in ('steered', 'guidance', 'length', 'experiment_prefix') and value not in name:
                return False
            if key in ('not_guidance', 'not_length') and value in name:
                return False
            if key == 'experiment_regex' and not re.search(value, name):
                return False
        return True

    def _start_index(file_name: str) -> int:
        """Sort key: the integer after '_start' (files without one sort as 0)."""
        if '_start' not in file_name:
            return 0
        return int(file_name.split('_start')[-1].split('_end')[0])

    experiment_names = [name for name in os.listdir(experiment_dir) if _matches(name)]

    results = {}
    for exp_name in experiment_names:
        exp_dir = os.path.join(experiment_dir, exp_name)
        if not os.path.isdir(exp_dir):
            continue
        csv_files = sorted(
            (f for f in os.listdir(exp_dir) if f.endswith('.csv')),
            key=_start_index,
        )
        if csv_files:
            frames = [pd.read_csv(os.path.join(exp_dir, f)) for f in csv_files]
            results[exp_name] = pd.concat(frames, ignore_index=True)

    return results

def load_experiment_results(
    project_root: str, 
    experiment_dir: str, 
    experiment_group: str, 
    experiment_filters: Dict = None,
    reaction_steps: List[int] = None,
    experiment_subdir: str = ''
) -> Dict[str, Any]:
    """
    Load single-step results for every parameter directory of an experiment group.

    Layout: ``<project_root>/<experiment_dir>/<experiment_group>/<params>/<experiment_subdir>``.
    Each ``<params>`` directory that passes ``experiment_filters`` is passed to
    ``load_single_step_results``. Its return value is stored under ``<params>``.

    Args:
        project_root: Root directory of the project.
        experiment_dir: Directory (relative to ``project_root``) holding experiment groups.
        experiment_group: Top-level experiment group (e.g. 'no_guidance', 'reaction_type_guidance').
        experiment_filters: Optional dict of filters on the parameter-directory name:
            - 'guidance_scale': requires 'guidance{value}' in the name (a directory named
              exactly 'no_guidance' always passes).
            - 'min_length' / 'renorm': require 'length{value}' / 'renorm{value}'.
            - 'time_stamp' / 'contains': require the value as a substring.
            - 'not_contains': requires the value to be absent.
            - 'time_regex': re.search of 'time{value}'.
            - 'experiment_regex': re.search of the value.
            Keys with a None value and unknown keys are ignored.
        reaction_steps: Currently unused; kept for signature compatibility.
        experiment_subdir: Optional subdirectory inside each parameter directory to load from.

    Returns:
        Dict mapping parameter-directory name -> the value returned by
        ``load_single_step_results`` for that directory (presumably a combined
        DataFrame -- confirm). Empty if the experiment group directory does not exist.
    """
    experiment_dir = os.path.join(project_root, experiment_dir, experiment_group)

    if not os.path.exists(experiment_dir):
        print(f"Warning: Experiment directory {experiment_dir} does not exist")
        return {}

    def _matches(name: str) -> bool:
        """Return True if the parameter-directory ``name`` passes every filter."""
        if experiment_filters is None:
            return True
        for key, value in experiment_filters.items():
            if value is None:
                continue
            if key == 'guidance_scale':
                if f'guidance{value}' not in name and name != 'no_guidance':
                    return False
            elif key == 'min_length':
                if f'length{value}' not in name:
                    return False
            elif key == 'renorm':
                if f'renorm{value}' not in name:
                    return False
            elif key == 'time_stamp':
                if value not in name:
                    return False
            elif key == 'time_regex':
                if not re.search(f'time{value}', name):
                    return False
            elif key == 'experiment_regex':
                if not re.search(value, name):
                    return False
            elif key == 'contains':
                if value not in name:
                    return False
            elif key == 'not_contains':
                if value in name:
                    return False
        return True

    experiment_params = [
        name for name in os.listdir(experiment_dir)
        if os.path.isdir(os.path.join(experiment_dir, name)) and _matches(name)
    ]

    print('-'*100)
    print(f'Loading from {experiment_params}')
    for exp_params in experiment_params:
        print(f'Loading results for {exp_params}')
    print('-'*100)

    results = {}
    for exp_params in experiment_params:
        exp_params_dir = os.path.join(experiment_dir, exp_params, experiment_subdir)
        results[exp_params] = load_single_step_results(exp_params_dir)

    return results

def load_experiment_results_old(project_root: str, experiment_dir: str, experiment_group: str, 
                          experiment_filters: Dict = None,
                          reaction_steps: List[int] = None) -> Dict[str, Dict[int, pd.DataFrame]]:
    """
    Load per-reaction-step experiment results matching the given filters.

    Layout: ``<project_root>/<experiment_dir>/<experiment_group>/<params>/reaction<N>/*.csv``.

    Args:
        project_root: Root directory of the project.
        experiment_dir: Directory (relative to ``project_root``) holding experiment groups.
        experiment_group: Top-level experiment group (e.g. 'no_guidance', 'reaction_type_guidance').
        experiment_filters: Optional dict of filters on the parameter-directory name
            ('guidance_scale', 'min_length', 'renorm', 'time_stamp', 'time_regex',
            'experiment_regex', 'contains', 'not_contains'). Keys with a None value
            and unknown keys are ignored.
        reaction_steps: Reaction steps ``N`` to load; None loads all available steps.

    Returns:
        Dict mapping parameter-directory name -> reaction step -> concatenated DataFrame
        of that step's CSV files, ordered by the 'start<N>' index in the file names.
        e.g. {'no_guidance_time123': {0: df0, 1: df1}, 'guidance0.5_length10_time456': {0: df0}}
    """
    experiment_dir = os.path.join(project_root, experiment_dir, experiment_group)

    if not os.path.exists(experiment_dir):
        print(f"Warning: Experiment directory {experiment_dir} does not exist")
        return {}

    # Find all experiment parameter directories
    experiment_params = []
    for f in os.listdir(experiment_dir):
        if not os.path.isdir(os.path.join(experiment_dir, f)):
            continue

        if experiment_filters is None:
            experiment_params.append(f)
            continue

        matches = True
        for key, value in experiment_filters.items():
            if value is None:
                continue
            if key == 'guidance_scale':
                if f'guidance{value}' not in f and f != 'no_guidance':
                    matches = False
            elif key == 'min_length':
                if f'length{value}' not in f:
                    matches = False
            elif key == 'renorm':
                if f'renorm{value}' not in f:
                    matches = False
            elif key == 'time_stamp':
                if value not in f:
                    matches = False
            elif key == 'time_regex':
                if not re.search(f'time{value}', f):
                    matches = False
            elif key == 'experiment_regex':
                if not re.search(value, f):
                    matches = False
            elif key == 'contains':
                if value not in f:
                    matches = False
            elif key == 'not_contains':
                if value in f:
                    matches = False

        if matches:
            experiment_params.append(f)

    step_prefix = 'reaction'
    results = {}
    for exp_params in experiment_params:
        exp_params_dir = os.path.join(experiment_dir, exp_params)

        # Collect (step, dir_name) for every 'reaction<N>' subdirectory. Parsing is
        # done before sorting so a malformed name is skipped with a warning instead
        # of crashing the sort.
        step_dirs = []
        for d in os.listdir(exp_params_dir):
            if not (d.startswith(step_prefix) and os.path.isdir(os.path.join(exp_params_dir, d))):
                continue
            try:
                step_dirs.append((int(d[len(step_prefix):]), d))
            except ValueError:
                print(f"Warning: Could not parse reaction step from {d}")
        step_dirs.sort()

        if not step_dirs:
            continue

        results[exp_params] = {}

        for reaction_step, reaction_dir in step_dirs:
            if reaction_steps is not None and reaction_step not in reaction_steps:
                continue

            reaction_path = os.path.join(exp_params_dir, reaction_dir)
            csv_files = [f for f in os.listdir(reaction_path) if f.endswith('.csv')]

            if not csv_files:
                print(f"Warning: No CSV files found in {reaction_path}")
                continue

            # Sort files by start index for consistent ordering
            csv_files = sorted(
                csv_files,
                key=lambda x: int(x.split('start')[-1].split('_end')[0]) if 'start' in x else 0
            )

            dfs = []
            for file in csv_files:
                file_path = os.path.join(reaction_path, file)
                try:
                    dfs.append(pd.read_csv(file_path))
                except Exception as e:
                    # Best-effort: one unreadable shard should not drop the whole step.
                    print(f"Warning: Could not read {file_path}: {e}")

            if dfs:
                results[exp_params][reaction_step] = pd.concat(dfs, ignore_index=True)

    return results

def preprocess_results_df(df: pd.DataFrame, classifier_checkpoint=None, true_routes=None) -> pd.DataFrame:
    """
    Return a copy of a results DataFrame with derived / re-parsed columns.

    List and dict columns (``rxn_insight_info``, ``round_trip_results``) are
    expected either as Python objects or as their repr strings, as written by a
    CSV round-trip.

    Args:
        df: Results DataFrame (not modified).
        classifier_checkpoint: Currently unused (classifier-output denormalisation is disabled).
        true_routes: Currently unused.

    Returns:
        A new DataFrame. Depending on which columns are present, it adds or overwrites:
            - rxn_insight_NAME: the 'NAME' entry of rxn_insight_info (None if missing).
            - round_trip_accuracy: whether product_smi is among round_trip_results.
            - topk: compare_reactant_smiles(reactant_predictions, true_reactants).
            - is_round_trip_correct: topk OR round_trip_accuracy.
            - conditional_starting_material_smi and pred_/true_tanimoto_to_starting_material:
              recomputed for rows that carry a conditional starting material.
            - true_class: missing entries filled with the rxn-insight class index of the
              true reaction (-1 when it cannot be determined).
    """
    df = df.copy()

    def _parse_repr(value):
        """Turn a repr string (from CSV) back into a Python object; pass others through."""
        # NOTE(security): eval() executes arbitrary code -- only use on result
        # files produced by this pipeline, never on untrusted CSVs.
        return eval(value) if isinstance(value, str) else value

    if 'rxn_insight_info' in df.columns:
        names = []
        for raw_info in df['rxn_insight_info']:
            info = _parse_repr(raw_info)
            names.append(info['NAME'] if isinstance(info, dict) else None)
        df['rxn_insight_NAME'] = names

    # Recompute round-trip accuracy from the stored round-trip products.
    if {'round_trip_accuracy', 'round_trip_results', 'product_smi'}.issubset(df.columns):
        round_trip_hits = []
        for raw_candidates, product in zip(df['round_trip_results'], df['product_smi']):
            candidates = _parse_repr(raw_candidates)
            round_trip_hits.append(
                isinstance(candidates, (list, tuple, set, np.ndarray)) and product in candidates
            )
        df['round_trip_accuracy'] = round_trip_hits

    # Exact-match column.
    if 'reactant_predictions' in df.columns and 'true_reactants' in df.columns:
        df['topk'] = [
            compare_reactant_smiles(pred, true)
            for pred, true in zip(df['reactant_predictions'], df['true_reactants'])
        ]

    # Combined correctness metric.
    if 'topk' in df.columns and 'round_trip_accuracy' in df.columns:
        df['is_round_trip_correct'] = df['topk'] | df['round_trip_accuracy']

    # Tanimoto to the conditional starting material.
    # TODO: this should move to the evaluation script.
    if 'conditional_starting_material' in df.columns and df['conditional_starting_material'].notna().any():
        def _strip_condition_tokens(condition):
            """Remove the '<unk>' / '.' separators and the '</s>' end token; None if not a string."""
            if not isinstance(condition, str):
                return None
            return condition.split('<unk>')[-1].split('.')[-1].split('</s>')[0]

        cond_smi = df['conditional_starting_material'].apply(_strip_condition_tokens)
        df['conditional_starting_material_smi'] = cond_smi
        has_cond = cond_smi.notna()
        for out_col, ref_col in (
            ('pred_tanimoto_to_starting_material', 'reactant_predictions'),
            ('true_tanimoto_to_starting_material', 'true_reactants'),
        ):
            # Rows without a conditional starting material keep their existing value.
            if out_col not in df.columns:
                df[out_col] = np.nan
            df.loc[has_cond, out_col] = [
                get_tanimoto(smi, other)
                for smi, other in zip(cond_smi[has_cond], df.loc[has_cond, ref_col])
            ]

    # Fill in missing true classes from rxn-insight on the true reaction.
    # TODO: this should move to the evaluation script.
    if {'true_class', 'true_reactants', 'product_smi'}.issubset(df.columns):
        missing = df['true_class'].isna()
        if missing.any():
            def _true_class_index(true_reactants, product_smi):
                """rxn-insight class index of 'true_reactants>>product_smi', or -1."""
                if not isinstance(true_reactants, str) or not isinstance(product_smi, str):
                    return -1
                info = get_rxn_insight_info(true_reactants + '>>' + product_smi)
                return class_to_idx[info['CLASS']] if info is not None else -1

            df.loc[missing, 'true_class'] = [
                _true_class_index(reactants, product)
                for reactants, product in zip(
                    df.loc[missing, 'true_reactants'], df.loc[missing, 'product_smi']
                )
            ]

    return df

def calculate_sample_level_accuracy(results: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, float]]:
    """
    Calculate metrics over all individual predictions (rows), not per product.

    Args:
        results: Dict mapping experiment names to DataFrames.

    Returns:
        Dict mapping experiment name -> metrics. Depending on available columns:
            - sample_exact_match_accuracy: mean of 'topk'.
            - sample_round_trip_accuracy: mean of 'round_trip_accuracy'.
            - sample_class_accuracy: fraction of rows with pred_class == true_class.
            - sample_name_accuracy: fraction of rows whose rxn_insight_NAME is present
              (not NaN/None), non-empty and not 'OtherReaction'.
            - sample_tanimoto_to_starting_material_mse / sample_tanimoto_to_target_mse:
              MSE over rows where both pred and true values are present.
            - total_samples: number of rows.
        Ratios are NaN for an experiment with zero rows.
    """
    def _ratio(count, total):
        """count / total, or NaN when total is 0."""
        return count / total if total else float('nan')

    accuracy_results = {}

    for exp_name, df in results.items():
        metrics = {}
        total_samples = len(df)

        if 'topk' in df.columns:
            metrics['sample_exact_match_accuracy'] = _ratio(df['topk'].sum(), total_samples)

        if 'round_trip_accuracy' in df.columns:
            metrics['sample_round_trip_accuracy'] = _ratio(
                df['round_trip_accuracy'].sum(), total_samples
            )

        if 'pred_class' in df.columns and 'true_class' in df.columns:
            class_matches = (df['pred_class'] == df['true_class']).sum()
            metrics['sample_class_accuracy'] = _ratio(class_matches, total_samples)

        if 'rxn_insight_NAME' in df.columns:
            names = df['rxn_insight_NAME']
            # Missing names become NaN after a CSV round-trip; they must not count as named.
            has_name = names.notna() & (names != '') & (names != 'OtherReaction')
            metrics['sample_name_accuracy'] = _ratio(has_name.sum(), total_samples)

        for pred_col, true_col, metric_key in (
            ('pred_tanimoto_to_starting_material', 'true_tanimoto_to_starting_material',
             'sample_tanimoto_to_starting_material_mse'),
            ('pred_tanimoto_to_target', 'true_tanimoto_to_target',
             'sample_tanimoto_to_target_mse'),
        ):
            if pred_col in df.columns and true_col in df.columns:
                valid_mask = df[[pred_col, true_col]].notna().all(axis=1)
                if valid_mask.sum() > 0:
                    diff = df.loc[valid_mask, pred_col] - df.loc[valid_mask, true_col]
                    metrics[metric_key] = (diff ** 2).mean()

        metrics['total_samples'] = total_samples
        accuracy_results[exp_name] = metrics

    return accuracy_results

def calculate_exact_match_accuracy(results: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, float]]:
    """
    Calculate product-level (any-of-N) metrics for each experiment.

    A product counts as solved for a metric if ANY of its rows satisfies it.

    Args:
        results: Dict mapping experiment names to DataFrames with a 'product_smi'
            column plus any of the metric columns below.

    Returns:
        Dict mapping experiment name -> metrics:
            - exact_match_accuracy: fraction of products with a truthy 'topk' row
              (0.0 when the column is absent).
            - round_trip_accuracy: fraction with a truthy 'round_trip_accuracy' row.
            - class_accuracy: fraction with a row where pred_class == true_class
              (0.0 when 'true_class' is absent).
            - name_accuracy: fraction with a row whose rxn_insight_NAME is present
              (not NaN/None), non-empty and not 'OtherReaction'.
            - tanimoto_to_starting_material_mse / tanimoto_to_target_mse: mean squared
              error computed on the FIRST row of each product.
        Ratios are NaN for an experiment with no products.
    """
    def _share_of_products(row_flags: pd.Series, products: pd.Series, n_products: int) -> float:
        """Fraction of products having at least one truthy flag among their rows."""
        if n_products == 0:
            return float('nan')
        hits = row_flags.groupby(products, sort=False).any()
        return int(hits.sum()) / n_products

    accuracy_results = {}

    for exp_name, df in results.items():
        products = df['product_smi']
        # unique() counts a NaN product in the denominator (it can never be a hit).
        n_products = len(products.unique())
        metrics = {}

        topk_flags = df['topk'] if 'topk' in df.columns else pd.Series(False, index=df.index)
        metrics['exact_match_accuracy'] = _share_of_products(topk_flags, products, n_products)

        if 'round_trip_accuracy' in df.columns:
            metrics['round_trip_accuracy'] = _share_of_products(
                df['round_trip_accuracy'], products, n_products
            )

        if 'pred_class' in df.columns:
            if 'true_class' in df.columns:
                class_flags = df['pred_class'] == df['true_class']
            else:
                class_flags = pd.Series(False, index=df.index)
            metrics['class_accuracy'] = _share_of_products(class_flags, products, n_products)

        if 'rxn_insight_NAME' in df.columns:
            names = df['rxn_insight_NAME']
            # Missing names become NaN after a CSV round-trip; they must not count as named.
            name_flags = names.notna() & (names != '') & (names != 'OtherReaction')
            metrics['name_accuracy'] = _share_of_products(name_flags, products, n_products)

        # First row of each (non-NaN) product, in original row order.
        first_rows = df[products.notna()].drop_duplicates(subset='product_smi', keep='first')
        for pred_col, true_col, metric_key in (
            ('pred_tanimoto_to_starting_material', 'true_tanimoto_to_starting_material',
             'tanimoto_to_starting_material_mse'),
            ('pred_tanimoto_to_target', 'true_tanimoto_to_target',
             'tanimoto_to_target_mse'),
        ):
            if pred_col in df.columns and true_col in df.columns and len(first_rows) > 0:
                squared = (first_rows[pred_col] - first_rows[true_col]) ** 2
                metrics[metric_key] = np.mean(squared.tolist())

        accuracy_results[exp_name] = metrics

    return accuracy_results

def calculate_mixed_param_route_completion(
    results: Dict[str, pd.DataFrame], 
    true_routes: List[Dict],
    use_starting_material: bool = False,
    max_steps: int = 15,
    starting_material_key: str = 'route_most_similar_starting_material'
) -> Dict:
    """
    Calculate route completion allowing mixed parameters across steps.

    A step of a ground-truth route counts as solved when *any* parameter
    combination in ``results`` has rows for that step's product whose ``topk``
    column contains a truthy value, so different steps of the same route may be
    solved by different parameter combinations.

    Args:
        results: Dict mapping experiment (parameter-combination) names to
            DataFrames. Rows are matched on ``product_smi`` and only used when
            the frame has a ``topk`` column; matched rows are also read for
            ``pred_tanimoto_to_starting_material``, ``reactant_predictions``,
            ``true_class`` and ``pred_class``.
        true_routes: List of ground truth routes. Each route needs a ``'route'``
            sequence of reaction strings (the part before ``'>>'`` is used as
            the product, the part after as the reactants) and a
            ``starting_material_key`` entry; ``'main_target'`` is optional.
        use_starting_material: When True, additionally count fully completed
            routes in which some correct (``topk``) prediction contains the
            route's starting material as one of its '.'-separated components.
        max_steps: Only the first ``max_steps`` reactions of each route are
            evaluated and required for completion.
        starting_material_key: Key of the starting-material SMILES in each route.

    Returns:
        Dict with route completion statistics, per-route/per-step details and
        ``step_param_usage`` (step index -> parameter combination -> number of
        routes in which that combination had predictions for the step).
    """
    route_stats = {
        'total_routes': 0,
        'fully_completed_routes': 0,
        'completion_rate': 0.0,
        'routes_with_starting_material': 0 if use_starting_material else None,
        'route_details': [],  # Per-route details
        # step index -> param combo -> number of routes in which that combo had
        # predictions for the step's product (counted whether or not it solved it).
        'step_param_usage': defaultdict(lambda: defaultdict(int)),
    }

    for route_idx, route in enumerate(true_routes):
        route_stats['total_routes'] += 1
        # Evaluate exactly the first `max_steps` reactions so the number of
        # processed steps always matches `total_steps` in the completion check
        # (the previous `reaction_idx > max_steps` test processed one extra step).
        considered_reactions = route['route'][:max(max_steps, 0)]
        total_steps = len(considered_reactions)
        steps_solved = 0
        route_has_starting_material = False
        route_detail = {
            'route_idx': route_idx,
            'main_target': route.get('main_target', ''),
            'total_steps': total_steps,
            'steps_solved': 0,
            'step_results': [],  # One entry per evaluated step
        }

        for reaction_idx, reaction in enumerate(considered_reactions):
            reaction_parts = reaction.split('>>')
            true_product = reaction_parts[0]
            true_reactants = reaction_parts[1]
            starting_material = route[starting_material_key]  # NOTE: assumes the latest route file format
            # Independent of the parameter combination, so compute it once per step.
            true_tanimoto = get_tanimoto(starting_material, true_reactants)
            step_result = {
                'reaction_idx': reaction_idx,
                'true_product': true_product,
                'true_reactants': true_reactants,
                'true_distance_to_starting_material': true_tanimoto,
                'solved': False,
                'working_params': [],
                'has_starting_material': [],
                'num_samples': [],
                'min_pred_tanimoto_to_starting_material': [],
                'max_pred_tanimoto_to_starting_material': [],
                'mean_pred_tanimoto_to_starting_material': [],
                'true_tanimoto_to_starting_material': [],
                'avg_true_class_matches': []
            }

            # Check all parameter combinations for this step
            reactions_in_one_step_per_combo = []
            for param_combo, df in results.items():
                product_results = df[df['product_smi'] == true_product]
                if len(product_results) == 0 or 'topk' not in product_results.columns:
                    continue

                pred_tanimoto = product_results['pred_tanimoto_to_starting_material']
                # Per-row flag: does the prediction contain the starting material
                # as one of its '.'-separated components?
                contains_starting_material = product_results['reactant_predictions'].apply(
                    lambda smiles: starting_material in smiles.split('.')
                )

                step_result['working_params'].append(param_combo)
                step_result['min_pred_tanimoto_to_starting_material'].append(pred_tanimoto.min())
                step_result['max_pred_tanimoto_to_starting_material'].append(pred_tanimoto.max())
                step_result['mean_pred_tanimoto_to_starting_material'].append(pred_tanimoto.mean())
                step_result['true_tanimoto_to_starting_material'].append(true_tanimoto)
                step_result['has_starting_material'].append(contains_starting_material.any())
                step_result['avg_true_class_matches'].append(
                    (product_results['true_class'] == product_results['pred_class']).mean()
                )
                step_result['num_samples'].append(len(product_results))
                reactions_in_one_step_per_combo.append(product_results['reactant_predictions'].tolist())
                # Track parameter usage statistics
                route_stats['step_param_usage'][reaction_idx][param_combo] += 1

                if product_results['topk'].any():
                    step_result['solved'] = True
                    if use_starting_material:
                        # Only correct (topk) predictions count towards reaching
                        # the starting material.
                        correct_mask = product_results['topk'].eq(True)
                        if contains_starting_material[correct_mask].any():
                            route_has_starting_material = True

            # Compute pairwise Jaccard similarities between parameter combinations
            if len(reactions_in_one_step_per_combo) > 1:
                param_combos = list(step_result['working_params'])
                jaccard_matrix = {}

                for (idx1, reactions1), (idx2, reactions2) in combinations(enumerate(reactions_in_one_step_per_combo), 2):
                    combo1 = param_combos[idx1]
                    combo2 = param_combos[idx2]
                    jaccard_matrix[(combo1, combo2)] = jaccard_similarity(reactions1, reactions2)

                step_result['jaccard_similarities'] = jaccard_matrix
                if jaccard_matrix:
                    step_result['mean_jaccard_similarity'] = sum(jaccard_matrix.values()) / len(jaccard_matrix)

            route_detail['step_results'].append(step_result)
            if step_result['solved']:
                steps_solved += 1

        route_detail['steps_solved'] = steps_solved
        route_stats['route_details'].append(route_detail)

        # A route is complete when every evaluated step was solved.
        if steps_solved == total_steps:
            route_stats['fully_completed_routes'] += 1
            if route_has_starting_material:
                route_stats['routes_with_starting_material'] += 1

    if route_stats['total_routes'] > 0:
        route_stats['completion_rate'] = (route_stats['fully_completed_routes'] /
                                          route_stats['total_routes'])

    return route_stats

def compute_average_topk_and_coverage(df, topk_key='new_topk', coverage_key='round_trip_accuracy',
                                      topk_ks=(1, 3, 5, 10, 50, 100), coverage_ks=(1, 3, 5, 10)):
    '''
    Compute top-k exact-match accuracy and round-trip coverage over products.

    Rows are grouped by ``product_smi``. Within a product, the row order of
    ``df`` is taken as the prediction rank (the first row is rank 1).

    Args:
        df: DataFrame with a ``product_smi`` column plus ``topk_key`` and
            ``coverage_key`` columns whose value 1/True marks a hit.
        topk_key: Column flagging an exact match with the ground truth.
        coverage_key: Column flagging a successful round trip.
        topk_ks: Cut-offs k reported for top-k accuracy.
        coverage_ks: Cut-offs k reported for coverage.

    Returns:
        (topk, coverage): two dicts mapping k to the fraction of products with at
        least one hit among their first k rows. All values are 0.0 when ``df``
        contains no products.
    '''
    num_products = df['product_smi'].nunique()
    if num_products == 0:
        return {k: 0.0 for k in topk_ks}, {k: 0.0 for k in coverage_ks}

    # 0-based position of each row within its product group (original row order).
    rank = df.groupby('product_smi').cumcount()
    topk_hits = df[topk_key] == 1
    coverage_hits = df[coverage_key] == 1

    def _fraction_with_hit(hits, k):
        # Each product counts at most once, so several correct samples for one
        # product cannot push the accuracy above 1.
        return (hits & (rank < k)).groupby(df['product_smi']).any().sum() / num_products

    topk = {k: _fraction_with_hit(topk_hits, k) for k in topk_ks}
    coverage = {k: _fraction_with_hit(coverage_hits, k) for k in coverage_ks}
    return topk, coverage

def analyze_parameter_effectiveness(route_completion_results: Dict) -> Dict:
    """
    Summarise, for every reaction step, how often each parameter combination was used.

    Args:
        route_completion_results: Output of calculate_mixed_param_route_completion;
            its 'step_param_usage' (step -> param -> count) and 'total_routes'
            entries are read.

    Returns:
        Dict with
          - 'step_analysis': step -> {'total_success_count', 'param_success_rates'
            (param -> {'count', 'success_rate' = count / total_routes}),
            'best_param', 'best_success_rate'}. The first parameter reaching the
            highest strictly positive rate wins ties; 'best_param' stays None
            when no rate is positive.
          - 'overall_param_ranking': defaultdict(int) of param -> count summed
            over all steps.
    """
    usage_by_step = route_completion_results['step_param_usage']
    n_routes = route_completion_results['total_routes']

    ranking = defaultdict(int)
    per_step = {}
    for step, counts in usage_by_step.items():
        rates = {}
        top_param, top_rate = None, 0
        for param, n_used in counts.items():
            rate = n_used / n_routes
            rates[param] = {'count': n_used, 'success_rate': rate}
            ranking[param] += n_used
            if rate > top_rate:
                top_param, top_rate = param, rate
        per_step[step] = {
            'total_success_count': sum(counts.values()),
            'param_success_rates': rates,
            'best_param': top_param,
            'best_success_rate': top_rate,
        }

    return {'step_analysis': per_step, 'overall_param_ranking': ranking}

# Wrapper: mixed-parameter route completion plus per-step parameter analysis
def calculate_route_completion_rates(results: Dict[str, pd.DataFrame], 
                                   true_routes: List[Dict],
                                   use_starting_material: bool = False,
                                   max_steps: int = 15,
                                   starting_material_key: str = 'route_most_similar_starting_material') -> Dict[str, Dict]:
    """
    Calculate route completion with mixed parameter analysis.

    Args:
        results: Dict mapping experiment (parameter-combination) names to DataFrames.
        true_routes: List of ground truth routes.
        use_starting_material: Forwarded to calculate_mixed_param_route_completion.
        max_steps: Maximum number of steps considered per route.
        starting_material_key: Key of the starting-material SMILES in each route;
            forwarded to calculate_mixed_param_route_completion.

    Returns:
        {'mixed_param_completion': output of calculate_mixed_param_route_completion,
         'parameter_analysis': output of analyze_parameter_effectiveness}
    """
    mixed_results = calculate_mixed_param_route_completion(
        results, true_routes, use_starting_material, max_steps,
        starting_material_key=starting_material_key,
    )

    # Analyze parameter effectiveness per step
    param_analysis = analyze_parameter_effectiveness(mixed_results)

    return {
        'mixed_param_completion': mixed_results,
        'parameter_analysis': param_analysis
    }

def calculate_summary_statistics(accuracy_metrics: Dict, route_completion: Dict) -> Dict:
    """
    Calculate summary statistics across all experiments.

    Args:
        accuracy_metrics: Mapping experiment name -> metrics dict. Only real-valued,
            non-NaN scalar metrics (Python or NumPy numbers) are summarised;
            other values such as nested dicts are ignored.
        route_completion: Either the output of calculate_route_completion_rates
            (contains 'mixed_param_completion') or a mapping of per-parameter
            results that each carry a 'completion_rate'.

    Returns:
        Dict with 'accuracy_summary' (metric -> mean/std/min/max/count) and
        'route_completion_summary'. In the mixed case,
        'completion_distribution' partitions routes into exactly one of
        fully / partially / not completed.
    """
    import numbers  # local import: only needed for the numeric-type check below

    summary_stats = {
        'accuracy_summary': {},
        'route_completion_summary': {}
    }

    # Accuracy metrics summary
    if accuracy_metrics:
        accuracy_collections = defaultdict(list)

        for exp_metrics in accuracy_metrics.values():
            for metric_name, value in exp_metrics.items():
                # numbers.Real also matches NumPy scalars such as np.int64, which
                # are not subclasses of int and were previously dropped.
                if isinstance(value, numbers.Real) and not np.isnan(value):
                    accuracy_collections[metric_name].append(value)

        for metric_name, values in accuracy_collections.items():
            if values:
                summary_stats['accuracy_summary'][metric_name] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values),
                    'count': len(values)
                }

    if 'mixed_param_completion' in route_completion:
        mixed_results = route_completion['mixed_param_completion']
        summary_stats['route_completion_summary'] = {
            'total_routes': mixed_results['total_routes'],
            'fully_completed_routes': mixed_results['fully_completed_routes'],
            'overall_completion_rate': mixed_results['completion_rate']
        }

        if mixed_results['routes_with_starting_material'] is not None:
            summary_stats['route_completion_summary']['routes_with_starting_material'] = mixed_results['routes_with_starting_material']

        # Step-wise statistics
        if 'route_details' in mixed_results:
            steps_solved_per_route = [route['steps_solved'] for route in mixed_results['route_details']]
            total_steps_per_route = [route['total_steps'] for route in mixed_results['route_details']]

            # Each route lands in exactly one category (previously
            # 'partially_completed' also counted fully completed routes).
            distribution = {'fully_completed': 0, 'partially_completed': 0, 'not_completed': 0}
            for solved, total in zip(steps_solved_per_route, total_steps_per_route):
                if solved == total:
                    distribution['fully_completed'] += 1
                elif solved == 0:
                    distribution['not_completed'] += 1
                else:
                    distribution['partially_completed'] += 1

            summary_stats['route_completion_summary']['steps_statistics'] = {
                'mean_steps_solved': np.mean(steps_solved_per_route),
                'std_steps_solved': np.std(steps_solved_per_route),
                'mean_total_steps': np.mean(total_steps_per_route),
                'completion_distribution': distribution
            }

    else:
        # Individual per-parameter results instead of mixed results
        completion_rates = []
        total_routes_list = []
        completed_routes_list = []

        for param_results in route_completion.values():
            if isinstance(param_results, dict) and 'completion_rate' in param_results:
                completion_rates.append(param_results['completion_rate'])
                total_routes_list.append(param_results.get('total_routes', 0))
                completed_routes_list.append(param_results.get('fully_completed_routes', 0))

        if completion_rates:
            summary_stats['route_completion_summary'] = {
                'mean_completion_rate': np.mean(completion_rates),
                'std_completion_rate': np.std(completion_rates),
                'min_completion_rate': np.min(completion_rates),
                'max_completion_rate': np.max(completion_rates),
                'total_experiments': len(completion_rates)
            }

    return summary_stats

def calculate_per_experiment_statistics(results: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, Any]]:
    """
    Calculate statistics for each individual experiment (parameter configuration).

    Args:
        results: Dict mapping experiment name -> DataFrame of that experiment's
            samples; each DataFrame is passed as-is to
            _calculate_per_experiment_metrics.

    Returns:
        Dict mapping experiment names to their metrics dicts (values are
        scalars, plus the nested 'avg_topk' / 'avg_coverage' dicts when
        computed).
    """
    return {
        exp_params: _calculate_per_experiment_metrics(experiment_df)
        for exp_params, experiment_df in results.items()
    }

def _calculate_per_experiment_metrics(df: pd.DataFrame, total_samples_per_product: int = 100) -> Dict[str, float]:
    """Helper function to calculate metrics for a single experiment DataFrame"""
    metrics = {}
    total_samples = len(df)
    
    if total_samples == 0:
        return {'total_samples': 0}
    
    # Basic counts
    metrics['total_samples'] = total_samples
    # df['topk'] = df['topk'].apply(lambda x:  eval(x)[0]).astype(bool)
    # df['pred_tanimoto_to_starting_material'] = df['pred_tanimoto_to_starting_material'].apply(lambda x: eval(x)[0])
    # df['pred_tanimoto_to_target'] = df['pred_tanimoto_to_target'].apply(lambda x: eval(x)[0])
    
    # Count unique products
    if 'product_smi' in df.columns:
        unique_products = df['product_smi'].nunique()
        metrics['total_products'] = unique_products
        #metrics['avg_samples_per_product'] = total_samples / unique_products if unique_products > 0 else 0
        metrics['avg_samples_per_product'] = df.groupby('product_smi').size().mean()
        metrics['perc_samples_per_product'] = metrics['avg_samples_per_product'] / total_samples_per_product
    
    # Exact match accuracy (sample level)
    if 'topk' in df.columns:
        exact_matches = df['topk'].sum()
        metrics['sample_exact_match_accuracy'] = exact_matches / total_samples
        
        # Product-level exact match
        if 'product_smi' in df.columns:
            products_with_exact_match = df[df['topk']]['product_smi'].nunique()
            metrics['products_with_exact_match'] = products_with_exact_match
            metrics['percentage_products_with_exact_match'] = products_with_exact_match / unique_products if unique_products > 0 else 0
    
    # Class accuracy (sample level)
    if 'pred_class' in df.columns and 'true_class' in df.columns:
        class_matches = (df['pred_class'] == df['true_class']).sum()
        metrics['sample_class_accuracy'] = class_matches / total_samples
        # Product-level class accuracy
        if 'product_smi' in df.columns:
            products_with_class_correct = df[df['pred_class'] == df['true_class']]['product_smi'].nunique()
            metrics['products_with_class_correct_samples'] = products_with_class_correct
            metrics['percentage_products_with_class_correct'] = products_with_class_correct / unique_products if unique_products > 0 else 0
            # Average class-correct samples per product
            class_correct_per_product = df[df['pred_class'] == df['true_class']].groupby('product_smi').size()
            metrics['avg_class_correct_samples_per_product'] = class_correct_per_product.mean() if len(class_correct_per_product) > 0 else 0
            metrics['perc_class_correct_samples_per_product'] = metrics['avg_class_correct_samples_per_product'] / total_samples_per_product

    # pred class and round trip accuracy
    if 'pred_class' in df.columns and 'round_trip_accuracy' in df.columns:
        class_and_round_trip_correct = (df['pred_class'] == df['true_class']) & df['round_trip_accuracy']
        metrics['sample_class_and_round_trip_correct'] = class_and_round_trip_correct.sum() / total_samples
        # Product-level class and round trip accuracy
        if 'product_smi' in df.columns:
            products_with_class_and_round_trip_correct = df[class_and_round_trip_correct]['product_smi'].nunique()
            metrics['products_with_class_and_round_trip_correct_samples'] = products_with_class_and_round_trip_correct
            metrics['percentage_products_with_class_and_round_trip_correct'] = products_with_class_and_round_trip_correct / unique_products if unique_products > 0 else 0
            # Average class-and-round-trip-correct samples per product
            class_and_round_trip_correct_per_product = df[class_and_round_trip_correct].groupby('product_smi').size()
            metrics['avg_class_and_round_trip_correct_samples_per_product'] = class_and_round_trip_correct_per_product.mean() if len(class_and_round_trip_correct_per_product) > 0 else 0
            metrics['perc_class_and_round_trip_correct_samples_per_product'] = metrics['avg_class_and_round_trip_correct_samples_per_product'] / total_samples_per_product
    
    # RXN name accuracy (sample level)  
    if 'rxn_insight_NAME' in df.columns:
        name_matches = ((df['rxn_insight_NAME'] != '') &
                       (df['rxn_insight_NAME'] != 'OtherReaction')).sum()
        metrics['sample_rxn_name_accuracy'] = name_matches / total_samples
        
        # Product-level rxn name accuracy
        if 'product_smi' in df.columns:
            rxn_name_correct_mask = (df['rxn_insight_NAME'] != '') & (df['rxn_insight_NAME'] != 'OtherReaction')
            products_with_rxn_name_correct = df[rxn_name_correct_mask]['product_smi'].nunique()
            metrics['products_with_rxn_name_correct_samples'] = products_with_rxn_name_correct
            metrics['percentage_products_with_rxn_name_correct'] = products_with_rxn_name_correct / unique_products if unique_products > 0 else 0
            
            # Average rxn-name-correct samples per product
            rxn_name_correct_per_product = df[rxn_name_correct_mask].groupby('product_smi').size()
            metrics['avg_rxn_name_correct_samples_per_product'] = rxn_name_correct_per_product.mean() if len(rxn_name_correct_per_product) > 0 else 0
            metrics['perc_rxn_name_correct_samples_per_product'] = metrics['avg_rxn_name_correct_samples_per_product'] / total_samples_per_product
    
    # Round trip accuracy (sample level)
    if 'round_trip_accuracy' in df.columns:
        round_trip_matches = df['round_trip_accuracy'].sum()
        metrics['sample_round_trip_accuracy'] = round_trip_matches / total_samples
        
        # Product-level round trip accuracy
        if 'product_smi' in df.columns:
            products_with_round_trip_correct = df[df['round_trip_accuracy']]['product_smi'].nunique()
            metrics['products_with_round_trip_correct_samples'] = products_with_round_trip_correct
            metrics['percentage_products_with_round_trip_correct'] = products_with_round_trip_correct / unique_products if unique_products > 0 else 0
            
            # Average round-trip-correct samples per product
            round_trip_correct_per_product = df[df['round_trip_accuracy']].groupby('product_smi').size()
            metrics['avg_round_trip_correct_samples_per_product'] = round_trip_correct_per_product.mean() if len(round_trip_correct_per_product) > 0 else 0
            metrics['perc_round_trip_correct_samples_per_product'] = metrics['avg_round_trip_correct_samples_per_product'] / total_samples_per_product

    # Tanimoto metrics (continuous values)
    if 'pred_tanimoto_to_starting_material' in df.columns:
        valid_mask = df['pred_tanimoto_to_starting_material'].notna()
        if valid_mask.sum() > 0:
            metrics['avg_tanimoto_to_starting'] = df.loc[valid_mask, 'pred_tanimoto_to_starting_material'].mean()
            metrics['max_tanimoto_to_starting'] = df.loc[valid_mask, 'pred_tanimoto_to_starting_material'].max()
        # product level
        if 'product_smi' in df.columns:
            products_with_max_tanimoto_to_starting = df.groupby('product_smi')['pred_tanimoto_to_starting_material'].max().mean()
            products_with_min_tanimoto_to_starting = df.groupby('product_smi')['pred_tanimoto_to_starting_material'].min().mean()
            metrics['products_with_max_tanimoto_to_starting'] = products_with_max_tanimoto_to_starting
            metrics['products_with_min_tanimoto_to_starting'] = products_with_min_tanimoto_to_starting
    
    if 'pred_tanimoto_to_target' in df.columns:
        valid_mask = df['pred_tanimoto_to_target'].notna()
        if valid_mask.sum() > 0:
            metrics['avg_tanimoto_to_target'] = df.loc[valid_mask, 'pred_tanimoto_to_target'].mean()
            metrics['max_tanimoto_to_target'] = df.loc[valid_mask, 'pred_tanimoto_to_target'].max()
    
    #Add step-level breakdown if multiple steps
    if 'reaction_step' in df.columns:
        step_counts = df['reaction_step'].value_counts().sort_index()
        for step, count in step_counts.items():
            metrics[f'samples_step_{step}'] = count
    
    # compute regular topk and round trip accuracy
    topk, coverage = compute_average_topk_and_coverage(df, topk_key='topk', coverage_key='round_trip_accuracy')
    metrics['avg_topk'] = topk
    metrics['avg_coverage'] = coverage
    
    return metrics

def calculate_per_product_aggregation(results: Dict[str, Dict[int, pd.DataFrame]]) -> Dict[str, Dict]:
    """
    Pool the samples of every experiment and summarise them product by product.

    Each DataFrame in ``results`` is copied and tagged with two extra columns,
    ``experiment_params`` and ``reaction_step``, before all frames are
    concatenated with a fresh index.

    Args:
        results: Nested mapping experiment_params -> reaction_step -> DataFrame.

    Returns:
        {'per_product': ..., 'aggregated_metrics': ...}, built by
        _calculate_per_product_metrics and _calculate_aggregated_metrics; both
        are empty dicts when ``results`` contains no DataFrames.
    """
    def _tagged(frame, exp_params, step):
        # Copy so the caller's DataFrame is not modified.
        tagged = frame.copy()
        tagged['experiment_params'] = exp_params
        tagged['reaction_step'] = step
        return tagged

    frames = [
        _tagged(frame, exp_params, step)
        for exp_params, step_data in results.items()
        for step, frame in step_data.items()
    ]

    if len(frames) == 0:
        return {'per_product': {}, 'aggregated_metrics': {}}

    pooled = pd.concat(frames, ignore_index=True)
    per_product = _calculate_per_product_metrics(pooled)
    return {
        'per_product': per_product,
        'aggregated_metrics': _calculate_aggregated_metrics(per_product),
    }


def _calculate_per_product_metrics(df: pd.DataFrame) -> Dict[str, Dict]:
    """Calculate metrics for each individual product across all experiments"""
    per_product_results = {}
    
    if 'product_smi' not in df.columns:
        print("Warning: 'product_smi' column not found in DataFrame")
        return per_product_results
    
    # For each (product_smi, true_reactants) pair, keep only rows from the first reaction_step
    df_deduplicated = df.groupby(['product_smi', 'true_reactants'])['reaction_step'].transform('first')
    df = df[df['reaction_step'] == df_deduplicated]
    unique_products = df[['product_smi', 'true_reactants']].apply(tuple, axis=1).unique()
    
    for product, _ in unique_products:
        product_df = df[df['product_smi'] == product].copy()
        if len(product_df) > 100:
            # sample 100 samples
            product_df = product_df.sample(n=100)
            print(f"Warning: {product} has {len(product_df)} samples")
        product_results = {}
        
        # Basic counts
        product_results['total_samples'] = len(product_df)
        product_results['all_samples'] = product_df.to_dict('records')  # Store all samples
        
        # Get list of experiment configs that contributed samples
        product_results['experiment_configs'] = sorted(product_df['experiment_params'].unique().tolist())
        
        # Exact match metrics
        if 'topk' in product_df.columns:
            exact_match_samples = product_df[product_df['topk'] == True]
            product_results['exact_match_found'] = len(exact_match_samples) > 0
            product_results['total_exact_match_samples'] = len(exact_match_samples)
            
            if len(exact_match_samples) > 0:
                product_results['configs_with_exact_match'] = sorted(exact_match_samples['experiment_params'].unique().tolist())
            else:
                product_results['configs_with_exact_match'] = []
        
        # Class accuracy metrics
        if 'pred_class' in product_df.columns and 'true_class' in product_df.columns:
            class_correct_samples = product_df[product_df['pred_class'] == product_df['true_class']]
            product_results['total_class_correct_samples'] = len(class_correct_samples)
            product_results['total_class_accuracy'] = len(class_correct_samples) / len(product_df) if len(product_df) > 0 else 0
            
            if len(class_correct_samples) > 0:
                # Best single sample class accuracy (this would be 1.0 if any sample is correct)
                product_results['best_class_accuracy'] = 1.0
                product_results['configs_with_class_correct'] = sorted(class_correct_samples['experiment_params'].unique().tolist())
            else:
                product_results['best_class_accuracy'] = 0.0
                product_results['configs_with_class_correct'] = []
        
        # RXN name metrics
        if 'rxn_insight_NAME' in product_df.columns:
            rxn_name_correct_mask = (product_df['rxn_insight_NAME'] != '') & (product_df['rxn_insight_NAME'] != 'OtherReaction')
            rxn_name_correct_samples = product_df[rxn_name_correct_mask]
            product_results['total_rxn_name_correct_samples'] = len(rxn_name_correct_samples)
            product_results['total_rxn_name_accuracy'] = len(rxn_name_correct_samples) / len(product_df) if len(product_df) > 0 else 0
            
            if len(rxn_name_correct_samples) > 0:
                product_results['best_rxn_name_accuracy'] = 1.0
                product_results['configs_with_rxn_name_correct'] = sorted(rxn_name_correct_samples['experiment_params'].unique().tolist())
            else:
                product_results['best_rxn_name_accuracy'] = 0.0
                product_results['configs_with_rxn_name_correct'] = []
        
        # Round trip metrics
        if 'round_trip_accuracy' in product_df.columns:
            round_trip_correct_samples = product_df[product_df['round_trip_accuracy'] == True]
            product_results['total_round_trip_correct_samples'] = len(round_trip_correct_samples)
            product_results['total_round_trip_accuracy'] = len(round_trip_correct_samples) / len(product_df) if len(product_df) > 0 else 0
            
            if len(round_trip_correct_samples) > 0:
                product_results['best_round_trip_accuracy'] = 1.0
                product_results['configs_with_round_trip_correct'] = sorted(round_trip_correct_samples['experiment_params'].unique().tolist())
            else:
                product_results['best_round_trip_accuracy'] = 0.0
                product_results['configs_with_round_trip_correct'] = []
        
        # Tanimoto metrics (continuous - find best values)
        if 'pred_tanimoto_to_starting_material' in product_df.columns:
            tanimoto_starting_valid = product_df['pred_tanimoto_to_starting_material'].dropna()
            if len(tanimoto_starting_valid) > 0:
                product_results['best_tanimoto_to_starting_material'] = tanimoto_starting_valid.max()
                product_results['avg_tanimoto_to_starting_material'] = tanimoto_starting_valid.mean()
                
                # Find config with best tanimoto
                best_idx = product_df['pred_tanimoto_to_starting_material'].idxmax()
                product_results['config_with_best_tanimoto_starting'] = product_df.loc[best_idx, 'experiment_params']
        
        if 'pred_tanimoto_to_target' in product_df.columns:
            tanimoto_target_valid = product_df['pred_tanimoto_to_target'].dropna()
            if len(tanimoto_target_valid) > 0:
                product_results['best_tanimoto_to_target'] = tanimoto_target_valid.max()
                product_results['avg_tanimoto_to_target'] = tanimoto_target_valid.mean()
                
                # Find config with best tanimoto
                best_idx = product_df['pred_tanimoto_to_target'].idxmax()
                product_results['config_with_best_tanimoto_target'] = product_df.loc[best_idx, 'experiment_params']
        
        per_product_results[product] = product_results
    
    return per_product_results

def _calculate_aggregated_metrics(per_product_results: Dict[str, Dict]) -> Dict[str, float]:
    """Calculate aggregated metrics from per-product results"""
    if not per_product_results:
        return {}
    
    total_products = len(per_product_results)
    aggregated = {'total_products': total_products}
    num_samples_per_product = [p.get('total_samples', 0) for p in per_product_results.values()]
    aggregated['avg_num_samples_per_product'] = sum(num_samples_per_product) / len(num_samples_per_product) if num_samples_per_product else 0
    
    # Exact match metrics
    products_with_exact_match = sum(1 for p in per_product_results.values() if p.get('exact_match_found', False))
    aggregated['products_with_exact_match'] = products_with_exact_match
    aggregated['percentage_products_with_exact_match'] = products_with_exact_match / total_products if total_products > 0 else 0
    
    # Class accuracy metrics
    products_with_class_correct = sum(1 for p in per_product_results.values() if p.get('total_class_correct_samples', 0) > 0)
    aggregated['products_with_class_correct_samples'] = products_with_class_correct
    aggregated['percentage_products_with_class_correct'] = products_with_class_correct / total_products if total_products > 0 else 0
    
    # Average class-correct samples per product
    class_correct_samples = [p.get('total_class_correct_samples', 0) for p in per_product_results.values()]
    aggregated['avg_class_correct_samples_per_product'] = sum(class_correct_samples) / len(class_correct_samples) if class_correct_samples else 0
    #class_accuracy_samples = [p.get('total_class_accuracy', 0) for p in per_product_results.values()]
    avg_class_accuracy_per_product = [correct/total for correct, total in zip(class_correct_samples, num_samples_per_product)]
    aggregated['avg_class_accuracy_per_product'] = sum(avg_class_accuracy_per_product) / len(avg_class_accuracy_per_product) if class_correct_samples else 0
    
    # RXN name metrics
    products_with_rxn_name_correct = sum(1 for p in per_product_results.values() if p.get('total_rxn_name_correct_samples', 0) > 0)
    aggregated['products_with_rxn_name_correct_samples'] = products_with_rxn_name_correct
    aggregated['percentage_products_with_rxn_name_correct'] = products_with_rxn_name_correct / total_products if total_products > 0 else 0

    # Average rxn-name-correct samples per product
    rxn_name_correct_samples = [p.get('total_rxn_name_correct_samples', 0) for p in per_product_results.values()]
    aggregated['avg_rxn_name_correct_samples_per_product'] = sum(rxn_name_correct_samples) / len(rxn_name_correct_samples) if rxn_name_correct_samples else 0
    #rxn_name_accuracy_samples = [p.get('total_rxn_name_accuracy', 0) for p in per_product_results.values()]
    avg_rxn_name_accuracy_per_product = [correct/total for correct, total in zip(rxn_name_correct_samples, num_samples_per_product)]
    aggregated['avg_rxn_name_accuracy_per_product'] = sum(avg_rxn_name_accuracy_per_product) / len(avg_rxn_name_accuracy_per_product) if rxn_name_correct_samples else 0
    #aggregated['avg_rxn_name_accuracy_per_product'] = sum(rxn_name_correct_samples) / sum(num_samples_per_product) if rxn_name_correct_samples else 0
    
    # Round trip metrics
    products_with_round_trip_correct = sum(1 for p in per_product_results.values() if p.get('total_round_trip_correct_samples', 0) > 0)
    aggregated['products_with_round_trip_correct_samples'] = products_with_round_trip_correct
    aggregated['percentage_products_with_round_trip_correct'] = products_with_round_trip_correct / total_products if total_products > 0 else 0
    
    # Average round-trip-correct samples per product
    round_trip_correct_samples = [p.get('total_round_trip_correct_samples', 0) for p in per_product_results.values()]
    aggregated['avg_round_trip_correct_samples_per_product'] = sum(round_trip_correct_samples) / len(round_trip_correct_samples) if round_trip_correct_samples else 0
    #round_trip_accuracy_samples = [p.get('total_round_trip_accuracy', 0) for p in per_product_results.values()]
    avg_round_trip_accuracy_per_product = [correct/total for correct, total in zip(round_trip_correct_samples, num_samples_per_product)]
    aggregated['avg_round_trip_accuracy_per_product'] = sum(avg_round_trip_accuracy_per_product) / len(avg_round_trip_accuracy_per_product) if round_trip_correct_samples else 0
    #aggregated['avg_round_trip_accuracy_per_product'] = sum(round_trip_correct_samples) / sum(num_samples_per_product) if round_trip_correct_samples else 0
    
    # Tanimoto metrics
    tanimoto_starting_values = [p.get('best_tanimoto_to_starting_material') for p in per_product_results.values() if p.get('best_tanimoto_to_starting_material') is not None]
    if tanimoto_starting_values:
        aggregated['avg_best_tanimoto_to_starting_per_product'] = sum(tanimoto_starting_values) / len(tanimoto_starting_values)
    
    tanimoto_target_values = [p.get('best_tanimoto_to_target') for p in per_product_results.values() if p.get('best_tanimoto_to_target') is not None]
    if tanimoto_target_values:
        aggregated['avg_best_tanimoto_to_target_per_product'] = sum(tanimoto_target_values) / len(tanimoto_target_values)
    
    return aggregated

def calculate_improvement_metrics(guided_results: Dict, baseline_results: Dict) -> Dict:
    """
    Compare a guided run against a baseline run.

    Both inputs are expected to be per-product aggregation results carrying
    'per_product' and 'aggregated_metrics' entries; a missing entry is treated
    as an empty dict.

    Args:
        guided_results: Per-product aggregation of the guided experiments.
        baseline_results: Per-product aggregation of the baseline experiments.

    Returns:
        Dict with 'comparison_summary' (aggregate metric comparison),
        'per_product_improvements' (product -> improvement record) and
        'product_improvement_stats' (statistics over those records).
    """
    comparison_summary = _calculate_overall_comparison(
        guided_results.get('aggregated_metrics', {}),
        baseline_results.get('aggregated_metrics', {}),
    )
    per_product = _calculate_per_product_improvements(
        guided_results.get('per_product', {}),
        baseline_results.get('per_product', {}),
    )
    return {
        'comparison_summary': comparison_summary,
        'per_product_improvements': per_product,
        'product_improvement_stats': _calculate_product_improvement_stats(per_product),
    }


def _calculate_overall_comparison(guided_agg: Dict, baseline_agg: Dict) -> Dict:
    """Calculate overall comparison between guided and baseline results"""
    comparison = {}
    
    # Define metrics to compare
    metrics_to_compare = [
        ('avg_num_samples_per_product', 'Avg Num Samples'),
        ('percentage_products_with_exact_match', 'Exact Match'),
        ('percentage_products_with_class_correct', 'Class Correct'), 
        ('percentage_products_with_rxn_name_correct', 'RXN Name Correct'),
        ('percentage_products_with_round_trip_correct', 'Round Trip Correct'),
        ('avg_class_correct_samples_per_product', 'Avg Class Correct Samples'),
        ('avg_class_accuracy_per_product', 'Avg Class Accuracy'),
        ('avg_rxn_name_correct_samples_per_product', 'Avg RXN Name Samples'),
        ('avg_rxn_name_accuracy_per_product', 'Avg RXN Name Accuracy'),
        ('avg_round_trip_correct_samples_per_product', 'Avg Round Trip Samples'),
        ('avg_round_trip_accuracy_per_product', 'Avg Round Trip Accuracy'),
        ('avg_best_tanimoto_to_starting_per_product', 'Avg Best Tanimoto Starting'),
        ('avg_best_tanimoto_to_target_per_product', 'Avg Best Tanimoto Target')
    ]
    
    for metric_key, metric_name in metrics_to_compare:
        guided_value = guided_agg.get(metric_key, 0)
        baseline_value = baseline_agg.get(metric_key, 0)
        
        if baseline_value > 0:
            improvement_ratio = guided_value / baseline_value
            improvement_percentage = ((guided_value - baseline_value) / baseline_value) * 100
        else:
            improvement_ratio = float('inf') if guided_value > 0 else 1.0
            improvement_percentage = float('inf') if guided_value > 0 else 0.0
        
        comparison[metric_key] = {
            'guided_value': guided_value,
            'baseline_value': baseline_value,
            'absolute_improvement': guided_value - baseline_value,
            'improvement_ratio': improvement_ratio,
            'improvement_percentage': improvement_percentage,
            'metric_name': metric_name
        }
    
    return comparison


def _calculate_per_product_improvements(guided_products: Dict, baseline_products: Dict) -> Dict:
    """Calculate improvement metrics for each individual product"""
    per_product_improvements = {}
    
    # Find products that exist in both guided and baseline
    common_products = set(guided_products.keys()) & set(baseline_products.keys())
    guided_only_products = set(guided_products.keys()) - set(baseline_products.keys())
    baseline_only_products = set(baseline_products.keys()) - set(guided_products.keys())
    
    for product in common_products:
        guided_data = guided_products[product]
        baseline_data = baseline_products[product]
        
        improvements = {
            'product_in_both': True,
            'guided_total_samples': guided_data.get('total_samples', 0),
            'baseline_total_samples': baseline_data.get('total_samples', 0)
        }
        
        # Exact match improvement
        guided_exact = guided_data.get('exact_match_found', False)
        baseline_exact = baseline_data.get('exact_match_found', False)
        improvements['exact_match_improvement'] = guided_exact and not baseline_exact
        improvements['exact_match_maintained'] = guided_exact and baseline_exact
        improvements['exact_match_lost'] = not guided_exact and baseline_exact
        
        # Class accuracy improvements
        guided_class_samples = guided_data.get('total_class_correct_samples', 0)
        baseline_class_samples = baseline_data.get('total_class_correct_samples', 0)
        improvements['class_correct_samples_improvement'] = guided_class_samples - baseline_class_samples
        improvements['class_correct_samples_improved'] = guided_class_samples > baseline_class_samples
        
        # RXN name improvements
        guided_rxn_samples = guided_data.get('total_rxn_name_correct_samples', 0)
        baseline_rxn_samples = baseline_data.get('total_rxn_name_correct_samples', 0)
        improvements['rxn_name_samples_improvement'] = guided_rxn_samples - baseline_rxn_samples
        improvements['rxn_name_samples_improved'] = guided_rxn_samples > baseline_rxn_samples
        
        # Round trip improvements
        guided_rt_samples = guided_data.get('total_round_trip_correct_samples', 0)
        baseline_rt_samples = baseline_data.get('total_round_trip_correct_samples', 0)
        improvements['round_trip_samples_improvement'] = guided_rt_samples - baseline_rt_samples
        improvements['round_trip_samples_improved'] = guided_rt_samples > baseline_rt_samples
        
        # Tanimoto improvements
        guided_tanimoto_starting = guided_data.get('best_tanimoto_to_starting_material')
        baseline_tanimoto_starting = baseline_data.get('best_tanimoto_to_starting_material')
        if guided_tanimoto_starting is not None and baseline_tanimoto_starting is not None:
            improvements['tanimoto_starting_improvement'] = guided_tanimoto_starting - baseline_tanimoto_starting
            improvements['tanimoto_starting_improved'] = guided_tanimoto_starting > baseline_tanimoto_starting
        
        guided_tanimoto_target = guided_data.get('best_tanimoto_to_target')
        baseline_tanimoto_target = baseline_data.get('best_tanimoto_to_target')
        if guided_tanimoto_target is not None and baseline_tanimoto_target is not None:
            improvements['tanimoto_target_improvement'] = guided_tanimoto_target - baseline_tanimoto_target
            improvements['tanimoto_target_improved'] = guided_tanimoto_target > baseline_tanimoto_target
        
        per_product_improvements[product] = improvements
    
    # Handle products only in guided or baseline
    for product in guided_only_products:
        per_product_improvements[product] = {
            'product_in_both': False,
            'only_in_guided': True,
            'guided_data': guided_products[product]
        }
    
    for product in baseline_only_products:
        per_product_improvements[product] = {
            'product_in_both': False,
            'only_in_baseline': True,
            'baseline_data': baseline_products[product]
        }
    
    return per_product_improvements


def _calculate_product_improvement_stats(per_product_improvements: Dict) -> Dict:
    """Calculate aggregate statistics about product-level improvements"""
    
    # Filter to products that exist in both guided and baseline
    common_products = {k: v for k, v in per_product_improvements.items() if v.get('product_in_both', False)}
    total_common_products = len(common_products)
    
    if total_common_products == 0:
        return {'total_common_products': 0}
    
    stats = {
        'total_common_products': total_common_products,
        'products_only_in_guided': len([v for v in per_product_improvements.values() if v.get('only_in_guided', False)]),
        'products_only_in_baseline': len([v for v in per_product_improvements.values() if v.get('only_in_baseline', False)])
    }
    
    # Exact match improvements
    exact_match_improved = sum(1 for v in common_products.values() if v.get('exact_match_improvement', False))
    exact_match_maintained = sum(1 for v in common_products.values() if v.get('exact_match_maintained', False))
    exact_match_lost = sum(1 for v in common_products.values() if v.get('exact_match_lost', False))
    
    stats['exact_match_improved_count'] = exact_match_improved
    stats['exact_match_maintained_count'] = exact_match_maintained
    stats['exact_match_lost_count'] = exact_match_lost
    stats['fraction_products_exact_match_improved'] = exact_match_improved / total_common_products
    stats['fraction_products_exact_match_maintained'] = exact_match_maintained / total_common_products
    stats['fraction_products_exact_match_lost'] = exact_match_lost / total_common_products
    
    # Class accuracy improvements
    class_improved = sum(1 for v in common_products.values() if v.get('class_correct_samples_improved', False))
    stats['class_improved_count'] = class_improved
    stats['fraction_products_class_improved'] = class_improved / total_common_products
    
    # Average improvements
    class_improvements = [v.get('class_correct_samples_improvement', 0) for v in common_products.values()]
    stats['avg_class_samples_improvement_per_product'] = sum(class_improvements) / len(class_improvements) if class_improvements else 0
    
    # RXN name improvements  
    rxn_name_improved = sum(1 for v in common_products.values() if v.get('rxn_name_samples_improved', False))
    stats['rxn_name_improved_count'] = rxn_name_improved
    stats['fraction_products_rxn_name_improved'] = rxn_name_improved / total_common_products
    
    rxn_name_improvements = [v.get('rxn_name_samples_improvement', 0) for v in common_products.values()]
    stats['avg_rxn_name_samples_improvement_per_product'] = sum(rxn_name_improvements) / len(rxn_name_improvements) if rxn_name_improvements else 0
    
    # Round trip improvements
    round_trip_improved = sum(1 for v in common_products.values() if v.get('round_trip_samples_improved', False))
    stats['round_trip_improved_count'] = round_trip_improved
    stats['fraction_products_round_trip_improved'] = round_trip_improved / total_common_products
    
    round_trip_improvements = [v.get('round_trip_samples_improvement', 0) for v in common_products.values()]
    stats['avg_round_trip_samples_improvement_per_product'] = sum(round_trip_improvements) / len(round_trip_improvements) if round_trip_improvements else 0
    
    # Tanimoto improvements
    tanimoto_starting_improvements = [v.get('tanimoto_starting_improvement') for v in common_products.values() if v.get('tanimoto_starting_improvement') is not None]
    if tanimoto_starting_improvements:
        stats['avg_tanimoto_starting_improvement_per_product'] = sum(tanimoto_starting_improvements) / len(tanimoto_starting_improvements)
        stats['tanimoto_starting_improved_count'] = sum(1 for imp in tanimoto_starting_improvements if imp > 0)
    
    tanimoto_target_improvements = [v.get('tanimoto_target_improvement') for v in common_products.values() if v.get('tanimoto_target_improvement') is not None]
    if tanimoto_target_improvements:
        stats['avg_tanimoto_target_improvement_per_product'] = sum(tanimoto_target_improvements) / len(tanimoto_target_improvements)
        stats['tanimoto_target_improved_count'] = sum(1 for imp in tanimoto_target_improvements if imp > 0)
    
    return stats

def process_results_df(results: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
    """
    Add a ``topk`` column to every step DataFrame, mutating it in place.

    Each value of ``results`` is iterated as a ``{step_idx: DataFrame}`` mapping;
    ``topk`` is ``compare_reactant_smiles(true_reactants, reactant_predictions)``
    evaluated row by row. The same ``results`` object is returned.
    """
    def _row_topk(row):
        return compare_reactant_smiles(row['true_reactants'], row['reactant_predictions'])

    for steps in results.values():
        for step_df in steps.values():
            step_df['topk'] = step_df.apply(_row_topk, axis=1)
    return results

def generate_evaluation_summary(project_root: str, experiment_dir: str, experiment_group: str,
                              experiment_filters: Dict = None, baseline_experiment_group: str = None,
                              baseline_experiment_filters: Dict = None,
                              use_starting_material: bool = False,
                              max_steps: int = 16,
                              process_df: bool = False,
                              true_routes_path: str = None) -> Dict:
    """
    Generate an evaluation summary for a group of synthesis experiments.

    Loads the ground-truth routes and the experiment results, then computes
    route completion rates, per-experiment statistics and a per-product
    aggregation across experiments.

    Args:
        project_root: Root directory of the project.
        experiment_dir: Directory containing the experiments (forwarded to
            ``load_experiment_results``).
        experiment_group: Experiment group to load.
        experiment_filters: Filters to select specific experiments.
        baseline_experiment_group: Accepted but currently unused (baseline
            comparison is disabled, see TODO below).
        baseline_experiment_filters: Accepted but currently unused.
        use_starting_material: Whether to evaluate starting material constraints.
        max_steps: Maximum steps to consider per route.
        process_df: Accepted but currently unused; results are used as loaded.
        true_routes_path: Path to the ground-truth routes JSON, relative to
            ``<project_root>/data/``. Required.

    Returns:
        Dict with keys 'experiment_filters', 'num_experiments', 'experiment_names',
        'per_experiment_statistics', 'per_product_aggregation' and 'route_completion'.

    Raises:
        ValueError: If ``true_routes_path`` is not provided.
    """
    if true_routes_path is None:
        # Without this check os.path.join fails with an opaque TypeError.
        raise ValueError(
            "true_routes_path is required: pass the ground-truth routes JSON path "
            "relative to <project_root>/data/"
        )

    # Load ground truth routes
    full_true_routes_path = os.path.join(project_root, 'data', true_routes_path)
    with open(full_true_routes_path, 'r', encoding='utf-8') as f:
        true_routes = json.load(f)

    # Load experiment results
    results = load_experiment_results(project_root, experiment_dir, experiment_group, experiment_filters)
    print(f'Loaded {len(results)} experiments')

    route_completion = calculate_route_completion_rates(
        results, true_routes, use_starting_material, max_steps
    )

    per_experiment_stats = calculate_per_experiment_statistics(results)
    print(f'Calculated stats for {len(per_experiment_stats)} experiments')

    # Per-product aggregation across all loaded experiments
    per_product_aggregation = calculate_per_product_aggregation(results)
    print(f'Calculated per-product metrics for {len(per_product_aggregation["per_product"])} products')

    evaluation_summary = {
        'experiment_filters': experiment_filters,
        'num_experiments': len(results),
        'experiment_names': list(results.keys()),
        'per_experiment_statistics': per_experiment_stats,
        'per_product_aggregation': per_product_aggregation,
        'route_completion': route_completion,
    }

    # TODO: baseline comparison is disabled. To re-enable when
    # baseline_experiment_group is not None:
    #     baseline_results = load_experiment_results(project_root, experiment_dir,
    #                                                baseline_experiment_group, baseline_experiment_filters)
    #     baseline_per_product_aggregation = calculate_per_product_aggregation(baseline_results)
    #     improvement_metrics = calculate_improvement_metrics(per_product_aggregation,
    #                                                         baseline_per_product_aggregation)
    #     evaluation_summary['baseline_comparison'] = {
    #         'baseline_experiment_group': baseline_experiment_group,
    #         'baseline_experiment_filters': baseline_experiment_filters,
    #         'baseline_num_experiments': len(baseline_results),
    #         'baseline_per_product_aggregation': baseline_per_product_aggregation,
    #         'improvement_metrics': improvement_metrics
    #     }

    return evaluation_summary


def generate_evaluation_summary_old2(project_root: str, experiment_dir: str, experiment_group: str,
                              experiment_filters: Dict = None,
                              use_starting_material: bool = False,
                              max_steps: int = 15,
                              process_df: bool = False,
                              classifier_checkpoint_path: str = None,
                              true_routes_path: str = None) -> Dict:
    """
    Legacy (v2): generate an evaluation summary for manual synthesis experiments.

    Args:
        project_root: Root directory of the project.
        experiment_dir: Directory containing the experiments.
        experiment_group: Experiment group to load.
        experiment_filters: Filters to select specific experiments.
        use_starting_material: Whether to evaluate starting material constraints.
        max_steps: Maximum steps to consider per route.
        process_df: Accepted for compatibility; currently unused.
        classifier_checkpoint_path: Accepted for compatibility; currently unused.
        true_routes_path: Path to ground-truth routes, relative to
            ``<project_root>/data/``. Defaults to
            'uspto_190/in_json/test_linear_routes_with_tanimoto_weight1.json'.

    Returns:
        Dict with keys 'experiment_filters', 'num_experiments', 'experiment_names',
        'accuracy_metrics', 'route_completion', 'summary_statistics' and
        'sample_accuracy_metrics'.
    """
    # Load ground truth routes
    if true_routes_path is None:
        true_routes_path = 'uspto_190/in_json/test_linear_routes_with_tanimoto_weight1.json'

    full_true_routes_path = os.path.join(project_root, 'data', true_routes_path)
    with open(full_true_routes_path, 'r', encoding='utf-8') as f:
        true_routes = json.load(f)

    # Load experiment results
    results = load_experiment_results(project_root, experiment_dir, experiment_group, experiment_filters)
    print(f'Loaded {len(results)} experiments')

    # Calculate all metrics.
    # accuracy_metrics was previously commented out, which left it undefined and made
    # every call fail with NameError at calculate_summary_statistics below.
    # NOTE(review): confirm calculate_exact_match_accuracy accepts the structure
    # returned by load_experiment_results here.
    accuracy_metrics = calculate_exact_match_accuracy(results)
    sample_accuracy_metrics = calculate_sample_level_accuracy(results)
    route_completion = calculate_route_completion_rates(
        results, true_routes, use_starting_material, max_steps
    )

    # Calculate summary statistics across all experiments
    summary_stats = calculate_summary_statistics(accuracy_metrics, route_completion)

    evaluation_summary = {
        'experiment_filters': experiment_filters,
        'num_experiments': len(results),
        'experiment_names': list(results.keys()),
        'accuracy_metrics': accuracy_metrics,
        'route_completion': route_completion,
        'summary_statistics': summary_stats,
        'sample_accuracy_metrics': sample_accuracy_metrics
    }

    return evaluation_summary


def generate_evaluation_summary_old(project_root: str, experiment_subdir: str,
                              experiment_filters: Dict = None,
                              use_starting_material: bool = False,
                              max_steps: int = 15,
                              classifier_checkpoint_path: str = None,
                              true_routes_path: str = None) -> Dict:
    """
    Legacy: generate an evaluation summary for manual synthesis experiments.

    Args:
        project_root: Root directory of the project.
        experiment_subdir: Subdirectory containing experiments.
        experiment_filters: Filters to select specific experiments.
        use_starting_material: Whether to evaluate starting material constraints.
        max_steps: Maximum steps to consider per route.
        classifier_checkpoint_path: Path to a classifier checkpoint, relative to
            ``<project_root>/checkpoints/``; loaded with ``torch.load`` when given.
        true_routes_path: Path to ground-truth routes, relative to
            ``<project_root>/data/``. Defaults to
            'uspto_190/in_json/test_linear_routes_with_tanimoto_weight1.json'
            (the same default as generate_evaluation_summary_old2).

    Returns:
        Dict with keys 'experiment_filters', 'num_experiments', 'experiment_names',
        'accuracy_metrics', 'route_completion' and 'summary_statistics'.
    """
    # Load classifier checkpoint if provided.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    classifier_checkpoint = None
    if classifier_checkpoint_path:
        full_path = os.path.join(project_root, 'checkpoints', classifier_checkpoint_path)
        classifier_checkpoint = torch.load(full_path, map_location=torch.device('cpu'))

    # Load ground truth routes (a None path previously crashed os.path.join)
    if true_routes_path is None:
        true_routes_path = 'uspto_190/in_json/test_linear_routes_with_tanimoto_weight1.json'
    full_true_routes_path = os.path.join(project_root, 'data', true_routes_path)
    with open(full_true_routes_path, 'r', encoding='utf-8') as f:
        true_routes = json.load(f)

    # Load experiment results.
    # NOTE(review): newer callers pass (project_root, experiment_dir, experiment_group,
    # experiment_filters); with three positional arguments experiment_filters may land
    # in the experiment_group slot. Confirm against load_experiment_results' signature.
    results = load_experiment_results(project_root, experiment_subdir, experiment_filters)
    print(f'Loaded {len(results)} experiments')

    # Preprocess all results
    processed_results = {
        exp_name: preprocess_results_df(df, classifier_checkpoint, true_routes)
        for exp_name, df in results.items()
    }

    # Calculate all metrics
    accuracy_metrics = calculate_exact_match_accuracy(processed_results)
    route_completion = calculate_route_completion_rates(
        processed_results, true_routes, use_starting_material, max_steps
    )

    # Calculate summary statistics across all experiments
    summary_stats = calculate_summary_statistics(accuracy_metrics, route_completion)

    evaluation_summary = {
        'experiment_filters': experiment_filters,
        'num_experiments': len(processed_results),
        'experiment_names': list(processed_results.keys()),
        'accuracy_metrics': accuracy_metrics,
        'route_completion': route_completion,
        'summary_statistics': summary_stats
    }

    return evaluation_summary

def print_evaluation_summary(summary: Dict):
    """
    Print a human-readable evaluation summary.

    Args:
        summary: Dict with at least 'num_experiments' and 'experiment_names'.
            'accuracy_metrics' and 'route_completion' (each keyed by experiment
            name) are optional: summaries produced by the current
            generate_evaluation_summary have no 'accuracy_metrics', so a missing
            section prints empty per-experiment entries instead of raising KeyError.
            Numeric values are printed with 4 decimals; other values (None,
            strings, nested structures) are printed as-is.
    """
    def _fmt(value):
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            # None / str / dict etc. do not support the 'f' format code.
            return str(value)

    accuracy_metrics = summary.get('accuracy_metrics') or {}
    route_completion = summary.get('route_completion') or {}

    print("=== Manual Synthesis Evaluation Summary ===")
    print(f"Number of experiments: {summary['num_experiments']}")
    print(f"Experiments: {', '.join(summary['experiment_names'])}")
    print()

    print("=== Accuracy Metrics (Product-level) ===")
    for exp_name in summary['experiment_names']:
        print(f"\n{exp_name}:")
        for metric_name, value in accuracy_metrics.get(exp_name, {}).items():
            print(f"  {metric_name}: {_fmt(value)}")

    print("\n=== Route Completion Metrics ===")
    for exp_name in summary['experiment_names']:
        print(f"\n{exp_name}:")
        stats = route_completion.get(exp_name, {})
        print(f"  Total routes: {stats.get('total_routes', 0)}")
        print(f"  Fully completed: {stats.get('fully_completed_routes', 0)}")
        print(f"  Completion rate: {_fmt(stats.get('completion_rate', 0))}")
        if stats.get('routes_with_starting_material') is not None:
            print(f"  With starting material: {stats.get('routes_with_starting_material', 0)}")


def extract_reaction_idx_from_name(exp_name: str) -> int:
    """
    Return the reaction index embedded in an experiment name.

    The index is the run of digits directly following the first occurrence of
    ``reaction`` (e.g. ``'guided_reaction12_k5'`` -> 12). The extracted index is
    printed. If no such token exists a warning is printed and 0 is returned,
    which callers cannot distinguish from a genuine ``reaction0``.
    """
    found = re.search(r'reaction(\d+)', exp_name)
    if found is None:
        print(f'Warning: Could not extract reaction index from {exp_name}')
        return 0
    reaction_index = int(found.group(1))
    print(f'Extracted reaction index: {reaction_index}')
    return reaction_index

def group_experiments_by_reaction_step(results: Dict[str, pd.DataFrame]) -> Dict[int, Dict[str, pd.DataFrame]]:
    """
    Invert an experiment -> step mapping into a step -> experiment mapping.

    Args:
        results: Mapping of experiment name to a ``{step_idx: DataFrame}``
            mapping (each value is iterated with ``.items()``, despite the
            annotation).

    Returns:
        Plain dict mapping step index to ``{experiment_name: DataFrame}``.
    """
    by_step = {}
    for exp_name, steps in results.items():
        for step_idx, step_df in steps.items():
            by_step.setdefault(step_idx, {})[exp_name] = step_df
    return by_step

def calculate_route_completion_rates_old(results: Dict[str, pd.DataFrame], 
                                   true_routes: List[Dict],
                                   use_starting_material: bool = False,
                                   max_steps: int = 15) -> Dict[str, Dict]:
    """
    Calculate how many ground-truth routes can be completed entirely.

    Each experiment in ``results`` is treated as covering a single reaction
    step, whose index is parsed from the experiment name with
    ``extract_reaction_idx_from_name``. For every route, the first
    ``max_steps`` reactions are checked. A step counts as solved when the
    experiment for that step index has a row whose ``product_smi`` equals the
    step's product and at least one such row has ``is_correct`` set.

    Args:
        results: Dict mapping experiment names to DataFrames. The DataFrames
            are read via the ``product_smi`` and ``is_correct`` columns (and
            ``reactant_predictions`` when ``use_starting_material`` is set).
        true_routes: List of ground truth routes. Each has a ``'route'`` list
            of reaction SMILES strings and optionally ``'main_target'`` /
            ``'starting_material'``.
        use_starting_material: Whether to also count completed routes whose
            correct predictions contain the route's chosen starting material.
        max_steps: Maximum number of reaction steps to consider per route.

    Returns:
        Dict keyed by experiment name with ``total_routes``,
        ``fully_completed_routes``, ``completion_rate``,
        ``steps_solved_per_route`` and ``routes_with_starting_material``
        (``None`` when ``use_starting_material`` is False).
    """
    completion_results = {}

    for exp_name, df in results.items():
        route_stats = {
            'total_routes': 0,
            'fully_completed_routes': 0,
            'completion_rate': 0.0,
            'steps_solved_per_route': [],
            'routes_with_starting_material': 0 if use_starting_material else None
        }

        # Each experiment covers one reaction step; its index is parsed from the name.
        reaction_dfs = {extract_reaction_idx_from_name(exp_name): df}

        for route in true_routes:
            route_stats['total_routes'] += 1
            steps_solved = 0
            route_has_starting_material = False
            starting_material = None  # resolved lazily, at most once per route

            for step_idx, reaction in enumerate(route['route']):
                # Consider exactly ``max_steps`` steps so the count below can
                # match ``total_steps``; a ``>`` check would admit one extra step
                # and make routes longer than ``max_steps`` never "complete".
                if step_idx >= max_steps:
                    break

                # NOTE(review): the text before '>>' is treated as the product,
                # i.e. routes are assumed to be stored as product>>reactants.
                true_product = reaction.split('>>')[0]

                if step_idx not in reaction_dfs:
                    continue
                step_df = reaction_dfs[step_idx]
                product_results = step_df[step_df['product_smi'] == true_product]

                if len(product_results) == 0 or 'is_correct' not in product_results.columns:
                    continue
                if not product_results['is_correct'].any():
                    continue
                steps_solved += 1

                if (use_starting_material and 'starting_material' in route
                        and not route_has_starting_material):
                    if starting_material is None:
                        starting_material = pick_starting_material(
                            route['main_target'], route['starting_material']
                        )
                    correct_results = product_results[product_results['is_correct']]
                    route_has_starting_material = any(
                        starting_material in reactant_predictions
                        for reactant_predictions in correct_results['reactant_predictions']
                    )

            route_stats['steps_solved_per_route'].append(steps_solved)

            # A route is fully completed when every considered step was solved.
            total_steps = min(len(route['route']), max_steps)
            if steps_solved == total_steps:
                route_stats['fully_completed_routes'] += 1
                if route_has_starting_material:
                    route_stats['routes_with_starting_material'] += 1

        if route_stats['total_routes'] > 0:
            route_stats['completion_rate'] = (route_stats['fully_completed_routes'] /
                                              route_stats['total_routes'])

        completion_results[exp_name] = route_stats

    return completion_results

def pick_starting_material(main_target: str, starting_material: str) -> str:
    """
    Choose a single starting-material SMILES to use for ``main_target``.

    Candidates come from ``get_similarity(main_target, starting_material,
    'tanimoto', 1)``. Each candidate is ranked first by the atom count of its
    SMILES (RDKit ``GetNumAtoms``), largest first, and then by the candidate's
    element at index 2, also largest first. Remaining ties keep the order
    returned by ``get_similarity``.

    NOTE(review): assumes each candidate is a tuple whose index 1 is a SMILES
    string and index 2 is its similarity to the target. Confirm this against
    ``get_similarity``. An empty candidate list raises IndexError.

    Returns:
        The SMILES (index 1) of the top-ranked candidate.
    """
    candidates = get_similarity(main_target,
                                starting_material,
                                'tanimoto',
                                1)
    annotated = []
    for candidate in candidates:
        atom_count = Chem.MolFromSmiles(candidate[1]).GetNumAtoms()
        annotated.append(candidate + (atom_count,))
    # Heaviest first; the similarity-like score at index 2 breaks ties.
    annotated.sort(key=lambda c: (c[3], c[2]), reverse=True)
    return annotated[0][1]

def extract_reaction_idx_from_name(exp_name: str) -> int:
    """
    Parse the reaction index embedded in an experiment name.

    Searches for the first ``reaction<digits>`` substring (e.g.
    ``'run_reaction4_topk10'`` -> ``4``). When none is found, a warning is
    printed and ``0`` is returned.
    """
    import re

    hit = re.search(r'reaction(\d+)', exp_name)
    if hit:
        return int(hit.group(1))
    print(f'Warning: Could not extract reaction index from {exp_name}')
    return 0

def load_ground_truth_routes(config):
    """
    Load the ground-truth routes JSON file referenced by ``config``.

    The file is read from
    ``PROJECT_ROOT/data/<config.route_dataset.type>/<config.route_dataset.path>``.

    Args:
        config: Config object exposing ``route_dataset.type`` and
            ``route_dataset.path``.

    Returns:
        The parsed JSON content of that file.
    """
    ground_truth_route_path = os.path.join(PROJECT_ROOT,
                                           'data',
                                           config.route_dataset.type,
                                           config.route_dataset.path)
    # JSON is UTF-8 by specification; pin the encoding rather than relying on
    # the platform default (e.g. cp1252 on Windows), which mis-decodes any
    # non-ASCII content.
    with open(ground_truth_route_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_reactions(pathway):
    """
    Collect the reaction SMILES of one retrosynthetic pathway tree.

    The tree is walked depth-first in pre-order (a node before its children,
    children left to right). Every node whose ``'type'`` is ``'reaction'``
    contributes its ``'smiles'`` value; nested nodes are found under the
    optional ``'children'`` key.

    Args:
        pathway: Root node (dict) of a single pathway.

    Returns:
        List of reaction SMILES strings in traversal order.
    """
    found = []
    pending = [pathway]
    while pending:
        node = pending.pop()
        if node.get('type') == 'reaction':
            found.append(node['smiles'])
        if 'children' in node:
            # Push in reverse so the leftmost child is processed next,
            # reproducing recursive pre-order.
            pending.extend(reversed(node['children']))
    return found

def extract_routes_for_target_desp(config, target_idx):
    """
    Load the DESP search outputs for one target and flatten its routes.

    Reads ``search_stats.pkl`` and ``output_graph.pkl`` from
    ``PROJECT_ROOT/experiments/<experiment_group>/<experiment_params>/
    <experiment_name>/strategy_<strategy>/graphs_for_mol<target_idx>/``.

    Args:
        config: Experiment config (``general.experiment_group``,
            ``general.experiment_params``, ``general.experiment_name`` and
            ``search.strategy`` are read).
        target_idx: Index of the target molecule.

    Returns:
        Tuple ``(routes, num_routes, search_stats)``:
        - ``routes``: one list of reaction SMILES per route (via ``extract_reactions``).
        - ``num_routes``: ``len(routes)``.
        - ``search_stats``: the unpickled stats mapping, with
          ``'num_unique_routes'`` set to ``num_routes``.
    """
    graph_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        config.general.experiment_group,
        config.general.experiment_params,
        config.general.experiment_name,
        f'strategy_{config.search.strategy}',
        f'graphs_for_mol{target_idx}'
    )
    # NOTE: pickle can execute arbitrary code on load; these files are assumed
    # to be this project's own search outputs, never untrusted input.
    # ``with`` closes the handles (the bare ``open(...)`` calls leaked them).
    with open(os.path.join(graph_dir, 'search_stats.pkl'), 'rb') as f:
        search_stats = pickle.load(f)
    with open(os.path.join(graph_dir, 'output_graph.pkl'), 'rb') as f:
        routes = pickle.load(f)
    if routes is None:
        routes = []
    routes_as_array = [extract_reactions(route) for route in routes]
    # TODO: count number of unique routes. For now this is the number of
    # extracted routes; duplicates are not removed.
    search_stats['num_unique_routes'] = len(routes_as_array)
    return routes_as_array, len(routes_as_array), search_stats

def extract_routes_for_target(config, target_idx):
    '''
        Extract routes for a given target index.

        Dispatches on ``config.search.type``: ``'desp'`` uses
        ``extract_routes_for_target_desp`` and ``'retro_star'`` uses
        ``extract_routes_for_target_retro_star``. The chosen backend's return
        value is passed through unchanged.

        Raises:
            ValueError: if ``config.search.type`` is any other value.
    '''
    search_type = config.search.type
    if search_type == 'desp':
        return extract_routes_for_target_desp(config, target_idx)
    if search_type == 'retro_star':
        return extract_routes_for_target_retro_star(config, target_idx)
    raise ValueError(f'Invalid search type: {search_type}')

def compute_diversity(output_graph, routes, return_packing_set=False, config=None):
    """
    Estimate how many mutually distinct routes ``routes`` contains.

    Each route is converted with ``output_graph.to_synthesis_graph`` and the
    result is passed to ``diversity.estimate_packing_number`` using
    ``diversity.reaction_jaccard_distance``. The radius is
    ``config.search.diversity_radius`` when ``config`` is truthy, else 0.99.

    Returns:
        The packing-set size, or ``(size, packing_set)`` when
        ``return_packing_set`` is true.
    """
    synthesis_graphs = [output_graph.to_synthesis_graph(r) for r in routes]
    if config:
        radius = config.search.diversity_radius
    else:
        radius = 0.99  # because comparison uses ">", not ">="
    packing = diversity.estimate_packing_number(
        routes=synthesis_graphs,
        distance_metric=diversity.reaction_jaccard_distance,
        radius=radius,
    )
    size = len(packing)
    if not return_packing_set:
        return size
    return size, packing

def extract_routes_for_target_retro_star(
    config,
    target_idx
):
    '''
        Extract routes for a given target index from a Retro* search output.

        Loads ``output_graph.pkl`` and ``search_stats.pkl`` from the target's
        ``graphs_for_mol<target_idx>`` experiment directory. It then extracts up
        to ``config.search.max_routes_to_extract`` routes in the order found.

        Args:
            config: config object
            target_idx: index of the target molecule

        Returns:
            Tuple ``(routes_objects, num_unique_routes, search_stats)``:
            - ``routes_objects``: packing set of routes in SynthesisGraph format
              (i.e. containing only reactions), computed with
              ``config.search.diversity_radius``.
            - ``num_unique_routes``: size of the packing set computed by
              ``compute_diversity`` with ``config=None`` (its default radius).
            - ``search_stats``: unpickled search statistics, with
              ``'num_unique_routes'`` set to ``num_unique_routes``.
    '''
    print(f'======= processing molecule {target_idx}')
    output_graph_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        config.general.experiment_group,
        config.general.experiment_params,
        config.general.experiment_name, # experiment subfolder
        f'strategy_{config.search.strategy}',
        f'graphs_for_mol{target_idx}'
    )
    output_graph_path = os.path.join(output_graph_dir, 'output_graph.pkl')
    search_stats_path = os.path.join(output_graph_dir, 'search_stats.pkl')
    # NOTE: pickle can execute arbitrary code on load; these are assumed to be
    # this project's own search outputs.
    with open(output_graph_path, 'rb') as f:
        output_graph = pickle.load(f)
    with open(search_stats_path, 'rb') as f:
        search_stats = pickle.load(f)

    # Extract the routes simply in the order they were found.
    print('======= extracting routes')
    # TODO: append the target smiles to the metrics dictionary
    start_time = time.time()
    routes = list(
        iter_routes_time_order(
            output_graph,
            max_routes=config.search.max_routes_to_extract
        )
    )
    print(f'Extracted {len(routes)} routes in {time.time()-start_time} seconds')

    # Packing set at the configured radius: these are the routes returned.
    num_packed_routes, routes_objects = compute_diversity(
        output_graph,
        routes,
        return_packing_set=True,
        config=config
    )
    # TODO: temporary hack. num_unique_routes is computed with the default
    # radius of compute_diversity (config=None), not the configured one, so it
    # can differ from len(routes_objects).
    num_unique_routes, _ = compute_diversity(
        output_graph,
        routes,
        return_packing_set=True,
        config=None
    )
    # Report both counts so each is attributed to the radius that produced it.
    print(f'======= {num_unique_routes} unique routes found with default radius; '
          f'{num_packed_routes} routes kept with {config.search.diversity_radius} radius')
    print(f'======= original num of routes: {len(routes)}')
    search_stats['num_unique_routes'] = num_unique_routes
    return routes_objects, num_unique_routes, search_stats

def get_rxn_insight_info(rxn_smi):
    """
    Return the rxn-insight reaction info for ``rxn_smi``, or None on failure.

    Any exception raised while building the ``Reaction`` or calling its
    ``get_reaction_info`` is printed and swallowed, so callers must be prepared
    for a ``None`` result.
    """
    try:
        info = Reaction(rxn_smi).get_reaction_info()
    except Exception as err:
        print(f'error: {err}')
        return None
    return info

def get_rxn_insight_info_for_list(rxns_smiles):
    """
    Apply ``get_rxn_insight_info`` to each reaction SMILES, preserving order.

    Reactions that fail yield ``None`` entries, so the output has one element
    per input.
    """
    return [get_rxn_insight_info(smi) for smi in rxns_smiles]

def compute_topk_accuracy(results, true_reactants, topk=None):
    '''
    Record in which top-k bucket ``true_reactants`` first appears in ``results``.

    results: list of strings (ranked predictions, best first)
    true_reactants: string representing true reactants
    topk: dict mapping k -> running count. When omitted, a fresh
        ``{1: 0, 3: 0, 5: 0, 10: 0}`` is created on every call. The dict is
        updated in place and returned.

    Keys are visited in the dict's iteration order, and only the first k whose
    prefix ``results[:k]`` contains ``true_reactants`` is incremented.
    NOTE(review): with ascending keys this is a per-bucket count, not
    cumulative top-k accuracy (a top-1 hit does not also increment k=3).
    Confirm callers accumulate the buckets if cumulative accuracy is intended.
    '''
    if topk is None:
        # Fresh dict per call: the former dict default was created once and
        # shared, so counts silently accumulated across independent calls.
        topk = {1: 0, 3: 0, 5: 0, 10: 0}
    for key in topk:
        if true_reactants in results[:key]:
            topk[key] += 1
            break
    return topk

def get_round_trip_results(predictions, config, batch_size=32):
    '''
    Run the forward (reactants -> product) model over retrosynthetic predictions.

    Args:
        predictions: sequence of SMILES strings; each is wrapped in ``Molecule``
            and fed to the forward model.
        config: config exposing ``single_step_evaluation.forward_model_dir``,
            ``num_results_forward`` and ``num_augmentations``.
        batch_size: number of predictions sent to the forward model per call.

    Returns:
        ``turn_results_to_mol_smiles`` applied to the collected forward-model
        results. These are intended to align one-to-one with ``predictions``
        (assumes the model returns one result per input).
    '''
    forward_model_dir = os.path.join(
        PROJECT_ROOT,
        'checkpoints',
        config.single_step_evaluation.forward_model_dir
    )
    # NOTE: this is a forward model, so we input the reactants as our product, and the product in the reactants field
    print(f'============ get_round_trip_results,'
          f'config.single_step_evaluation.num_results_forward: {config.single_step_evaluation.num_results_forward},'
          f' config.single_step_evaluation.num_augmentations: {config.single_step_evaluation.num_augmentations}')
    root_aligned_forward = RootAlignedForwardModel(use_cache=True,
                                                   num_augmentations=config.single_step_evaluation.num_augmentations,
                                                   default_num_results=config.single_step_evaluation.num_results_forward, # 10
                                                   model_dir=forward_model_dir,
                                                   config=config)
    # process predictions in batches
    round_trip_results = []
    print(f'============ get_round_trip_results, len(predictions): {len(predictions)}, batch_size: {batch_size}')
    for i in range(0, len(predictions), batch_size):
        batch = predictions[i:i+batch_size]
        try:
            out = root_aligned_forward([Molecule(mol) for mol in batch],
                                       num_results=config.single_step_evaluation.num_results_forward)
        except Exception as e:
            print(f'Error getting round trip results: {e}')
            # Keep one (empty) result per input so later batches stay aligned
            # with ``predictions``; dropping the batch would shift every result.
            # NOTE(review): assumes turn_results_to_mol_smiles accepts an empty
            # prediction list for an input, as when a model returns nothing.
            out = [[] for _ in batch]
        round_trip_results.extend(out)
    return turn_results_to_mol_smiles(round_trip_results)

def _instantiate_single_step_retro_model(config, model_dir, conditional_starting_material, conditional_target):
    """Build the single-step model named by ``config.single_step_model.model_type``.

    Raises:
        ValueError: if the model type is not one of 'retroknn',
            'rootaligned_original', 'rootaligned' or 'neuralsym'.
    """
    ssm_cfg = config.single_step_model
    model_type = ssm_cfg.model_type
    # TODO: add other models here based on config.single_step_model.name
    if model_type == 'retroknn':
        return RetroKNNModel(
            use_cache=True,
            default_num_results=ssm_cfg.default_num_results,
            model_dir=model_dir,
            device=device
        )
    if model_type == 'rootaligned_original':
        return RootAlignedModel(
            use_cache=True,
            num_augmentations=ssm_cfg.num_augmentations,
            default_num_results=ssm_cfg.default_num_results, # 10
            model_dir=model_dir
        )
    if model_type == 'rootaligned':
        return RootAlignedFixedModel(
            use_cache=True,
            num_augmentations=ssm_cfg.num_augmentations,
            default_num_results=ssm_cfg.default_num_results, # 10
            model_dir=model_dir,
            config=config,
            conditional_starting_material=conditional_starting_material,
            conditional_target=conditional_target
        )
    if model_type == 'neuralsym':
        templates_path = os.path.join(PROJECT_ROOT,
                                      'data',
                                      'desp_data',
                                      'idx2template_retro.json')
        predictor = NeuralSymPredictor(use_cache=True,
                                       default_num_results=ssm_cfg.default_num_results)
        # The NeuralSym checkpoint lives at the same checkpoints/<model_dir> path.
        predictor.setup(model_dir, templates_path)
        return predictor
    raise ValueError(f'Invalid model name: {model_type}')


def get_retrosynthetic_results(config, product_smi, conditional_starting_material=None, conditional_target=None):
    """
    Predict reactant SMILES for ``product_smi`` with the configured single-step model.

    The checkpoint directory is ``PROJECT_ROOT/checkpoints/<config.single_step_model.model_dir>``.
    The conditional starting material / target are forwarded only to the
    'rootaligned' model.

    Returns:
        ``turn_results_to_mol_smiles`` applied to the model's output for the
        single input molecule, requesting
        ``config.single_step_model.default_num_results`` predictions.

    Raises:
        ValueError: if ``config.single_step_model.model_type`` is unknown.
    """
    target = Molecule(product_smi)
    checkpoint_dir = os.path.join(
        PROJECT_ROOT,
        'checkpoints',
        config.single_step_model.model_dir
    )
    model = _instantiate_single_step_retro_model(
        config, checkpoint_dir, conditional_starting_material, conditional_target
    )
    raw_results = model([target], num_results=config.single_step_model.default_num_results)
    return turn_results_to_mol_smiles(raw_results)


