import click
import polars as pl
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Module-level InChI -> molecule converter, used by get_salts_count below.
# NOTE(review): going by the flag names, valid_only=True drops entries that
# fail to parse and suppress_warnings=True silences parser output — confirm
# against the scikit-fingerprints documentation.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
def count_salts(input_file: str) -> None:
    """
    Count salt molecules in the dataset and print the total and percentage.

    Reads the Parquet file at ``input_file``, passes its "InChI" column to
    ``get_salts_count`` via ``run_in_parallel`` (batch_size=1000, n_jobs=-1),
    and sums the returned counts. The percentage is computed relative to the
    number of rows in the file. An empty file is reported as 0 salts and
    0.00% instead of failing on a division by zero.

    :param input_file: path to an existing Parquet file containing an
        "InChI" column.
    """
    df = pl.read_parquet(input_file)
    num_rows = len(df)

    if num_rows == 0:
        # Nothing to process: skip the parallel run and avoid dividing by zero.
        num_salts = 0
        perc_salts = 0.0
    else:
        salts_counts = run_in_parallel(
            get_salts_count,
            data=df["InChI"],
            n_jobs=-1,
            batch_size=1000,
            verbose=True,
        )
        num_salts = sum(salts_counts)
        # Denominator is the row count of the input file.
        perc_salts = num_salts / num_rows

    print(f"Number of salts: {num_salts}")
    print(f"Percentage of salts: {perc_salts:.2%}")


def get_salts_count(inchis: pl.Series) -> int:
    """
    Return how many of the given InChI strings describe a salt.

    Each InChI is converted to a molecule with the module-level
    ``mol_from_inchi`` transformer and split into fragments with RDKit's
    ``GetMolFrags``. A molecule is counted when it has at least two fragments
    and, among their formal charges, at least one is negative and at least
    one is positive.

    :param inchis: InChI strings to inspect.
    :return: number of molecules counted as salts.
    """
    # NOTE(review): RDKit is imported locally, presumably so each parallel
    # worker resolves it on its own — confirm before moving to module level.
    from rdkit.Chem import GetFormalCharge, GetMolFrags

    n_salts = 0
    for molecule in mol_from_inchi.transform(inchis):
        fragments = GetMolFrags(molecule, asMols=True)
        if len(fragments) >= 2:
            fragment_charges = [GetFormalCharge(part) for part in fragments]
            # Opposite-signed fragments: lowest charge below zero, highest above.
            if min(fragment_charges) < 0 and max(fragment_charges) > 0:
                n_salts += 1

    return n_salts


# Script entry point: invoke the click command, which reads --input-file
# from the command line.
if __name__ == "__main__":
    count_salts()
