# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_tokenization.ipynb (unless otherwise specified).

# Names exported by ``from <module> import *``.
__all__ = [
    'get_default_tokenizer',
    'SmilesTokenizer',
    'RegexTokenizer',
    'SMI_REGEX_PATTERN',
    'NotCanonicalizableSmilesException',
    'canonicalize_smi',
    'process_reaction',
]

# Cell
import collections
import logging
import os
import re
import numpy as np
from rdkit import Chem
from tqdm import tqdm
import pkg_resources

from typing import List

from transformers import BertTokenizer
# Chemical element symbols. ``element_pattern`` below is derived from this list;
# SMI_REGEX_PATTERN deliberately does not use it (see the comment there).
atoms = [
    "In","Rb","Sr","Pt","Sn",
    "Tc","Mn","Ce","Bi","Cd",
    "Re","Zn","Os","Ge","Ta",
    "Si","Tb","Ca","Pd","Sc",
    "Cu","Ar","Ga","Xe","Ba",
    "Fe","Sm","Gd","Zr","Eu",
    "Sb","Au","No","Er","Mo",
    "Be","La","Ac","Se","Br",
    "Nb","Co","Ir","Cs","Dy",
    "Lu","Th","Na","As","Pb",
    "Yb","He","Ru","Nd","Ra",
    "Ti","At","Te","Ag","Mg",
    "Al","Hg","Rh","Ni","Cl",
    "Cr","Hf","Pr","Li","Tl",
    "H","P","U","F","O","V",
    "W","S","K","Y","B","I",
    "C","N",
]
# Cell
# Alternation of every symbol in ``atoms``. Kept for backward compatibility
# only: it is not safe to match these symbols outside square brackets.
element_pattern = '|'.join(atoms)

# Outside square brackets SMILES only allows the "organic subset"
# (B, C, N, O, P, S, F, Cl, Br, I and aromatic b, c, n, o, p, s). Matching the
# full ``element_pattern`` there mis-splits ordinary SMILES: in "CSc1ccccc1"
# the S-c bond would be read as scandium "Sc" (likewise "Sn", "Cs", "Co", ...).
# All other elements are written in brackets and captured whole by the
# ``\[[^\]]+\]`` alternative, so only the two-letter organic-subset atoms need
# an explicit alternative; single letters fall through to ``[a-zA-Z]``.
_UNBRACKETED_MULTI_LETTER_ATOMS = ("Br", "Cl")
_unbracketed_atom_pattern = "|".join(_UNBRACKETED_MULTI_LETTER_ATOMS)

# One capture group; alternatives are tried left to right, so longer tokens
# ("Br", ">>", "@@") precede their prefixes. Reaction separators (">", ">>"),
# the molecule separator "." and the remaining SMILES symbols (":", "~", "*",
# "$", "?") are listed explicitly because re.findall silently skips any
# character that no alternative matches.
SMI_REGEX_PATTERN = (
    r"(\[[^\]]+\]"
    rf"|{_unbracketed_atom_pattern}"
    r"|>>?|\.|@@|@|\+\d*|\-\d*|=|#|/|\\|%|\d|[a-zA-Z]|\(|\)|:|~|\*|\$|\?)"
)
def get_default_tokenizer():
    """Build a ``SmilesTokenizer`` using the vocabulary bundled with ``rxnfp``.

    Returns:
        A case-sensitive (``do_lower_case=False``) ``SmilesTokenizer`` whose
        vocabulary is ``models/transformers/bert_ft_10k_25s/vocab.txt`` from the
        installed ``rxnfp`` package.
    """
    vocab_path = pkg_resources.resource_filename(
        "rxnfp", "models/transformers/bert_ft_10k_25s/vocab.txt"
    )
    return SmilesTokenizer(vocab_path, do_lower_case=False)


