import torch
import os
from syntheseus.search.node_evaluation.common import NoCacheNodeEvaluator
from multiguide.syntheseus.value_function.desp_synthetic_distance import SyntheticDistance, smiles_to_fp
from multiguide.syntheseus.value_function.retro_star_value_function import RetroStarVM
from multiguide.helpers import PROJECT_ROOT
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.dataset.helpers import get_vocab_size_from_config, tokenize, turn_seq_to_ids, get_vocab_from_trained_model
from multiguide.dataset.helpers import get_tanimoto
import torch.nn.functional as F
from rdkit import Chem
from multiguide.dataset.helpers import get_sorted_cano_smiles, get_rxn_insight_info, class_to_idx

class ValueNodeWithPropertyEvaluator(NoCacheNodeEvaluator):
    """Node evaluator combining a learned value function with a property cost.

    For each node the estimated cost (lower is better) is::

        original_score_weight * value_model(fingerprint(smiles)) + property_cost

    where ``property_cost`` is selected by ``config.classifier_guidance.property``
    and is computed from per-molecule values stored in ``node.mol.metadata``.

    Args:
        config: Experiment configuration, read via attribute access
            (e.g. ``config.value_function.type``). Presumably an OmegaConf
            ``DictConfig`` despite the ``dict`` annotation -- TODO confirm.
        **kwargs: Forwarded unchanged to ``NoCacheNodeEvaluator.__init__``.
    """

    # Properties where a prediction *below* ``prediction_threshold`` is desirable.
    _LOWER_IS_BETTER = ('toxicity', 'sa_score', 'np_score')
    # Properties where a prediction *above* ``prediction_threshold`` is desirable.
    _HIGHER_IS_BETTER = ('yield',)
    # Length of the fingerprint produced by ``smiles_to_fp`` and fed to the value
    # model; it must match the model's expected input width.
    FP_SIZE = 2048

    def __init__(self,
                 config: dict,
                 **kwargs):
        super().__init__(**kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config
        self.load_value_function()
        # The learned property predictor (``load_property_predictor``) is
        # currently not loaded; ``predict`` derives property costs from node
        # metadata instead.
        print("Loaded retro value model!")

    def load_property_predictor(self):
        """Load the trained property predictor onto ``self.device`` in eval mode.

        The checkpoint is read from
        ``PROJECT_ROOT/experiments/<experiment_name>/checkpoints/<checkpoint_path>``.
        The predictor's alphabet size is the vocabulary size of the ONMT model
        at ``config.classifier_guidance.onmt_checkpoint_path``.
        Not called by ``__init__`` at present.
        """
        guidance = self.config.classifier_guidance
        checkpoint_path = os.path.join(PROJECT_ROOT,
                                       'experiments',
                                       guidance.experiment_name,
                                       'checkpoints',
                                       guidance.checkpoint_path)
        self.property_predictor_checkpoint = torch.load(checkpoint_path, map_location=self.device)
        vocab_size = len(get_vocab_from_trained_model(guidance.onmt_checkpoint_path))
        self.property_predictor = PropertyPredictor(self.config, alphabet_size=vocab_size)
        self.property_predictor.load_state_dict(self.property_predictor_checkpoint["model_state_dict"])
        self.property_predictor.to(self.device)
        self.property_predictor.eval()

    def load_value_function(self):
        """Build ``self.model`` from ``config.value_function`` and load its weights.

        The checkpoint is read from ``PROJECT_ROOT/checkpoints/<model_path>``.
        ``'retro_star'`` checkpoints are used directly as a state dict.
        ``'desp'`` checkpoints provide an ``"args"`` object (model
        hyper-parameters) and a ``"state_dict"``.

        Raises:
            ValueError: if ``config.value_function.type`` is not supported.
        """
        vf_cfg = self.config.value_function
        path = os.path.join(PROJECT_ROOT,
                            'checkpoints',
                            vf_cfg.model_path)
        # weights_only=False unpickles arbitrary Python objects (the DESP
        # checkpoint stores an ``args`` object), so only load trusted files.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        # loading the value function (i.e. the RetroCost in TangoStar)
        print(f'======= using value function {vf_cfg.type}')
        if vf_cfg.type == 'retro_star':
            # TODO: maybe get the params from the checkpoint
            self.model = RetroStarVM(n_layers=vf_cfg.n_layers,
                                     fp_dim=vf_cfg.fp_dim,
                                     latent_dim=vf_cfg.latent_dim,
                                     dropout_rate=vf_cfg.dropout_rate,
                                     device=self.device).to(self.device)
            state_dict = checkpoint
        elif vf_cfg.type == 'desp':
            pretrain_args = checkpoint["args"]
            pretrain_args.output_dim = 1
            # Strip the "module." key prefix -- presumably left over from saving
            # a DataParallel/DDP-wrapped model.
            state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
            self.model = SyntheticDistance(pretrain_args).to(self.device)
        else:
            raise ValueError(f"Value function type {vf_cfg.type} not found")

        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _evaluate_nodes(self, nodes, graph=None):
        """Return the estimated cost of each node's molecule, in input order.

        Every ``node.mol.metadata`` must contain the keys read below; a missing
        key raises ``KeyError``. ``graph`` is accepted for interface
        compatibility and is not used.
        """
        return [
            self.predict(
                node.mol.smiles,
                node.mol.metadata['true_prev_reaction_type'],
                node.mol.metadata['sample_prev_reaction_type'],
                node.mol.metadata['true_tanimoto'],
                node.mol.metadata['sample_tanimoto'],
                node.mol.metadata['sample_max_tanimoto'],
                node.mol.metadata['sample_max_tanimoto_reactant'],
                node.mol.metadata['max_tanimoto_property_weight'],
            )
            for node in nodes
        ]

    def get_property_reward_old(self, property_pred):
        """Convert a property prediction to a sigmoid reward (higher is better).

        Returns ``sigmoid(sigmoid_steepness * d)``, where ``d`` is the signed
        distance from ``get_distance_to_threshold_based_on_property``. With a
        positive steepness, predictions on the desirable side of the threshold
        score above 0.5.

        Raises:
            ValueError: if the configured property is not supported.
        """
        # TODO: can remove sigmoid as in retrostar (think some more about whether it's needed or not)
        distance = self.get_distance_to_threshold_based_on_property(property_pred)
        return torch.sigmoid(self.config.classifier_guidance.sigmoid_steepness * distance)

    def get_property_reward(self, property_pred, reaction_types=None):
        """Return whether the arg-max of ``property_pred`` equals ``reaction_types``.

        Args:
            property_pred: Tensor of scores (presumably one per reaction class);
                its flattened arg-max is taken as the predicted class.
            reaction_types: Reference class index (int or one-element tensor).
                This used to be read from an undefined global, which made the
                method always fail; it must now be passed explicitly.

        Returns:
            bool: True if the predicted class matches ``reaction_types``.

        Raises:
            ValueError: if the configured property is not ``'reaction_type'``
                or ``reaction_types`` is not provided.
        """
        if self.config.classifier_guidance.property != 'reaction_type':
            raise ValueError(f"Property {self.config.classifier_guidance.property} not supported for filtering")
        if reaction_types is None:
            raise ValueError("reaction_types must be provided to compute the reaction_type reward")
        return property_pred.argmax().item() == torch.as_tensor(reaction_types).item()

    def get_distance_to_threshold_based_on_property(self, scores_values):
        """Signed distance of ``scores_values`` from ``prediction_threshold``.

        A positive result means the prediction lies on the desirable side of the
        threshold: below it for toxicity / sa_score / np_score, above it for
        yield (a positive distance yields a higher reward / lower cost).

        Raises:
            ValueError: if the configured property is not supported.
        """
        guidance = self.config.classifier_guidance
        if guidance.property in self._LOWER_IS_BETTER:
            # NOTE(review): np_score is treated as lower-is-better here, as in
            # get_property_reward_old; confirm this is the intended direction.
            return guidance.prediction_threshold - scores_values
        if guidance.property in self._HIGHER_IS_BETTER:
            return scores_values - guidance.prediction_threshold
        raise ValueError(f"Property {guidance.property} not found")

    @staticmethod
    def _mismatch_cost(a, b):
        """Return 1 if ``a != b`` else 0: a float tensor for tensor inputs, a float otherwise."""
        mismatch = a != b
        return mismatch.float() if torch.is_tensor(mismatch) else float(mismatch)

    @staticmethod
    def _debug_print(**fields):
        """Print one ``name: value`` line per field between two rows of 100 asterisks."""
        print('*' * 100)
        for name, value in fields.items():
            print(f'{name}: {value}')
        print('*' * 100)

    def _property_cost(
        self,
        target,
        true_prev_reaction_type,
        sample_prev_reaction_type,
        true_tanimoto,
        sample_tanimoto,
        sample_max_tanimoto,
        sample_max_tanimoto_reactant,
        max_tanimoto_property_weight,
    ):
        """Return the weighted property cost used by ``predict`` (lower is better).

        Raises:
            ValueError: if ``config.classifier_guidance.property`` is not supported.
        """
        guidance = self.config.classifier_guidance
        prop = guidance.property
        if prop == 'reaction_type':
            # 1 when the sampled previous reaction class differs from the true one.
            mismatch = self._mismatch_cost(true_prev_reaction_type, sample_prev_reaction_type)
            return guidance.eval.reaction_type_weight * mismatch
        if prop == 'tanimoto_like_tango':
            # TODO: could make smthg more complex to make sure we match the tanimoto increase of true routes.
            # NOTE: this automatically chooses the most similar reactant.
            self._debug_print(target=target,
                              sample_tanimoto=sample_tanimoto,
                              true_tanimoto=true_tanimoto)
            return guidance.eval.tanimoto_weight * (1 - sample_tanimoto)
        if prop == 'max_tanimoto':
            self._debug_print(target=target,
                              sample_tanimoto=sample_tanimoto,
                              sample_max_tanimoto_reactant=sample_max_tanimoto_reactant,
                              sample_max_tanimoto=sample_max_tanimoto,
                              max_tanimoto_property_weight=max_tanimoto_property_weight)
            return max_tanimoto_property_weight * (1 - sample_max_tanimoto)
        if prop == 'tanimoto_and_reaction_type':
            # Previously ``(a != b).float()``, which crashed on plain Python/numpy
            # values (``bool`` has no ``.float()``); the helper handles both.
            reaction_type_cost = self._mismatch_cost(sample_prev_reaction_type, true_prev_reaction_type)
            tanimoto_cost = 1 - sample_tanimoto
            self._debug_print(target=target,
                              sample_prev_reaction_type=sample_prev_reaction_type,
                              true_prev_reaction_type=true_prev_reaction_type,
                              reaction_type_cost=reaction_type_cost,
                              sample_tanimoto=sample_tanimoto,
                              true_tanimoto=true_tanimoto,
                              tanimoto_cost=tanimoto_cost)
            return (guidance.eval.reaction_type_weight * reaction_type_cost
                    + guidance.eval.tanimoto_weight * tanimoto_cost)
        raise ValueError(f"Property {prop} not supported for guided score")

    def predict(
        self,
        target,
        true_prev_reaction_type,
        sample_prev_reaction_type,
        true_tanimoto,
        sample_tanimoto,
        sample_max_tanimoto,
        sample_max_tanimoto_reactant,
        max_tanimoto_property_weight,
        as_item=True
    ):
        """
        Predict the combined (value-model + property) cost of 'target'.

        Args:
            target (str): target molecule SMILES
            true_prev_reaction_type: reference class of the previous reaction
                (used by the reaction-type modes).
            sample_prev_reaction_type: class of the sampled previous reaction.
            true_tanimoto: reference Tanimoto similarity (printed only).
            sample_tanimoto: Tanimoto similarity of the sample; cost is
                ``1 - sample_tanimoto`` in the tanimoto modes.
            sample_max_tanimoto: used by ``'max_tanimoto'`` as
                ``1 - sample_max_tanimoto``.
            sample_max_tanimoto_reactant: printed only.
            max_tanimoto_property_weight: weight of the ``'max_tanimoto'`` cost.
            as_item (bool): return a Python float instead of a tensor.

        Returns:
            float (or tensor if ``as_item`` is False):
            ``original_score_weight * value_model(target) + property_cost``.

        Raises:
            ValueError: if the configured property is not supported.
        """
        target_fp = smiles_to_fp(target, fp_size=self.FP_SIZE).float().unsqueeze(0)
        target_fp = target_fp.to(self.device)
        # With as_item the result becomes a Python float, so no autograd graph is
        # needed; an enclosing no_grad context is still honoured otherwise.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not as_item):
            dist = self.model(target_fp)
        property_cost = self._property_cost(
            target,
            true_prev_reaction_type,
            sample_prev_reaction_type,
            true_tanimoto,
            sample_tanimoto,
            sample_max_tanimoto,
            sample_max_tanimoto_reactant,
            max_tanimoto_property_weight,
        )
        combined_pred = self.config.classifier_guidance.eval.original_score_weight * dist + property_cost
        return combined_pred.item() if as_item else combined_pred