import torch
import os
from syntheseus.search.node_evaluation.common import NoCacheNodeEvaluator
from multiguide.syntheseus.value_function.desp_synthetic_distance import SyntheticDistance, smiles_to_fp
from multiguide.syntheseus.value_function.retro_star_value_function import RetroStarVM
from multiguide.helpers import PROJECT_ROOT

class ValueNodeEvaluator(NoCacheNodeEvaluator):
    """
    Value node evaluator for the retro-star search.

    Loads a pretrained value model (``RetroStarVM`` or ``SyntheticDistance``,
    selected by ``config.value_function.type``) from a checkpoint under
    ``PROJECT_ROOT/checkpoints`` and uses it to assign a cost to every node's
    molecule SMILES.
    """

    # Fingerprint length passed to ``smiles_to_fp`` in ``predict``.
    # ``__init__`` overrides it per instance for the 'retro_star' model so the
    # fingerprint always matches the ``fp_dim`` the model was built with.
    fp_size = 2048

    def __init__(
        self,
        config: dict,
        **kwargs
    ):
        """
        Build the value model and load its weights.

        Args:
            config: configuration object. NOTE(review): it is read through
                attribute access (``config.value_function.model_path``), so a
                plain ``dict`` will not work despite the annotation;
                presumably an attribute-style mapping -- confirm with callers.
                Fields read from ``config.value_function``: ``model_path``,
                ``type`` and, for ``type == 'retro_star'``, ``n_layers``,
                ``fp_dim``, ``latent_dim`` and ``dropout_rate``.
            **kwargs: forwarded unchanged to ``NoCacheNodeEvaluator``.

        Raises:
            ValueError: if ``config.value_function.type`` is neither
                ``'retro_star'`` nor ``'desp'``.
        """
        super().__init__(**kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        vf_config = config.value_function
        path = os.path.join(PROJECT_ROOT, 'checkpoints', vf_config.model_path)
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only trusted checkpoints must be loaded here. The 'desp' branch reads
        # a non-tensor ``args`` object from the checkpoint, which is likely why
        # the flag is disabled.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        # define model
        print(f'======= using value function {vf_config.type}')
        if vf_config.type == 'retro_star':
            # TODO: maybe get the params from the checkpoint
            self.model = RetroStarVM(
                n_layers=vf_config.n_layers,
                fp_dim=vf_config.fp_dim,
                latent_dim=vf_config.latent_dim,
                dropout_rate=vf_config.dropout_rate,
                device=self.device
            ).to(self.device)
            # Keep the fingerprint width in sync with the model instead of
            # relying on the 2048 default (assumes RetroStarVM consumes
            # fingerprints of length ``fp_dim`` -- TODO confirm).
            self.fp_size = vf_config.fp_dim
            # The retro_star checkpoint is used directly as the state dict.
            state_dict = checkpoint
        elif vf_config.type == 'desp':
            # The desp checkpoint bundles its construction args with the weights.
            pretrain_args = checkpoint["args"]
            pretrain_args.output_dim = 1
            state_dict = checkpoint["state_dict"]
            # Drop every "module." occurrence from the parameter names
            # (presumably left by a DataParallel-style wrapper -- verify).
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            self.model = SyntheticDistance(pretrain_args).to(self.device)
        else:
            raise ValueError(f"Value function type {vf_config.type} not found")
        # Load the checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Loaded retro value model!")

    def _evaluate_nodes(self, nodes, graph=None):
        """
        Return one model estimate per node, in the order of ``nodes``.

        Args:
            nodes: iterable of nodes exposing ``node.mol.smiles``.
            graph: unused here; accepted for interface compatibility.

        Returns:
            list: the value of ``predict(node.mol.smiles)`` for each node.
        """
        return [self.predict(node.mol.smiles) for node in nodes]

    def predict(self, target, as_item=True):
        """
        Predict the synthetic cost of 'target'.

        Args:
            target (str): target molecule SMILES
            as_item (bool): if True return ``dist.item()`` (a Python number),
                otherwise return the model output tensor itself.

        Returns:
            The model estimate for ``target``. If anything raises while
            fingerprinting or running the model, a warning is printed and
            ``float('inf')`` (or a one-element ``inf`` tensor on
            ``self.device`` when ``as_item`` is False) is returned instead.
        """
        try:
            target_fp = smiles_to_fp(target, fp_size=self.fp_size).float().unsqueeze(0)
            target_fp = target_fp.to(self.device)
            with torch.no_grad():
                dist = self.model(target_fp)
            return dist.item() if as_item else dist
        except Exception as e:
            # Deliberately broad best-effort: an unusable SMILES must not abort
            # the search, so it gets a very high cost to discourage this path.
            # NOTE(review): this also hides genuine model/device errors.
            print(f"Warning: Could not process SMILES '{target}': {e}")
            if as_item:
                return float('inf')  # or a large finite value like 999.0
            return torch.tensor([float('inf')], device=self.device)