'''
    We use this script to run and evaluate single-step retrosynthesis for a given molecule in the following steps:
    1. get the reactant predictions from the single-step model, with or without guidance
    2. evaluate the predictions with: 1) round-trip accuracy, 2) rxn insight, 3) top-k accuracy, 4) classifier score
    3. save the results together with their scores to a CSV file named after the experiment
'''
from networkx.classes import non_neighbors
import torch
import os
import hydra
import pandas as pd
import json

# Instead of direct import, use conditional import
import sys
from unittest.mock import MagicMock

def safe_import_helpers():
    '''
        Import the evaluation helpers after replacing the RDKit drawing modules with mocks.

        The two `rdkit.Chem.Draw*` entries of `sys.modules` are overwritten with
        `MagicMock` objects (and stay overwritten afterwards), then the helper
        functions are imported from `multiguide.evaluation.helpers`.

        Returns the six helpers as a tuple, in this order:
        get_retrosynthetic_results, get_rxn_insight_info, get_round_trip_results,
        get_classifier_score, compute_topk_accuracy, compare_reactant_smiles.
    '''
    # Each module name gets its own mock object, installed before the import below runs.
    for mocked_module in ('rdkit.Chem.Draw.rdMolDraw2D', 'rdkit.Chem.Draw'):
        sys.modules[mocked_module] = MagicMock()

    from multiguide.evaluation.helpers import (
        get_retrosynthetic_results,
        get_rxn_insight_info,
        get_round_trip_results,
        get_classifier_score,
        compute_topk_accuracy,
        compare_reactant_smiles,
    )
    helpers = (
        get_retrosynthetic_results,
        get_rxn_insight_info,
        get_round_trip_results,
        get_classifier_score,
        compute_topk_accuracy,
        compare_reactant_smiles,
    )
    return helpers

# Bind the evaluation helpers at module level through safe_import_helpers(),
# which mocks the RDKit drawing modules before importing them.
(
    get_retrosynthetic_results,
    get_rxn_insight_info,
    get_round_trip_results,
    get_classifier_score,
    compute_topk_accuracy,
    compare_reactant_smiles,
) = safe_import_helpers()

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import remove_dative_bonds, class_to_idx, get_tanimoto

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _expand_record_columns(df, records, prefix):
    '''
        Add one `{prefix}_{key}` column to df for every key of the per-prediction records.

        df: the DataFrame to extend in place
        records: one entry per row; entries exposing `.keys()` are expanded, any other
            entry (None, the -1 placeholder, ...) yields -1 in every expanded column.
            May be None, in which case nothing is added.
        prefix: the prefix of the new column names
    '''
    if records is None:
        return
    # The first mapping-like entry decides which keys become columns.
    template = next((rec for rec in records if hasattr(rec, 'keys')), None)
    if not template:
        return
    for key in template.keys():
        df[f'{prefix}_{key}'] = [
            rec[key] if hasattr(rec, 'keys') else -1
            for rec in records
        ]

@hydra.main(config_path='../configs', config_name='config.yaml')
def process_and_save_one_molecule(config):
    '''
        This function runs the single step evaluation for one molecule and processes and saves the results.

        config: the hydra config; `classifier_guidance.experiment_name` and
            `classifier_guidance.target_class_index` determine the output location
            `<PROJECT_ROOT>/experiments/<experiment prefix>_target<idx>/<experiment_name>.csv`.

        Nothing is written when the evaluation yields no reactant prediction.
    '''
    # run single step evaluation for one molecule
    df_dict = get_results_and_evaluate_for_one_molecule(config)
    if not df_dict:
        # The evaluation returns an empty dict when there are no predictions;
        # there are no columns to expand and nothing worth saving.
        print(f'======== No results to save for product {config.single_step_evaluation.product_smi}')
        return
    # process and save results
    df = pd.DataFrame(df_dict)
    # reorganize rxn_insight into one column per key in rxn_insight_info
    _expand_record_columns(df, df_dict['rxn_insight_info'], 'rxn_insight')
    # The per-k records live under the `*_detailed` keys; `topk` and
    # `round_trip_accuracy` themselves hold one flat value per prediction and
    # cannot be expanded by key.
    # NOTE(review): assumes compute_topk_accuracy returns a mapping per prediction — confirm.
    _expand_record_columns(df, df_dict['topk_detailed'], 'topk')
    _expand_record_columns(df, df_dict['round_trip_accuracy_detailed'], 'round_trip_accuracy')
    experiment_name = config.classifier_guidance.experiment_name
    experiment_dir = f'{experiment_name.split("_")[0]}_target{config.classifier_guidance.target_class_index}'
    out_dir = os.path.join(PROJECT_ROOT, 'experiments', experiment_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f'{experiment_name}.csv')
    df.to_csv(out_path, index=False)
    print(f'======== Df saved to {out_path}')

