import time
import torch
import random
import warnings
import multiprocessing
from collections import defaultdict
from functools import partial
import torch.nn.functional as F
from typing import List, Optional, Sequence, TypeVar, Any
from syntheseus.reaction_prediction.inference import RootAlignedModel
from syntheseus.reaction_prediction.utils.inference import (
    get_unique_file_in_dir,
)
from syntheseus import Molecule
from syntheseus.interface.models import (
    BackwardReactionModel,
    ForwardReactionModel,
    InputType,
    ReactionModel,
    ReactionType,
)
from syntheseus.interface.reaction import SingleProductReaction
from syntheseus.reaction_prediction.utils.inference import (
    get_unique_file_in_dir,
    process_raw_smiles_outputs_backwards,
)

from rdkit import Chem
from multiguide.onmt.guided_translator import build_classifier_guided_translator
#from multiguide.training.helpers import set_product_classifier
from multiguide.dataset.helpers import turn_seq_to_ids, compare_reactant_smiles, get_vocab_from_trained_model
from multiguide.dataset.helpers import get_tanimoto, class_to_idx, get_rxn_insight_info

# Module-wide torch device used by this file: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def patched_canonicalize(smiles, return_max_frag=True, synthon=None):
    """Canonicalize a SMILES string after clearing atom maps and dative bonds.

    Any input that cannot be turned into a canonical SMILES yields an empty
    result instead of raising, so callers can treat "" as "illegal prediction".

    Args:
        smiles: SMILES string to canonicalize.
        return_max_frag: When true, return a ``(canonical, largest_fragment)``
            pair; otherwise return only the canonical SMILES string.
        synthon: When truthy, RDKit sanitization is skipped while parsing.

    Returns:
        ``(smi, max_frag_smi)`` if ``return_max_frag`` else ``smi``. On failure
        the result is ``('', '')`` or ``''`` respectively.
    """
    empty = ('', '') if return_max_frag else ''

    # Strings containing '<' or '>' (e.g. leaked special tokens or reaction
    # arrows) are never accepted. Reject them up front rather than parsing
    # them with RDKit and discarding the result afterwards.
    if '<' in smiles or '>' in smiles:
        print(f'Found < or > in smiles: {smiles}')
        return empty

    mol = Chem.MolFromSmiles(smiles, sanitize=not synthon)
    if mol is None:
        return empty

    # Remove dative bonds by converting them back to single bonds.
    for bond in mol.GetBonds():
        if bond.GetBondType() == Chem.BondType.DATIVE:
            bond.SetBondType(Chem.BondType.SINGLE)
    # Drop atom-map numbers so they do not leak into the canonical SMILES.
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=True)
    except Exception:
        # Best effort: a molecule RDKit cannot serialise counts as illegal.
        # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
        return empty

    if not return_max_frag:
        return smi

    # Pair every parsable fragment with its atom count.
    fragments = []
    for frag_smi in smi.split('.'):
        frag_mol = Chem.MolFromSmiles(frag_smi, sanitize=not synthon)
        if frag_mol is not None:
            fragments.append((frag_smi, len(frag_mol.GetAtoms())))
    if not fragments:
        return smi, ''

    # `max` returns the first of equally large fragments, which matches a
    # stable descending sort followed by taking element 0.
    largest_smi = max(fragments, key=lambda item: item[1])[0]
    return smi, patched_canonicalize(largest_smi, return_max_frag=False, synthon=synthon)