class SmilesTokenizer(BertTokenizer):
    """BERT tokenizer whose text splitting is done by a SMILES regex.

    Adapted from https://github.com/huggingface/transformers
    and https://github.com/rxn4chemistry/rxnfp. Everything except ``_tokenize``
    (vocabulary loading, special tokens, id conversion) is inherited from
    ``BertTokenizer``; ``_tokenize`` delegates to ``RegexTokenizer``.

    Args:
        vocab_file: path to a vocabulary file with one token per line.
    """

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "[UNK]",
        sep_token: str = "[SEP]",
        pad_token: str = "[PAD]",
        cls_token: str = "[CLS]",
        mask_token: str = "[MASK]",
        do_lower_case=False,
        **kwargs,
    ) -> None:
        """Create a SmilesTokenizer.

        Args:
            vocab_file: vocabulary file containing one token per line.
            unk_token: unknown token. Defaults to "[UNK]".
            sep_token: separator token. Defaults to "[SEP]".
            pad_token: padding token. Defaults to "[PAD]".
            cls_token: classification token. Defaults to "[CLS]".
            mask_token: mask token. Defaults to "[MASK]".
            do_lower_case: forwarded to ``BertTokenizer``. Defaults to False.
            **kwargs: further keyword arguments forwarded to ``BertTokenizer``.
        """
        special_tokens = {
            "unk_token": unk_token,
            "sep_token": sep_token,
            "pad_token": pad_token,
            "cls_token": cls_token,
            "mask_token": mask_token,
        }
        super().__init__(
            vocab_file=vocab_file,
            do_lower_case=do_lower_case,
            **special_tokens,
            **kwargs,
        )
        # Regex splitter used by ``_tokenize`` instead of BERT's own splitting.
        self.tokenizer = RegexTokenizer()

    @property
    def vocab_list(self) -> List[str]:
        """Vocabulary tokens, in the iteration order of ``self.vocab``."""
        return list(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` (e.g. a SMILES or reaction SMILES) into regex tokens.

        Args:
            text: text to tokenize.
        Returns:
            the list of tokens produced by ``self.tokenizer``.
        """
        return self.tokenizer.tokenize(text)


class RegexTokenizer:
    """Split text into tokens with a compiled regular expression.

    Tokens are the matches returned by ``re.Pattern.findall``. Characters that
    no alternative of the pattern matches are silently skipped, so the pattern
    must cover every symbol that should survive tokenization.
    """

    def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN) -> None:
        """Constructs a RegexTokenizer.

        Args:
            regex_pattern: regex pattern used for tokenization. Defaults to
                ``SMI_REGEX_PATTERN``. It should contain at most one capturing
                group; with several, ``findall`` yields tuples, not strings.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text: str) -> List[str]:
        """Regex tokenization.

        Args:
            text: text to tokenize.
        Returns:
            a new list with the extracted tokens, in order of appearance.
        """
        # findall already builds a fresh list; no extra copy is needed.
        return self.regex.findall(text)


# Cell
class NotCanonicalizableSmilesException(ValueError):
    """Raised when a SMILES string cannot be parsed for canonicalization."""


def canonicalize_smi(smi, remove_atom_mapping=False):
    r"""Return the RDKit canonical form of a SMILES string.

    Args:
        smi: SMILES string to canonicalize.
        remove_atom_mapping: when True, the ``molAtomMapNumber`` property is
            cleared from every atom that has it before the SMILES is written.

    Returns:
        The string produced by ``Chem.MolToSmiles`` for the parsed molecule.

    Raises:
        NotCanonicalizableSmilesException: if ``Chem.MolFromSmiles`` gives back
            a falsy result (it returns None for unparsable input).
    """
    molecule = Chem.MolFromSmiles(smi)
    if not molecule:
        raise NotCanonicalizableSmilesException("Molecule not canonicalizable")

    if remove_atom_mapping:
        mapped_atoms = (
            atom for atom in molecule.GetAtoms() if atom.HasProp("molAtomMapNumber")
        )
        for mapped_atom in mapped_atoms:
            mapped_atom.ClearProp("molAtomMapNumber")

    return Chem.MolToSmiles(molecule)


def process_reaction(rxn):
    """Canonicalize a reaction SMILES of the form ``reactants>reagents>products``.

    Reagents (if any) are merged into the precursors. Every molecule is
    canonicalized with atom mapping removed, and each side is sorted.

    Args:
        rxn: reaction SMILES containing exactly two ``>`` characters
            (``a>>b`` is accepted and has no reagents).

    Returns:
        ``"<precursors>>><products>"`` with molecules joined by ``.``, or ``""``
        if any molecule cannot be canonicalized.

    Raises:
        ValueError: if ``rxn`` does not split into exactly three parts on ``>``.
    """
    reactants, reagents, products = rxn.split(">")
    try:
        precursors = [canonicalize_smi(smi, True) for smi in reactants.split(".")]
        if reagents:
            precursors.extend(canonicalize_smi(smi, True) for smi in reagents.split("."))
        canonical_products = [canonicalize_smi(smi, True) for smi in products.split(".")]
    except NotCanonicalizableSmilesException:
        return ""

    return ">>".join([".".join(sorted(precursors)), ".".join(sorted(canonical_products))])

def _build_vocab(
    input_paths=(
        "reaction_smiles_type_pretrain_eval.txt",
        "reaction_smiles_type_pretrain_train.txt",
    ),
    output_path="vocab.txt",
):
    """Tokenize every line of ``input_paths`` and write the distinct tokens.

    All input files are read (in the given order) before tokenization starts.
    Lines are stripped of surrounding whitespace and tokenized with a default
    ``RegexTokenizer``.

    Args:
        input_paths: text files holding one (reaction) SMILES per line.
        output_path: destination file; receives one token per line.

    Returns:
        The set of distinct tokens that was written.
    """
    lines = []
    for path in input_paths:
        with open(path, "r", encoding="utf-8") as f:
            lines += [line.strip() for line in f]

    tokenizer = RegexTokenizer()
    vocab = set()
    for smiles in tqdm(lines):
        vocab.update(tokenizer.tokenize(smiles))

    # Sort before writing: set iteration order of str is hash-randomized per
    # process, and a vocabulary file's line order determines token ids.
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(f"{item}\n" for item in sorted(vocab))
    return vocab


# Only build the vocabulary when run as a script, never as an import side effect
# (importing this module would otherwise require the data files in the cwd).
if __name__ == "__main__":
    _build_vocab()