import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time

from multiguide.dataset.helpers import smi_tokenizer
from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_vocab_from_trained_model, get_vocab_size_from_file, \
                            get_levenstein_similarity, get_token_prefix_similarity, \
                            get_token_prefix_similarity_vectorized
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.dataset.helpers import turn_ids_to_seq, turn_seq_to_ids

# Run the guidance models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Wrapper around the translator's generator that adds property-classifier guidance.

class ClassifierGuidedGenerator(nn.Module):
    def __init__(self, original_generator, translator,
                 property_model, target_classes, guidance_scale=1.0, n_candidates_to_evaluate=10,
                 debug_classifier_scores=False, are_scores_close_tolerance=1e-1, 
                 are_scores_close_debug=True, config=None):
        '''
            Wrap ``original_generator`` so that its scores can be steered by ``property_model``.

            original_generator: the generator being wrapped (stored as-is).
            translator: the translator driving beam search; the guidance helpers read its
                ``beam_size`` and ``_tgt_eos_idx`` attributes.
            property_model: predictor whose weights are loaded from
                ``PROJECT_ROOT/experiments/<experiment_name>/checkpoints/<checkpoint_path>``.
            target_classes: target class index per node depth (used by _assign_target_class_index).
            guidance_scale: weight given to the classifier scores; 0 disables guidance.
            n_candidates_to_evaluate: number of top tokens per beam re-scored by the classifier;
                clipped to the vocabulary size.
            debug_classifier_scores, are_scores_close_tolerance, are_scores_close_debug:
                debugging / consistency-check knobs stored on the instance.
            config: experiment config; its ``classifier_guidance`` section is required.

            Raises:
                ValueError: if ``config`` is None.
        '''
        super().__init__()
        if config is None:
            raise ValueError("ClassifierGuidedGenerator requires a config with a 'classifier_guidance' section")
        self.original_generator = original_generator
        self.translator = translator
        self.property_model = property_model
        self.guidance_scale = guidance_scale
        self.debug_classifier_scores = debug_classifier_scores
        self.are_scores_close_tolerance = are_scores_close_tolerance
        self.are_scores_close_debug = are_scores_close_debug
        self.config = config
        self.batch_indices = None
        self.reaction_types = None
        checkpoint_path = os.path.join(PROJECT_ROOT,
                                        'experiments', 
                                        config.classifier_guidance.experiment_name,
                                        'checkpoints',
                                        config.classifier_guidance.checkpoint_path)
        print(f'======= loading property checkpoint from {checkpoint_path}')
        self.property_checkpoint = torch.load(checkpoint_path, map_location=device)
        self.property_model.load_state_dict(self.property_checkpoint['model_state_dict'])
        self.property_model = self.property_model.to(device)
        self.vocab = get_vocab_from_trained_model(self.config.classifier_guidance.onmt_checkpoint_path)
        # topk cannot return more candidates than there are tokens in the vocabulary.
        self.n_candidates_to_evaluate = min(n_candidates_to_evaluate, len(self.vocab))
        self.save_tensors = True
        self.target_classes = target_classes
        self.node_depth = None
        self.conditional_starting_materials = None
        self.conditional_targets = None
        self.stop_guidance = False
        self.immediate_target = None
        self.original_scores_weight = 1.0
        self.keep_using_conditiong = True  # (sic) attribute name kept for compatibility
        # Conditioning inputs that are presumably set by the caller before decoding. The
        # evaluation helpers guard them with ``is not None``, so start them at None rather
        # than leaving them undefined (which raised AttributeError instead of skipping).
        self.conditional_starting_material = None
        self.conditional_target = None
        self.product_smiles = None

    def _select_topn_candidates(self, original_scores):
        '''
            Select the top-n candidate tokens per beam to apply classifier guidance to,
            based on the n_candidates_to_evaluate parameter.

            original_scores: [batch_size*beam_size, vocab_size] (one row per beam; the row
                count must be a multiple of ``translator.beam_size``).
            returns: topn_scores: [batch_size*beam_size, n_candidates_to_evaluate]
                     topn_ids: [batch_size*beam_size, n_candidates_to_evaluate]
        '''
        beam_size = self.translator.beam_size
        n_beams = (original_scores.shape[0] // beam_size) * beam_size
        vocab_size = original_scores.shape[-1]
        # One batched topk over every beam row replaces the per-beam Python loop; `view`
        # still raises if the row count is not a multiple of beam_size.
        per_beam_scores = original_scores.view(n_beams, vocab_size)
        topn_scores, topn_ids = torch.topk(per_beam_scores, self.n_candidates_to_evaluate, dim=-1)
        return topn_scores, topn_ids

    def _create_candidate_sequences(self, alive_seq, topn_ids):
        '''
            Create one extended sequence per (beam, candidate token) for classifier evaluation.

            alive_seq: [n_beams, seq_len] prefixes currently alive in beam search.
            topn_ids: [n_beams, n_candidates] candidate next-token ids per beam.
            returns: [n_beams * n_candidates, seq_len + 1]; rows are grouped by beam, i.e. the
                n_candidates extensions of beam 0 come first, then those of beam 1, etc.
                Only the first (alive_seq.shape[0] // beam_size) * beam_size beams are used.
        '''
        beam_size = self.translator.beam_size
        n_beams = (alive_seq.shape[0] // beam_size) * beam_size
        beam_ids = topn_ids[:n_beams]
        n_candidates = beam_ids.shape[1]
        # Repeat each prefix once per candidate (beam-major) and append the candidate token,
        # instead of concatenating row by row in Python.
        prefixes = alive_seq[:n_beams].repeat_interleave(n_candidates, dim=0)
        next_tokens = beam_ids.reshape(-1, 1)
        return torch.cat([prefixes, next_tokens], dim=1)
    
    def _evaluate_with_token_prefix_similarity(self, eval_seqs):
        '''
            Score each candidate row by its token-prefix similarity to
            ``self.conditional_starting_material``.

            eval_seqs: 2-D tensor of token ids; each id is mapped to text through ``self.vocab``.
            returns: tensor of shape [n_rows, 1] on ``device``.
        '''
        generated_token_lists = [[self.vocab[token] for token in row] for row in eval_seqs]
        # The reference material is tokenized with the SMILES tokenizer, then split on spaces.
        reference_tokens = smi_tokenizer(self.conditional_starting_material)[0].split()
        similarities = []
        for generated in generated_token_lists:
            similarities.append(
                get_token_prefix_similarity(target_tokens=reference_tokens,
                                            generated_tokens=generated)
            )
        return torch.tensor(similarities).unsqueeze(1).to(device)

    def _evaluate_with_token_prefix_similarity_vectorized(self, eval_seqs):
        """
        Vectorized counterpart of _evaluate_with_token_prefix_similarity.

        Args:
            eval_seqs: tensor of shape (batch_size, seq_len) containing token indices

        Returns:
            The result of ``get_token_prefix_similarity_vectorized`` with a trailing unit
            dimension added. If the starting material cannot be converted to ids
            (``turn_seq_to_ids`` returned None) a zero tensor of shape (batch_size, 1) on
            ``device`` is returned instead.
        """
        starting_material = self.conditional_starting_material
        onmt_checkpoint_path = self.config.classifier_guidance.onmt_checkpoint_path
        target_indices = turn_seq_to_ids(starting_material, onmt_checkpoint_path=onmt_checkpoint_path)
        if target_indices is not None:
            similarities = get_token_prefix_similarity_vectorized(
                eval_seqs.to(device), target_indices.to(device)
            )
            return similarities.unsqueeze(1)
        # Fallback: the starting material has out-of-vocabulary tokens.
        print(f'==== conditional starting material: {self.conditional_starting_material}')
        print(f'Conditional starting material contains tokens not in the vocabulary')
        return torch.zeros(eval_seqs.shape[0], device=device).unsqueeze(1)
        
    def _evaluate_with_edit_distance(self, eval_seqs):
        '''
            Score each candidate row by the Levenshtein similarity between its decoded string
            (vocab tokens joined with no separator) and ``self.conditional_starting_material``.

            returns: tensor of shape [n_rows, 1] on ``device``.
        '''
        decoded_rows = [''.join([self.vocab[token] for token in row]) for row in eval_seqs]
        similarities = []
        for decoded in decoded_rows:
            similarities.append(get_levenstein_similarity(decoded, self.conditional_starting_material))
        return torch.tensor(similarities).unsqueeze(1).to(device)
    
    def _evaluate_with_classifier(self, eval_seqs, with_starting_material_similarity=False, with_target_similarity=False):
        '''
            Score candidate sequences with ``self.property_model``, in chunks of
            ``config.classifier_guidance.search_batch_size`` rows.

            eval_seqs: 2-D tensor of candidate token ids, one candidate per row.
            with_starting_material_similarity: when ``name == 'similarity'`` and
                ``similarity_target == 'starting_material'``, append the rows of
                ``self.conditional_starting_materials[self.batch_indices]`` as extra columns.
            with_target_similarity: same, for ``similarity_target == 'main_target'`` and
                ``self.conditional_targets``.
            When ``with_product`` is set and ``self.product_smiles`` is not None, the rows of
            ``self.product_smiles[self.batch_indices]`` are appended first.

            returns: model outputs for all rows concatenated along dim 0. When the model returns
                a ``(values, log_var)`` tuple, the values are used. The result is un-normalized
                with the checkpoint's ``target_std``/``target_mean`` when both
                ``normalize_prediction`` and ``as_regression`` are set. If the candidates are
                longer than ``max_length_for_guidance``, a float zero tensor shaped like
                ``eval_seqs`` is returned instead.
        '''
        cg = self.config.classifier_guidance
        chunk_size = cg.search_batch_size
        print(f'==== running property model on {eval_seqs.shape[0]} sequences, with batch size {self.config.classifier_guidance.search_batch_size}')
        self.property_model.eval()
        if eval_seqs.shape[1] > cg.max_length_for_guidance:
            # NOTE(review): shaped like eval_seqs, not like the model output — confirm that
            # downstream consumers (e.g. _process_scores) handle this case.
            return torch.zeros_like(eval_seqs, device=device).float()

        # Gather the per-row conditioning tensors once, then slice them per chunk.
        is_similarity = cg.name == 'similarity'
        product_smiles = None
        product_seq = None
        if cg.with_product and self.product_smiles is not None:
            product_smiles = self.product_smiles[self.batch_indices]
            product_seq = [turn_ids_to_seq(p, onmt_checkpoint_path=cg.onmt_checkpoint_path) for p in product_smiles]
        starting_materials = None
        if (is_similarity and cg.similarity_target == 'starting_material'
                and self.conditional_starting_materials is not None
                and with_starting_material_similarity):
            starting_materials = self.conditional_starting_materials[self.batch_indices]
        targets = None
        # Guard on the tensor that is actually concatenated (``conditional_targets``).
        if (is_similarity and cg.similarity_target == 'main_target'
                and self.conditional_targets is not None
                and with_target_similarity):
            targets = self.conditional_targets[self.batch_indices]

        all_scores_values_normalized = []
        with torch.no_grad():
            for start in range(0, eval_seqs.shape[0], chunk_size):
                eval_seqs_batch = eval_seqs[start:start + chunk_size].to(device)
                stop = start + eval_seqs_batch.shape[0]
                if product_smiles is not None:
                    print(f'Using product sequences for guidance, product_seq: {product_seq}')
                    eval_seqs_batch = torch.cat([eval_seqs_batch, product_smiles[start:stop]], dim=1)
                if starting_materials is not None:
                    eval_seqs_batch = torch.cat([eval_seqs_batch, starting_materials[start:stop]], dim=1)
                if targets is not None:
                    eval_seqs_batch = torch.cat([eval_seqs_batch, targets[start:stop]], dim=1)
                scores_normalized = self.property_model(eval_seqs_batch)
                if isinstance(scores_normalized, tuple):
                    # (values, log_var): guide on the predicted values. The log-variance is not
                    # on the scale that the mean/std un-normalization below assumes.
                    values_normalized, _log_var_normalized = scores_normalized
                    all_scores_values_normalized.append(values_normalized)
                elif isinstance(scores_normalized, torch.Tensor):
                    all_scores_values_normalized.append(scores_normalized)
                else:
                    raise ValueError(f"Unknown scores type: {type(scores_normalized)}")
            scores_values = torch.cat(all_scores_values_normalized, dim=0)
        # unnormalize scores
        if cg.normalize_prediction and cg.as_regression:
            scores_values = scores_values * self.property_checkpoint['target_std'] + self.property_checkpoint['target_mean']
        return scores_values

    def define_property_model(
        self,
        onmt_checkpoint_path,
        checkpoint_path,
        output_size=None,
        use_cache=True,
    ):
        '''
            Build a PropertyPredictor sized for the vocabulary of ``onmt_checkpoint_path`` and
            load its weights from ``PROJECT_ROOT/checkpoints/<checkpoint_path>``.

            onmt_checkpoint_path: checkpoint whose vocabulary sizes the predictor input.
            checkpoint_path: property checkpoint, relative to ``PROJECT_ROOT/checkpoints``.
            output_size: forwarded to PropertyPredictor.
            use_cache: when True (default), the (model, checkpoint) pair is memoized per
                (onmt_checkpoint_path, checkpoint_path, output_size) on this instance. The
                evaluation helpers call this once per decoding step, and without the cache the
                vocabulary and checkpoint were re-read from disk every step.

            returns: (property_model on ``device``, loaded checkpoint dict).
        '''
        cache_key = (onmt_checkpoint_path, checkpoint_path, output_size)
        cache = getattr(self, '_property_model_cache', None)
        if use_cache and cache is not None and cache_key in cache:
            return cache[cache_key]

        vocab = get_vocab_from_trained_model(onmt_checkpoint_path)
        # TODO: override this
        property_model = PropertyPredictor(
            self.config, 
            len(vocab), 
            output_size=output_size
        )
        full_checkpoint_path = os.path.join(PROJECT_ROOT,
                                            'checkpoints',
                                            checkpoint_path)
        property_checkpoint = torch.load(full_checkpoint_path, map_location=device)
        property_model.load_state_dict(property_checkpoint['model_state_dict'])
        property_model = property_model.to(device)

        if use_cache:
            if cache is None:
                # A plain dict (not an nn.Module attribute), so the cached models are not
                # registered as submodules of this generator.
                cache = {}
                self._property_model_cache = cache
            cache[cache_key] = (property_model, property_checkpoint)
        return property_model, property_checkpoint

    def _evaluate_with_classifier_tanimoto(self, eval_seqs):
        '''
            Score candidate sequences with the Tanimoto regressor configured under
            ``config.classifier_guidance.tanimoto``.

            Each chunk of ``search_batch_size`` rows is extended column-wise with the matching
            rows of ``self.conditional_starting_materials[self.batch_indices]`` before scoring.

            eval_seqs: 2-D tensor of candidate token ids, one candidate per row.
            returns: predictions concatenated along dim 0, un-normalized with the Tanimoto
                checkpoint's ``target_std``/``target_mean``. If the candidates are longer than
                ``max_length_for_guidance``, a float zero tensor shaped like ``eval_seqs`` is
                returned instead.
        '''
        cg = self.config.classifier_guidance
        print(f'==== running property model on {eval_seqs.shape[0]} sequences, with batch size {self.config.classifier_guidance.search_batch_size}')
        # Check the length cap before building/loading the model.
        if eval_seqs.shape[1] > cg.max_length_for_guidance:
            return torch.zeros_like(eval_seqs, device=device).float()
        property_model, property_checkpoint = self.define_property_model(
            onmt_checkpoint_path=cg.tanimoto.onmt_checkpoint_path,
            checkpoint_path=cg.tanimoto.checkpoint_path,
            output_size=1
        )
        property_model.eval()
        chunk_size = cg.search_batch_size
        conditional_starting_materials = self.conditional_starting_materials[self.batch_indices]
        all_scores_values = []
        with torch.no_grad():
            for start in range(0, eval_seqs.shape[0], chunk_size):
                eval_seqs_batch = eval_seqs[start:start + chunk_size].to(device)
                stop = start + eval_seqs_batch.shape[0]
                eval_seqs_batch = torch.cat([eval_seqs_batch, conditional_starting_materials[start:stop]], dim=1)
                # Run the Tanimoto model that was just loaded; its checkpoint statistics are
                # the ones used for un-normalization below.
                all_scores_values.append(property_model(eval_seqs_batch))
            scores_values = torch.cat(all_scores_values, dim=0)
        # unnormalize scores
        return scores_values * property_checkpoint['target_std'] + property_checkpoint['target_mean']

    def _evaluate_with_classifier_reaction_type(self, eval_seqs):
        '''
            Score candidate sequences with the reaction-type classifier configured under
            ``config.classifier_guidance.reaction_type`` (built with
            ``output_size=reaction_type.num_classes``).

            eval_seqs: 2-D tensor of candidate token ids, one candidate per row.
            returns: classifier outputs for all rows, concatenated along dim 0. If the
                candidates are longer than ``max_length_for_guidance``, a float zero tensor
                shaped like ``eval_seqs`` is returned instead.
        '''
        cg = self.config.classifier_guidance
        print(f'==== running property model on {eval_seqs.shape[0]} sequences, with batch size {self.config.classifier_guidance.search_batch_size}')
        # Check the length cap before building/loading the model, so that a step whose
        # result is discarded does not read the checkpoint.
        if eval_seqs.shape[1] > cg.max_length_for_guidance:
            return torch.zeros_like(eval_seqs, device=device).float()
        property_model, _property_checkpoint = self.define_property_model(
            onmt_checkpoint_path=cg.reaction_type.onmt_checkpoint_path,
            checkpoint_path=cg.reaction_type.checkpoint_path,
            output_size=cg.reaction_type.num_classes
        )
        property_model.eval()
        chunk_size = cg.search_batch_size
        with torch.no_grad():
            chunk_outputs = [
                property_model(eval_seqs[start:start + chunk_size].to(device))
                for start in range(0, eval_seqs.shape[0], chunk_size)
            ]
            scores_values = torch.cat(chunk_outputs, dim=0)
        return scores_values
    
    def _process_scores(self, scores_values, batch_size):
        '''
            Turn raw property-model outputs into one guidance score per (beam, candidate).

            scores_values: model outputs with one row per evaluated candidate; the per-candidate
                score picked out below must have batch_size*beam_size*n_candidates_to_evaluate
                elements in total.
            returns: classifier_scores: [batch_size*beam_size, n_candidates_to_evaluate]

            Side effect: if, for every beam, the softmax confidence of the best candidate is
            below ``min_confidence_for_guidance``, ``self.guidance_scale`` is set to 0.0.
        '''
        cg = self.config.classifier_guidance
        prop = cg.property
        if prop in ('max_tanimoto', 'tanimoto_like_tango'):
            # NOTE: the interpretation here is take the next token with the highest similarity;
            # scores are centred on their mean so they act as a relative similarity.
            classifier_scores = scores_values - scores_values.mean()
        elif prop == 'reaction_type':
            reaction_types = self.reaction_types[self.batch_indices]
            print(f'reaction_types.shape: {reaction_types.shape}')
            print(f'reaction_types {reaction_types[:10]}')
            # Pick, for each row, the output at that row's requested reaction-type index.
            classifier_scores = scores_values.gather(dim=-1, index=reaction_types.unsqueeze(-1)).squeeze(-1)
        elif prop == 'tanimoto_and_reaction_type':
            # steering towards reaction type
            reaction_type_scores_values = scores_values[..., 0]
            reaction_types = self.reaction_types[self.batch_indices]
            reaction_type_scores_values = reaction_type_scores_values.gather(dim=-1, index=reaction_types.unsqueeze(-1)).squeeze(-1)
            # steering towards tanimoto
            # NOTE(review): `[..., 0, 1]` vs `[..., 0]` above implies a specific output layout —
            # confirm against the combined model's output shape.
            tanimoto_scores_values = scores_values[..., 0, 1]
            tanimoto_scores_values = tanimoto_scores_values - tanimoto_scores_values.mean()
            # combine the two
            t_w = cg.tanimoto.steering_weight
            r_w = cg.reaction_type.steering_weight
            classifier_scores = t_w * tanimoto_scores_values + r_w * reaction_type_scores_values
        elif prop == 'length':
            classifier_scores = scores_values[..., int(cg.target_class_index)]
        else:
            raise ValueError(f"Property {self.config.classifier_guidance.property} not supported for guided score")
        # Reshape classifier scores to match topn_scores shape
        classifier_scores = classifier_scores.view(batch_size * self.translator.beam_size, self.n_candidates_to_evaluate)
        max_confidence = F.softmax(classifier_scores, dim=-1).max(dim=-1).values
        # Ignore guidance if max confidence is below threshold for every beam. A tensor
        # reduction replaces the builtin all(), which iterated the tensor element by element.
        # NOTE(review): this resets guidance_scale on the instance, so later steps stay
        # unguided unless the caller restores it — confirm this is intended.
        if bool(torch.all(max_confidence < cg.min_confidence_for_guidance)):
            self.guidance_scale = 0.0
        return classifier_scores
    
    def _apply_classifier_guidance(self, original_scores, classifier_scores, topn_scores, topn_ids):
        '''
            Blend the classifier scores into the beam-search scores of the top-n candidates.

            Returns a copy of ``original_scores`` where, for each beam row, the entries at
            ``topn_ids`` are overwritten with
            ``self.combine_scores(topn_scores, classifier_scores, self.guidance_scale)``.
            When ``guidance_scale`` is 0 the result is checked against the input with
            ``self.are_scores_close`` and a ValueError is raised if they differ.
        '''
        beam_size = self.translator.beam_size
        n_rows = (original_scores.shape[0] // beam_size) * beam_size
        combined_scores = self.combine_scores(topn_scores, classifier_scores, self.guidance_scale)
        # Scatter every beam's combined scores in one advanced-indexing assignment.
        guided_scores = original_scores.clone()
        row_index = torch.arange(n_rows, device=topn_ids.device).unsqueeze(1)
        guided_scores[row_index, topn_ids[:n_rows]] = combined_scores[:n_rows]
        if self.guidance_scale != 0.:
            return guided_scores
        # Without guidance the scores must come back unchanged.
        scores_match = self.are_scores_close(original_scores,
                                             guided_scores,
                                             tolerance=self.are_scores_close_tolerance,
                                             debug=self.are_scores_close_debug)
        if not scores_match:
            print(f'==== Scores are not close')
            print(f'==== original_scores: {original_scores}')
            print(f'==== guided_scores: {guided_scores}')
            raise ValueError("Scores are not close")
        return guided_scores
    
    def _save_tensors(self, original_scores, guided_scores, classifier_scores, topn_scores, topn_ids):
        '''
            Save the guidance tensors of the current step as ``.pt`` files under
            ``PROJECT_ROOT/debug_tensors/<experiment_name>/<dataset.data_dir>`` and set
            ``self.save_tensors`` to False (presumably so callers only dump once).
        '''
        cg = self.config.classifier_guidance
        target_dir = os.path.join(PROJECT_ROOT, "debug_tensors", cg.experiment_name, cg.dataset.data_dir)
        os.makedirs(target_dir, exist_ok=True)
        named_tensors = (
            ("guided_scores.pt", guided_scores),
            ("original_scores.pt", original_scores),
            ("classifier_scores.pt", classifier_scores),
            ("topn_scores.pt", topn_scores),
            ("topn_ids.pt", topn_ids),
        )
        for file_name, tensor in named_tensors:
            torch.save(tensor, os.path.join(target_dir, file_name))
        self.save_tensors = False

    def _should_apply_classifier_guidance(self, alive_seq):
        '''
            Check if we should apply classifier guidance to the current sequence.

            Guidance requires a positive ``guidance_scale`` and a current length
            (``alive_seq.shape[1]``) within [min_length_for_guidance, max_length_for_guidance].
            For classification (``as_regression`` False), a target must also be specified:
            either ``target_class_index != -1`` or at least one entry of
            ``self.reaction_types`` different from -1. ``reaction_types`` of None counts as
            "none specified".

            returns: bool
        '''
        cg = self.config.classifier_guidance
        seq_len = alive_seq.shape[1]
        # TODO: we use max length guidance because the property predictor max len is capped at 500
        seq_length_ok = cg.min_length_for_guidance <= seq_len <= cg.max_length_for_guidance
        if cg.as_regression:
            return self.guidance_scale > 0.0 and seq_length_ok
        # If we're doing classification, ignore the guidance if no class is specified.
        # reaction_types starts as None; torch.all(None == -1) raised TypeError.
        return self.guidance_scale > 0.0 and seq_length_ok and (
            cg.target_class_index != -1
            or (self.reaction_types is not None
                and bool(torch.any(torch.as_tensor(self.reaction_types) != -1)))
        )
    
    def _penalize_eos_token(self, original_scores):
        '''
            Add ``config.classifier_guidance.eos_penalty`` to the score of the translator's
            EOS token (``translator._tgt_eos_idx``) in every row.

            Modifies ``original_scores`` in place and returns the same tensor.
        '''
        eos_index = self.translator._tgt_eos_idx
        penalty = self.config.classifier_guidance.eos_penalty
        original_scores[..., eos_index].add_(penalty)
        return original_scores
    
    def _assign_target_class_index(self):
        '''
            For multi-step classification, set ``config.classifier_guidance.target_class_index``
            from ``self.target_classes[self.node_depth]``, or -1 when the depth is unknown
            (None) or past the end of ``target_classes``.

            Does nothing when ``as_regression`` is set or ``multi_step_classes`` is off.
            The config object is mutated in place.
        '''
        cg = self.config.classifier_guidance
        # The inner re-check of as_regression in the nested form was unreachable.
        if cg.as_regression or not cg.multi_step_classes:
            return
        if self.node_depth is not None and self.node_depth < len(self.target_classes):
            print(f'==== assigning target class index for node depth {self.node_depth} with len(target_classes) {len(self.target_classes)}')
            self.config.classifier_guidance.target_class_index = self.target_classes[self.node_depth]
        else:
            self.config.classifier_guidance.target_class_index = -1
        print(f'==== target class index: {self.config.classifier_guidance.target_class_index}')

    def _evaluate_with_fake_tanimoto(self, eval_seqs):
        '''
            Debug scorer: decode each candidate row through ``self.vocab`` and look up a
            canned score with ``self.get_fake_tanimoto``.

            returns: tensor of shape [n_rows, 1] (not moved to ``device``).
        '''
        fake_scores = []
        for row in eval_seqs:
            decoded_tokens = [self.vocab[token] for token in row]
            fake_scores.append(self.get_fake_tanimoto(decoded_tokens, self.conditional_starting_materials))
        return torch.tensor(fake_scores).unsqueeze(1)

    def get_fake_tanimoto(self, eval_seq, starting_material):
        '''
            Debug stub: return a canned "Tanimoto" score for the last token of ``eval_seq``.

            The score is 0.38 when (last token, its index in ``eval_seq``) matches the hard-coded
            reference tokenization below, whose positions are numbered from 1. Otherwise the
            integer 0 is returned. ``starting_material`` is accepted for interface
            compatibility but unused.
        '''
        reference_tokens = (
            'C C C C O C ( = O ) N 1 C C N C C 1 . '
            'C C ( C ) ( C ) O C ( = O ) N [C@@H] ( C C C O [Si] '
            '( c 1 c c c c c 1 ) ( c 1 c c c c c 1 ) C ( C ) ( C ) C ) '
            'C ( = O ) O </s>'
        ).split()
        fake_tanimoto_dict = {
            (token, position): 0.38
            for position, token in enumerate(reference_tokens, start=1)
        }
        last_token = (eval_seq[-1], len(eval_seq) - 1)
        return fake_tanimoto_dict.get(last_token, 0)

    def _debug_(self, original_scores, guided_scores, top_n=5):
        # Get top N sequences by their best original scores
        max_original_scores = original_scores.max(dim=-1).values
        top_beam_indices = max_original_scores.topk(top_n).indices
        
        original_top_indices = original_scores.argmax(dim=-1)
        guided_top_indices = guided_scores.argmax(dim=-1)
        
        print(f"Showing top {top_n} sequences by original scores:")
        for i, beam_idx in enumerate(top_beam_indices):
            orig_top_token = original_top_indices[beam_idx]
            guided_top_token = guided_top_indices[beam_idx]
            orig_token_str = self.vocab[orig_top_token]
            guided_token_str = self.vocab[guided_top_token]
            
            print(f"Rank {i+1} (Beam {beam_idx}):")
            print(f"  Original top: {orig_token_str} (score: {original_scores[beam_idx, orig_top_token]:.4f})")
            print(f"  Guided top: {guided_token_str} (score: {guided_scores[beam_idx, guided_top_token]:.4f})")
            print(f"  Changed: {orig_top_token != guided_top_token}")
            print()

    def debug_next_sequences(self, guided_scores, eval_seq_str, classifier_scores, topn_ids, top_n=5):
        # Get top sequences by guided scores
        max_guided_scores = guided_scores.max(dim=-1).values
        top_beam_indices = max_guided_scores.topk(top_n).indices
        guided_top_indices = guided_scores.argmax(dim=-1)
        
        print(f"Top {top_n} next sequences:")
        for i, beam_idx in enumerate(top_beam_indices):
            guided_top_token = guided_top_indices[beam_idx]
            # Find the sequence that corresponds to this beam + guided token choice
            guided_token_pos = (topn_ids[beam_idx] == guided_top_token).nonzero()
            if len(guided_token_pos) > 0:
                seq_idx = beam_idx * self.n_candidates_to_evaluate + guided_token_pos[0].item()
                if seq_idx < len(eval_seq_str):
                    next_seq = eval_seq_str[seq_idx]
                    score = guided_scores[beam_idx, guided_top_token].item()
                    classifier_score = classifier_scores[beam_idx, guided_token_pos].item()
                    print(f"{i+1}. {next_seq} (score: {score:.3f}) (classifier_score: {classifier_score:.3f})")
                else:
                    print(f"{i+1}. INDEX_ERROR")
            else:
                print(f"{i+1}. TOKEN_NOT_FOUND")
        
    def seq_contains_special_tokens(self, seq):
        return any(token<4 for token in seq)

    def forward(self, *args, **kwargs):
        """Generator step with optional classifier guidance on the next-token scores.

        Flow, as implemented below:

        1. In 'enforce_starting_material' mode at the configured node depth (and while
           ``self.keep_using_conditiong`` is set), normalise
           ``self.conditional_starting_materials`` and force full guidance:
           ``guidance_scale = 1``, ``original_scores_weight = 0.0``, guidance from
           length 0, and the candidate count capped at ``len(self.vocab)``. These are
           assignments on ``self`` / ``self.config``, so they persist across calls.
        2. Call the wrapped generator. If ``_should_apply_classifier_guidance`` declines,
           return those scores, passed through ``_penalize_eos_token`` while
           ``keep_using_conditiong`` is set.
        3. Otherwise build candidate sequences from the top-n token ids, score them with
           the scorer selected by ``config.classifier_guidance.property`` /
           ``similarity_type``, and merge them via ``_process_scores`` and
           ``_apply_classifier_guidance``.

        Returns:
            The guided scores, or the (possibly EOS-penalised) generator scores when
            guidance is skipped.
        """
        # TODO: add smthg to only introduce guidance later in the beam search
        # e.g. once alive_seq is long enough... or change the guidance scale based on the length of the sequence
        # classifier_guidance.n_candidates_to_evaluate=72
        # Enforce-starting-material setup (only at the configured depth while conditioning is active).
        if 'similarity_type' in self.config.classifier_guidance.keys()\
            and self.config.classifier_guidance.similarity_type=='enforce_starting_material' \
            and self.node_depth==self.config.classifier_guidance.enforce_starting_material_at_depth \
            and self.keep_using_conditiong:
            # prepare the conditional material as well
            # not sure if should have the . at the end? probably not
            # Keep only the last '.'-separated fragment (cut at '</s>'), re-wrapped as '<s>' + fragment + '.'.
            if self.conditional_starting_materials.startswith('.'):
                self.conditional_starting_materials = '<s>'+self.conditional_starting_materials.split('.')[-1].split('</s>')[0]+'.'
            self.guidance_scale = 1
            self.config.classifier_guidance.min_length_for_guidance = 0
            self.config.classifier_guidance.n_candidates_to_evaluate = min(self.config.classifier_guidance.n_candidates_to_evaluate, len(self.vocab))
            self.n_candidates_to_evaluate = min(self.n_candidates_to_evaluate, len(self.vocab))
            self.original_scores_weight = 0.0

        self._assign_target_class_index()
        original_scores = self.original_generator(*args, **kwargs)
        alive_seq = self.translator.decode_strategy.alive_seq
        #print(f'==== alive_seq.shape= {alive_seq.shape}')
        # TODO: maybe look into why we get wrong intermediate tokens when enforcing starting material
        # if alive_seq.shape[-1]>1 and any(self.seq_contains_special_tokens(seq[1:]) for seq in alive_seq):
        #     print(f'==== alive_seq contains special tokens')
        batch_size = original_scores.shape[0] // self.translator.beam_size
        # only start with guidance after a certain length and target_class_index is not -1
        if not self._should_apply_classifier_guidance(alive_seq):
            # Penalize EOS token for short sequences
            if self.keep_using_conditiong:
                original_scores = self._penalize_eos_token(original_scores)
            return original_scores
        # Candidate continuations to score, built from alive_seq and the top-n token ids.
        topn_scores, topn_ids = self._select_topn_candidates(original_scores)
        eval_seqs = self._create_candidate_sequences(alive_seq, topn_ids)
        # true_eval_seq = [(st,id) for id,st in enumerate(eval_seq_str) if st.startswith(self.conditional_starting_material)]
        #if len(true_eval_seq) > 0 and len(smi_tokenizer(true_eval_seq[0][0].split('<s>')[-1].split('</s>')[0])[0].split()) > 29:
        # if len(true_eval_seq) > 0:
        #     print(f'here')
        if self.config.classifier_guidance.property == 'max_tanimoto' or self.config.classifier_guidance.property == 'tanimoto_like_tango':
            if self.config.classifier_guidance.similarity_type == 'token_prefix':
                scores_values = self._evaluate_with_token_prefix_similarity_vectorized(eval_seqs)
                # NOTE(review): conditional_starting_material_ids is not used in this branch (only by the commented-out check below).
                conditional_starting_material_ids = turn_seq_to_ids(self.conditional_starting_materials, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path)
                #if eval_seqs[0].shape[0]==conditional_starting_material_ids.shape[0]:
                    # print(f'==== eval_seqs[0].shape[0]: {eval_seqs[0].shape[0]}')
                    # print(f'==== conditional_starting_material: {self.conditional_starting_material}')
                    # print(f'==== eval_seqs[0]: {turn_ids_to_seq(eval_seqs[0], onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path)}')
                    # print(f'==== here')
            elif self.config.classifier_guidance.similarity_type == 'fake_tanimoto':
                scores_values = self._evaluate_with_fake_tanimoto(eval_seqs)
            elif self.config.classifier_guidance.similarity_type == 'target_and_starting_material_similarity':
                # combine the scores from the starting material and the target
                scores_values_starting_material = self._evaluate_with_classifier(eval_seqs, with_starting_material_similarity=True, 
                            with_target_similarity=False)
                scores_values_target = self._evaluate_with_classifier(eval_seqs, with_starting_material_similarity=False, 
                            with_target_similarity=True)
                # Convex combination: weight w on starting-material similarity, 1 - w on target similarity.
                w = self.config.classifier_guidance.target_and_starting_material_combination_weight
                scores_values =  w * scores_values_starting_material + (1 - w) * scores_values_target
            # 'enforce_starting_material' only takes this branch at the configured depth while
            # conditioning is active; otherwise it falls through to the generic `else` below.
            elif self.config.classifier_guidance.similarity_type == 'enforce_starting_material' \
                    and self.node_depth==self.config.classifier_guidance.enforce_starting_material_at_depth \
                    and self.keep_using_conditiong:
                    # prepare the conditional material as well
                    # not sure if should have the . at the end? probably not
                    # eval_seq_str = [turn_ids_to_seq(seq, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path) for seq in eval_seqs]
                    # alive_seq_str = [turn_ids_to_seq(seq, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path) for seq in alive_seq]  
                    scores_values = self._evaluate_with_token_prefix_similarity_vectorized(eval_seqs)
                    conditional_starting_material_ids = turn_seq_to_ids(self.conditional_starting_materials, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path)
                    # TODO: change condition to be that the conditional material is fully generated
                    # NOTE(review): first_eval_seq is computed on every call but only referenced in commented-out prints.
                    first_eval_seq = turn_ids_to_seq(eval_seqs[0], onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path)
                    # When the first candidate is as long as the tokenised starting material, turn
                    # guidance off (original_scores_weight=1, guidance_scale=0) and stop conditioning.
                    if eval_seqs[0].shape[0]==conditional_starting_material_ids.shape[0]:
                        # print(f'Enforcing starting material at depth {self.node_depth} with condition {self.config.classifier_guidance.enforce_starting_material_at_depth}')
                        # print(f'first_eval_seq: {first_eval_seq}')
                        # print(f'conditional_starting_material: {self.conditional_starting_material}')
                        self.original_scores_weight = 1.0
                        self.guidance_scale = 0.
                        self.keep_using_conditiong = False
                        # eval_seq_str = [turn_ids_to_seq(seq, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path) for seq in eval_seqs]
                        # alive_seq_str = [turn_ids_to_seq(seq, onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path) for seq in alive_seq] 
                        # NOTE(review): leftover debug print; it emits only a bare header line.
                        print(f'==== seq_str:')
            else:
                # tanimoto, fingerprint, etc
                scores_values = self._evaluate_with_classifier(eval_seqs, with_starting_material_similarity=True, with_target_similarity=False)
            # else:
            #     raise ValueError(f"Similarity type {self.config.classifier_guidance.similarity_type} not found")
        elif self.config.classifier_guidance.property == 'tanimoto_and_reaction_type':
            tanimoto_scores_values = self._evaluate_with_classifier_tanimoto(
                eval_seqs
            )
            reaction_type_scores_values = self._evaluate_with_classifier_reaction_type(
                eval_seqs
            )
            # Last dim holds [reaction_type, tanimoto] scores, in that order.
            scores_values = torch.stack([reaction_type_scores_values, tanimoto_scores_values], dim=-1)
        else:
            scores_values = self._evaluate_with_classifier(eval_seqs)
        # if scores_values.max().item()==1.:    
        #     print(f'==== scores_values: {scores_values}')
        #     self.stop_guidance = True

        classifier_scores = self._process_scores(scores_values, batch_size)
        #print('about to apply classifier guidance')
        guided_scores = self._apply_classifier_guidance(original_scores, classifier_scores, topn_scores, topn_ids)
        #print('after applying classifier guidance')
        #self._debug_(original_scores, guided_scores)
        #self.debug_next_sequences(guided_scores, eval_seq_str, classifier_scores, topn_ids, top_n=10)
        # save the scores for debugging/figure plotting in files
        # (only while guidance is active and self.save_tensors is set)
        if self.guidance_scale>0.0 and self.save_tensors:
            self._save_tensors(original_scores, guided_scores, classifier_scores, topn_scores, topn_ids)
        #print('about to return guided scores')
        return guided_scores
    
    def get_distance_to_threshold_based_on_property(self, scores_values):
        if self.config.classifier_guidance.property == 'toxicity':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is less toxic
        elif self.config.classifier_guidance.property == 'yield':
            return scores_values - self.config.classifier_guidance.prediction_threshold # positive is higher yield
        elif self.config.classifier_guidance.property == 'sa_score':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is lower sa_score
        elif self.config.classifier_guidance.property == 'np_score':
            return self.config.classifier_guidance.prediction_threshold - scores_values # positive is lower np_score
        elif self.config.classifier_guidance.property == 'length':
            return scores_values - self.config.classifier_guidance.prediction_threshold # less than max length is better
        elif self.config.classifier_guidance.property == 'tanimoto':
            return scores_values - self.config.classifier_guidance.prediction_threshold # positive is higher tanimoto
        else:
            raise ValueError(f"Property {self.config.classifier_guidance.property} not found")
        
    def are_scores_close(self, curr_scores, guided_scores, tolerance=1e-6, debug=False):
        # Check if infinities are in the same positions
        inf_same_pos = ((torch.isinf(curr_scores) & torch.isinf(guided_scores)).sum() == 
                        torch.isinf(curr_scores).sum())
        if debug:
            print(f"Infinities in same positions: {inf_same_pos}")
        # Look at non-infinite values
        finite_mask = ~torch.isinf(curr_scores) & ~torch.isinf(guided_scores)
        if finite_mask.any():
            curr_finite = curr_scores[finite_mask]
            guided_finite = guided_scores[finite_mask]
            max_finite_diff = (curr_finite - guided_finite).abs().max().item()
            if debug:
                print(f"Max difference in finite values: {max_finite_diff}")
            # Print a few examples of differences
            diff_indices = torch.where((curr_finite - guided_finite).abs() > tolerance)
            if len(diff_indices[0]) > 0:
                for i in range(min(5, len(diff_indices[0]))):
                    idx = diff_indices[0][i]
                    if debug:
                        print(f"Diff at finite value: {curr_finite[idx]} vs {guided_finite[idx]}")
            
            return max_finite_diff < tolerance
    
    def combine_scores(self, original_scores, classifier_scores, guidance_scale=1.0):
        '''
            original_scores: [batch_size, beam_size, vocab_size]
            classifier_scores: [batch_size, beam_size, n_candidates_to_evaluate]
        '''
        if self.config.classifier_guidance.property == 'tanimoto_and_reaction_type':
            assert guidance_scale == 1.0, "Guidance scale for tanimoto_and_reaction_type should be 1.0"+\
                "Use tanimoto_steering_weight and reaction_type_steering_weight to specify each guidance scale."
        classifier_scores = classifier_scores.view(original_scores.shape[0], -1)
        if not self.config.classifier_guidance.combine_renormalize:
            output = original_scores + guidance_scale * classifier_scores  # Simplified example
        else:
            combined_log_probs = original_scores + guidance_scale * classifier_scores
            # NOTE: no need to renormalize here because beam search uses relative ranking to select beams
            log_sum = torch.logsumexp(combined_log_probs, dim=-1, keepdim=True)
            normalized_log_probs = combined_log_probs - log_sum
            #replace nan with -inf
            output = torch.nan_to_num(normalized_log_probs, -float('inf'))
        return output
        #return combined_log_probs

    def forward_old(self, *args, **kwargs):
        """Legacy generator step with classifier guidance (predecessor of ``forward``).

        Runs the wrapped generator, re-scores the top ``self.n_candidates_to_evaluate``
        next tokens of every (batch, beam) row with ``self.property_model``, and writes
        the combined scores (see ``combine_scores``) back into those entries.

        Returns:
            Scores with the same shape as the generator output. While fewer than
            ``config.classifier_guidance.min_length_for_guidance`` tokens exist,
            guidance is skipped and ``eos_penalty`` is added in place to the EOS column.

        Raises:
            ValueError: if the property model returns neither a tensor nor a tuple, or
                if ``guidance_scale`` is 0 but the guided scores differ from the
                originals. In the latter case both tensors are saved to disk first.
        """
        # TODO: add smthg to only introduce guidance later in the beam search
        # e.g. once alive_seq is long enough... or change the guidance scale based on the length of the sequence
        original_scores = self.original_generator(*args, **kwargs)
        cg_config = self.config.classifier_guidance
        alive_seq = self.translator.decode_strategy.alive_seq
        # only start with guidance after a certain length
        if alive_seq.shape[1] < cg_config.min_length_for_guidance:
            # Penalize EOS token for short sequences (in place on the generator output)
            original_scores[..., self.translator._tgt_eos_idx] += cg_config.eos_penalty
            return original_scores
        beam_size = self.translator.beam_size
        batch_size = original_scores.shape[0] // beam_size
        n_rows = batch_size * beam_size
        n_cand = self.n_candidates_to_evaluate
        vocab_size = original_scores.shape[-1]
        # Top-n candidate tokens for every (batch, beam) row in a single topk call.
        topn_scores, topn_ids = torch.topk(original_scores.view(n_rows, vocab_size), n_cand, dim=-1)
        # Candidate sequences: row r * n_cand + k is alive_seq[r] followed by topn_ids[r, k]
        # (first iteration has length min_length_for_guidance + 1).
        eval_seqs = torch.cat(
            [alive_seq[:n_rows].repeat_interleave(n_cand, dim=0), topn_ids.reshape(-1, 1)],
            dim=1,
        )
        search_batch_size = cg_config.search_batch_size
        print(f'==== running property model on {eval_seqs.shape[0]} sequences, with batch size {search_batch_size}')
        self.property_model.eval()
        all_scores_values_normalized = []
        with torch.no_grad():
            for start in range(0, eval_seqs.shape[0], search_batch_size):
                eval_seqs_batch = eval_seqs[start:start + search_batch_size].to(device)
                scores_normalized = self.property_model(eval_seqs_batch)
                if isinstance(scores_normalized, tuple):
                    # NOTE(review): the SECOND tuple element is kept as the scores —
                    # confirm this matches the property model's output order.
                    _, scores_log_var_normalized = scores_normalized
                    all_scores_values_normalized.append(scores_log_var_normalized)
                elif isinstance(scores_normalized, torch.Tensor):
                    all_scores_values_normalized.append(scores_normalized)
                else:
                    raise ValueError(f"Unknown scores type: {type(scores_normalized)}")
            scores_values = torch.cat(all_scores_values_normalized, dim=0)
        # TODO: could use log_var in a smart way if available
        classifier_scores = scores_values[..., int(cg_config.target_class_index)]
        classifier_scores = classifier_scores.view(n_rows, n_cand)
        # Switch guidance off (persistently, via self.guidance_scale) when the classifier is
        # uninformative for every row. The softmax max is >= 1/n_cand, so this can only
        # trigger when n_cand > 100.
        max_confidence = F.softmax(classifier_scores, dim=-1).max(dim=-1).values
        if bool((max_confidence < 0.01).all()):
            print(f'==== max confidence: {max_confidence}')
            self.guidance_scale = 0.0
        combined_scores = self.combine_scores(topn_scores, classifier_scores, self.guidance_scale)
        # Update a copy of the full score matrix with the guided candidate scores.
        guided_scores = original_scores.clone()
        row_idx = torch.arange(n_rows, device=topn_ids.device).unsqueeze(1)
        guided_scores[row_idx, topn_ids] = combined_scores
        if self.guidance_scale == 0.:
            # if not using guidance, the scores should be the same
            if not self.are_scores_close(original_scores,
                                         guided_scores,
                                         tolerance=self.are_scores_close_tolerance,
                                         debug=self.are_scores_close_debug):
                # The message below says the tensors were saved, so actually save them,
                # both into experiments/<experiment_name>/ (the layout __init__ uses for
                # the property checkpoint).
                dump_dir = os.path.join(PROJECT_ROOT, "experiments", cg_config.experiment_name)
                os.makedirs(dump_dir, exist_ok=True)
                original_scores_path = os.path.join(dump_dir, "original_scores.pt")
                guided_scores_path = os.path.join(dump_dir, "guided_scores.pt")
                torch.save(original_scores, original_scores_path)
                torch.save(guided_scores, guided_scores_path)
                print(f'Scores are not close, saved them in {original_scores_path} and {guided_scores_path}')
                raise ValueError("Scores are not close")
        return guided_scores

    def forward_(self, *args, **kwargs):
        # Get the original scores from the translator's generator
        # TODO: what input goes here? need partially completed sequence
        original_scores = self.original_generator(*args, **kwargs)

        # Apply your classifier guidance/bias
        # The exact implementation depends on your classifier and how you want to combine scores
        # HACK: need to choose topn candidates here to evaluate for the classifier, and from which we choose topk for the beam search
        # choose topn candidates to evaluate for the classifier
        classifier_scores = self.classifier_model(self.translator.alive_seq)  # or whatever inputs your classifier needs


        # HACK: need to choose topn candidates here to evaluate for the classifier, and from which we choose topk for the beam search
        
        # Combine scores (e.g., add, multiply, etc.)
        # This is similar to your pseudocode: original_scores + my_own_bias
        modified_scores = self.combine_scores(original_scores, classifier_scores)
        
        return modified_scores
    
    def combine_scores_(self, original_scores, classifier_scores):
        # Implement your score combination logic
        # Could be simple addition with a weight parameter
        # return original_scores + (lambda_weight * classifier_scores)
        # Or more complex combinations
        combined_log_probs = original_scores + self.guidance_scale * classifier_scores  # Simplified example
        log_sum = torch.logsumexp(combined_log_probs, dim=-1, keepdim=True)
        normalized_log_probs = combined_log_probs - log_sum
        return normalized_log_probs