class RootAlignedFixedModel(RootAlignedModel):
    def __init__(
        self, config=None, target_classes=None, conditional_starting_materials=None, conditional_targets=None, *args, **kwargs
    ):
        '''
        This class is a wrapper around the RootAlignedModel class in syntheseus.
        It adds guidance to the model by overriding the translator.

        Args:
            config: Run configuration. Must provide
                `classifier_guidance.onmt_checkpoint_path` and `search.steered`.
            target_classes: Forwarded to `build_classifier_guided_translator`
                when `config.search.steered` is truthy.
            conditional_starting_materials: Initial value stored on the instance;
                it is overwritten on every `__call__`.
            conditional_targets: Initial value stored on the instance;
                it is overwritten on every `__call__`.
            *args, **kwargs: Forwarded unchanged to `RootAlignedModel.__init__`.

        Raises:
            ValueError: If `config` is None.
        '''
        # Fail fast with a clear message: without a config the constructor would
        # otherwise build the whole parent model and then die on an attribute
        # lookup on None.
        if config is None:
            raise ValueError(
                'RootAlignedFixedModel requires a config providing '
                '`classifier_guidance` and `search` sections.'
            )
        super().__init__(*args, **kwargs)
        # HACK: using the same code as in syntheseus's RootAlignedModel to obtain the opt for onmt's translator
        self.config = config
        self.load_opt()

        # Per-call guidance state; refreshed by `__call__` before each translation.
        self.node_depth = None
        self.reaction_types = None
        self.product_smiles = None
        self.conditional_starting_materials = conditional_starting_materials
        self.conditional_targets = conditional_targets
        self.vocab = get_vocab_from_trained_model(self.config.classifier_guidance.onmt_checkpoint_path)

        # NOTE: actually there seems to be a bug in syntheseus's RootAlignedModel
        # the score.py module does not get opt set properly, so we need to patch it here as a temp fix
        from root_aligned import score
        # Replace the module's function with a partial that includes self.opt.
        # This is a process-wide monkeypatch of `root_aligned.score`.
        score.canonicalize_smiles_clear_map = partial(patched_canonicalize, synthon=self.opt.synthon)

        # NOTE: we introduce guidance by overriding the translator
        if config.search.steered:  # overrides translator
            self.translator = build_classifier_guided_translator(
                opt=self.opt,
                report_score=False,
                target_classes=target_classes,
                config=config,
            )
            
    def __call__(
        self, inputs: list[InputType], node_depth: Optional[list[int]] = None, num_results: Optional[int] = None,
        reaction_types: Optional[list[int]] = None,
        conditional_starting_materials: Optional[list[str]] = None,
        conditional_targets: Optional[list[str]] = None,
    ) -> list[Sequence[ReactionType]]:
        """Given a batch of inputs to the reaction model, return a batch of results.

        The guidance arguments are stashed on the instance so that
        `_get_reactions` can hand them to the translator, then the call is
        delegated to the parent implementation.

        Args:
            inputs: Batch of inputs to the reaction model, each either a molecule or a set of
                molecules, depending on directionality.
            node_depth: Stored as `self.node_depth` for the guided translator.
            num_results: Number of results to return for each input in the batch. Many models may
                only be able to produce a finite number of candidate outputs, thus the returned
                lists are allowed to be shorter than `num_results`. If not provided, the default
                number of results will be used.
            reaction_types: Stored as `self.reaction_types`. `_get_reactions` calls
                `repeat_interleave` on it, so a tensor is presumably expected
                despite the annotation — TODO confirm.
            conditional_starting_materials: Stored as `self.conditional_starting_materials`.
            conditional_targets: Stored as `self.conditional_targets`.

        Returns:
            Whatever the parent `__call__` returns for `(inputs, num_results)`.
        """
        # NOTE(review): if the parent caches results keyed only on the input and
        # num_results, the guidance state set below is not part of that key —
        # confirm caching is disabled or acceptable when guidance varies.
        self.node_depth = node_depth
        self.reaction_types = reaction_types
        self.conditional_starting_materials = conditional_starting_materials
        self.conditional_targets = conditional_targets
        self.product_smiles = inputs
        return super().__call__(inputs, num_results)
    
    def _assign_class_with_product_classifier_if_applicable(self, inputs):
        '''
        Predict a class for the single input product with the product classifier.

        Only runs when `classifier_guidance.as_regression` is falsy and
        `classifier_guidance.with_product_classifier` is truthy.

        Args:
            inputs: Batch of input molecules; must contain exactly one element
                when the classifier is applicable.

        Returns:
            The index tensor of the highest-probability class, or None when the
            product classifier is not applicable.
        '''
        guidance_cfg = self.config.classifier_guidance
        if guidance_cfg.as_regression or not guidance_cfg.with_product_classifier:
            return None

        assert len(inputs) == 1, 'Product classifier only supports one input'

        # Imported lazily: the module-level import of this helper is commented
        # out in this file, so referencing the bare name raised NameError.
        # NOTE(review): confirm the helper still lives at this path.
        from multiguide.training.helpers import set_product_classifier

        product_classifier = set_product_classifier(self.config)
        product_classifier.eval()
        with torch.no_grad():
            product_classifier.to(device)
            # Same calling convention as the other `turn_seq_to_ids` call sites
            # in this class (sequence first, checkpoint path by keyword).
            # NOTE(review): verify the classifier shares the ONMT vocabulary.
            input_ids = turn_seq_to_ids(
                inputs[0].smiles,
                onmt_checkpoint_path=guidance_cfg.onmt_checkpoint_path,
            )
            input_ids = input_ids.unsqueeze(0).to(device)
            classifier_scores = product_classifier(input_ids)
            # `max` yields (values, indices); only the class index is returned.
            _confidence, predicted_class = F.softmax(classifier_scores, dim=1).max(dim=1)

        return predicted_class

    def _get_reactions(
        self, inputs: List[Molecule], num_results: int
    ):
        '''
        This function is a wrapper around the _get_reactions function in syntheseus's RootAlignedModel.
        It adds guidance to the model by using the translator.

        The guidance state stored on the instance by `__call__` is copied onto
        `self.translator.model.generator`, then `_get_reactions_original` runs
        the translation. Timing and (on CUDA) peak-memory figures are printed.

        Args:
            inputs: Molecules to expand in this batch.
            num_results: Maximum number of reactions to keep per input.

        Returns:
            The value returned by `_get_reactions_original(inputs, num_results)`.
        '''
        start_time = time.time()
        generator = self.translator.model.generator

        def encode_and_tile(sequences, **encode_kwargs):
            # Tokenise each string, pad along the sequence-length dimension with
            # the '<blank>' id, and repeat every row once per augmentation.
            token_ids = [
                turn_seq_to_ids(
                    seq,
                    onmt_checkpoint_path=self.config.classifier_guidance.onmt_checkpoint_path,
                    **encode_kwargs,
                )
                for seq in sequences
            ]
            padded = torch.nn.utils.rnn.pad_sequence(
                token_ids,
                batch_first=True,
                padding_value=self.vocab.index('<blank>'),
            )
            return padded.repeat_interleave(self.num_augmentations, dim=0)

        generator.node_depth = self.node_depth
        if self.reaction_types is not None:
            generator.reaction_types = self.reaction_types.repeat_interleave(self.num_augmentations)
        else:
            generator.reaction_types = None
        generator.immediate_target = inputs[0].smiles

        # NOTE: turn to sequence here because the info is deduplicated
        # NOTE(review): these per-call lists follow the caller's batch order; if
        # the parent de-duplicates or reorders `inputs`, they may not line up
        # row-for-row with the translated batch — confirm against the caller.
        # When a list is None the generator keeps whatever was set previously.
        if self.conditional_starting_materials is not None:
            generator.conditional_starting_materials = encode_and_tile(
                self.conditional_starting_materials
            )
        if self.conditional_targets is not None:
            generator.conditional_targets = encode_and_tile(self.conditional_targets)

        # Build the product tokens from `inputs` (the batch that is actually
        # translated) rather than from the list stored by `__call__`, so the
        # rows always match the augmented batch. The stored attribute still
        # acts as the on/off switch, and may be absent if `__call__` never ran.
        if getattr(self, 'product_smiles', None) is not None:
            generator.product_smiles = encode_and_tile(
                ['>>' + mol.smiles for mol in inputs],
                use_unk=True,
            )

        # Attribute name (including its spelling) is kept as-is; presumably it
        # is read elsewhere by the guided generator — verify before renaming.
        generator.keep_using_conditiong = True

        # add product predictor here
        print(f'======= device: {device}')
        on_cuda = device.type == 'cuda'
        if on_cuda:
            torch.cuda.reset_peak_memory_stats(device=device)

        reactions = self._get_reactions_original(inputs, num_results)
        # TODO: maybe print the reactions here? or evaluate them somehow to see if the guidance applies
        print(f'======= _get_reactions time: {time.time() - start_time} seconds')

        if on_cuda:
            mib = 1024**2
            peak_mb = torch.cuda.max_memory_allocated(device=device) / mib
            # Query the same device the peak was measured on, not a fixed index 0.
            total_mb = torch.cuda.get_device_properties(device).total_memory / mib
            print(f"\n=== PEAK MEMORY FOR ONE PRODUCT ===")
            print(f"Peak: {peak_mb:.0f}MB / {total_mb:.0f}MB ({100*peak_mb/total_mb:.1f}%)")
            if peak_mb > 0:  # avoid dividing by zero when nothing was allocated
                print(f"Estimated batch size: {int(0.8 * total_mb / peak_mb)} products")

        if len(reactions) > 0 and len(reactions[0]) > 0:
            print(f'First result (out of {len(reactions[0])}): {reactions[0][0].reactants}')
        else:
            print(f'No reactions found')
        return reactions

    def _get_reactions_original(
        self, inputs, num_results: int, random_augmentation=False
    ) -> List[Sequence[SingleProductReaction]]:
        '''
        Augment the inputs, translate them, and rank the canonicalized outputs.

        Args:
            inputs: Product molecules to expand.
            num_results: Maximum number of ranked reactant sets kept per input.
            random_augmentation: When true, augment with RDKit random SMILES;
                otherwise augment with root-aligned SMILES over sampled roots.

        Returns:
            One entry per input, built by `process_raw_smiles_outputs_backwards`
            from the ranked reactant SMILES and their scores.
        '''
        # Step 1: Perform data augmentation.
        augmented_inputs = []
        if random_augmentation:
            for mol in inputs:
                augmented_inputs.append(mol)
                for _ in range(self.num_augmentations - 1):
                    randomized_smi = Chem.MolToSmiles(mol.rdkit_mol, doRandom=True)
                    augmented_inputs.append(Molecule(smiles=randomized_smi, canonicalize=False))
        else:
            from root_aligned.preprocessing.generate_PtoR_data import clear_map_canonical_smiles

            for mol in inputs:
                atom_map_numbers = list(range(1, mol.rdkit_mol.GetNumAtoms() + 1))
                product_roots = [-1]  # -1 means "no explicit root"
                if len(atom_map_numbers) < self.num_augmentations:
                    # Fewer atoms than augmentations: use every atom once, then
                    # fill the remainder by sampling with replacement.
                    product_roots.extend(atom_map_numbers)
                    product_roots.extend(
                        random.choices(product_roots, k=self.num_augmentations - len(product_roots))
                    )
                else:
                    # Enough atoms: draw distinct roots until the quota is met.
                    while len(product_roots) < self.num_augmentations:
                        candidate = random.sample(atom_map_numbers, 1)[0]
                        if candidate not in product_roots:
                            product_roots.append(candidate)
                assert len(product_roots) == self.num_augmentations

                for root_map_number in product_roots:
                    # Map numbers are 1-based; non-positive values select no root.
                    root_index = root_map_number - 1 if root_map_number > 0 else -1
                    rooted_smi = clear_map_canonical_smiles(
                        mol.smiles, canonical=True, root=root_index
                    )
                    augmented_inputs.append(Molecule(smiles=rooted_smi, canonicalize=False))

        assert len(augmented_inputs) == len(inputs) * self.num_augmentations

        # Step 2: Map from `Molecule`s to SMILES bytes to align with `root_aligned/OpenNMT.py`.
        augmented_batch = self._mols_to_batch(augmented_inputs)
        self.translator.augmented_products = [mol.smiles for mol in augmented_inputs]

        # Step 3: Translate.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="__floordiv__ is deprecated")

            _, augmented_predictions = self.translator.translate(
                src=augmented_batch,
                src_feats=defaultdict(list),
                tgt=None,
                batch_size=100000,
                batch_type="tokens",
                attn_debug=False,
                align_debug=False,
            )  # shape: `[data_size x augmentation_size, beam_size]`

        # Step 4: Unravel and canonicalize.
        # shape: `[data_size x augmentation_size x beam_size]`
        lines = [
            prediction.replace(" ", "")
            for beam_predictions in augmented_predictions
            for prediction in beam_predictions
        ]
        print('got lines in guided rsmiles')
        print(f'len(lines) {len(lines)}')

        from root_aligned.score import canonicalize_smiles_clear_map

        # NOTE: this is modified by Najwa, added synthon False
        # NOTE: this is where the non parsable smiles are removed
        canonicalize_func = partial(canonicalize_smiles_clear_map, synthon=False)
        pool = multiprocessing.Pool(multiprocessing.cpu_count())
        try:
            # Canonicalize reactants and modify illegal reactants into empty strings.
            raw_predictions = pool.map(func=canonicalize_func, iterable=lines)
        finally:
            # Always release the worker processes, even if `map` raised.
            pool.close()
            pool.join()
        print('done cano in multiprocessing pool')

        # From `[data_size x augmentation_size x beam_size]` to `[data_size, augmentation_size, beam_size]`.
        # Assumes every augmented input produced exactly `beam_size` predictions.
        predictions: List[List[Any]] = [
            [[] for _ in range(self.num_augmentations)] for _ in range(len(inputs))
        ]
        per_input = self.beam_size * self.num_augmentations
        for flat_idx, prediction in enumerate(raw_predictions):
            input_idx = flat_idx // per_input
            augmentation_idx = flat_idx % per_input // self.beam_size
            predictions[input_idx][augmentation_idx].append(prediction)

        # Step 5: Rank legal reactants from all augmentations and beams.
        ranked_results = []  # shape: `[data_size, augmentation_size x beam_size]`
        ranked_scores = []

        from root_aligned.score import compute_rank

        for prediction in predictions:
            # NOTE: and compute rank is where the deduplication is done
            rank, _ = compute_rank(prediction)
            # Highest score first, truncated to `num_results` results.
            top_ranked = sorted(rank.items(), key=lambda item: item[1], reverse=True)[:num_results]
            ranked_results.append([item[0][0] for item in top_ranked])  # Output reactant SMILES.
            ranked_scores.append([item[1] for item in top_ranked])  # Output scores used for ranking.

        # print if ground truth reactants are in the ranked results + their rank
        if self.config.classifier_guidance.debug_print_ground_truth_reactants:
            self._print_ground_truth_debug(inputs, ranked_results, ranked_scores)

        return [
            process_raw_smiles_outputs_backwards(
                mol, outputs, self._build_kwargs_from_scores(scores)
            )
            for mol, outputs, scores in zip(inputs, ranked_results, ranked_scores)
        ]

    def _print_ground_truth_debug(self, inputs, ranked_results, ranked_scores):
        '''Print where the ground-truth reactants rank and Tanimoto-sorted outputs.'''
        for mol, outputs, scores in zip(inputs, ranked_results, ranked_scores):
            ground_truth_reactants = mol.metadata['ground_truth_reactants']
            print(f'ground_truth_reactants: {ground_truth_reactants}')
            ground_truth_in_output = [
                (output, index)
                for index, output in enumerate(outputs)
                if compare_reactant_smiles(ground_truth_reactants, output)
            ]
            print(f'ground_truth_in_output {ground_truth_in_output}')
            if ground_truth_in_output:
                print('ground truth reactants in output:')
                index = ground_truth_in_output[0][1]
                print('='*5, f'at index {index} out of {len(outputs)}')
                print('='*5, f'score: {scores[index]}')

        ranked_results_with_tanimoto = []
        for mol, outputs in zip(inputs, ranked_results):
            starting_material = mol.metadata['starting_material']
            outputs_with_tanimoto = []
            for index, output in enumerate(outputs):
                # Score every reactant of this output against the starting material.
                reactants_with_tanimoto = [
                    (reactant, r_index, get_tanimoto(starting_material, reactant))
                    for r_index, reactant in enumerate(output.split('.'))
                ]
                reactants_with_tanimoto.sort(key=lambda x: x[2], reverse=True)
                outputs_with_tanimoto.append((reactants_with_tanimoto, index))
            # Order outputs by their best-matching reactant.
            outputs_with_tanimoto.sort(key=lambda x: x[0][0][2], reverse=True)
            ranked_results_with_tanimoto.append(outputs_with_tanimoto)
        print(f'======= ranked_results_with_tanimoto: {ranked_results_with_tanimoto}')
        
    def load_opt(self):
        '''
        This function is a wrapper around the load_opt function in syntheseus's RootAlignedModel.
        It loads the opt from the model directory.

        Reads the unique `*.yml` file in `self.model_dir` into an
        `argparse.Namespace`, adds the unique `*.pt` checkpoint as `opt.models`,
        sets `output`, `gpu` and `synthon`, and stores the result as `self.opt`.
        '''
        import argparse
        import os

        import yaml

        # Parse arguments for calling external functions from `root_aligned/OpenNMT.py`
        config_file_path = get_unique_file_in_dir(self.model_dir, pattern="*.yml")
        print(f'======= config_file_path: {config_file_path}')
        with open(config_file_path, "r") as f:
            opt_from_config = yaml.safe_load(f)

        opt = argparse.Namespace()
        # An empty YAML file loads as None; treat it as "no options".
        for key, value in (opt_from_config or {}).items():
            setattr(opt, key, value)

        opt.models = [get_unique_file_in_dir(self.model_dir, pattern="*.pt")]
        print(f'======= opt.models: {opt.models}')
        # Portable null sink ("/dev/null" on POSIX).
        opt.output = os.devnull
        print(f'========= self.device: {self.device}')
        # Normalise first so that `torch.device` objects and strings such as
        # "cpu:0" are recognised as CPU, not only the exact string "cpu".
        # NOTE(review): an index-less device such as "cuda" yields gpu=None.
        torch_device = torch.device(self.device)
        opt.gpu = -1 if torch_device.type == "cpu" else torch_device.index
        print(f'========= opt.gpu: {opt.gpu}')
        opt.synthon = False

        self.opt = opt