'''
    We use this script to run the single-step synthesis of a given molecule in the following steps:
    1. get the predictions from the single-step model with/without guidance
    2. evaluate the predictions with: 
        1) round-trip accuracy, 
        2) rxn insight, 
        3) topk accuracy, 
        4) classifier score
    3. save results with scores to a well named file
'''
import os
from unittest.mock import MagicMock
import sys
from pathlib import Path
import hydra
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import DataStructs
import json

from dataclasses import dataclass
from typing import Optional, List, Dict, Any

# Put the directory two levels above this file at the front of the import path
# (presumably the repository root, so that the `multiguide` imports below resolve
# when the file is run as a script -- confirm against the repo layout).
sys.path.insert(0, str(Path(__file__).parent.parent))
# Mock the problematic RDKit drawing modules: any later import of these two
# module names receives a MagicMock stand-in instead of the real module.
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()

from multiguide.helpers import PROJECT_ROOT
from multiguide.evaluation.helpers import get_results_and_evaluate_for_one_molecule

# Module-wide torch device: CUDA when available, otherwise CPU.
# In this file it is only passed to the 'retroknn' model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_retrosynthetic_results(config, product_smi, conditional_starting_material=None, conditional_target=None):
    '''
        Build the single-step model selected by the config and run it on one product.

        config: the config object; `config.single_step_model` selects the model
                (`model_type`), its checkpoint (`model_dir`) and its parameters
        product_smi: the product SMILES string
        conditional_starting_material: forwarded to the 'rootaligned' model only
        conditional_target: forwarded to the 'rootaligned' model only

        Returns the output of `turn_results_to_mol_smiles` applied to the model
        results for the single input molecule.
        Raises ValueError if `model_type` is not one of
        'retroknn', 'rootaligned_original', 'rootaligned', 'neuralsym'.

        NOTE(review): Molecule, the model classes and turn_results_to_mol_smiles
        are not imported in the visible part of this file -- confirm where they
        are expected to come from.
    '''
    mol = Molecule(product_smi)
    model_cfg = config.single_step_model
    checkpoint_dir = os.path.join(PROJECT_ROOT, 'checkpoints', model_cfg.model_dir)
    model_type = model_cfg.model_type

    # TODO: add other models here based on config.single_step_model.name
    if model_type == 'retroknn':
        model = RetroKNNModel(
            use_cache=True,
            default_num_results=model_cfg.default_num_results,
            model_dir=checkpoint_dir,
            device=device,
        )
    elif model_type == 'rootaligned_original':
        model = RootAlignedModel(
            use_cache=True,
            num_augmentations=model_cfg.num_augmentations,
            default_num_results=model_cfg.default_num_results,
            model_dir=checkpoint_dir,
        )
    elif model_type == 'rootaligned':
        # The only model that receives the conditioning strings and the full config.
        model = RootAlignedFixedModel(
            use_cache=True,
            num_augmentations=model_cfg.num_augmentations,
            default_num_results=model_cfg.default_num_results,
            model_dir=checkpoint_dir,
            config=config,
            conditional_starting_material=conditional_starting_material,
            conditional_target=conditional_target,
        )
    elif model_type == 'neuralsym':
        # The template index lives with the DESP data, not with the checkpoint.
        templates_path = os.path.join(PROJECT_ROOT, 'data', 'desp_data', 'idx2template_retro.json')
        model = NeuralSymPredictor(
            use_cache=True,
            default_num_results=model_cfg.default_num_results,
        )
        model.setup(checkpoint_dir, templates_path)
    else:
        raise ValueError(f'Invalid model name: {model_type}')

    raw_results = model([mol], num_results=model_cfg.default_num_results)
    return turn_results_to_mol_smiles(raw_results)

