import time
from typing import List
import random
import warnings
import re
import numpy as np
import multiprocessing
from collections import defaultdict
from typing import Any
import onmt
import torch
#torch.serialization.add_safe_globals([onmt.inputters.text_dataset.TextMultiField])
from syntheseus.reaction_prediction.inference import RootAlignedModel
from syntheseus.reaction_prediction.utils.inference import (
    get_unique_file_in_dir,
)
from syntheseus import Molecule
from rdkit import Chem
from multiguide.dataset.helpers import smi_tokenizer
from multiguide.onmt.guided_translator import build_classifier_guided_translator
from syntheseus.reaction_prediction.utils.inference import (
    get_unique_file_in_dir,
    process_raw_smiles_outputs_backwards,
)

def patched_canonicalize(smiles, return_max_frag=True, opt=None):
    '''
    Canonicalize a SMILES string after removing its atom-map numbers.

    Dative bonds are converted to single bonds and every `molAtomMapNumber`
    property is cleared before the molecule is written back as isomeric SMILES.

    Args:
        smiles: SMILES string to canonicalize.
        return_max_frag: when True, also return the canonical SMILES of the
            fragment (dot-separated component) with the most atoms.
        opt: optional options object. Only its `synthon` attribute is read;
            a truthy value disables RDKit sanitization while parsing. A missing
            `opt` or a missing attribute is treated as `synthon=False`.

    Returns:
        `(smi, max_frag_smi)` when `return_max_frag` is True, otherwise `smi`.
        Empty strings are returned in place of any SMILES that could not be
        parsed or written.
    '''
    # The signature advertises opt=None, so the default must not crash on `opt.synthon`.
    sanitize = not getattr(opt, "synthon", False)
    failure = ('', '') if return_max_frag else ''

    mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
    if mol is None:
        return failure

    for bond in mol.GetBonds():
        if bond.GetBondType() == Chem.BondType.DATIVE:
            bond.SetBondType(Chem.BondType.SINGLE)
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=True)
    except Exception:
        # Best-effort: an unwritable molecule is reported as "illegal" (empty),
        # but interpreter-level signals such as KeyboardInterrupt still propagate.
        return failure

    if not return_max_frag:
        return smi

    # Pair every parsable fragment with its atom count.
    fragments = []
    for frag_smi in smi.split("."):
        frag_mol = Chem.MolFromSmiles(frag_smi, sanitize=sanitize)
        if frag_mol is not None:
            fragments.append((frag_smi, frag_mol.GetNumAtoms()))
    if not fragments:
        return smi, ''

    # max() keeps the first of equally sized fragments, like the stable
    # descending sort it replaces.
    largest_smi = max(fragments, key=lambda item: item[1])[0]
    return smi, patched_canonicalize(largest_smi, return_max_frag=False, opt=opt)

