import click
import polars as pl
from skfp.utils import run_in_parallel


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--as-generic-cyclic-skeleton",
    is_flag=True,
    help="Return generic cyclic skeleton instead of specific Bemis-Murcko scaffold",
    required=False,
)
def count_scaffolds(input_file: str, as_generic_cyclic_skeleton: bool = False) -> None:
    """
    Count unique Bemis-Murcko scaffolds in the dataset.

    Reads the "InChI" column of the Parquet file, computes one scaffold per
    molecule in parallel, then prints the number of distinct scaffolds and
    that number as a percentage of all molecules in the file.
    """
    df = pl.read_parquet(input_file)

    scaffolds = run_in_parallel(
        batch_get_scaffolds,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=1000,
        flatten_results=True,
        verbose=True,
        as_generic_cyclic_skeleton=as_generic_cyclic_skeleton,
    )

    # Molecules whose scaffold could not be computed come back as None: they
    # are excluded from the distinct count but still count toward the total.
    scaffold_series = pl.Series(scaffolds)
    num_molecules = len(scaffold_series)
    num_scaffolds = scaffold_series.drop_nulls().n_unique()

    print(f"Number of distinct scaffolds: {num_scaffolds}")
    if num_molecules == 0:
        # Avoid ZeroDivisionError on an empty input file.
        print("Percentage of scaffolds: n/a (no molecules in input)")
        return
    print(f"Percentage of scaffolds: {num_scaffolds / num_molecules:.2%}")


def batch_get_scaffolds(
    inchis: pl.Series, as_generic_cyclic_skeleton: bool = False
) -> list[str | None]:
    """
    Compute one Bemis-Murcko scaffold per InChI string in a batch.

    Every element of `inchis` is passed to `get_bemis_murcko_scaffold` with the
    same `as_generic_cyclic_skeleton` flag. The returned list is aligned with
    the input order and holds None wherever a scaffold could not be computed.
    """

    def _scaffold_of(inchi):
        # Bind the shared flag once so `map` only has to supply the InChI.
        return get_bemis_murcko_scaffold(
            inchi=inchi, as_generic_cyclic_skeleton=as_generic_cyclic_skeleton
        )

    return list(map(_scaffold_of, inchis))


def get_bemis_murcko_scaffold(
    inchi: str, as_generic_cyclic_skeleton: bool = False
) -> str | None:
    """
    Get the Bemis-Murcko scaffold of a molecule given as an InChI string.

    If `as_generic_cyclic_skeleton` is True, return the generic cyclic skeleton
    (via `MurckoScaffold.MakeScaffoldGeneric`) instead of the specific scaffold.

    Returns:
        The scaffold as a SMILES string, or None if the InChI is missing/empty,
        cannot be parsed, or any RDKit step fails. RDKit logging is suppressed
        while the scaffold is computed.
    """
    from rdkit.Chem import MolFromInchi, MolToSmiles, SanitizeMol
    from rdkit.Chem.Scaffolds import MurckoScaffold
    from skfp.utils import no_rdkit_logs

    # Null entries from the input column arrive as None; treat them like an
    # unparseable molecule.
    if not inchi:
        return None

    with no_rdkit_logs(suppress_warnings=True):
        try:
            mol = MolFromInchi(inchi)
            if mol is None:
                # RDKit signals an unparseable InChI by returning None.
                return None

            # Here we retain exo atoms, which doesn't fully adhere to the
            # original Bemis-Murcko definition.
            # NOTE(review): acyclic molecules presumably yield an empty scaffold
            # (SMILES ""), which is then counted as one scaffold — confirm intent.
            scaffold = MurckoScaffold.GetScaffoldForMol(mol)

            if as_generic_cyclic_skeleton:
                scaffold = MurckoScaffold.MakeScaffoldGeneric(scaffold)

                # Second call of `GetScaffoldForMol` will remove leftovers after
                # exo-substituents, which weren't removed by `MakeScaffoldGeneric`
                scaffold = MurckoScaffold.GetScaffoldForMol(scaffold)

            SanitizeMol(scaffold)
            return MolToSmiles(scaffold)

        except Exception:
            # Deliberate best-effort: a failure on one molecule yields None
            # rather than aborting the whole parallel batch.
            return None


if __name__ == "__main__":
    # Script entry point: click parses the command-line options for the command.
    count_scaffolds()
