import re

import click
import numpy as np
import polars as pl
from rdkit.Contrib.IFG.ifg import identify_functional_groups
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Matches one uppercase letter optionally followed by one lowercase letter
# (the shape of an element symbol, e.g. "C" or "Cl").
# NOTE(review): not referenced in the visible part of this file - confirm it is still needed.
ATOMS_PATTERN = re.compile(r"[A-Z][a-z]?")
# Shared InChI -> molecule transformer, used by get_functional_groups().
# NOTE(review): valid_only=True presumably drops unparsable InChIs instead of
# returning placeholders - confirm against the scikit-fingerprints docs.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
def count_functional_groups(input_file: str) -> None:
    """
    Count unique functional groups in the dataset. Uses Ertl's algorithm for
    functional groups detection. Only considers those between 2 and 20 atoms
    (inclusive), and appearing in at least 10 molecules, following the settings
    from the original paper.

    The resulting number of functional groups is printed to standard output.
    An empty dataset results in 0 functional groups being reported.

    :param input_file: path to input Parquet file, with "InChI" column
    """
    df = pl.read_parquet(input_file)

    # Process molecules in parallel to get lists of functional groups
    results = run_in_parallel(
        get_functional_groups,
        data=df["InChI"].to_list(),
        n_jobs=-1,
        batch_size=1000,
        verbose=True,
    )

    if len(results) > 0:
        fg_series = pl.Series("functional_group", np.concatenate(results))
    else:
        # No batches at all (empty input); np.concatenate() raises ValueError
        # for an empty sequence, so build an empty string Series directly
        fg_series = pl.Series("functional_group", [], dtype=pl.Utf8)

    # Each entry is one (molecule, functional group) occurrence, so the count
    # of a value is the number of molecules containing that group
    fg_counts = fg_series.value_counts()

    filtered_counts = fg_counts.filter(pl.col("count") >= 10)

    print(f"Number of functional groups: {len(filtered_counts)}")


def get_functional_groups(inchis: list[str]) -> np.ndarray:
    """
    Get functional groups from a batch of molecules.

    Each functional group is recorded at most once per molecule. Only groups
    with between 2 and 20 atoms (inclusive) are kept. Groups whose atoms
    string cannot be parsed back into a molecule are skipped.

    :param inchis: batch of InChI strings
    :returns: 1D string array with one entry per (molecule, unique functional
        group) pair; empty string array if no groups were found
    """
    from rdkit import rdBase
    from rdkit.Chem import MolFromSmiles

    # Silence all RDKit "rdApp" loggers while processing this batch
    rdBase.DisableLog("rdApp.*")
    mols = mol_from_inchi.transform(inchis)

    all_fgs = []
    for mol in mols:
        # Set, so a group occurring many times in one molecule counts only once
        mol_func_groups = {fg.atoms for fg in identify_functional_groups(mol)}
        for fg in mol_func_groups:
            # No sanitization, the fragment is only parsed to count its atoms
            fg_mol = MolFromSmiles(fg, sanitize=False)
            # MolFromSmiles returns None on a parse failure; skip such groups
            # instead of crashing the whole batch
            if fg_mol is not None and 2 <= fg_mol.GetNumAtoms() <= 20:
                all_fgs.append(fg)

    # Explicit string dtype: an empty list would otherwise become a float64
    # array, which then gets mixed with string arrays from other batches
    return np.array(all_fgs, dtype=str)


# Run the click command only when executed as a script; click parses
# --input-file from the command line.
if __name__ == "__main__":
    count_functional_groups()