def single_step_evaluation_of_one_molecule(
    config,
    reactant_predictions,
    product_smi,
    true_reactants,
    target_property
):
    '''
        This function evaluates the single-step predictions for one molecule.

        config: the config object (reads `single_step_evaluation` and
                `classifier_guidance`)
        reactant_predictions: a list of reactant predictions (SMILES strings)
        product_smi: the product SMILES string
        true_reactants: the true reactants SMILES string
        target_property: value stored in the 'true_class' column and written to
                         `config.classifier_guidance.target_class_index`

        Returns a dict mapping column name -> list with one entry per prediction
        (after `remove_dative_bonds`), or an empty dict when there are no
        predictions. -1 is used as the "not available" value for optional
        columns.

        Side effects: mutates `config.classifier_guidance.target_class_index`,
        reads the building-block JSON from disk and prints progress messages.
    '''
    eval_cfg = config.single_step_evaluation
    bbs_path = os.path.join(
        PROJECT_ROOT,
        'data',
        'desp_data',
        'canon_building_block_mol2idx_no_isotope.json'
    )
    # Use a context manager so the file handle is closed deterministically.
    with open(bbs_path, 'r', encoding='utf-8') as bbs_file:
        bbs = json.load(bbs_file)

    original_starting_material = eval_cfg.original_starting_material
    original_target = eval_cfg.original_target

    # A missing (None) starting material gets the same -1 sentinel that a
    # missing target already gets, instead of crashing inside get_tanimoto.
    if original_starting_material is not None:
        true_tanimoto_to_starting_material = get_tanimoto(
            true_reactants,
            original_starting_material
        )
    else:
        true_tanimoto_to_starting_material = -1
    if original_target:
        true_tanimoto_to_target = get_tanimoto(true_reactants, original_target)
    else:
        true_tanimoto_to_target = -1

    # TODO: remove this side effect later
    config.classifier_guidance.target_class_index = target_property
    # remove dative bonds before any scoring
    reactant_predictions = remove_dative_bonds(reactant_predictions)
    if len(reactant_predictions) == 0:
        print(f'Found no reactant predictions for product {product_smi}')
        return {}
    num_predictions = len(reactant_predictions)

    # evaluate with classifier score
    if eval_cfg.compute_classifier_score:
        output, confidence = get_classifier_score(reactant_predictions, config)
    else:
        output = None
        confidence = None
    print(f'classifier score: {output}, confidence: {confidence}')

    results_as_rxn_smiles = [
        reactant_prediction + '>>' + product_smi
        for reactant_prediction in reactant_predictions
    ]
    rxn_insight_info = [get_rxn_insight_info(result) for result in results_as_rxn_smiles]
    print(f'rxn insight info: {rxn_insight_info}')

    # NOTE: {1: 0} because we're checking each reactant individually.
    if true_reactants:
        # one dict per reactant prediction (a list of dicts)
        topk = [
            compute_topk_accuracy([pred], true_reactants, topk={1: 0})
            for pred in reactant_predictions
        ]
    else:
        topk = [-1] * num_predictions

    # round_trip_results holds, for each reactant prediction, the list of
    # products predicted by the forward (round-trip) model.
    round_trip_results = get_round_trip_results(reactant_predictions, config)
    # NOTE: {1: 0, 3: 0, 5: 0, 10: 0} because we get a list of potential
    # products from the round trip model
    round_trip_accuracy = [
        compute_topk_accuracy(
            round_trip_result,
            product_smi, topk={1: 0, 3: 0, 5: 0, 10: 0}
        )
        for round_trip_result in round_trip_results
    ]
    print(f'round trip accuracy: {round_trip_accuracy}')

    # Collect all results column by column. The insertion order below is the
    # column order of the DataFrame built by the caller, so keep it stable.
    df_dict = {}
    df_dict['product_smi'] = [product_smi] * num_predictions
    if true_reactants:
        df_dict['true_reactants'] = [true_reactants] * num_predictions
    else:
        df_dict['true_reactants'] = [-1] * num_predictions
    df_dict['true_class'] = [target_property] * num_predictions
    df_dict['reactant_predictions'] = reactant_predictions
    df_dict['classifier_property'] = [
        config.classifier_guidance.property
    ] * num_predictions
    if confidence is not None:
        df_dict['classifier_score'] = confidence.tolist()
    else:
        df_dict['classifier_score'] = [-1] * num_predictions
    if output is not None:
        df_dict['classifier_output'] = output.tolist()
    else:
        df_dict['classifier_output'] = [-1] * num_predictions
    df_dict['round_trip_results'] = round_trip_results
    # rxn_insight_info is always a list here; individual entries may be None.
    df_dict['rxn_insight_NAME'] = [
        info['NAME'] if info is not None else ''
        for info in rxn_insight_info
    ]
    df_dict['pred_class'] = [
        class_to_idx[info['CLASS']] if info is not None else -1
        for info in rxn_insight_info
    ]
    df_dict['true_tanimoto_to_target'] = [true_tanimoto_to_target] * num_predictions
    if original_target:
        df_dict['pred_tanimoto_to_target'] = [
            get_tanimoto(reactant_prediction, original_target)
            for reactant_prediction in reactant_predictions
        ]
    else:
        df_dict['pred_tanimoto_to_target'] = [-1] * num_predictions
    df_dict['true_tanimoto_to_starting_material'] = [
        true_tanimoto_to_starting_material
    ] * num_predictions
    if original_starting_material is not None:
        df_dict['pred_tanimoto_to_starting_material'] = [
            get_tanimoto(reactant_prediction, original_starting_material)
            for reactant_prediction in reactant_predictions
        ]
    else:
        df_dict['pred_tanimoto_to_starting_material'] = [-1] * num_predictions
    df_dict['topk_detailed'] = topk
    df_dict['topk'] = [
        compare_reactant_smiles(true_reactants, reactant_prediction)
        for reactant_prediction in reactant_predictions
    ]
    df_dict['round_trip_accuracy_detailed'] = round_trip_accuracy
    df_dict['round_trip_accuracy'] = [
        product_smi in round_trip_results_for_one_pred_reactant
        for round_trip_results_for_one_pred_reactant in round_trip_results
    ]
    df_dict['rxn_insight_info'] = rxn_insight_info
    # A prediction counts as "all building blocks" only if every
    # '.'-separated component is a key of the building-block JSON.
    df_dict['all_pred_reactants_are_bbs'] = [
        all(m in bbs for m in reactant_prediction.split('.'))
        for reactant_prediction in reactant_predictions
    ]
    return df_dict

