import os
import torch
import math
import logging
logger = logging.getLogger(__name__)

# NOTE: Collection imported here instead of from collections.abc
# to make casting work for python <3.9
from typing import (
    Collection,
    Optional,
    Sequence,
    cast,
    List,
    TypeVar
)
from collections import defaultdict
import torch.nn.functional as F
from syntheseus.search.graph.and_or import ANDOR_NODE, AndNode, AndOrGraph, OrNode
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.search.graph.node import BaseGraphNode
from syntheseus.search.graph.base_graph import RetrosynthesisSearchGraph
from syntheseus.search.node_evaluation.base import BaseNodeEvaluator, NoCacheNodeEvaluator
from syntheseus.search.algorithms.best_first.base import PriorityQueue

from multiguide.property.property_predictor import PropertyPredictor
from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import tokenize, get_vocab_size_from_config, get_vocab_from_trained_model
from multiguide.dataset.helpers import turn_seq_to_ids, get_tanimoto, get_rxn_insight_info, class_to_idx
from multiguide.dataset.helpers import get_sorted_cano_smiles

# Type variable for the search graph accepted by `expand_node`; restricted to
# RetrosynthesisSearchGraph subclasses.
GraphType = TypeVar("GraphType", bound=RetrosynthesisSearchGraph)
# Unconstrained type variable; not referenced by the definitions visible in this part of the file.
AlgReturnType = TypeVar("AlgReturnType")

# Module-wide torch device: CUDA when torch reports it as available, otherwise CPU.
# Evaluated once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MolIsPurchasableGuidedCost(NoCacheNodeEvaluator[OrNode]):
    """Node evaluator that assigns an infinite cost to every OrNode.

    The purchasability-based variant (0.0 when ``node.mol.metadata`` has a
    truthy ``"is_purchasable"`` entry, infinity otherwise) is currently
    disabled, so the result does not depend on the nodes themselves.
    """

    def _evaluate_nodes(  # type: ignore[override]
        self,
        nodes: Sequence[OrNode],
        graph: Optional[AndOrGraph] = None,
    ) -> list[float]:
        """Return one ``math.inf`` per entry of ``nodes``; ``graph`` is ignored."""
        # Disabled alternative:
        # [0.0 if node.mol.metadata.get("is_purchasable") else math.inf for node in nodes]
        return [math.inf] * len(nodes)