def get_results_and_evaluate_for_one_molecule(
    config,
    conditional_starting_material=None,
    conditional_target=None
):
    '''
        Predict reactants for the configured product and evaluate these predictions.

        config: the config object; `single_step_evaluation` provides the product SMILES,
            the true reactants and the reaction class. `classifier_guidance.target_class_index`
            is overwritten with that reaction class.
        conditional_starting_material: forwarded unchanged to get_retrosynthetic_results
        conditional_target: forwarded unchanged to get_retrosynthetic_results

        Returns whatever single_step_evaluation_of_one_molecule returns for the
        first element of the retrosynthetic results.
    '''
    eval_config = config.single_step_evaluation
    # The guidance target class follows the reaction class under evaluation.
    config.classifier_guidance.target_class_index = eval_config.rxn_class
    rxn_class = eval_config.rxn_class
    all_predictions = get_retrosynthetic_results(
        config,
        eval_config.product_smi,
        conditional_starting_material=conditional_starting_material,
        conditional_target=conditional_target
    )
    # Only the first element of the results is evaluated.
    first_predictions = all_predictions[0]
    return single_step_evaluation_of_one_molecule(
        config=config,
        reactant_predictions=first_predictions,
        product_smi=eval_config.product_smi,
        true_reactants=eval_config.true_reactants,
        target_property=rxn_class
    )

