'''
    We use this script to run the single-step synthesis of a given molecule in the following steps:
    1. get the predictions from the single-step model with/without guidance
    2. evaluate the predictions with: 
        1) round-trip accuracy, 
        2) rxn insight, 
        3) topk accuracy, 
        4) classifier score
    3. save results with scores to a well-named CSV file
'''
import os
from unittest.mock import MagicMock
import sys
from pathlib import Path
import hydra
import pandas as pd
import torch

# Put the directory two levels above this file at the front of the import
# path (presumably so the `multiguide` imports below resolve -- verify).
sys.path.insert(0, str(Path(__file__).parent.parent))
# Register stand-ins for the problematic RDKit drawing modules before any
# later import can load the real ones; each name gets its own MagicMock.
sys.modules.update({
    module_name: MagicMock()
    for module_name in ('rdkit.Chem.Draw.rdMolDraw2D', 'rdkit.Chem.Draw')
})

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_batch
from multiguide.dataset.helpers import parse_batch_to_reaction_data, update_config_for_reaction, flatten_predictions
from multiguide.evaluation.helpers import get_results_and_evaluate_for_one_molecule

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced anywhere else in this script as shown.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@hydra.main(config_path='../configs', config_name='config.yaml')
def evaluate_single_step_model(config):
    '''
        Run the single-step evaluation for every reaction parsed from the
        configured batch and save all flattened predictions to one CSV file.

        Args:
            config: Hydra config. The fields read directly here are
                `single_step_evaluation.start_idx`,
                `single_step_evaluation.end_idx`,
                `classifier_guidance.dataset.separator` and
                `general.experiment_group` / `experiment_params` /
                `experiment_name`; the helpers called below receive the
                whole config as well.

        Output:
            PROJECT_ROOT/experiments/<experiment_group>/<experiment_params>/
            <experiment_name>/start<start_idx>_end<end_idx>.csv

        Reactions for which the evaluation returns nothing are reported on
        stdout and skipped. The CSV is written even when no rows were
        collected; in that case a warning is printed first.
    '''
    start_idx = config.single_step_evaluation.start_idx
    # NOTE(review): end_idx is only used to name the output file in this
    # function; presumably the batch range is bounded inside get_batch --
    # confirm against that helper.
    end_idx = config.single_step_evaluation.end_idx
    batch = get_batch(config)
    reactions = parse_batch_to_reaction_data(
        batch,
        start_idx
    )
    all_rows = []
    for reaction in reactions:
        print(f'Idx {reaction.batch_index}')
        print(f'conditional_starting_material: {reaction.conditional_starting_material}')
        print(f'conditional_target: {reaction.conditional_target}')
        # Update config with current reaction (return value is not used, so
        # this presumably mutates `config` in place -- verify in the helper).
        update_config_for_reaction(config, reaction)
        # Get predictions together with their evaluation scores
        df_dict = get_results_and_evaluate_for_one_molecule(
            config,
            conditional_starting_material=reaction.conditional_starting_material,
            conditional_target=reaction.conditional_target
        )
        if not df_dict:
            print(f'No predictions found for {reaction.reactants}>>{reaction.product}, '
                  f'idx: {reaction.batch_index}')
            continue
        # Flatten predictions into rows
        rows = flatten_predictions(
            df_dict,
            reaction,
            config.classifier_guidance.dataset.separator
        )
        all_rows.extend(rows)
    if not all_rows:
        # Still write the file so the output location stays predictable, but
        # make the empty result visible instead of failing silently.
        print(f'WARNING: no rows collected for start{start_idx}_end{end_idx}; '
              f'the saved csv will be empty')
    df = pd.DataFrame(all_rows)
    out_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        config.general.experiment_group,
        config.general.experiment_params,
        config.general.experiment_name
    )
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        f'start{start_idx}_end{end_idx}.csv'
    )
    df.to_csv(out_path, index=False)
    print(f'======== df saved to {out_path}')

# Script entry point. The function is called without arguments because the
# @hydra.main decorator on evaluate_single_step_model supplies `config`.
if __name__ == '__main__':
    evaluate_single_step_model()