def get_round_trip_results(predictions, config):
    '''
        Run the forward (round-trip) model on every predicted reactant set.

        predictions: list of reactant SMILES strings
        config: the config object; reads `single_step_evaluation.forward_model_dir`,
                `num_results_forward` and `num_augmentations`

        Returns the output of `turn_results_to_mol_smiles` applied to the
        concatenated forward-model results.
    '''
    eval_cfg = config.single_step_evaluation
    checkpoint_dir = os.path.join(PROJECT_ROOT, 'checkpoints', eval_cfg.forward_model_dir)
    # NOTE: this is a forward model, so we input the reactants as our product, and the product in the reactants field
    print(
        f'============ get_round_trip_results,'
        f'config.single_step_evaluation.num_results_forward: {eval_cfg.num_results_forward},'
        f' config.single_step_evaluation.num_augmentations: {eval_cfg.num_augmentations}'
    )
    forward_model = RootAlignedForwardModel(
        use_cache=True,
        num_augmentations=eval_cfg.num_augmentations,
        default_num_results=eval_cfg.num_results_forward,
        model_dir=checkpoint_dir,
        config=config,
    )
    # Feed the forward model in chunks of `batch_size` predictions.
    batch_size = 1
    print(f'============ get_round_trip_results, len(predictions): {len(predictions)}, batch_size: {batch_size}')
    collected = []
    for chunk_start in range(0, len(predictions), batch_size):
        chunk = predictions[chunk_start:chunk_start + batch_size]
        molecules = [Molecule(smiles) for smiles in chunk]
        collected.extend(
            forward_model(molecules, num_results=eval_cfg.num_results_forward)
        )
    return turn_results_to_mol_smiles(collected)


