import click
import numpy as np
import polars as pl
from rdkit import rdBase
from rdkit.Contrib.SA_Score.sascorer import calculateScore as sa_score
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Module-level InChI -> molecule transformer, used by batch_score_synthesizability.
# Created once at import time with warnings suppressed and valid_only=True.
mol_from_inchi = MolFromInchiTransformer(suppress_warnings=True, valid_only=True)

# Disable RDKit logging for every log channel matching "rdApp.*".
rdBase.DisableLog("rdApp.*")


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(),
    help="Output Parquet file path",
    required=False,
)
def get_synthesizability_scores(
    input_file: str, output_file: str | None = None
) -> None:
    """Compute synthetic accessibility (SA) scores for a Parquet file of InChIs.

    Reads the "InChI" column of the input file and scores it in parallel
    batches. Entries that yield no molecule are skipped, so there can be fewer
    scores than input rows.

    If an output file is given, the scores are written to it as a single
    "SA_SCORE" column. Otherwise summary statistics are printed to stdout.
    """
    df = pl.read_parquet(input_file)

    # One array of scores per batch of (up to) 1000 InChI strings.
    batch_scores = run_in_parallel(
        batch_score_synthesizability,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=1000,
        verbose=True,
    )

    # np.concatenate raises ValueError on an empty sequence, which would
    # happen if no batches were produced (e.g. an input file with no rows).
    if len(batch_scores) > 0:
        scores = np.concatenate(batch_scores)
    else:
        scores = np.empty(0, dtype=float)

    if output_file:
        # Build the frame only when it is actually written.
        pl.DataFrame({"SA_SCORE": scores}).write_parquet(output_file)
        print(f"Results saved to {output_file}")
    elif scores.size == 0:
        # NumPy reductions such as np.min fail on an empty array, so report
        # the situation instead of crashing inside print_statistics.
        print("No valid molecules were scored; no statistics to print")
    else:
        print_statistics(scores)


def batch_score_synthesizability(inchis: pl.Series) -> np.ndarray:
    """Score one batch of InChI strings and return the scores as an array.

    The InChI strings are converted to molecules with the module-level
    ``mol_from_inchi`` transformer. Every molecule that is not ``None`` is
    passed to ``sa_score``; ``None`` entries are skipped, so the result may
    be shorter than the input.
    """
    scores = []
    for molecule in mol_from_inchi.transform(inchis):
        if molecule is None:
            continue
        scores.append(sa_score(molecule))
    return np.array(scores)


def print_statistics(data: np.ndarray) -> None:
    """Print a two-line, tab-separated summary of ``data`` to stdout.

    The first line is a header; the second holds the matching values
    formatted with two decimals: minimum, 1st and 5th percentiles, first
    quartile, mean, median, third quartile, 95th and 99th percentiles, and
    maximum.
    """
    header = (
        "min",
        "perc1",
        "perc5",
        "q1",
        "mean",
        "median",
        "q3",
        "perc95",
        "perc99",
        "max",
    )
    # The header is emitted before any statistic is computed.
    print(*header, sep="\t")

    p01, p05, p25, p75, p95, p99 = np.percentile(data, [1, 5, 25, 75, 95, 99])
    # Values listed in the same order as the header columns.
    row = (
        np.min(data),
        p01,
        p05,
        p25,
        np.mean(data),
        np.median(data),
        p75,
        p95,
        p99,
        np.max(data),
    )
    print(*(f"{value:.2f}" for value in row), sep="\t")


# Script entry point: click parses the command-line options and invokes the command.
if __name__ == "__main__":
    get_synthesizability_scores()
