#!/usr/bin/env python
""" Translator Class and builder """
# At the top of your file
import sys
import os
from pathlib import Path
# Add the directory containing this file to sys.path
sys.path.append(str(Path(__file__).parent))
from typing import List, Tuple, Dict, Any
import codecs
import torch
from rdkit import Chem
import torch.nn as nn
import onmt.model_builder
import onmt.decoders.ensemble
from onmt.translate.beam_search import BeamSearch
from onmt.translate.greedy_search import GreedySearch
from onmt.translate.translator import Translator, GeneratorLM
from onmt.constants import ModelTask

from multiguide.onmt.guided_generator import ClassifierGuidedGenerator
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_vocab_size, get_vocab_from_trained_model, get_tanimoto
from multiguide.dataset.helpers import turn_ids_to_seq, get_sorted_cano_smiles, get_rxn_insight_info, class_to_idx

# Module-wide torch device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_classifier_guided_translator(opt, report_score=True, logger=None, out_file=None, config=None, target_classes=None):
    """Build a translator and, when a config is given, attach classifier guidance.

    The test model(s) listed in ``opt.models`` are loaded (through the ensemble
    loader when more than one model is listed) and a translator is created with
    ``from_opt``:

    * language-model checkpoints get a ``GeneratorLM``; ``config`` and
      ``target_classes`` are not used for them;
    * other checkpoints get a ``TranslatorClassifierGuided`` on which
      ``set_classifier_guidance_config(config, target_classes)`` is called;
    * when ``config`` is ``None`` there is nothing to guide with, so OpenNMT's
      stock ``Translator`` is returned instead of failing on the missing config.

    Args:
        opt: Translation options; ``models``, ``output`` and ``report_align``
            are read here, the rest is consumed by the loaders / ``from_opt``.
        report_score: Forwarded to ``from_opt``.
        logger: Optional logger forwarded to ``from_opt``.
        out_file: Open text stream for the predictions. When ``None``,
            ``opt.output`` is opened for writing as UTF-8.
        config: Classifier-guidance configuration, or ``None`` for an
            unguided translator.
        target_classes: Forwarded to ``set_classifier_guidance_config``.

    Returns:
        The constructed translator.
    """
    if out_file is None:
        out_file = codecs.open(opt.output, "w+", "utf-8")

    # Several checkpoints are decoded as an ensemble.
    load_test_model = (
        onmt.decoders.ensemble.load_test_model
        if len(opt.models) > 1
        else onmt.model_builder.load_test_model
    )
    fields, model, model_opt = load_test_model(opt)

    scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)

    # Keyword arguments shared by every ``from_opt`` call below.
    from_opt_kwargs = dict(
        global_scorer=scorer,
        out_file=out_file,
        report_align=opt.report_align,
        report_score=report_score,
        logger=logger,
    )

    if model_opt.model_task == ModelTask.LANGUAGE_MODEL:
        return GeneratorLM.from_opt(model, fields, opt, model_opt, **from_opt_kwargs)

    if config is None:
        # Without a guidance config the guided translator cannot be set up
        # (it reads ``config.classifier_guidance``), so fall back to the
        # unguided OpenNMT translator.
        if logger is not None:
            logger.warning(
                "No classifier-guidance config given; building an unguided Translator."
            )
        return Translator.from_opt(model, fields, opt, model_opt, **from_opt_kwargs)

    translator = TranslatorClassifierGuided.from_opt(
        model, fields, opt, model_opt, **from_opt_kwargs
    )
    translator.set_classifier_guidance_config(config, target_classes)
    return translator
    