def compute_topk_accuracy(results, true_reactants, topk=None):
    '''
    Record at which top-k cutoff `true_reactants` first appears in `results`.

    results: list of strings, best prediction first
    true_reactants: string to look for in `results`
    topk: dict mapping cutoff k -> counter; iterated in insertion order.
          Defaults to a fresh {1: 0, 3: 0, 5: 0, 10: 0} on every call.

    Only the first cutoff (in dict order) whose prefix `results[:k]` contains
    `true_reactants` is incremented; later cutoffs are left untouched.

    Returns the (updated) `topk` dict. A dict passed in by the caller is
    updated in place and returned, so callers may accumulate over many calls.
    '''
    # A dict literal as default would be shared (and mutated) across calls,
    # so the default is created here instead.
    if topk is None:
        topk = {1: 0, 3: 0, 5: 0, 10: 0}
    for key in topk:
        if true_reactants in results[:key]:
            topk[key] += 1
            break
    return topk

# Maps a reaction-class name (the 'CLASS' field of a rxn-insight result) to the
# integer index stored in the 'pred_class' column.
# The string '-1' maps to -1 (presumably an "unknown class" marker -- confirm).
class_to_idx = {
    'Heteroatom Alkylation and Arylation': 0,
    'Acylation': 1,
    'C-C Coupling': 2,
    'Aromatic Heterocycle Formation': 3,
    'Protection': 4,
    'Deprotection': 5,
    'Reduction': 6,
    'Oxidation': 7,
    'Functional Group Addition': 8,
    'Functional Group Interconversion': 9,
    'Miscellaneous': 10,
    '-1': -1
}

def remove_dative_bonds(reactant_predictions):
    '''
    Replace every dative bond in each prediction with a single bond.

    reactant_predictions: iterable of reactant SMILES strings

    Returns a new list with the RDKit-written SMILES of each parseable
    prediction, in input order. Predictions that RDKit cannot parse are
    skipped (with a printed message) instead of raising AttributeError, so
    the returned list may be shorter than the input.
    '''
    out_reactant_predictions = []
    for reactant_prediction in reactant_predictions:
        mol = Chem.MolFromSmiles(reactant_prediction)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            print(f'Skipping unparseable reactant prediction: {reactant_prediction}')
            continue
        for bond in mol.GetBonds():
            if bond.GetBondType() == Chem.BondType.DATIVE:
                bond.SetBondType(Chem.BondType.SINGLE)
        out_reactant_predictions.append(Chem.MolToSmiles(mol))
    return out_reactant_predictions

def clear_atom_map(smi):
    """
    Strip atom-map numbers from a SMILES string and canonicalize it.

    Args:
        smi (str): SMILES string, possibly atom-mapped.

    Returns:
        str: Canonical SMILES with every atom-map number set to 0. If RDKit
        cannot parse ``smi``, the input is returned unchanged.
    """
    parsed = Chem.MolFromSmiles(smi)
    if parsed is not None:
        for atom in parsed.GetAtoms():
            atom.SetAtomMapNum(0)
        unmapped_smiles = Chem.MolToSmiles(parsed)
        return Chem.CanonSmiles(unmapped_smiles)
    # Unparseable input is passed through as-is.
    return smi


def compare_reactant_smiles(smiles1, smiles2):
    """
    Tell whether two reactant SMILES describe the same set of molecules.

    Each argument is split on '.', every component is passed through
    ``clear_atom_map``, and the resulting sets are compared, so component
    order and duplicates do not matter.

    Returns:
        bool: True when both sets of components are equal.
    """
    def _components(smiles):
        return {clear_atom_map(part) for part in smiles.split('.')}

    return _components(smiles1) == _components(smiles2)