class RootAlignedForwardModel(RootAlignedModel):
    def __init__(self, config=None, product_smiles=None, *args, **kwargs):
        '''
        Initialise the wrapped syntheseus RootAlignedModel and patch root_aligned's scorer.

        `config` and `product_smiles` are accepted but not read in this constructor;
        all remaining positional and keyword arguments are forwarded to the parent
        constructor.

        Side effect: `root_aligned.score.canonicalize_smiles_clear_map` is replaced
        module-wide by `patched_canonicalize` bound to this instance's `opt`.
        '''
        import functools

        super().__init__(*args, **kwargs)

        # HACK: reuse the option-loading logic of syntheseus's RootAlignedModel so
        # that `self.opt` is available for onmt's translator.
        self.load_opt()

        # NOTE: score.py does not seem to get `opt` set properly by syntheseus's
        # RootAlignedModel, so as a temporary fix the module-level function is swapped
        # for a partial that carries `self.opt`.
        from root_aligned import score as score_module

        score_module.canonicalize_smiles_clear_map = functools.partial(
            patched_canonicalize, opt=self.opt
        )
        # NOTE(review): guidance is meant to be introduced by overriding the translator,
        # but no translator override happens in this constructor — confirm where it is installed.

    def get_root_id(self, mol, root_map_number):
        root = -1
        for i, atom in enumerate(mol.GetAtoms()):
            if atom.GetAtomMapNum() == root_map_number:
                root = i
                break
        return root
    
    def clear_map_canonical_smiles(self, smi, canonical=True, root=-1):
        '''
        Strip atom-map numbers from `smi` and write it back as isomeric SMILES.

        Dative bonds are turned into single bonds first. `root` and `canonical`
        are passed to RDKit's `MolToSmiles` as `rootedAtAtom` and `canonical`.
        If `smi` cannot be parsed it is returned unchanged.
        '''
        parsed = Chem.MolFromSmiles(smi)
        if parsed is None:
            return smi

        for bond in parsed.GetBonds():
            is_dative = bond.GetBondType() == Chem.BondType.DATIVE
            if is_dative:
                bond.SetBondType(Chem.BondType.SINGLE)

        for atom in parsed.GetAtoms():
            if not atom.HasProp('molAtomMapNumber'):
                continue
            atom.ClearProp('molAtomMapNumber')

        return Chem.MolToSmiles(
            parsed,
            isomericSmiles=True,
            rootedAtAtom=root,
            canonical=canonical,
        )
    
    def get_cano_map_number(self, smi, root=-1):
        '''
        Return the atom-map numbers of `smi`, ordered like the atoms of its
        map-free SMILES as produced by `clear_map_canonical_smiles(smi, root=root)`.

        Args:
            smi: atom-mapped SMILES string.
            root: atom index forwarded to `clear_map_canonical_smiles`.

        Returns:
            A list with one atom-map number per atom of the map-free molecule, or
            None when `smi` cannot be parsed or no complete atom correspondence
            between the two molecules is found.
        '''
        atommap_mol = Chem.MolFromSmiles(smi)
        if atommap_mol is None:
            # Unparsable input: report failure the same way as a failed match below.
            return None
        canonical_mol = Chem.MolFromSmiles(self.clear_map_canonical_smiles(smi, root=root))
        if canonical_mol is None:
            return None

        atom_number = len(canonical_mol.GetAtoms())

        # First attempt: match the map-free molecule inside the mapped one.
        cano2atommapIdx = atommap_mol.GetSubstructMatch(canonical_mol)
        correct_mapped = [
            canonical_mol.GetAtomWithIdx(i).GetSymbol() == atommap_mol.GetAtomWithIdx(index).GetSymbol()
            for i, index in enumerate(cano2atommapIdx)
        ]
        if sum(correct_mapped) < atom_number or len(cano2atommapIdx) < atom_number:
            # Fallback: match in the opposite direction and invert the mapping.
            cano2atommapIdx = [0] * atom_number
            atommap2canoIdx = canonical_mol.GetSubstructMatch(atommap_mol)
            if len(atommap2canoIdx) != atom_number:
                return None
            for i, index in enumerate(atommap2canoIdx):
                cano2atommapIdx[index] = i

        id2atommap = [atom.GetAtomMapNum() for atom in atommap_mol.GetAtoms()]
        return [id2atommap[cano2atommapIdx[i]] for i in range(atom_number)]
    
    def _get_reactions(
        self, inputs: List[Molecule], num_results: int, random_augmentation: bool = True
    ):
        '''
        Predict outputs for a batch of inputs using root-aligned SMILES augmentation.

        Every input is expanded into `self.num_augmentations` re-rooted, shuffled
        SMILES variants, all variants are translated with `self.translator`, the raw
        outputs are canonicalized with `root_aligned.score.canonicalize_smiles_clear_map`
        and finally ranked per input with `root_aligned.score.compute_rank`.

        Args:
            inputs: molecules to predict for; `.smiles` may contain several
                dot-separated components.
            num_results: maximum number of ranked outputs kept per input.
            random_augmentation: accepted for interface compatibility; it is not
                consulted by this implementation.

        Returns:
            A list with one entry per input, each built by
            `process_raw_smiles_outputs_backwards` from the ranked outputs and scores.
        '''
        # Step 1: Perform data augmentation.
        augmented_inputs = []
        for input_mol in inputs:
            reactant = input_mol.smiles.split(".")
            # Candidate roots are 1-based atom positions; values <= 0 mean "no explicit root".
            rea_atom_map_numbers = [
                [i + 1 for i in range(Chem.MolFromSmiles(rea).GetNumAtoms())] for rea in reactant
            ]
            # Number of distinct root combinations. Exact Python integers are used so
            # that inputs with many/large components cannot overflow a fixed-width product.
            max_times = 1
            for map_numbers in rea_atom_map_numbers:
                max_times *= len(map_numbers)
            times = min(self.num_augmentations, max_times)

            # The first entry keeps every component un-rooted; then draw `times`
            # distinct random root combinations.
            reactant_roots = [[-1 for _ in reactant]]
            j = 0
            while j < times:
                candidate = [random.sample(map_numbers, 1)[0] for map_numbers in rea_atom_map_numbers]
                if candidate not in reactant_roots:
                    reactant_roots.append(candidate)
                    j += 1
            if j < self.num_augmentations:
                # Not enough distinct combinations: pad by re-sampling existing ones.
                reactant_roots.extend(random.choices(reactant_roots, k=self.num_augmentations - times))
                times = self.num_augmentations

            assert times == self.num_augmentations

            for k in range(times):
                # Shuffle the component order together with the chosen roots.
                tmp = list(zip(reactant, reactant_roots[k]))
                random.shuffle(tmp)
                aligned_reactants = []
                for rea, rea_root_atom_index in tmp:
                    # Convert the 1-based position to a 0-based atom index (-1 = no root).
                    rea_root = rea_root_atom_index - 1 if rea_root_atom_index > 0 else -1
                    aligned_reactants.append(
                        self.clear_map_canonical_smiles(rea, canonical=True, root=rea_root)
                    )

                randomized_mol = Molecule(smiles=".".join(aligned_reactants), canonicalize=False)
                augmented_inputs.append(randomized_mol)

        assert len(augmented_inputs) == len(inputs) * self.num_augmentations

        # Step 2: Map from `Molecule`s to SMILES bytes to align with `root_aligned/OpenNMT.py`.
        augmented_batch = self._mols_to_batch(augmented_inputs)

        # Step 3: Translate.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="__floordiv__ is deprecated")

            _, augmented_predictions = self.translator.translate(
                src=augmented_batch,
                src_feats=defaultdict(list),
                tgt=None,
                batch_size=2048,
                batch_type="tokens",
                attn_debug=False,
                align_debug=False,
            )  # shape: `[data_size x augmentation_size, beam_size]`

        # Step 4: Unravel and canonicalize.
        # shape: `[data_size x augmentation_size x beam_size]`
        lines = [
            prediction.replace(" ", "")
            for beam_predictions in augmented_predictions
            for prediction in beam_predictions
        ]

        from root_aligned.score import canonicalize_smiles_clear_map

        # The context manager guarantees the worker processes are released even
        # when canonicalization raises.
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            raw_predictions = pool.map(
                func=canonicalize_smiles_clear_map, iterable=lines
            )  # Canonicalize reactants and modify illegal reactants into empty strings.

        # From `[data_size x augmentation_size x beam_size]` to `[data_size, augmentation_size, beam_size]`.
        predictions: List[List[Any]] = [
            [[] for _ in range(self.num_augmentations)] for _ in range(len(inputs))
        ]

        per_input = self.beam_size * self.num_augmentations
        for i, prediction in enumerate(raw_predictions):
            predictions[i // per_input][i % per_input // self.beam_size].append(prediction)

        # Step 5: Rank legal reactants from all augmentations and beams.
        ranked_results = []  # shape: `[data_size, augmentation_size x beam_size]`
        ranked_scores = []

        from root_aligned.score import compute_rank

        for input_predictions in predictions:
            rank, _ = compute_rank(input_predictions)
            rank = sorted(rank.items(), key=lambda x: x[1], reverse=True)
            rank = rank[:num_results]  # Truncate to `num_results` results.
            ranked_results.append([item[0][0] for item in rank])  # Output reactant SMILES.
            ranked_scores.append([item[1] for item in rank])  # Output scores used for ranking.

        return [
            process_raw_smiles_outputs_backwards(
                input_mol, outputs, self._build_kwargs_from_scores(scores)
            )
            for input_mol, outputs, scores in zip(inputs, ranked_results, ranked_scores)
        ]

    def load_opt(self):
        '''
        Assemble the option namespace for onmt's translator and store it on `self.opt`.

        Every top-level key of the `*.yml` file located via `get_unique_file_in_dir`
        in `self.model_dir` becomes an attribute of the namespace. On top of that
        `models`, `output`, `gpu` and `synthon` are set explicitly. Progress is
        printed to stdout.
        '''
        import argparse
        import torch
        import yaml

        # Parse arguments for calling external functions from `root_aligned/OpenNMT.py`
        config_file_path = get_unique_file_in_dir(self.model_dir, pattern="*.yml")
        print(f'======= config_file_path: {config_file_path}')
        with open(config_file_path, "r") as config_file:
            opt_from_config = yaml.safe_load(config_file)

        # Mirror the YAML mapping onto a namespace, one attribute per key.
        namespace = argparse.Namespace()
        for option_name, option_value in opt_from_config.items():
            setattr(namespace, option_name, option_value)

        checkpoint_path = get_unique_file_in_dir(self.model_dir, pattern="*.pt")
        namespace.models = [checkpoint_path]
        print(f'======= opt.models: {namespace.models}')

        namespace.output = "/dev/null"

        print(f'========= self.device: {self.device}')
        if self.device == "cpu":
            namespace.gpu = -1
        else:
            namespace.gpu = torch.device(self.device).index
        print(f'========= opt.gpu: {namespace.gpu}')

        namespace.synthon = False

        self.opt = namespace