import click
import numpy as np
import polars as pl
from rdkit.Chem import MolFromInchi
from skfp.distances import (
    bulk_tanimoto_binary_distance,
)
from skfp.fingerprints import ECFPFingerprint
from skfp.utils import run_in_parallel


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(),
    help="Output .txt file path",
    required=False,
)
def compute_normalized_n_circles(input_file: str, output_file: str | None) -> None:
    """
    Compute the normalized #Circles measure, using parallel approximation algorithm
    (simplified Algorithm 4 from original paper). #Circles is divided by dataset size
    to enable direct comparisons between datasets.

    Molecules are shuffled with a fixed seed and split into large batches. #Circles
    is computed independently in each batch, and the per-batch counts are summed.
    Molecules from different batches are never compared, which is what makes this an
    approximation.

    Reference paper: https://arxiv.org/abs/2112.12542

    :param input_file: Path to the input Parquet file containing an "InChI" column.
    :param output_file: Path to the output .txt file. If given, #Circles and the
        normalized value (as a plain fraction) are written on two separate lines.
        Otherwise both values are printed to stdout, with the normalized value
        shown as a percentage.
    :raises click.ClickException: If the "InChI" column has no non-null values.
    """
    df = pl.read_parquet(input_file, columns=["InChI"])

    # Null InChIs cannot be parsed by RDKit, so drop them up front. The fixed
    # shuffle seed keeps batch assignment, and therefore the result, reproducible.
    inchis = df["InChI"].drop_nulls().shuffle(seed=0).to_list()
    if not inchis:
        raise click.ClickException(f"No InChI strings found in {input_file}")

    # Each batch yields its own circle count (an int); see get_n_circles_batch.
    results = run_in_parallel(
        get_n_circles_batch,
        data=inchis,
        n_jobs=-1,
        batch_size=1_000_000,
        verbose=True,
    )

    n_circles = sum(results)
    # Normalize by the number of non-null InChIs processed. This equals the
    # dataset size when the column has no nulls.
    normalized_n_circles = n_circles / len(inchis)

    if output_file:
        with open(output_file, "w", encoding="utf-8") as file:
            file.write(str(n_circles) + "\n")
            file.write(str(normalized_n_circles) + "\n")
    else:
        print(f"#Circles: {n_circles}")
        print(f"Normalized #Circles: {normalized_n_circles:.2%}")


def get_n_circles_batch(inchis: list[str], threshold: float = 0.75) -> int:
    """
    Greedily count #Circles within a single batch of molecules.

    Molecules are visited in order. The first parseable molecule opens the first
    circle. Each later molecule opens a new circle only when its Tanimoto distance
    (on 1024-bit ECFP fingerprints) to every existing circle center is strictly
    greater than ``threshold``. InChIs that RDKit fails to parse are skipped.

    :param inchis: InChI strings belonging to this batch.
    :param threshold: Tanimoto distance threshold for opening a new circle. The
        default of 0.75 follows the original paper.
    :return: Number of circles found. This is 0 if the batch is empty or no InChI
        in it could be parsed.
    """
    # Imported and configured inside the function, presumably so that it also
    # takes effect in the parallel worker processes that run this function.
    from rdkit import RDLogger

    RDLogger.DisableLog("rdApp.*")  # turn off unnecessary warnings

    fp = ECFPFingerprint(fp_size=1024)  # original paper setting

    circles_fps: np.ndarray | None = None  # one fingerprint row per circle center
    n_circles = 0
    for inchi in inchis:
        mol = MolFromInchi(inchi)
        if mol is None:
            continue

        mol_fp = fp.transform([mol])

        # The first valid molecule always opens the first circle.
        if circles_fps is None:
            circles_fps = mol_fp
            n_circles = 1
            continue

        dists = bulk_tanimoto_binary_distance(circles_fps, mol_fp)
        if np.min(dists) > threshold:
            n_circles += 1
            circles_fps = np.vstack([circles_fps, mol_fp])

    return n_circles


if __name__ == "__main__":
    # Script entry point: Click parses the command-line options and invokes the command.
    compute_normalized_n_circles()