def get_tanimoto(main_target, sm):
    '''
    Get the Tanimoto similarity between two molecules.

    Both SMILES are parsed with RDKit, turned into RDKit fingerprints
    (Chem.RDKFingerprint) and compared with DataStructs.TanimotoSimilarity.
    Args:
        main_target (str): SMILES of the first molecule
        sm (str): SMILES of the second molecule
    Returns:
        float: Tanimoto similarity of the two fingerprints (a similarity,
        not a distance)
    Raises:
        ValueError: if RDKit cannot parse either SMILES string
    '''
    mol1 = Chem.MolFromSmiles(sm)
    mol2 = Chem.MolFromSmiles(main_target)
    # MolFromSmiles returns None on a parse failure; fail with a readable
    # message instead of an opaque argument error from RDKFingerprint.
    if mol1 is None:
        raise ValueError(f'Cannot compute Tanimoto similarity: invalid SMILES {sm!r}')
    if mol2 is None:
        raise ValueError(f'Cannot compute Tanimoto similarity: invalid SMILES {main_target!r}')
    # Generate fingerprints
    fp1 = Chem.RDKFingerprint(mol1)
    fp2 = Chem.RDKFingerprint(mol2)
    # Calculate Tanimoto similarity
    return DataStructs.TanimotoSimilarity(fp1, fp2)

def get_results_and_evaluate_for_one_molecule(
    config,
    conditional_starting_material=None,
    conditional_target=None
):
    '''
        Predict reactants for the product stored in the config and evaluate them.

        The product, true reactants and reaction class are all read from
        `config.single_step_evaluation`. Returns the column dict produced by
        `single_step_evaluation_of_one_molecule` (empty when nothing was predicted).

        Side effect: sets `config.classifier_guidance.target_class_index` to
        `config.single_step_evaluation.rxn_class`.

        NOTE(review): this definition shadows the same-named import from
        multiguide.evaluation.helpers at the top of the file.
    '''
    eval_cfg = config.single_step_evaluation
    rxn_class = eval_cfg.rxn_class
    config.classifier_guidance.target_class_index = rxn_class

    per_product_predictions = get_retrosynthetic_results(
        config,
        eval_cfg.product_smi,
        conditional_starting_material=conditional_starting_material,
        conditional_target=conditional_target
    )
    # Only one product was submitted, so keep the first (and only) result list.
    predictions = per_product_predictions[0]

    return single_step_evaluation_of_one_molecule(
        config=config,
        reactant_predictions=predictions,
        product_smi=eval_cfg.product_smi,
        true_reactants=eval_cfg.true_reactants,
        target_property=rxn_class
    )


@dataclass
class ReactionData:
    """Container for reaction data from batch.

    Holds the seven values of one batch tuple (in tuple order) plus the
    position of that row in the dataset.
    """
    # Ground-truth reactants; copied to config.single_step_evaluation.true_reactants.
    reactants: str
    # Product SMILES; copied to config.single_step_evaluation.product_smi.
    product: str
    # Reaction class; copied to rxn_class and classifier_guidance.target_class_index.
    # NOTE(review): annotated as str, but it comes straight from a CSV column
    # and may be numeric -- confirm.
    class_idx: str
    # Starting material without conditioning tokens (may be None).
    original_starting_material: Optional[str]
    # Starting material wrapped in conditioning tokens; passed to the model.
    conditional_starting_material: str
    # Target without conditioning tokens (may be None).
    original_target: Optional[str]
    # Target wrapped in conditioning tokens; passed to the model (may be None).
    conditional_target: Optional[str]
    batch_index: int  # Original index in dataset