class FakePredictor(PropertyPredictor):
    """Rule-based stand-in for ``PropertyPredictor`` used with the toy experiment.

    Instead of running a network, ``forward`` checks each row of token ids
    against one hard-coded target expression and returns a fixed score.
    """

    def __init__(self, config, alphabet_size):
        """Load the source vocabulary of the toy-experiment checkpoint.

        Args:
            config: Experiment configuration; ``classifier_guidance.experiment_name``
                selects the checkpoint directory and
                ``classifier_guidance.prediction_threshold`` selects the target.
            alphabet_size: Forwarded unchanged to ``PropertyPredictor``.
        """
        super(FakePredictor, self).__init__(config, alphabet_size)
        self.config = config
        model_path = os.path.join(PROJECT_ROOT, "checkpoints", "toy_experiment", config.classifier_guidance.experiment_name, "onmt_model_step_2000.pt")
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu')
        # Index -> token list of the source field.
        self.vocab = checkpoint['vocab']['src'].base_field.vocab.itos

    def get_true_seq_indices(self):
        """Return the vocabulary ids of the target sequence, wrapped in <s> ... </s>.

        The target is ``'21'`` when ``prediction_threshold < 3`` and
        ``'5 + ( 6 + 10 )'`` otherwise (earlier variants were ``'14'`` and
        ``'( 3 * 4 ) + 2'``).

        Raises:
            ValueError: If a token of the target is missing from the vocabulary.
        """
        if self.config.classifier_guidance.prediction_threshold < 3:
            true_seq = '21'
        else:
            true_seq = '5 + ( 6 + 10 )'
        tokens = ['<s>'] + true_seq.split(' ') + ['</s>']
        return [self.vocab.index(token) for token in tokens]

    @staticmethod
    def _contains_run(candidate, reference):
        """Return True when ``candidate`` occurs as a contiguous run inside ``reference``."""
        width = len(candidate)
        # An empty candidate matches trivially; one longer than the reference never does.
        return any(
            reference[start:start + width] == candidate
            for start in range(len(reference) - width + 1)
        )

    def forward(self, x):
        """Score each row of ``x`` against the hard-coded target sequence.

        A row that appears as a contiguous run of token ids inside the target
        (including its <s>/</s> markers) gets the number of target tokens
        (markers excluded); every other row gets 10000.

        The comparison is done token by token. Comparing the ``str()`` forms,
        as was done previously, lets ``[1, 2]`` match inside ``"11, 21"``.

        Args:
            x: Tensor of token ids; assumed 2-D (rows are sequences) — TODO confirm.

        Returns:
            A ``(x.shape[0], 1)`` float tensor created on the default (CPU) device.
        """
        true_seq_indices = self.get_true_seq_indices()
        match_score = len(true_seq_indices[1:-1])
        output = torch.zeros(x.shape[0], 1)
        for i in range(x.shape[0]):
            if self._contains_run(x[i].tolist(), true_seq_indices):
                output[i] = match_score
            else:
                output[i] = 10000
        return output
    
