from ChemCoTBench.eval_moledit import evaluate_moledit_score, record_moledit_results
from ChemCoTBench.eval_molopt import evaluate_molopt_score, record_molopt_results
from ChemCoTBench.eval_molund import evaluate_molund_score, record_molund_results
from ChemCoTBench.eval_rxn import evaluate_rxn_score, record_rxn_score

def eval_all_ChemCoTBench(log_name, dataset_path, logs_dir, results_dir, num_samples=1):
    """Dispatch a log to the matching ChemCoTBench evaluation routine(s).

    When ``log_name`` contains the substring ``'moledit'`` only the mol-edit
    evaluator is invoked, against ``{dataset_path}/mol_edit``. For any other
    ``log_name`` the mol-opt, mol-und and rxn evaluators are invoked, in that
    order, against ``{dataset_path}/mol_opt``, ``{dataset_path}/mol_und`` and
    ``{dataset_path}/rxn`` respectively.

    Args:
        log_name: Name of the log; only its ``'moledit'`` substring is
            inspected here, the value itself is forwarded unchanged.
        dataset_path: Base path that the per-task sub-directory is appended to.
        logs_dir: Forwarded unchanged to the evaluator(s).
        results_dir: Forwarded unchanged to the evaluator(s).
        num_samples: Forwarded unchanged to the evaluator(s); defaults to 1.

    Returns:
        None. Whatever the delegated evaluators return is discarded.
    """
    # Trailing positional arguments shared by every evaluator call.
    shared_args = (logs_dir, results_dir, num_samples)

    if 'moledit' in log_name:
        evaluate_moledit_score(log_name, f'{dataset_path}/mol_edit', *shared_args)
        return

    # NOTE(review): all three evaluators run for every non-moledit log;
    # presumably each one decides from log_name whether it applies -- confirm.
    evaluate_molopt_score(log_name, f'{dataset_path}/mol_opt', *shared_args)
    evaluate_molund_score(log_name, f'{dataset_path}/mol_und', *shared_args)
    evaluate_rxn_score(log_name, f'{dataset_path}/rxn', *shared_args)
        
def record_all_ChemCoTBench(log_name, dataset_path, logs_dir, results_dir, num_samples=1):
    """Dispatch a log to the matching ChemCoTBench result-recording routine(s).

    When ``log_name`` contains the substring ``'moledit'`` only the mol-edit
    recorder is invoked, against ``{dataset_path}/mol_edit``. For any other
    ``log_name`` the mol-opt, mol-und and rxn recorders are invoked, in that
    order, against ``{dataset_path}/mol_opt``, ``{dataset_path}/mol_und`` and
    ``{dataset_path}/rxn`` respectively.

    Args:
        log_name: Name of the log; only its ``'moledit'`` substring is
            inspected here, the value itself is forwarded unchanged.
        dataset_path: Base path that the per-task sub-directory is appended to.
        logs_dir: Forwarded unchanged to the recorder(s).
        results_dir: Forwarded unchanged to the recorder(s).
        num_samples: Forwarded unchanged to the recorder(s); defaults to 1.

    Returns:
        None. Whatever the delegated recorders return is discarded.
    """
    # Trailing positional arguments shared by every recorder call.
    shared_args = (logs_dir, results_dir, num_samples)

    if 'moledit' in log_name:
        record_moledit_results(log_name, f'{dataset_path}/mol_edit', *shared_args)
        return

    # NOTE(review): all three recorders run for every non-moledit log;
    # presumably each one decides from log_name whether it applies -- confirm.
    record_molopt_results(log_name, f'{dataset_path}/mol_opt', *shared_args)
    record_molund_results(log_name, f'{dataset_path}/mol_und', *shared_args)
    record_rxn_score(log_name, f'{dataset_path}/rxn', *shared_args)