import click
import numpy as np
import polars as pl
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Basic descriptors, commonly used to describe molecular datasets.
# The order must match the columns produced by mols_batch_descriptors(): this
# list is used both as the output Parquet schema and to label the rows of the
# printed statistics table.
DESCRIPTOR_NAMES = [
    "molecular_weight",
    # NOTE(review): computed with RDKit's Mol.GetNumAtoms(), which by default
    # does not count implicit hydrogens - confirm this is the intended count
    "num_atoms",
    "HBA",  # number of hydrogen bond acceptors
    "HBD",  # number of hydrogen bond donors
    "logP",  # Crippen logP
    "TPSA",  # topological polar surface area
    "num_rotatable_bonds",
]


# Shared InChI -> RDKit molecule converter used by mols_batch_descriptors().
# With valid_only=True, InChIs that fail to parse are expected to be dropped,
# so the number of descriptor rows may be smaller than the number of inputs.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(),
    help="Output Parquet file path",
    required=False,
)
def calculate_simple_descriptors(input_file: str, output_file: str | None) -> None:
    """
    Analyze distributions of simple physicochemical descriptors, typically
    checked for molecular datasets, e.g. molecular weight.

    If an output file is given, per-molecule descriptors are saved there as
    Parquet; otherwise summary statistics are printed to stdout.
    \f
    :param input_file: path to input Parquet file, with "InChI" column
    :param output_file: path to output Parquet file; if omitted, summary
        statistics are printed instead
    """
    df = pl.read_parquet(input_file, columns=["InChI"])

    batch_results = run_in_parallel(
        mols_batch_descriptors,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=1000,
        verbose=True,
    )

    n_descriptors = len(DESCRIPTOR_NAMES)
    # A batch with no valid molecules may come back as a 1-D empty array, which
    # np.vstack cannot stack with (k, n_descriptors) arrays; force every batch
    # to 2-D. np.vstack also raises on an empty list (e.g. empty input file).
    arrays = [np.asarray(batch).reshape(-1, n_descriptors) for batch in batch_results]
    results = np.vstack(arrays) if arrays else np.empty((0, n_descriptors))

    if output_file:
        # orient="row" is required: without it, polars infers orientation from
        # the data shape vs. schema length, and when the number of rows equals
        # the number of descriptors the data can be silently transposed.
        df_out = pl.DataFrame(results, schema=DESCRIPTOR_NAMES, orient="row")
        df_out.write_parquet(output_file)
        print(f"Results saved to {output_file}")
    elif results.shape[0] == 0:
        # min/max/percentiles are undefined on an empty column
        print("No valid molecules found, cannot compute statistics")
    else:
        print_descriptor_statistics(results)


def mols_batch_descriptors(inchis: pl.Series) -> np.ndarray:
    """
    Calculate simple physicochemical descriptors for a batch of molecules.

    :param inchis: InChI strings of the molecules in the batch
    :return: float array of shape (n_valid_molecules, len(DESCRIPTOR_NAMES)),
        with columns in the same order as DESCRIPTOR_NAMES. mol_from_inchi is
        configured with valid_only=True, so unparsable InChIs are expected to
        be dropped and the result may have fewer rows than the input.
    """
    from rdkit.Chem.Crippen import MolLogP
    from rdkit.Chem.Descriptors import MolWt
    from rdkit.Chem.rdMolDescriptors import (
        CalcNumHBA,
        CalcNumHBD,
        CalcNumRotatableBonds,
        CalcTPSA,
    )

    mols = mol_from_inchi.transform(inchis)
    # column order must match DESCRIPTOR_NAMES
    descriptors = [
        [
            MolWt(mol),
            mol.GetNumAtoms(),
            CalcNumHBA(mol),
            CalcNumHBD(mol),
            MolLogP(mol),
            CalcTPSA(mol),
            CalcNumRotatableBonds(mol),
        ]
        for mol in mols
    ]

    # reshape keeps the result 2-D even when no molecule in the batch is valid;
    # np.array([]) alone would be 1-D and break np.vstack in the caller
    return np.array(descriptors, dtype=float).reshape(-1, len(DESCRIPTOR_NAMES))


def print_descriptor_statistics(data: np.ndarray) -> None:
    """
    Print a tab-separated table of summary statistics, one row per descriptor.

    :param data: 2-D array with one column per entry of DESCRIPTOR_NAMES, in
        the same order
    """
    stat_labels = (
        "min",
        "perc1",
        "perc5",
        "q1",
        "mean",
        "median",
        "q3",
        "perc95",
        "perc99",
        "max",
    )
    # the leading empty cell sits above the descriptor-name column
    print("", *stat_labels, sep="\t")

    for column_index, descriptor in enumerate(DESCRIPTOR_NAMES):
        print(descriptor, end="\t")
        values = data[:, column_index]
        p1, p5, p25, p75, p95, p99 = np.percentile(values, [1, 5, 25, 75, 95, 99])
        row_stats = (
            np.min(values),
            p1,
            p5,
            p25,
            np.mean(values),
            np.median(values),
            p75,
            p95,
            p99,
            np.max(values),
        )
        print(*(f"{stat:.2f}" for stat in row_stats), sep="\t")


if __name__ == "__main__":
    # Script entry point: the click-decorated command parses the CLI options.
    calculate_simple_descriptors()