class ClassifierGuidedBeamSearch(BeamSearch):
    """Beam search that adds scaled classifier scores to the log-probs of each step.

    Args:
        *args: Forwarded to ``BeamSearch``.
        property_model: Classifier used for guidance; ``None`` disables guidance
            so that ``advance`` behaves exactly like the parent class.
        guidance_scale: Multiplier applied to the classifier scores.
        **kwargs: Forwarded to ``BeamSearch``.
    """

    def __init__(self, *args, property_model=None, guidance_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.property_model = property_model
        self.guidance_scale = guidance_scale

    def _get_classifier_scores(self, shape):
        """Hook returning classifier scores to add to the step's log-probs.

        ``advance`` adds the result (times ``guidance_scale``) to ``log_probs``,
        so it must be broadcastable to ``shape``. No scoring rule is provided
        by this class; it has to be supplied by a subclass.

        Args:
            shape: Shape of the ``log_probs`` tensor passed to ``advance``.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            "ClassifierGuidedBeamSearch._get_classifier_scores must be implemented "
            "by a subclass before a property_model can be used for guidance."
        )

    def advance(self, log_probs, attn):
        """Apply classifier guidance (if enabled) and delegate to ``BeamSearch.advance``."""
        # Guidance is added BEFORE beam scoring so that it influences which
        # hypotheses survive the step.
        if self.property_model is not None:
            classifier_scores = self._get_classifier_scores(log_probs.shape)
            log_probs = log_probs + self.guidance_scale * classifier_scores

        return super().advance(log_probs, attn)
    
class TranslatorClassifierGuided(Translator):
    def __init__(self, *args, **kwargs):
        """Create the translator with every guidance-related attribute unset.

        All positional and keyword arguments go straight to the parent
        constructor. The attributes below start as ``None``.
        """
        super().__init__(*args, **kwargs)

        # Guidance configuration (assigned by set_classifier_guidance_config).
        self.config = None
        self.target_classes = None
        self.property_model = None

        # Decode strategy of the batch being translated (assigned by translate_batch).
        self.decode_strategy = None

        # Product SMILES indexed per source by adjust_final_beam_scores;
        # not assigned anywhere in this class.
        self.augmented_products = None

    def set_classifier_guidance_config(self, config, target_classes=None):
        """Load the property classifier and wrap the model's generator with it.

        The predictor is built with the vocabulary size of the OpenNMT
        checkpoint at ``config.classifier_guidance.onmt_checkpoint_path``, its
        weights are read from
        ``PROJECT_ROOT/checkpoints/<config.classifier_guidance.checkpoint_path>``
        (key ``'model_state_dict'``), and ``self.model.generator`` is replaced
        by a ``ClassifierGuidedGenerator`` that wraps the unguided generator.

        The translator is only modified once the predictor has been loaded, so
        a bad configuration leaves it untouched. Calling the method again
        re-wraps the original unguided generator rather than stacking a second
        guidance layer on top of the first.

        Args:
            config: Configuration object with a ``classifier_guidance`` section.
            target_classes: Forwarded to ``ClassifierGuidedGenerator``.

        Raises:
            ValueError: If ``classifier_guidance.predictor_type`` is not
                ``'neural_network'``.
        """
        guidance_cfg = config.classifier_guidance

        # Validate first: nothing on ``self`` has been changed at this point.
        if guidance_cfg.predictor_type != 'neural_network':
            raise ValueError(f"Classifier {guidance_cfg.predictor_type} not found")

        vocab = get_vocab_from_trained_model(guidance_cfg.onmt_checkpoint_path)
        property_model = PropertyPredictor(config, len(vocab))
        checkpoint_path = os.path.join(PROJECT_ROOT,
                                       'checkpoints',
                                       guidance_cfg.checkpoint_path)
        property_checkpoint = torch.load(checkpoint_path, map_location=device)
        property_model.load_state_dict(property_checkpoint['model_state_dict'])
        property_model = property_model.to(device)
        # NOTE(review): the predictor is left in whatever train/eval mode its
        # constructor produced — confirm that this is intended for inference.

        # Pick the generator to wrap. If the current one is already a guided
        # wrapper installed by an earlier call, reuse the generator remembered
        # from that call.
        current_generator = self.model.generator
        if isinstance(current_generator, ClassifierGuidedGenerator):
            base_generator = getattr(self, '_unguided_generator', current_generator)
        else:
            base_generator = current_generator

        # Commit the state before building the wrapper, which receives ``self``.
        self.config = config
        self.target_classes = target_classes
        self.property_checkpoint = property_checkpoint
        self.property_model = property_model
        self._unguided_generator = base_generator

        self.model.generator = ClassifierGuidedGenerator(
            original_generator=base_generator,
            translator=self,
            config=self.config,
            property_model=self.property_model,
            target_classes=target_classes,
            guidance_scale=guidance_cfg.guidance_scale,
            n_candidates_to_evaluate=guidance_cfg.n_candidates_to_evaluate,
            debug_classifier_scores=guidance_cfg.debug_classifier_scores,
            are_scores_close_tolerance=guidance_cfg.are_scores_close_tolerance,
            are_scores_close_debug=guidance_cfg.are_scores_close_debug,
        )


    def adjust_final_beam_scores(
        self,
        translations: Dict[str, Any],
        batch,
        reaction_types_per_beam: Any,
        conditional_starting_materials_per_beam: Any,
        conditional_targets_per_beam: Any
    ) -> Dict[str, Any]:
        """Re-rank the finished hypotheses of each source with a property score.

        For every hypothesis a ``target_score`` is computed according to
        ``config.classifier_guidance.property``:

        * ``'reaction_type'`` (needs ``reaction_types_per_beam``): 1 when the
          class reported by ``get_rxn_insight_info`` for
          ``<predicted reactants>>><product>`` equals the requested class,
          otherwise 0 (also 0 when no reaction info is returned);
        * ``'tanimoto'`` (needs ``conditional_starting_materials_per_beam``):
          the value of ``get_tanimoto`` for the BOS-prefixed prediction.

        The new score is ``original_score_weight * original +
        adjusted_score_weight * target_score``. The hypotheses of each source
        are then sorted by it in descending order.

        Args:
            translations: Result dict with per-source lists under
                ``'predictions'``, ``'scores'``, ``'attention'`` and optionally
                ``'alignment'``. Modified in place.
            batch: Unused; kept for interface compatibility.
            reaction_types_per_beam: Per-source requested reaction classes
                (each entry must support ``.item()``), or ``None``.
            conditional_starting_materials_per_beam: Per-source references for
                the Tanimoto score, or ``None``.
            conditional_targets_per_beam: Unused; kept for interface compatibility.

        Returns:
            The same ``translations`` dict, re-ordered.

        Raises:
            ValueError: If the configured property is unsupported or its
                per-source data is ``None``.
        """
        guidance_cfg = self.config.classifier_guidance

        for beam_idx, translations_in_beam in enumerate(translations['predictions']):
            adjusted_scores = []
            # Canonical product SMILES of this source; identical for all of its
            # hypotheses, so it is computed at most once.
            canonical_product = None

            for seq_idx, pred_seq in enumerate(translations_in_beam):
                with torch.no_grad():
                    if guidance_cfg.property == 'reaction_type' and reaction_types_per_beam is not None:
                        seq = turn_ids_to_seq(pred_seq, guidance_cfg.onmt_checkpoint_path)
                        seq = seq.replace('</s>', '')
                        if canonical_product is None:
                            product = self.augmented_products[beam_idx]
                            canonical_product = Chem.MolToSmiles(Chem.MolFromSmiles(product))
                        reaction_smiles = get_sorted_cano_smiles([seq]) + '>>' + canonical_product
                        reaction_info = get_rxn_insight_info(reaction_smiles)
                        # -1 never equals a requested class, so missing info scores 0.
                        reaction_type = class_to_idx[reaction_info['CLASS']] if reaction_info is not None else -1
                        target_score = int(reaction_type == reaction_types_per_beam[beam_idx].item())

                    elif guidance_cfg.property == 'tanimoto' and conditional_starting_materials_per_beam is not None:
                        # Build the BOS token with the prediction's own dtype and
                        # device so the concatenation cannot fail on a mismatch.
                        bos = pred_seq.new_tensor([self._tgt_bos_idx])
                        full_seq = torch.cat([bos, pred_seq])
                        target_score = get_tanimoto(full_seq, conditional_starting_materials_per_beam[beam_idx])
                    else:
                        raise ValueError(f"Property {guidance_cfg.property} not found")

                    original_score = translations['scores'][beam_idx][seq_idx]
                    adjusted_score = (guidance_cfg.original_score_weight * original_score +
                                      guidance_cfg.adjusted_score_weight * target_score)
                    adjusted_scores.append(adjusted_score)

            # Best adjusted score first; ties keep their original order.
            sorted_indices = sorted(range(len(adjusted_scores)),
                                    key=lambda i: adjusted_scores[i],
                                    reverse=True)

            print(f'Beam {beam_idx} ranking change: {list(range(len(sorted_indices)))} -> {sorted_indices}')

            # Apply the same permutation to every per-hypothesis field.
            translations['predictions'][beam_idx] = [translations['predictions'][beam_idx][i] for i in sorted_indices]
            translations['scores'][beam_idx] = [adjusted_scores[i] for i in sorted_indices]
            translations['attention'][beam_idx] = [translations['attention'][beam_idx][i] for i in sorted_indices]

            # Alignments are optional: the key may be missing, or the entry for
            # this source may be empty or absent.
            alignment = translations.get('alignment')
            if alignment is not None and beam_idx < len(alignment) and len(alignment[beam_idx]) > 0:
                alignment[beam_idx] = [alignment[beam_idx][i] for i in sorted_indices]

        return translations
        
    def translate_batch(self, batch, src_vocabs, attn_debug):
        """Translate one batch while exposing the decode strategy to the generator.

        Overrides onmt's ``Translator.translate_batch``. The decode strategy is
        stored on ``self.decode_strategy`` so that partially completed sequences
        (``self.decode_strategy.alive_seq``) can be read by
        ``ClassifierGuidedGenerator``.

        Before decoding, the per-source conditioning stored on the generator is
        tiled ``beam_size * n_candidates_to_evaluate`` times. These attributes
        are overwritten in place:

        * ``reaction_types`` (and ``batch_indices``, derived from
          ``batch.indices``) when ``reaction_types`` is not ``None``;
        * ``conditional_starting_materials``, ``product_smiles`` and
          ``conditional_targets`` along dim 0 when not ``None``.

        Args:
            batch: Batch to translate; ``batch_size`` and ``indices`` are read here.
            src_vocabs: Forwarded to ``_translate_batch_with_strategy``.
            attn_debug: When true, attention is returned by the decode strategy.

        Returns:
            Whatever ``_translate_batch_with_strategy`` returns.

        Raises:
            RuntimeError: If ``set_classifier_guidance_config`` was never called.
        """
        if self.config is None:
            raise RuntimeError(
                "set_classifier_guidance_config() must be called before translate_batch()"
            )

        with torch.no_grad():
            if self.sample_from_topk != 0 or self.sample_from_topp != 0:
                # Random sampling requested.
                self.decode_strategy = GreedySearch(
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    unk=self._tgt_unk_idx,
                    batch_size=batch.batch_size,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    return_attention=attn_debug or self.replace_unk,
                    sampling_temp=self.random_sampling_temp,
                    keep_topk=self.sample_from_topk,
                    keep_topp=self.sample_from_topp,
                    beam_size=self.beam_size,
                    ban_unk_token=self.ban_unk_token,
                )
            else:
                # TODO: support these blacklisted features
                assert not self.dump_beam
                self.decode_strategy = BeamSearch(
                    self.beam_size,
                    batch_size=batch.batch_size,
                    pad=self._tgt_pad_idx,
                    bos=self._tgt_bos_idx,
                    eos=self._tgt_eos_idx,
                    unk=self._tgt_unk_idx,
                    n_best=self.n_best,
                    global_scorer=self.global_scorer,
                    min_length=self.min_length,
                    max_length=self.max_length,
                    return_attention=attn_debug or self.replace_unk,
                    block_ngram_repeat=self.block_ngram_repeat,
                    exclusion_tokens=self._exclusion_idxs,
                    stepwise_penalty=self.stepwise_penalty,
                    ratio=self.ratio,
                    ban_unk_token=self.ban_unk_token,
                )

            generator = self.model.generator
            # One copy per beam per candidate evaluated by the generator.
            reps = self.beam_size * self.config.classifier_guidance.n_candidates_to_evaluate

            if generator.reaction_types is not None:
                generator.batch_indices = batch.indices.repeat_interleave(reps)
                generator.reaction_types = generator.reaction_types.repeat_interleave(reps)

            for attr in ("conditional_starting_materials", "product_smiles", "conditional_targets"):
                per_source = getattr(generator, attr)
                if per_source is not None:
                    setattr(generator, attr, per_source.repeat_interleave(reps, dim=0))

            # NOTE: final re-ranking through adjust_final_beam_scores
            # (config.classifier_guidance.readjust_translations) is currently
            # disabled. Re-enabling it requires keeping the per-source
            # (untiled) conditioning tensors from before the tiling above.
            return self._translate_batch_with_strategy(
                batch, src_vocabs, self.decode_strategy
            )
    
    def _translate_batch_with_strategy(
        self, batch, src_vocabs, decode_strategy
    ):
        """Run the step-by-step decoding loop for one batch.

        Same flow as onmt's ``Translator._translate_batch_with_strategy``. The
        one addition is that whenever finished sequences are dropped, the
        generator's ``batch_indices`` are filtered alongside the decoder state
        so that batched guidance keeps matching the surviving rows.

        Args:
            batch: Batch being translated.
            src_vocabs: Forwarded to scoring / generation helpers.
            decode_strategy: Strategy object driving the search.

        Returns:
            The value of ``self.report_results(...)`` for this batch.
        """
        # (0) Prep the components of the search.
        use_src_map = self.copy_attn
        parallel_paths = decode_strategy.parallel_paths  # beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)

        gold_score = self._gold_score(
            batch,
            memory_bank,
            src_lengths,
            src_vocabs,
            use_src_map,
            enc_states,
            batch_size,
            src,
        )

        # (2) prep decode_strategy. Possibly repeat src objects.
        src_map = batch.src_map if use_src_map else None
        target_prefix = batch.tgt if self.tgt_prefix else None
        (
            fn_map_state,
            memory_bank,
            memory_lengths,
            src_map,
        ) = decode_strategy.initialize(
            memory_bank, src_lengths, src_map, target_prefix=target_prefix
        )
        if fn_map_state is not None:
            self.model.decoder.map_state(fn_map_state)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=decode_strategy.batch_offset,
            )

            decode_strategy.advance(log_probs, attn)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break
            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank
                    )
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

                # MODIFICATION: keep the generator's batch indices aligned with
                # the surviving decoder rows.
                batch_indices = getattr(self.model.generator, 'batch_indices', None)
                if batch_indices is not None:
                    # NOTE: reaction_types stay unmodified; only the batch
                    # indices used to select them are updated.
                    # batch_indices holds n_candidates consecutive entries per
                    # decoder row (see translate_batch), whereas select_indices
                    # are row indices. Scale them to reach the first entry of
                    # each surviving row, then re-expand per candidate.
                    n_candidates = self.config.classifier_guidance.n_candidates_to_evaluate
                    self.model.generator.batch_indices = batch_indices.index_select(
                        0, select_indices * n_candidates
                    ).repeat_interleave(n_candidates)

            if parallel_paths > 1 or any_finished:
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices)
                )
        return self.report_results(
            gold_score,
            batch,
            batch_size,
            src,
            src_lengths,
            src_vocabs,
            use_src_map,
            decode_strategy,
        )