def parse_batch_to_reaction_data(
    batch: List[tuple],
    start_idx: int
) -> List[ReactionData]:
    """
        Turn each 7-tuple of a batch into a ReactionData object.

        The tuple layout is (reactants, product, class_idx,
        original_starting_material, conditional_starting_material,
        original_target, conditional_target). `batch_index` is the tuple's
        position in `batch` offset by `start_idx`.
    """
    parsed = []
    for offset, entry in enumerate(batch):
        (
            reactants,
            product,
            class_idx,
            original_sm,
            conditional_sm,
            original_tgt,
            conditional_tgt,
        ) = entry
        parsed.append(
            ReactionData(
                reactants=reactants,
                product=product,
                class_idx=class_idx,
                original_starting_material=original_sm,
                conditional_starting_material=conditional_sm,
                original_target=original_tgt,
                conditional_target=conditional_tgt,
                batch_index=start_idx + offset,
            )
        )
    return parsed

def update_config_for_reaction(config, reaction: ReactionData) -> None:
    """Copy the fields of one reaction into the config (mutates it in place).

    Writes product, true reactants, original starting material/target and the
    reaction class to ``config.single_step_evaluation``; the reaction class is
    also written to ``config.classifier_guidance.target_class_index``.
    """
    eval_cfg = config.single_step_evaluation
    rxn_class = reaction.class_idx

    eval_cfg.product_smi = reaction.product
    eval_cfg.true_reactants = reaction.reactants
    eval_cfg.original_starting_material = reaction.original_starting_material
    eval_cfg.original_target = reaction.original_target
    config.classifier_guidance.target_class_index = rxn_class
    eval_cfg.rxn_class = rxn_class

def flatten_predictions(
    df_dict: Dict[str, Any],
    reaction: ReactionData,
    separator: str
) -> List[Dict[str, Any]]:
    """
    Expand a column-oriented prediction dict into one dict per prediction.

    Args:
        df_dict: Column name -> value. A value that is a list whose length
            equals the number of predictions is indexed per row; anything
            else is copied unchanged into every row.
        reaction: ReactionData object supplying the per-reaction metadata
        separator: Starting material separator string

    Returns:
        One row dict per entry of ``df_dict['reactant_predictions']``; an
        empty list when ``df_dict`` is empty or lacks that key. Metadata
        fields overwrite same-named keys from ``df_dict``.
    """
    if not df_dict or 'reactant_predictions' not in df_dict:
        return []

    row_count = len(df_dict['reactant_predictions'])

    # Fields shared by every row of this reaction.
    shared_fields = {
        'starting_material_separator': separator,
        'original_starting_material': reaction.original_starting_material,
        'original_target': reaction.original_target,
        'batch_index': reaction.batch_index
    }

    def _cell(column, row_idx):
        # Per-prediction columns are lists of exactly `row_count` items.
        per_prediction = isinstance(column, list) and len(column) == row_count
        return column[row_idx] if per_prediction else column

    return [
        {
            **{name: _cell(column, row_idx) for name, column in df_dict.items()},
            **shared_fields,
        }
        for row_idx in range(row_count)
    ]


def _read_subset(config):
    '''Read the CSV at PROJECT_ROOT/data/<data_dir>/<subset> into a DataFrame.'''
    eval_cfg = config.single_step_evaluation
    return pd.read_csv(os.path.join(PROJECT_ROOT, 'data', eval_cfg.data_dir, eval_cfg.subset))


def _rows_to_tuples(df, config, make_tuple):
    '''Map rows [start_idx:end_idx] of df through make_tuple and return a list.'''
    eval_cfg = config.single_step_evaluation
    rows = df.iloc[eval_cfg.start_idx:eval_cfg.end_idx]
    # DataFrame.apply on an empty frame does not yield a Series of tuples,
    # so an empty slice is answered directly.
    if rows.empty:
        return []
    return rows.apply(make_tuple, axis=1).tolist()


def _pick_columns(*columns):
    '''Return a row -> tuple function that reads the given columns in order.'''
    return lambda row: tuple(row[column] for column in columns)