def single_step_evaluation_of_one_molecule(
    config,
    reactant_predictions,
    product_smi,
    true_reactants,
    target_property
):
    '''
        This function evaluates the single step predictions for one molecule and
        collects the results column-wise.

        config: the config object. Reads `single_step_evaluation.*` and
            `classifier_guidance.property`; `classifier_guidance.target_class_index`
            is overwritten with target_property.
        reactant_predictions: a list of reactant predictions
        product_smi: the product SMILES string
        true_reactants: the true reactants SMILES string; when falsy, the
            `true_reactants` and `topk_detailed` columns are filled with -1
        target_property: the target reaction class, stored in the `true_class` column

        Returns a dict mapping column names to one entry per prediction, or an empty
        dict when no prediction is left after remove_dative_bonds.
    '''
    path = os.path.join(
        PROJECT_ROOT,
        'data',
        'desp_data',
        'canon_building_block_mol2idx_no_isotope.json'
    )
    # Read the building blocks inside a context manager so the file handle is closed.
    with open(path, 'r', encoding='utf-8') as bbs_file:
        bbs = json.load(bbs_file)
    # NOTE(review): get_tanimoto is called even when true_reactants is falsy — confirm it tolerates that.
    true_tanimoto_to_starting_material = get_tanimoto(
        true_reactants,
        config.single_step_evaluation.original_starting_material
    )
    if config.single_step_evaluation.original_target:
        true_tanimoto_to_target = get_tanimoto(
            true_reactants,
            config.single_step_evaluation.original_target
        )
    else:
        true_tanimoto_to_target = -1
    # TODO: remove this side effect later
    config.classifier_guidance.target_class_index = target_property
    # remove dative bonds
    reactant_predictions = remove_dative_bonds(reactant_predictions)
    n_predictions = len(reactant_predictions)
    if n_predictions == 0:
        print(f'Found no reactant predictions for product {product_smi}')
        return {}

    # evaluate with classifier score
    if config.single_step_evaluation.compute_classifier_score:
        output, confidence = get_classifier_score(reactant_predictions, config)
    else:
        output = None
        confidence = None
    print(f'classifier score: {output}, confidence: {confidence}')
    results_as_rxn_smiles = [
        reactant_prediction + '>>' + product_smi
        for reactant_prediction in reactant_predictions
    ]
    rxn_insight_info = get_rxn_insight_info(results_as_rxn_smiles)
    print(f'rxn insight info: {rxn_insight_info}')
    # NOTE: {1: 0} because we're checking each reactant individually.
    if true_reactants:
        # one compute_topk_accuracy result per reactant prediction
        topk = [
            compute_topk_accuracy([pred], true_reactants, topk={1: 0})
            for pred in reactant_predictions
        ]
    else:
        topk = [-1]*n_predictions
    # adapt the code from rsmiles
    round_trip_results = get_round_trip_results(
        reactant_predictions,
        config
    )
    # NOTE: {1: 0, 3: 0, 5: 0, 10: 0} because we get a list of potential products from the round trip model
    # round_trip_results is a list of round_trip predictions (i.e. products) for each reactant prediction
    round_trip_accuracy = [
        compute_topk_accuracy(
            round_trip_result,
            product_smi, topk={1: 0, 3: 0, 5: 0, 10: 0}
        )
        for round_trip_result in round_trip_results
    ]
    print(f'round trip accuracy: {round_trip_accuracy}')

    # Collect all results column-wise; the insertion order below is the column
    # order of the DataFrame built from this dict.
    df_dict = {}
    df_dict['product_smi'] = [product_smi]*n_predictions
    if true_reactants:
        df_dict['true_reactants'] = [true_reactants]*n_predictions
    else:
        df_dict['true_reactants'] = [-1]*n_predictions
    df_dict['true_class'] = [target_property]*n_predictions
    df_dict['reactant_predictions'] = reactant_predictions
    df_dict['classifier_property'] = [
        config.classifier_guidance.property
    ]*n_predictions
    # -1 marks scores that were not computed
    if confidence is not None:
        df_dict['classifier_score'] = confidence.tolist()
    else:
        df_dict['classifier_score'] = [-1]*n_predictions
    if output is not None:
        df_dict['classifier_output'] = output.tolist()
    else:
        df_dict['classifier_output'] = [-1]*n_predictions
    df_dict['round_trip_results'] = round_trip_results
    if rxn_insight_info is not None:
        # individual entries may be None when no info is available for a prediction
        df_dict['rxn_insight_NAME'] = [
            info['NAME']
            if info is not None
            else ''
            for info in rxn_insight_info
        ]
        df_dict['pred_class'] = [
            class_to_idx[info['CLASS']]
            if info is not None
            else -1
            for info in rxn_insight_info
        ]
    else:
        df_dict['rxn_insight_NAME'] = [-1]*n_predictions
        df_dict['pred_class'] = [-1]*n_predictions
    df_dict['true_tanimoto_to_target'] = [true_tanimoto_to_target]*n_predictions
    if config.single_step_evaluation.original_target:
        df_dict['pred_tanimoto_to_target'] = [
            get_tanimoto(
                reactant_prediction,
                config.single_step_evaluation.original_target
            )
            for reactant_prediction in reactant_predictions
        ]
    else:
        df_dict['pred_tanimoto_to_target'] = [-1]*n_predictions
    df_dict['true_tanimoto_to_starting_material'] = [true_tanimoto_to_starting_material]*n_predictions
    df_dict['pred_tanimoto_to_starting_material'] = [
        get_tanimoto(
            reactant_prediction,
            config.single_step_evaluation.original_starting_material
        )
        for reactant_prediction in reactant_predictions
    ]
    df_dict['topk_detailed'] = topk
    # NOTE(review): unlike topk_detailed, this comparison also runs when true_reactants is falsy — confirm intended.
    df_dict['topk'] = [
        compare_reactant_smiles(true_reactants, reactant_prediction)
        for reactant_prediction in reactant_predictions
    ]
    df_dict['round_trip_accuracy_detailed'] = round_trip_accuracy
    # True when the product appears among the round-trip predictions of a reactant prediction
    df_dict['round_trip_accuracy'] = [
        product_smi in round_trip_results_for_one_pred_reactant
        for round_trip_results_for_one_pred_reactant in round_trip_results
    ]
    df_dict['rxn_insight_info'] = rxn_insight_info
    # a prediction counts only if every '.'-separated molecule is a known building block
    df_dict['all_pred_reactants_are_bbs'] = [
        all(m in bbs for m in reactant_prediction.split('.'))
        for reactant_prediction in reactant_predictions
    ]
    return df_dict

# Script entry point: called without arguments because the function is wrapped
# by @hydra.main, which is expected to supply the `config` argument.
if __name__ == '__main__':
    process_and_save_one_molecule()