import re

import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

RDLogger.DisableLog("rdApp.*")  # type: ignore

# fmt: off
# Token tables. Their concatenation order below fixes every integer id in
# VOCAB, so entries must only ever be appended, never reordered.
ELEMENTS = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg",
    "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr",
    "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
    "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
    "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf",
    "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po",
    "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm",
    "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs",
    "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
]
CHARGE_BOND = ["-", "+", "=", "#", "$", ":", "/", "\\"]
CHIRALITY   = ["@", "@@", "@TH", "@AL", "@SP", "@TB", "@OH"]
DIGITS      = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
MISC        = ["(", ")", "[", "]", ".", "*", "%"]

SPECIAL_TOKENS = ["[START]", "[STOP]", "[PAD]", "[MASK]"]

# Special tokens come first, so "[START]" is id 0, "[STOP]" 1, "[PAD]" 2, "[MASK]" 3.
ALL_TOKENS = [*SPECIAL_TOKENS, *ELEMENTS, *CHARGE_BOND, *CHIRALITY, *DIGITS, *MISC]

# Alternation of every escaped token, longest (escaped) form first, so that a
# multi-character token such as "Cl" or "@@" is tried before its one-character
# prefix. The sort is stable, so ties keep ALL_TOKENS order.
tok_re = "|".join(sorted(map(re.escape, ALL_TOKENS), key=len, reverse=True))
SMI_REGEX = re.compile("(" + tok_re + ")")

# Token string -> integer id (its position in ALL_TOKENS).
VOCAB = dict(zip(ALL_TOKENS, range(len(ALL_TOKENS))))
# fmt: on


def collate_strings(
    tokens: list[list[str]] | list[str],
    vocab: dict,
    kekulize: bool = False,
    pad_to_mult: int | None = None,
    maxlen: int | None = None,
) -> Tensor:
    tokens = [tokenize(ex, kekulize=kekulize) if isinstance(ex, str) else ex for ex in tokens]
    if maxlen is not None:
        tokens = [toks[: maxlen - 2] for toks in tokens]

    tokens = [["[START]", *toks, "[STOP]"] for toks in tokens]
    tokens = [[vocab[c] for c in ex] for ex in tokens]

    tensor = pad_sequence([torch.tensor(ex) for ex in tokens], batch_first=True, padding_value=vocab["[PAD]"])

    if pad_to_mult is not None and tensor.size(1) % pad_to_mult != 0:
        pad_len = pad_to_mult - (tensor.size(1) % pad_to_mult)
        tensor = F.pad(tensor, (0, pad_len), value=vocab["[PAD]"])

    return tensor


def tokenize(example: str, kekulize: bool = False) -> list[str]:
    """Split a SMILES string into the tokens recognised by ``SMI_REGEX``.

    Args:
        example: The SMILES string to split.
        kekulize: When true, the string is first parsed with RDKit,
            kekulized and written back out as a Kekule SMILES; the result is
            what gets split. If any of those steps raises, the input string
            is split unchanged.

    Returns:
        The matched tokens in order of appearance. Characters that match no
        token in ``SMI_REGEX`` are absent from the result.
    """
    if not kekulize:
        return SMI_REGEX.findall(example)

    try:
        parsed = Chem.MolFromSmiles(example)
        Chem.Kekulize(parsed)
        smiles = Chem.MolToSmiles(parsed, kekuleSmiles=True)
    except Exception:
        # Best effort: any RDKit failure (for instance on input it cannot
        # handle) falls back to tokenizing the string exactly as given.
        smiles = example

    return SMI_REGEX.findall(smiles)