def get_batch(config):
    '''
        This function gets a batch of molecules from the dataset.

        The CSV `PROJECT_ROOT/data/<data_dir>/<subset>` is read, the
        conditioning columns are derived according to `data_dir`, and rows
        `[start_idx:end_idx]` are returned as a list of tuples.

        For every data_dir except 'uspto_50k/no_solutions' each tuple is
            (reactants, product, reaction class, original_starting_material,
             conditional_starting_material, original_target, conditional_target).
        For 'uspto_50k/no_solutions' each tuple is
            (reactants, product, true_class, conditional_starting_material,
             conditional_target) with atom maps cleared from the first two.

        Returns an empty list when the requested slice contains no rows.
        Raises ValueError for an unknown data_dir or a missing required column.
    '''
    data_dir = config.single_step_evaluation.data_dir

    if data_dir in ('uspto_50k/processed', 'uspto_50k_debug/processed'):
        df = _read_subset(config)
        if 'conditional_starting_material' not in df.columns:
            rxn_smiles = df['reactants>reagents>production']
            if config.classifier_guidance.with_reactants_as_starting_material:
                original_starting_material = rxn_smiles.apply(lambda x: x.split('>>')[0])
            else:
                original_starting_material = rxn_smiles.apply(
                    lambda x: choose_closest_starting_material(
                        x.split('>>')[0].split('.'),
                        x.split('>>')[1]
                    )
                )
            separator = config.classifier_guidance.dataset.separator
            df['original_starting_material'] = original_starting_material
            df['conditional_starting_material'] = df['original_starting_material'].apply(
                lambda x: '<s>' + x + separator
            )
        if 'conditional_target' not in df.columns:
            # TODO: this is not used at all now
            separator = config.classifier_guidance.dataset.separator
            df['original_target'] = df['sorted_cano_products']
            df['conditional_target'] = df['original_target'].apply(
                lambda x: separator + x + '</s>'
            )
        batch = _rows_to_tuples(df, config, _pick_columns(
            'sorted_cano_reactants',
            'sorted_cano_products',
            'reaction_type',
            'original_starting_material',
            'conditional_starting_material',
            'original_target',
            'conditional_target',
        ))
    elif data_dir in (
        'uspto_190/first_reactions',
        'uspto_190/train_unique_reactions',
        'uspto_190/first_reactions_with_targets',
        'uspto_190/first_reactions_nonlinear_with_targets',
    ):
        df = _read_subset(config)
        # Both columns are read below, so both must already be in the CSV.
        # (The first two data_dirs used to raise when the starting-material
        # column WAS present, which made them unusable.)
        if 'conditional_starting_material' not in df.columns:
            raise ValueError('conditional_starting_material not in df.columns')
        # TODO: change this to conditional_target for consistency
        if 'main_target' not in df.columns:
            raise ValueError('main_target not in df.columns')

        separator = config.classifier_guidance.dataset.separator
        df['original_target'] = df['main_target']
        # Series.apply forwards extra keyword arguments to the function, so
        # no `axis` argument may be passed here.
        df['conditional_target'] = df['original_target'].apply(
            lambda x: separator + x + '</s>'
        )
        # The CSV column holds the raw starting material; keep it as the
        # "original" and overwrite the column with the token-wrapped form.
        df['original_starting_material'] = df['conditional_starting_material']
        df['conditional_starting_material'] = df['original_starting_material'].apply(
            lambda x: separator + x + '</s>'
        )
        batch = _rows_to_tuples(df, config, _pick_columns(
            'reactant',
            'product',
            'rxn_insight_class',
            'original_starting_material',
            'conditional_starting_material',
            'original_target',
            'conditional_target',
        ))
    elif data_dir == 'uspto_50k/no_solutions':
        df = _read_subset(config)
        if 'conditional_starting_material' not in df.columns:
            rxn_smiles = df['reactants>reagents>production']
            if config.classifier_guidance.with_reactants_as_starting_material:
                conditional_starting_material = rxn_smiles.apply(
                    lambda x: '<s>' + x.split('>>')[0] + '</s>'
                )
            else:
                separator = config.classifier_guidance.dataset.separator
                conditional_starting_material = rxn_smiles.apply(
                    lambda x: '<s>' + choose_closest_starting_material(
                        x.split('>>')[0].split('.'),
                        x.split('>>')[1]
                    ) + separator
                )
            df['conditional_starting_material'] = conditional_starting_material

        if 'conditional_target' not in df.columns:
            # TODO: this is not used at all now
            separator = config.classifier_guidance.dataset.separator
            df['conditional_target'] = df['product'].apply(
                lambda x: separator + x + '</s>'
            )

        # NOTE(review): this branch yields 5-tuples, while
        # parse_batch_to_reaction_data unpacks 7 values per tuple -- the two
        # are incompatible as written. The shape is left unchanged because
        # the missing original_* values cannot be derived here; confirm the
        # intended layout.
        def _no_solution_tuple(row):
            reaction = row['reaction']
            return (
                clear_atom_map(reaction.split('>>')[0]),
                clear_atom_map(reaction.split('>>')[1]),
                row['true_class'],
                row['conditional_starting_material'],
                row['conditional_target'],
            )

        batch = _rows_to_tuples(df, config, _no_solution_tuple)
    elif data_dir == 'uspto_190/reactions_with_starting_material':
        df = _read_subset(config)
        # No target conditioning for this dataset.
        df['original_target'] = None
        df['conditional_target'] = None
        df['original_starting_material'] = df['conditional_starting_material']
        df['conditional_starting_material'] = df['original_starting_material'].apply(
            lambda x: '<s>' + x + '.'
        )
        batch = _rows_to_tuples(df, config, _pick_columns(
            'reactant',
            'product',
            'rxn_insight_class',
            'original_starting_material',
            'conditional_starting_material',
            'original_target',
            'conditional_target',
        ))
    else:
        raise ValueError(f'Invalid data directory: {data_dir}')
    return batch
   
