# At the very top of your script, before any other imports
import sys
from unittest.mock import MagicMock
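# Mock the problematic RDKit drawing modules so that importing them cannot fail.
# NOTE(review): with these entries in sys.modules, the later
# `from rdkit.Chem import Draw` likely resolves to a MagicMock, in which case
# Draw.MolToFile(...) silently writes no image files -- confirm before relying
# on the rendered route images.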
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()
from typing import List
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdChemReactions as Reactions
from rdchiral.main import rdchiralRun
from rdchiral.initialization import rdchiralReactants, rdchiralReaction
import torch
import graphviz
import tempfile
import logging
from rdkit.Chem import Draw
import json
import logging
from rdkit.Chem import rdChemReactions

# Configure logging with exactly ONE root handler that writes to sys.stdout
# (the stream the original comment intended for notebook cell output).
#
# Previously basicConfig() installed its own default handler and a second
# StreamHandler was then added on top, so every record was emitted twice;
# StreamHandler() also defaults to sys.stderr, not sys.stdout.
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)

logging.basicConfig(
    level=logging.INFO,
    handlers=[console_handler],
    force=True,  # This overwrites any existing logging configuration
)

# Root logger used by the functions in this module
log = logging.getLogger()

# Path to the JSON file mapping building-block SMILES to indices
bb_mol2idx = "./data/canon_building_block_mol2idx_no_isotope.json"

def _draw_and_connect_children(
    parent_node, child, img_map, dot, temp_img_dir, building_blocks=None
):
    """
    Helper function to draw and connect the children nodes in the synthetic route visualization.

    Adds one GraphViz node for `child`, an edge from `parent_node` to it, and
    then recurses into `child["children"]` when that key is present.

    Args:
        parent_node (str): The GraphViz node name of the parent (a SMILES string
            for molecule parents, "rxn" + SMILES for reaction parents).
        child (dict): The dictionary representing the child node. Reads the keys
            "smiles" and "type" ("mol" or "reaction"), plus "mol_type" /
            "orientation" as needed and the optional "children" list.
        img_map (dict): A dictionary to store the filepaths of the temporary molecule images.
        dot (graphviz.Digraph): The GraphViz object for rendering the visualization.
        temp_img_dir (str): The path to the temporary directory for storing molecule images.
        building_blocks (collection, optional): Container supporting `in` checks
            on SMILES strings. When omitted it is loaded once from `bb_mol2idx`
            and then handed down to every recursive call, instead of re-reading
            the JSON file for each node of the tree.

    Raises:
        TypeError: If `child["type"]` is neither "mol" nor "reaction".
    """
    if building_blocks is None:
        # Only the outermost call pays for file I/O and JSON parsing.
        with open(bb_mol2idx, "r") as f:
            building_blocks = json.load(f)

    child_node = child["smiles"]
    if child["type"] == "mol":
        mol = Chem.MolFromSmiles(child_node)
        # "/" would be interpreted as a path separator in the image filename
        escaped = child_node.replace("/", "_")
        file_path = f"{temp_img_dir}/{escaped}.png"
        Draw.MolToFile(mol, file_path, size=(200, 200))
        img_map[child_node] = file_path
        if child["mol_type"] == "starting":
            color = "plum1"
        elif (
            child["mol_type"] == "intermediate"
            and child["orientation"] == "top"
            and child_node not in building_blocks
        ):
            color = "royalblue"
        elif (
            child["mol_type"] == "intermediate"
            and child_node not in building_blocks
        ):
            color = "skyblue3"
        elif child["mol_type"] == "building" and child["orientation"] == "top":
            color = "springgreen4"
        else:
            # Remaining cases, including intermediates that are themselves
            # listed as building blocks.
            color = "springgreen3"
        dot.node(
            child_node,
            label="",
            image=file_path,
            shape="box",
            color=color,
            penwidth="2",
        )
    elif child["type"] == "reaction":
        if child["orientation"] == "top":
            color = "lightgoldenrod1"
        else:
            color = "lightgoldenrod3"
        # Reaction nodes are named after the molecule they hang off.
        child_node = "rxn" + parent_node
        dot.node(
            child_node,
            label="",
            shape="box",
            style="rounded",
            color=color,
            penwidth="2",
        )
    else:
        raise TypeError("Child type not recognized")
    dot.edge(parent_node, child_node, color="darkgrey")
    if "children" in child:
        children = child["children"]
        # Distinct loop name: the original reused `child`, shadowing the parameter.
        for i, grandchild in enumerate(children):
            log.info(f"Recursively drawing child {i} of {len(children)}")
            _draw_and_connect_children(
                child_node, grandchild, img_map, dot, temp_img_dir, building_blocks
            )

def visualize_route(route, path):
    """
    Visualize the synthetic route and save the image to the specified path.

    The root molecule is drawn as a node, then every entry of
    `route["children"]` is drawn recursively via `_draw_and_connect_children`.
    A route without a "children" key is rendered as a single root node instead
    of raising KeyError, matching how the helper treats "children" as optional.

    Args:
        route (dict): The dictionary representing the synthetic route. Reads
            the "smiles" key and the optional "children" list.
        path (str): The filename to save the visualization image (path + ".png")
    """
    log.info("Visualizing route...")
    dot = graphviz.Digraph(format="png")
    root_node = route["smiles"]
    mol = Chem.MolFromSmiles(root_node)
    img_map = {}
    children = route.get("children", [])
    with tempfile.TemporaryDirectory() as temp_img_dir:
        # Escape / characters in root_node so it can be used as a filename
        escaped = root_node.replace("/", "_")
        file_path = f"{temp_img_dir}/{escaped}.png"
        Draw.MolToFile(mol, file_path, size=(200, 200))
        img_map[root_node] = file_path
        dot.node(
            root_node,
            label="",
            image=file_path,
            shape="rect",
            color="lightsalmon",
            penwidth="2",
        )
        for i, child in enumerate(children):
            log.info(f"Drawing child {i} of {len(children)}")
            _draw_and_connect_children(
                root_node, child, img_map, dot, temp_img_dir
            )
        # Render while the temporary node images still exist on disk.
        dot.render(path)
    
def smiles_to_fp(smiles, fp_size=2048):
    """
    Convert a SMILES string to a Morgan fingerprint tensor.

    The fingerprint is computed with radius 2 and chirality enabled.

    Args:
        smiles (str): SMILES string
        fp_size (int): value passed as `nBits` to the fingerprint call
    Returns:
        torch.Tensor: fingerprint of the SMILES, dtype torch.uint8
    """
    bit_vect = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles),
        radius=2,
        nBits=fp_size,
        useChirality=True,
    )
    return torch.tensor(bit_vect, dtype=torch.uint8)


def template_to_fp(template):
    """
    Convert a reaction template string to a structural reaction fingerprint.

    Args:
        template (str): reaction template in SMARTS form
    Returns:
        torch.Tensor: fingerprint of the template
    """
    rxn = AllChem.ReactionFromSmarts(template)
    # Positional arguments of ReactionFingerprintParams; presumably
    # (includeAgents, bitRatioAgents, nonAgentWeight, agentWeight, fpSize,
    # fpType) -- confirm against the RDKit API reference.
    fp_params = Reactions.ReactionFingerprintParams(
        False,
        0.2,
        10,
        1,
        2048,
        Reactions.FingerprintType.names["AtomPairFP"],
    )
    structural_fp = Reactions.CreateStructuralFingerprintForReaction(rxn, fp_params)
    return torch.tensor(structural_fp)


def run_retro(product, template):
    """
    Run a retro reaction given the product and the template.

    Args:
        product: object handed straight to `rdchiralRun` as the molecule to
            transform (the caller in this file passes an `rdchiralReactants`).
        template (str): retro template of the form "lhs>>rhs"
    Returns:
        list[list[str]]: one entry per rdchiral outcome, each outcome split on
        "." into its component SMILES. Empty list if running the template
        raised.
    """
    lhs_components = template.split(">>")[0].split(".")
    if len(lhs_components) > 1:
        # Wrap a multi-component left-hand side in parentheses; presumably so
        # all components must match inside the single product molecule.
        template = "(" + template.replace(">>", ")>>")
    # Keep `template` as the SMARTS string so the error message below shows
    # the template text rather than the repr of the compiled reaction object.
    reaction = rdchiralReaction(template)
    try:
        outputs = rdchiralRun(reaction, product)
    except Exception as e:
        print(f"Error {e} running retro reaction {template} on product {product}")
        return []
    return [output.split(".") for output in outputs]


def run_unimolecular_reaction(reactant, template):
    """
    Apply a template to a single reactant with rdchiral.

    Args:
        reactant: reactant passed directly to `rdchiralRun`
        template (str): template of the form "lhs>>rhs"
    Returns:
        list[str]: the outcomes that consist of exactly one component
        (no "." in the outcome string)
    """
    # The left-hand side is always wrapped in parentheses before compiling.
    grouped_template = "(" + template.replace(">>", ")>>")
    reaction = rdchiralReaction(grouped_template)
    return [
        outcome
        for outcome in rdchiralRun(reaction, reactant)
        if len(outcome.split(".")) == 1  # should only be 1 product
    ]


def is_reactant_first(reactant, template):
    """
    Tell whether `reactant` matches the first reactant pattern of a template.

    The first pattern is the text before the first "." on the left-hand side
    of ">>".

    Args:
        reactant (Chem.Mol): reactant
        template (str): template
    Returns:
        bool: whether `reactant` has a substructure match for the first pattern
    """
    left_side = template.partition(">>")[0]
    first_smarts = left_side.partition(".")[0]
    query = Chem.MolFromSmarts(first_smarts)
    return reactant.HasSubstructMatch(query)


def is_reactant_second(reactant, template):
    """
    Check if `reactant` is the second reactant in bimolecular template.
    Args:
        reactant (Chem.Mol): reactant
        template (str): template
    Returns:
        bool: whether `reactant` is the second reactant
    """
    second_reactant = template.split(">>")[0].split(".")[1]
    pattern = Chem.MolFromSmarts(second_reactant)
    return reactant.HasSubstructMatch(pattern)


def flatten_output(outputs):
    """
    Postprocess the output of a reaction to remove duplicates and invalid SMILES.

    Only outcomes with exactly one product are kept. Each kept product is
    converted to SMILES and checked with `Chem.CanonSmiles`; if that check
    raises, the exception is printed and the product is skipped.

    Args:
        outputs (list): list of product sequences (one sequence per outcome)
    Returns:
        list: deduplicated valid product SMILES, in first-seen order so the
        result is reproducible between runs (a plain `set` gives an arbitrary
        order that can change from run to run)
    """
    products = []
    for product in outputs:
        if len(product) != 1:
            # Outcomes with zero or several products are ignored.
            continue
        smiles = Chem.MolToSmiles(product[0])
        try:
            # Used purely as a validity check; the return value is not needed.
            Chem.CanonSmiles(smiles)
        except Exception as e:
            print(e)
            continue
        products.append(smiles)
    # dict preserves insertion order, giving a stable order-preserving dedup.
    return list(dict.fromkeys(products))


def run_bimolecular_reaction(reactants, template):
    """
    Apply a template to a pair of reactants with rdchiral.

    Args:
        reactants (list): exactly two reactant SMILES strings
        template (str): template of the form "lhs>>rhs"
    Returns:
        list[str]: the outcomes that consist of exactly one component; an
        empty list if the template could not be initialized
    Raises:
        ValueError: if `reactants` does not contain exactly two entries
    """
    if len(reactants) != 2:
        raise ValueError("Bimolecular reaction requires two reactants!")
    joined_reactants = rdchiralReactants(".".join(reactants))
    # Wrap the left-hand side in parentheses before compiling the template.
    template = "(" + template.replace(">>", ")>>")
    try:
        reaction = rdchiralReaction(template)
    except Exception as e:
        print(f"Error {e} initializing template {template}")
        return []
    return [
        outcome
        for outcome in rdchiralRun(reaction, joined_reactants)
        if len(outcome.split(".")) == 1  # should only be 1 product
    ]


def get_valid_rxns(product, reactant, building_block, retro_template):
    """
    Collect the retro outcomes of `product` that reproduce the given precursors.

    Runs `retro_template` on `product` via `run_retro` and keeps an outcome
    when it is exactly the reactant/building-block pair (in either order) or
    exactly the single reactant.

    Args:
        product (str): product SMILES, wrapped in `rdchiralReactants`
        reactant (str): expected reactant SMILES
        building_block (str): expected building-block SMILES
        retro_template (str): retro template passed to `run_retro`
    Returns:
        list[tuple]: `(product, reactant, building_block)` for each two-component
        match and `(product, reactant, None)` for each single-reactant match
    """
    pair = [reactant, building_block]
    swapped_pair = [building_block, reactant]
    single = [reactant]

    matches = []
    for precursors in run_retro(rdchiralReactants(product), retro_template):
        if precursors == pair or precursors == swapped_pair:
            matches.append((product, reactant, building_block))
        elif precursors == single:
            matches.append((product, reactant, None))
    return matches

def parse_diffalign_samples_output(samples: List[str]) -> List[dict]:
    '''
    Parse the output of the DiffAlign samples function.

    Each sample is treated as a reaction SMILES string ("reactants>>products"):
    it is passed to `ReactionFromSmarts(..., useSmiles=True)` and, on failure,
    split with `str.split`. The annotation is therefore `List[str]`, not the
    previous `List[List[str]]`.

    Args:
        samples (list[str]): list of reaction SMILES strings
    Returns:
        list of dicts, each representing a sample reaction. E.g.
        [
            {
                "rxn_smiles": "C.C>>CC",
                "score": 0.0,
                "template": "",
                "reactants": ["C", "C"]
            }
        ]
        "score" is always 0.0 and "template" is always "".
    '''
    parsed_samples = []
    for sample in samples:
        # try to parse reaction with rdkit
        try:
            rxn = rdChemReactions.ReactionFromSmarts(sample, useSmiles=True)
            products_mol = rxn.GetProducts()
            reactants_mol = rxn.GetReactants()
            # processing each molecule: canonical, non-kekulized SMILES
            products_smiles = [
                Chem.MolToSmiles(mol, kekuleSmiles=False, canonical=True)
                for mol in products_mol
            ]
            reactants_smiles = [
                Chem.MolToSmiles(mol, kekuleSmiles=False, canonical=True)
                for mol in reactants_mol
            ]
            # reconstruct reaction
            rxn_smiles = '.'.join(reactants_smiles) + '>>' + '.'.join(products_smiles)
        except Exception as e:
            print(f"Error {e} parsing reaction {sample}")
            # error means reaction might be wrong, but it's still useful to get its reactants for further search
            reactants_smiles = sample.split('>>')[0].split('.')
            # NOTE(review): the fallback writes the reactants on BOTH sides of
            # ">>" (the product side of `sample` is discarded) -- kept as-is,
            # confirm this is intended.
            rxn_smiles = '.'.join(reactants_smiles) + '>>' + '.'.join(reactants_smiles)

        parsed_samples.append({
            "rxn_smiles": rxn_smiles,
            "score": 0.0,
            "template": '',
            "reactants": reactants_smiles
        })
    return parsed_samples

def parse_scored_samples_from_diffalign(samples: dict) -> List[dict]:
    '''
    Parse the output of the DiffAlign eval function.

    Args:
        samples (dict): mapping whose values are lists of sampled reactions;
            only the FIRST value is parsed. E.g.
            {'P':
                [{'rcts': [...],
                'prod': [...],
                'elbo': 2.0716702938079834,
                'loss_t': 0.0,
                'loss_0': -2.071672201156616,
                'sample_idx': 0,
                'count': 1,
                'weighted_prob': 1.0},
                ....
                ]
            }, with 'P' the product smiles
    Returns:
        list of dicts, each representing a sample reaction, with the keys
        "rxn_smiles", "score", "template" and "reactants". An empty `samples`
        yields an empty list.
    '''
    if not samples:
        # Previously an empty mapping raised IndexError.
        return []
    sampled_reactions = next(iter(samples.values()))
    parsed_samples = []
    for reaction_dict in sampled_reactions:
        # TODO: hmm keep this like this for now, but later might have to check how diffalign parses reactant mols, might have to account for the salts there
        raw_reactants = reaction_dict['rcts']
        raw_rxn_smiles = '.'.join(raw_reactants) + '>>' + '.'.join(reaction_dict['prod'])
        try:
            rxn = rdChemReactions.ReactionFromSmarts(raw_rxn_smiles, useSmiles=True)
            products_mol = rxn.GetProducts()
            reactants_mol = rxn.GetReactants()
            products_smiles = [
                Chem.MolToSmiles(mol, kekuleSmiles=False, canonical=True)
                for mol in products_mol
            ]
            reactants_smiles = [
                Chem.MolToSmiles(mol, kekuleSmiles=False, canonical=True)
                for mol in reactants_mol
            ]
            rxn_smiles = '.'.join(reactants_smiles) + '>>' + '.'.join(products_smiles)
        except Exception as e:
            # Report the reaction string built above. The original handler read
            # reaction_dict['rxn_smiles'], a key absent from the documented
            # schema, so the handler itself raised KeyError.
            print(f"Error {e} parsing reaction {raw_rxn_smiles}")
            reactants_smiles = raw_reactants
            rxn_smiles = raw_rxn_smiles

        # NOTE(review): 'weighted_prob' is available on each sample but the
        # emitted score has always been 0.0 (the original assigned it to an
        # unused local). Kept unchanged -- confirm whether it should be used.
        parsed_samples.append({
            "rxn_smiles": rxn_smiles,
            "score": 0.,
            "template": '',
            "reactants": reactants_smiles
        })

    return parsed_samples
    