class RetroStarSearchWithPropertyFilter(
    RetroStarSearch
):
    def __init__(
        self,
        config: dict,
        *args,
        **kwargs,
    ):
        """Initialise the underlying search algorithm and keep the run configuration.

        Args:
            config: Run configuration. The methods of this class read it with
                attribute access (e.g. ``config.search.filtered``), so despite
                the ``dict`` annotation it must support that access style.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Maps a depth to a set; starts empty.
        self.consumed_types_per_depth = defaultdict(set)
        self.config = config
        # The property predictor is not loaded here; the former hook was:
        # if self.config.search.filtered:
        #     self.load_property_predictor()

    def load_property_predictor(self):
        """Load the property-predictor checkpoint and build the model in eval mode.

        Sets ``self.property_predictor_checkpoint`` (the raw object returned by
        ``torch.load``) and ``self.property_predictor`` (a ``PropertyPredictor``
        moved to the module-level ``device`` with the checkpoint's
        ``"model_state_dict"`` loaded).
        """
        guidance_cfg = self.config.classifier_guidance
        ckpt_file = os.path.join(
            PROJECT_ROOT,
            'experiments',
            guidance_cfg.experiment_name,
            'checkpoints',
            guidance_cfg.checkpoint_path,
        )
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects; only point this at checkpoints from a trusted source.
        self.property_predictor_checkpoint = torch.load(ckpt_file, map_location=device, weights_only=False)

        # The alphabet size comes from the vocabulary of the trained ONMT model
        # (previously: get_vocab_size_from_config(self.config)).
        vocab = get_vocab_from_trained_model(guidance_cfg.onmt_checkpoint_path)
        self.property_predictor = PropertyPredictor(self.config, alphabet_size=len(vocab)).to(device)

        state_dict = self.property_predictor_checkpoint["model_state_dict"]
        self.property_predictor.load_state_dict(state_dict)
        self.property_predictor.eval()

    def get_distance_to_threshold_based_on_property(self, scores_values):
        if self.config.classifier_guidance.property == 'toxicity':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is less toxic
        elif self.config.classifier_guidance.property == 'yield':
            return scores_values - self.config.classifier_guidance.prediction_threshold # positive is higher yield
        elif self.config.classifier_guidance.property == 'sa_score':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is lower sa_score
        elif self.config.classifier_guidance.property == 'np_score':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is higher np_score
        else:
            raise ValueError(f"Property {self.config.classifier_guidance.property} not found")
        
    def has_acceptable_property_value_old(self, rxn: SingleProductReaction) -> bool:
        """Legacy filter: threshold the learned property prediction for ``rxn``'s reactants.

        The reactant SMILES are joined with ``'.'``, tokenized with the vocab file
        named in the config, and passed through ``self.property_predictor``. The
        raw output is mapped back to the target scale using the ``target_std`` /
        ``target_mean`` values stored in the checkpoint and compared with
        ``config.classifier_guidance.prediction_threshold``.

        Requires ``load_property_predictor()`` to have been called first, since
        that is what sets ``self.property_predictor`` and
        ``self.property_predictor_checkpoint``.

        Args:
            rxn: Candidate reaction; only its reactants are scored.

        Returns:
            True when the prediction is strictly on the desirable side of the
            threshold: above it for ``'yield'``, below it for ``'toxicity'``,
            ``'sa_score'`` and ``'np_score'``.

        Raises:
            ValueError: If the configured property is not one of the four above.
        """
        guidance = self.config.classifier_guidance
        property_name = guidance.property
        lower_is_better = ('toxicity', 'sa_score', 'np_score')
        # Validate before doing any tokenization or model work.
        if property_name != 'yield' and property_name not in lower_is_better:
            raise ValueError(f"Property {property_name} not found")

        # Only the reactant side of the reaction is scored.
        smiles = '.'.join([m.smiles for m in rxn.reactants])
        all_chars_path = os.path.join(PROJECT_ROOT, 'data', guidance.dataset.vocab_file)
        reactants_id = tokenize(smiles, all_chars_path).unsqueeze(0).to(device)

        # Pure inference: no autograd graph is needed for a threshold check.
        with torch.no_grad():
            property_pred = self.property_predictor(reactants_id)

        # Undo the target standardisation using the statistics saved in the checkpoint.
        checkpoint = self.property_predictor_checkpoint
        predicted_value = (property_pred * checkpoint['target_std'] + checkpoint['target_mean']).item()

        # TODO: fix this (soft score instead of a hard threshold)
        # distance_to_threshold = self.get_distance_to_threshold_based_on_property(property_pred_normalized)
        # input_val = self.config.classifier_guidance.sigmoid_steepness * (distance_to_threshold)
        # classifier_scores = F.logsigmoid(input_val) # log(sigmoid(input_val))

        threshold = guidance.prediction_threshold
        if property_name == 'yield':  # higher is better
            return predicted_value > threshold
        return predicted_value < threshold  # lower is better
    
    def has_acceptable_property_value(
        self, 
        rxn: SingleProductReaction, 
        true_reaction_type: int, 
        starting_material: str
    ) -> bool:
        """Annotate ``rxn`` with its reaction class / Tanimoto values and decide whether to keep it.

        Side effects: writes ``true_reaction_type``, ``sample_reaction_type``,
        ``sample_tanimoto``, ``sample_max_tanimoto`` and
        ``sample_max_tanimoto_reactant`` into ``rxn.metadata``. This happens
        before the configured property is validated.

        Args:
            rxn: Candidate reaction proposed by the single-step model.
            true_reaction_type: Target reaction class index. ``expand_node``
                stores this as a 1-element tensor for child nodes, so both plain
                ints and 1-element tensors are accepted here.
            starting_material: SMILES the reactants are compared against.

        Returns:
            A plain ``bool``, depending on ``config.classifier_guidance.property``:
            ``'reaction_type'`` keeps reactions whose class equals the target;
            ``'tanimoto_like_tango'`` / ``'max_tanimoto'`` keep reactions whose
            similarity exceeds ``prediction_threshold``;
            ``'tanimoto_and_reaction_type'`` requires both.

        Raises:
            ValueError: If the configured property is not supported for filtering.
        """
        reactant_smiles = get_sorted_cano_smiles([m.smiles for m in rxn.reactants])
        product_smiles = get_sorted_cano_smiles([s.smiles for s in rxn.products])
        reaction_info = get_rxn_insight_info(reactant_smiles + '>>' + product_smiles)
        # -1 marks "no class available" when the lookup returns None.
        reaction_type = class_to_idx[reaction_info['CLASS']] if reaction_info is not None else -1

        # Similarity against the full '.'-joined reactant string ...
        tanimoto_similarity = get_tanimoto(starting_material, reactant_smiles)
        # ... and the best single reactant (stays 0 / '' when nothing scores above 0).
        max_tanimoto = 0
        max_tanimoto_reactant = ''
        for reactant in reactant_smiles.split('.'):
            reactant_tanimoto = get_tanimoto(starting_material, reactant)
            if reactant_tanimoto > max_tanimoto:
                max_tanimoto = reactant_tanimoto
                max_tanimoto_reactant = reactant

        rxn.metadata['true_reaction_type'] = true_reaction_type
        rxn.metadata['sample_reaction_type'] = reaction_type
        rxn.metadata['sample_tanimoto'] = tanimoto_similarity
        rxn.metadata['sample_max_tanimoto'] = max_tanimoto
        rxn.metadata['sample_max_tanimoto_reactant'] = max_tanimoto_reactant

        guidance = self.config.classifier_guidance
        property_name = guidance.property
        # bool(...) collapses a 1-element tensor comparison into a real bool so the
        # declared return type holds whether the target is an int or a tensor.
        if property_name == 'reaction_type':
            return bool(reaction_type == true_reaction_type)
        if property_name in ('tanimoto_like_tango', 'max_tanimoto'):
            # NOTE(review): 'max_tanimoto' is filtered on the whole-reactant-set
            # similarity, not on max_tanimoto computed above -- confirm this is intended.
            return bool(tanimoto_similarity > guidance.prediction_threshold)
        if property_name == 'tanimoto_and_reaction_type':
            return bool(
                tanimoto_similarity > guidance.prediction_threshold
                and reaction_type == true_reaction_type
            )
        raise ValueError(f"Property {property_name} not supported for filtering")

    def _filter_reactions(
        self, 
        reactions: Sequence[SingleProductReaction],
        node: BaseGraphNode, 
        graph: AndOrGraph,
        reaction_types: torch.Tensor,
        starting_material: str
    ) -> List[SingleProductReaction]:
        """Apply the parent's reaction filter, then the optional property filter.

        Takes two more arguments than the parent implementation it delegates to
        (``reaction_types`` and ``starting_material``); both are forwarded to
        ``has_acceptable_property_value`` for every candidate.

        When ``config.search.filtered`` is falsy, the parent's result is returned
        unchanged. Otherwise the candidate counts before and after the property
        filter are printed.
        """
        # Parent filter first (e.g. drops reactions that contain the root molecule).
        reactions = super()._filter_reactions(reactions, node, graph)
        if not self.config.search.filtered:
            return reactions

        # TODO: change this to include reaction type
        print(f'reactions before filtering: {len(reactions)}')
        accepted = []
        for candidate in reactions:
            if self.has_acceptable_property_value(candidate, reaction_types, starting_material):
                accepted.append(candidate)
        reactions = accepted
        print(f'reactions after filtering: {len(reactions)}')
        return reactions

    def get_true_information_for_or_node(self, mol, depth):
        # Try molecule lookup first
        # TODO: can probably remove this lookup since we are now using branching factor
        #print('='*100)
        mol_smiles = mol.smiles  # or canonical SMILES
        # print(f'mol_smiles: {mol_smiles}')
        # print(f'depth: {depth}')
        if mol_smiles in mol.metadata['ground_truth_data']['mol_to_rxn_type']:
            #print(f'Assigning target reaction type with ground truth matches: {mol_smiles}')
            reaction_type = mol.metadata['ground_truth_data']['mol_to_rxn_type'][mol_smiles]
            tanimoto = mol.metadata['ground_truth_data']['mol_to_tanimoto'][mol_smiles]
            ground_truth_reactants = mol.metadata['ground_truth_data']['mol_to_ground_truth_reactants'][mol_smiles]
            #print(f'reaction_type: {reaction_type}')
            return reaction_type, tanimoto, ground_truth_reactants
        # handle case where depth is not in the ground truth data
        if depth not in mol.metadata['ground_truth_data']['depth_to_types']:
            #print(f'depth exceeded')
            return -1, 0, ''
        # Fallback: pop unused type from depth if available
        #available = mol.metadata['ground_truth_data']['depth_to_types'][depth]
        # print('='*100)
        # print(f'depth: {depth}')
        # print(f'mol.metadata["branching_factor"]: {mol.metadata["branching_factor"]}')
        # print('='*100)
        types_at_depth = mol.metadata['ground_truth_data']['depth_to_types'][depth]
        tanimotos_at_depth = mol.metadata['ground_truth_data']['depth_to_tanimotos'][depth]
        ground_truth_reactants_at_depth = mol.metadata['ground_truth_data']['depth_to_ground_truth_reactants'][depth]
        if mol.metadata['branching_factor'] < len(types_at_depth):
            #print(f'Assigning target reaction type with branching factor matches: {mol.metadata["branching_factor"]}')
            reaction_type = types_at_depth[mol.metadata['branching_factor']]
            tanimoto = tanimotos_at_depth[mol.metadata['branching_factor']]
            ground_truth_reactants = ground_truth_reactants_at_depth[mol.metadata['branching_factor']]
            #print(f'reaction_type: {reaction_type}')
            return reaction_type, tanimoto, ground_truth_reactants
        #print(f'branching factor exceeded: {mol.metadata["branching_factor"]}, {len(types_at_depth)}')
        return -1, 0, '' # no guidance
    
    def get_reaction_type_for_and_node(self, node):
        '''
        Return the reaction class index for the reaction stored on an AndNode.

        Builds a "reactants>>products" string from the sorted canonical SMILES of
        ``node.reaction`` and maps the ``'CLASS'`` entry of the resulting reaction
        info through ``class_to_idx``. Returns -1 when no reaction info is available.
        '''
        reaction = node.reaction
        reactant_side = get_sorted_cano_smiles([s.smiles for s in reaction.reactants])
        product_side = get_sorted_cano_smiles([s.smiles for s in reaction.products])
        info = get_rxn_insight_info(reactant_side + '>>' + product_side)
        if info is None:
            return -1
        return class_to_idx[info['CLASS']]

    def get_tanimoto_for_and_node(self, node):
        '''
        Compute Tanimoto similarities between the starting material and the
        reactants of an AndNode's reaction.

        The starting material is read from the metadata of the first product of
        ``node.reaction`` (key ``'starting_material'``).

        Returns a tuple ``(tanimoto, max_tanimoto, max_tanimoto_reactant)``:
          * ``tanimoto``: similarity against the full '.'-joined reactant string
          * ``max_tanimoto``: best similarity over the individual reactants
            (0 if none scores above 0)
          * ``max_tanimoto_reactant``: the reactant achieving it ('' if none)
        '''
        reactant_smiles = get_sorted_cano_smiles([s.smiles for s in node.reaction.reactants])
        starting_material = next(iter(node.reaction.products)).metadata['starting_material']
        tanimoto = get_tanimoto(starting_material, reactant_smiles)
        max_tanimoto = 0
        max_tanimoto_reactant = ''
        for reactant in reactant_smiles.split('.'):
            # Use a separate name here: reusing `tanimoto` would overwrite the
            # whole-set similarity with that of the last reactant.
            reactant_tanimoto = get_tanimoto(starting_material, reactant)
            if reactant_tanimoto > max_tanimoto:
                max_tanimoto = reactant_tanimoto
                max_tanimoto_reactant = reactant
        return tanimoto, max_tanimoto, max_tanimoto_reactant

    def expand_node(
        self, node: BaseGraphNode, graph: GraphType, force_expansion: bool = False
    ) -> Sequence[BaseGraphNode]:
        """
        Expand `node` with the reaction model, filter the proposals, add them to
        `graph`, and propagate guidance metadata to the newly created nodes.

        In the default case, checks self.can_expand_node (plus the optional depth limit
        `config.classifier_guidance.enforce_starting_material_at_depth`), and if this passes
        then the node is expanded with the reaction model.

        If force_expansion=True, then these checks are skipped and the node is expanded
        regardless of whether it should be.

        `node.mol.metadata` must already contain the keys read below
        ('sample_prev_reaction_type', 'true_prev_reaction_type', 'reaction_type_to_synthesize',
        'branching_factor', 'parent_product', 'starting_material', 'max_tanimoto_property_weight',
        and, when nodes are added, 'ground_truth_data' and 'true_tanimoto'); a missing key
        raises KeyError.

        Returns the new nodes without duplicates in first-seen order, or an empty list
        when the node is not expanded or there is nothing to expand.
        """
        # Debug dump of the guidance state attached to the node; runs even if the node
        # ends up not being expanded.
        print('*'*100)
        print(f'expanding node.mol.smiles: {node.mol.smiles}')
        print(f'node.depth: {node.depth}')
        print(f'previous synth step reaction type: node.mol.metadata["sample_prev_reaction_type"]: {node.mol.metadata["sample_prev_reaction_type"]}')
        print(f'previous synth step reaction type: node.mol.metadata["true_prev_reaction_type"]: {node.mol.metadata["true_prev_reaction_type"]}')
        print(f'target next synth step reaction type: node.mol.metadata["reaction_type_to_synthesize"]: {node.mol.metadata["reaction_type_to_synthesize"]}')
        print(f'node.mol.metadata["branching_factor"]: {node.mol.metadata["branching_factor"]}')
        print(f'node.mol.metadata["parent_product"]: {node.mol.metadata["parent_product"]}')
        print(f'node.mol.metadata["starting_material"]: {node.mol.metadata["starting_material"]}')
        print(f'node.mol.metadata["max_tanimoto_property_weight"]: {node.mol.metadata["max_tanimoto_property_weight"]}')
        print('*'*100)
        # `and` binds tighter than `or`, so this reads as:
        #   force_expansion OR (can_expand_node AND (no depth limit OR depth <= limit))
        # i.e. force_expansion bypasses the depth limit as well.
        if force_expansion or self.can_expand_node(node, graph) and (self.config.classifier_guidance.enforce_starting_material_at_depth is None or node.depth<=self.config.classifier_guidance.enforce_starting_material_at_depth):
            # Get molecules to expand
            # NOTE: seems like this is always a single molecule, at least for retrostar
            # see _get_mols_to_expand in base.py, AndOrSearchAlgorithm
            mols = list(self._get_mols_to_expand(node, graph))
            # Optionally terminate without expansion if there are no molecules to expand
            if len(mols) == 0:
                return list()
            # NOTE: this means the reaction type will also have a single value for this node
            #reaction_types = torch.tensor([self.get_reaction_type_for_node(m, depth=node.depth) for m in mols]).to(device)
            # node.mol.metadata['reaction_types'] = reaction_types
            # NOTE: do this here after we obtain the depth info from the search algorithm & before expanding this particular node
            # node.mol.metadata['depth'] = node.depth
            # branching_factor = node.mol.metadata['branching_factor']
            # depth_to_ground_truth_reactants = node.mol.metadata['ground_truth_data']['depth_to_ground_truth_reactants']
            # if node.depth in depth_to_ground_truth_reactants and branching_factor < len(depth_to_ground_truth_reactants[node.depth]):
            #     node.mol.metadata['ground_truth_reactants'] = depth_to_ground_truth_reactants[node.depth][branching_factor]
            # else:
            #     print(f'depth {node.depth} or branching factor {branching_factor} not found')
            #     node.mol.metadata['ground_truth_reactants'] = ''
            # if -1 in reaction_types:
            #     print(f'reaction_types: {reaction_types}')
            #     exit()
            #print(f'reaction_types: {reaction_types}')
            #reaction_types = torch.tensor([1] * len(mols)).to(device)

            # Get reactions for each of these molecules
            # Only the 'rootaligned' model is called with the conditioning inputs
            # (target reaction type and starting material); any other model gets the molecules only.
            if self.config.single_step_model.model_type == 'rootaligned':
                reaction_types = torch.tensor([node.mol.metadata['reaction_type_to_synthesize']]).to(device)
                rxn_model_output = self.reaction_model(
                    mols, 
                    reaction_types=reaction_types, 
                    conditional_starting_materials=[node.mol.metadata['starting_material']]
                )
            else:
                rxn_model_output = self.reaction_model(mols)

            # Filter reactions to remove unwanted ones
            # The filter receives the value stored in the node's metadata, not the
            # `reaction_types` tensor built above.
            filtered_rxn_list = [
                self._filter_reactions(rxn_list, node, graph, node.mol.metadata['reaction_type_to_synthesize'], node.mol.metadata['starting_material']) for rxn_list in rxn_model_output
            ]

            # Add new nodes to the graph
            new_nodes: list[BaseGraphNode] = list(
                graph.expand_with_reactions(
                    [rxn for rxn_list in filtered_rxn_list for rxn in rxn_list],
                    node,
                    ensure_tree=not self.unique_nodes,
                )
            )
            # add ground truth data to the nodes
            # NOTE: this code assumes new_nodes have the following order: reaction node followed by its reactant nodes
            # AndNode, OrNode, AndNode, OrNode, OrNode, AndNode...
            # NOTE(review): `branching_factor` and the `sample_*` locals are only bound in the
            # AndNode branch, so an OrNode appearing before any AndNode would raise
            # UnboundLocalError -- the ordering assumption above is load-bearing.
            starting_material = node.mol.metadata['starting_material']
            original_tanimoto_property_weight = node.mol.metadata['max_tanimoto_property_weight']
            tanimoto_property_weight = original_tanimoto_property_weight
            for new_node in new_nodes:
                if isinstance(new_node, AndNode):
                    # A new reaction starts: restart the sibling counter for its reactants.
                    branching_factor = 0
                    if 'sample_reaction_type' in new_node.reaction.metadata and 'sample_tanimoto' in new_node.reaction.metadata:
                        # if computed in filtering, use the values here
                        sample_reaction_type = new_node.reaction.metadata['sample_reaction_type']
                        sample_tanimoto = new_node.reaction.metadata['sample_tanimoto']
                        sample_max_tanimoto = new_node.reaction.metadata['sample_max_tanimoto']
                        sample_max_tanimoto_reactant = new_node.reaction.metadata['sample_max_tanimoto_reactant']
                    else:
                        # Filtering was skipped for this reaction; compute the values now.
                        sample_reaction_type = self.get_reaction_type_for_and_node(new_node)
                        sample_tanimoto, sample_max_tanimoto, sample_max_tanimoto_reactant = self.get_tanimoto_for_and_node(new_node)

                    # NOTE: if we found the sm, making guidance weight 0 (no need to guide this branch towards sm anymore)
                    # we use original_tanimoto_property_weight each time because each andnode defines a new branch
                    tanimoto_property_weight = int(sample_max_tanimoto_reactant!=starting_material)*original_tanimoto_property_weight
                elif isinstance(new_node, OrNode):
                    # Reactant node: inherits the values computed for the most recent AndNode.
                    new_node.mol.metadata['ground_truth_data'] = node.mol.metadata['ground_truth_data']
                    # NOTE: right now this is assigned randomly,
                    # could think of a better way to make assigning branching factor ordered
                    new_node.mol.metadata['branching_factor'] = branching_factor
                    # The lookup uses depth node.depth+2
                    # (presumably because an AndNode sits between `node` and this OrNode -- TODO confirm).
                    true_information = self.get_true_information_for_or_node(new_node.mol, depth=node.depth+2)
                    new_true_reaction_type = torch.tensor([true_information[0]]).to(device)
                    new_true_tanimoto = torch.tensor([true_information[1]]).to(device)
                    new_true_reactants = true_information[2]
                    # Stored as a 1-element tensor on `device`, not as a plain int.
                    new_node.mol.metadata['reaction_type_to_synthesize'] = new_true_reaction_type
                    # the reaction type of the parent node, which was the target reaction type producing this node
                    new_node.mol.metadata['true_prev_reaction_type'] = node.mol.metadata['reaction_type_to_synthesize']
                    new_node.mol.metadata['sample_prev_reaction_type'] = sample_reaction_type
                    new_node.mol.metadata['starting_material'] = node.mol.metadata['starting_material']
                    new_node.mol.metadata['parent_product'] = node.mol.smiles
                    new_node.mol.metadata['ground_truth_reactants'] = new_true_reactants
                    # 'sample_tanimoto' is this molecule vs. the starting material;
                    # 'reaction_tanimoto' is the value computed for the whole parent reaction.
                    new_node.mol.metadata['sample_tanimoto'] = get_tanimoto(new_node.mol.metadata['starting_material'], new_node.mol.smiles)
                    new_node.mol.metadata['reaction_tanimoto'] = sample_tanimoto
                    new_node.mol.metadata['true_next_tanimoto'] = new_true_tanimoto
                    new_node.mol.metadata['true_tanimoto'] = node.mol.metadata['true_tanimoto']
                    new_node.mol.metadata['sample_max_tanimoto'] = sample_max_tanimoto
                    new_node.mol.metadata['sample_max_tanimoto_reactant'] = sample_max_tanimoto_reactant
                    new_node.mol.metadata['max_tanimoto_property_weight'] = tanimoto_property_weight
                    # ground_data = node.mol.metadata['ground_truth_data']
                    # depth = node.depth
                    # if branching_factor < len(ground_data['depth_to_ground_truth_reactants'][depth+2]):
                    #     new_node.mol.metadata['ground_truth_reactants'] = ground_data['depth_to_ground_truth_reactants'][depth+2][branching_factor]
                    # else:
                    #     print(f'branching factor {branching_factor} not found for depth {depth+2}')
                    #     # most likely because it's a starting material
                    #     new_node.mol.metadata['ground_truth_reactants'] = ''
                    branching_factor += 1
            # Return unique nodes, but in a consistent order
            # (dict.fromkeys keeps the first occurrence of each node and preserves insertion order)
            return list(dict.fromkeys(new_nodes))
        else:
            return list()