@hydra.main(config_path='../configs', config_name='config.yaml')
def evaluate_single_step_model(config):
    '''
        Evaluate the single-step model on rows [start_idx:end_idx] of the dataset
        and save one CSV row per prediction.

        For every reaction in the batch the config is updated in place, the
        predictions are generated and scored, and the results are flattened into
        rows. The CSV is written to
        PROJECT_ROOT/experiments/<experiment_group>/<experiment_params>/<experiment_name>/
        as start<start_idx>_end<end_idx>.csv.
    '''
    start_idx = config.single_step_evaluation.start_idx
    end_idx = config.single_step_evaluation.end_idx

    reactions = parse_batch_to_reaction_data(get_batch(config), start_idx)

    collected_rows = []
    for reaction in reactions:
        print(f'Idx {reaction.batch_index}')
        print(f'conditional_starting_material: {reaction.conditional_starting_material}')
        print(f'conditional_target: {reaction.conditional_target}')
        # The evaluation helpers read the current reaction from the config.
        update_config_for_reaction(config, reaction)
        df_dict = get_results_and_evaluate_for_one_molecule(
            config,
            conditional_starting_material=reaction.conditional_starting_material,
            conditional_target=reaction.conditional_target
        )
        if df_dict:
            # One output row per prediction of this reaction.
            collected_rows.extend(
                flatten_predictions(
                    df_dict,
                    reaction,
                    config.classifier_guidance.dataset.separator
                )
            )
        else:
            print(f'No predictions found for {reaction.reactants}>>{reaction.product}, '
                  f'idx: {reaction.batch_index}')

    results_df = pd.DataFrame(collected_rows)
    general_cfg = config.general
    out_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        general_cfg.experiment_group,
        general_cfg.experiment_params,
        general_cfg.experiment_name
    )
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f'start{start_idx}_end{end_idx}.csv')
    results_df.to_csv(out_path, index=False)
    print(f'======== df saved to {out_path}')

# Script entry point. The function is called without arguments because the
# @hydra.main decorator supplies `config`.
if __name__ == '__main__':
    evaluate_single_step_model()
