import click
import numpy as np
import pandas as pd
import skfp.filters as skfp_filt
from rdkit import Chem, rdBase
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Silence all RDKit "rdApp" log channels (e.g. noise from malformed InChIs).
rdBase.DisableLog("rdApp.*")
# Keep all molecule properties when RDKit molecules are pickled; presumably so
# molecules survive transfer between parallel worker processes — confirm.
Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps)

# Converts InChI strings to RDKit molecules. With valid_only=True, InChIs that
# fail to parse are presumably dropped from the output (verify against the
# skfp docs), so downstream statistics cover parseable molecules only.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)

# Every filter is configured with return_indicators=True so that transform()
# yields one pass/fail indicator per molecule rather than the filtered
# molecules themselves; the column order of the statistics follows this list.
FILTERS = [
    skfp_filt.BeyondRo5Filter(return_indicators=True),
    skfp_filt.BrenkFilter(return_indicators=True),
    skfp_filt.FAF4DruglikeFilter(return_indicators=True),
    skfp_filt.FAF4LeadlikeFilter(return_indicators=True),
    skfp_filt.GhoseFilter(return_indicators=True),
    skfp_filt.GlaxoFilter(return_indicators=True),
    skfp_filt.GSKFilter(return_indicators=True),
    skfp_filt.HaoFilter(return_indicators=True),
    skfp_filt.LipinskiFilter(return_indicators=True),
    skfp_filt.OpreaFilter(return_indicators=True),
    skfp_filt.PfizerFilter(return_indicators=True),
    skfp_filt.REOSFilter(return_indicators=True),
    skfp_filt.RuleOfVeberFilter(return_indicators=True),
    skfp_filt.RuleOfXuFilter(return_indicators=True),
    skfp_filt.ZINCBasicFilter(return_indicators=True),
    skfp_filt.ZINCDruglikeFilter(return_indicators=True),
]
# Human-readable filter names (class name without the trailing "Filter"),
# aligned index-by-index with FILTERS.
FILTER_NAMES = [type(filt).__name__.removesuffix("Filter") for filt in FILTERS]


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(),
    help="Output TSV file path",
    required=False,
)
def get_filters_statistics(input_file: str, output_file: str | None = None) -> None:
    """
    Compute the percentage of molecules passing each molecular filter.

    Reads InChI strings from the "InChI" column of the input Parquet file,
    evaluates every filter in FILTERS on them in parallel batches, and reports
    per filter the percentage (rounded to 2 decimals) of molecules that pass.
    Because mol_from_inchi is built with valid_only=True, the percentages are
    presumably relative to the InChIs that could be parsed. The resulting
    table is written as TSV to --output-file if given, otherwise printed.
    """
    # Only the InChI column is used, so avoid loading the rest of the file.
    # Missing InChIs cannot be converted to molecules, so drop them up front.
    inchis = (
        pd.read_parquet(input_file, columns=["InChI"], dtype_backend="pyarrow")[
            "InChI"
        ]
        .dropna()
        .tolist()
    )
    if not inchis:
        # np.vstack would otherwise fail with an opaque "need at least one array".
        raise click.ClickException(f"No InChI values found in {input_file}")

    results = run_in_parallel(
        batch_filter_mols,
        data=inchis,
        n_jobs=-1,
        batch_size=1000,
        verbose=True,
    )
    indicators = np.vstack(results)
    if indicators.shape[0] == 0:
        # Percentages over zero molecules are undefined (mean would be NaN).
        raise click.ClickException(
            f"No valid molecules could be parsed from {input_file}"
        )

    # Column means of boolean indicators are pass fractions; scale to percent.
    stats = pd.DataFrame(indicators, columns=FILTER_NAMES)
    stats = stats.mean().mul(100).round(2).reset_index()
    stats.columns = ["filter", "percentage"]

    if output_file:
        stats.to_csv(output_file, index=False, sep="\t")
    else:
        print(stats)


def batch_filter_mols(inchis: list[str]) -> np.ndarray:
    """
    Evaluate every filter in FILTERS on one batch of InChI strings.

    The InChIs are converted to molecules with ``mol_from_inchi``; each
    filter's output is then laid out as one column of the result, in the same
    order as FILTERS (and therefore FILTER_NAMES).

    :param inchis: batch of InChI strings
    :return: 2D array with one row per converted molecule and one column
        per filter
    """
    molecules = mol_from_inchi.transform(inchis)
    # reshape(-1, 1) turns each filter's output into a single column so the
    # columns can be joined side by side.
    return np.concatenate(
        [mol_filter.transform(molecules).reshape(-1, 1) for mol_filter in FILTERS],
        axis=1,
    )


if __name__ == "__main__":
    # Script entry point: the click command parses --input-file / --output-file
    # from the command line and then runs the statistics computation.
    get_filters_statistics()
