from collections import defaultdict

import click
import polars as pl
from rdkit import rdBase
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Module-level InChI -> molecule converter, used by batch_atom_stats below.
# Constructed with suppress_warnings=True.
mol_from_inchi = MolFromInchiTransformer(suppress_warnings=True)
# Disable RDKit log channels matching "rdApp.*" (presumably so that unparsable
# InChI strings do not flood the console -- TODO confirm intent).
rdBase.DisableLog("rdApp.*")

# Reported atom categories, mapped to the element symbols that belong to each.
# Key order matters: batch_atom_stats assigns a symbol to the first group that
# contains it. "Other" is intentionally empty; it is flagged for any symbol
# that is not listed in one of the groups above it.
ATOM_GROUPS = {
    "C": {"C"},
    "N": {"N"},
    "O": {"O"},
    "S": {"S"},
    "Halogens": set("At Br Cl F I".split()),
    "Metalloids": set("As B Ge Sb Si Te".split()),
    "Metals": set(
        (
            "Ag Au Bi Cd Co Cr Cu Cs Fe Fr Hg "
            "Ir K Li Mn Mo Na Nb Ni Os Pb Pd "
            "Pt Rb Re Rh Ru Ta Tc V W Y Zn"
        ).split()
    ),
    "Other": set(),
}


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    required=True,
    help="Input Parquet file path",
)
@click.option(
    "--output-file",
    type=click.Path(),
    required=False,
    help="Output csv file for atom statistics",
)
def main(
    input_file: str,
    output_file: str | None = None,
) -> None:
    """Compute atom-group presence statistics for the molecules in a Parquet file.

    Reads the "InChI" column, counts for every atom group how many molecules
    contain it, and reports the counts together with the percentage relative
    to the total number of rows. The result is written as CSV to the output
    file when one is given, otherwise printed to standard output.
    """
    # Only the InChI column is used, so skip loading the remaining columns.
    df = pl.read_parquet(input_file, columns=["InChI"])

    results = run_in_parallel(
        batch_atom_stats,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=1000,
        verbose=True,
    )

    # Merge the per-batch counts into a single total per group.
    all_counts = defaultdict(int)
    for batch_atom_count in results:
        for atom, val in batch_atom_count.items():
            all_counts[atom] += val

    # The denominator is the number of input rows, so rows whose InChI was
    # skipped by batch_atom_stats still count towards the total.
    total_mols = len(df)
    rows = [
        (atom, count, (count / total_mols) * 100)
        for atom, count in sorted(all_counts.items())
    ]
    stats_df = pl.DataFrame(
        rows,
        schema=["atom", "count", "percentage_present"],
        # Each tuple is one record; stating it explicitly avoids polars having
        # to guess the orientation (and warn about it).
        orient="row",
    )

    if output_file is None:
        # Without a target, write_csv returns the CSV text instead of saving it.
        print(stats_df.write_csv())
    else:
        stats_df.write_csv(output_file)
        print(f"Atom stats saved to {output_file}")


def batch_atom_stats(inchis: list[str]) -> dict[str, int]:
    """Count, per atom group, how many molecules of a batch contain that group.

    A molecule adds at most 1 to a group regardless of how many atoms of that
    group it has, so the values are molecule counts, not atom counts. Symbols
    that belong to no group in ``ATOM_GROUPS`` are counted under "Other".
    Entries for which the conversion yields ``None`` are skipped.

    Args:
        inchis: InChI strings of one batch.

    Returns:
        Mapping from each ``ATOM_GROUPS`` key to the number of converted
        molecules containing at least one atom of that group. The mapping is
        empty when no molecule in the batch could be converted.
    """
    mols = mol_from_inchi.transform(inchis)

    # Invert ATOM_GROUPS once so each symbol is resolved with a single lookup.
    # setdefault keeps the first group listing a symbol, matching a
    # first-match scan over the groups in definition order.
    symbol_to_group: dict[str, str] = {}
    for group, elements in ATOM_GROUPS.items():
        for symbol in elements:
            symbol_to_group.setdefault(symbol, group)

    atom_counters: dict[str, int] = {}
    for mol in mols:
        if mol is None:
            continue

        groups_present = {
            symbol_to_group.get(atom.GetSymbol(), "Other")
            for atom in mol.GetAtoms()
        }
        # Touch every group so absent ones are still reported with a count of
        # zero once at least one molecule has been processed.
        for group in ATOM_GROUPS:
            atom_counters[group] = atom_counters.get(group, 0) + int(
                group in groups_present
            )

    return atom_counters


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
