import random
from collections import defaultdict

import click
import numpy as np
import polars as pl
from rdkit.SimDivFilters import MaxMinPicker
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel
from tqdm import tqdm

# Module-level InChI -> molecule transformer, created once at import time and
# reused by get_fingerprints() below.
# NOTE(review): valid_only=True presumably drops InChIs that fail to parse, so
# the output may be shorter than the input - confirm against the skfp docs.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input parquet file path",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(dir_okay=False),
    help="Output Parquet file path with the selected subset",
    required=True,
)
@click.option(
    "--subset-size",
    type=click.IntRange(min=1),
    help="Total target size of the subset",
    required=True,
)
def build_and_save_subset(
    input_file: str,
    output_file: str,
    subset_size: int,
):
    """
    Build a diverse subset and save to a Parquet file.

    Fingerprints are computed for the "InChI" column, cluster centers are
    chosen with RDKit's MaxMin picker, molecules are assigned to centers by
    ``select_diverse_subset`` and, if the result is smaller than requested,
    it is topped up with random rows by ``ensure_proper_size``.

    :param input_file: input Parquet file path, must contain an "InChI" column
    :param output_file: output Parquet file path with the selected subset
    :param subset_size: total target size of the subset
    :raises ValueError: if the number of fingerprints differs from the number
        of input rows
    """
    df = pl.read_parquet(input_file)

    fps = run_in_parallel(
        get_fingerprints,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=10000,
        flatten_results=True,
        verbose=True,
    )

    # everything downstream uses fingerprint positions as row indices into df,
    # so a length mismatch would silently select the wrong molecules
    if len(fps) != len(df):
        raise ValueError(
            f"Got {len(fps)} fingerprints for {len(df)} input rows; "
            "the input must contain only valid InChI strings"
        )

    print("Start maximum diversity picking for centers")
    picker = MaxMinPicker()
    # pickSize=len(fps) puts no limit on the number of centers here; picking
    # is expected to stop based on the threshold instead
    centers_idxs, _ = picker.LazyBitVectorPickWithThreshold(
        fps, poolSize=len(fps), pickSize=len(fps), threshold=0.9, seed=0
    )
    centers_idxs = list(centers_idxs)

    # use the size and path requested on the command line
    subset_df, selected_idxs = select_diverse_subset(
        df, fps, centers_idxs, subset_size
    )

    # we may get fewer molecules than necessary if we have some small clusters,
    # i.e. containing fewer molecules than we want
    # in those cases we select missing ones randomly - selected subset already
    # takes care of diversity
    if len(subset_df) < subset_size:
        subset_df = ensure_proper_size(
            df_full=df,
            df_diverse_subset=subset_df,
            selected_diverse_idxs=selected_idxs,
            target_subset_size=subset_size,
        )

    subset_df.write_parquet(output_file)
    print(f"Saved {len(subset_df)} molecules to {output_file}")


def get_fingerprints(inchis: pl.Series):
    """
    Compute Morgan fingerprints for a series of InChI strings.

    RDKit is imported inside the function rather than at module level.

    :param inchis: InChI strings to convert
    :return: fingerprints produced by a default-configured Morgan generator
    """
    from rdkit import rdBase
    from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

    # silence RDKit logging, the warnings are not useful here
    rdBase.DisableLog("rdApp.*")

    generator = GetMorganGenerator()
    molecules = mol_from_inchi.transform(inchis)
    return generator.GetFingerprints(molecules)


def select_diverse_subset(
    df: pl.DataFrame, fps: list, centers_idxs: list[int], subset_size: int
) -> tuple[pl.DataFrame, list[int]]:
    """
    Select a subset by filling one cluster per center up to an equal cap.

    Every center starts its own cluster. The remaining molecules are shuffled
    with a fixed seed, assigned to their most similar center in batches, and
    each cluster keeps members only up to the per-cluster cap (the center
    itself counts towards the cap).

    The result may contain fewer than ``subset_size`` rows when some clusters
    are small, and contains all centers even if there are more centers than
    ``subset_size``.

    :param df: full data frame, must contain an "InChI" column
    :param fps: fingerprints, one per row of ``df``; only the length is used
    :param centers_idxs: row indices of the cluster centers
    :param subset_size: total target size of the subset
    :return: tuple of (selected rows of ``df``, selected row indices)
    :raises ValueError: if ``centers_idxs`` is empty
    """
    n_centers = len(centers_idxs)
    if n_centers == 0:
        raise ValueError("centers_idxs must contain at least one center index")

    # always assign each center index to its own cluster
    # and remove those center indices from the pool of all_idxs
    clusters = {c: [c] for c in centers_idxs}
    centers_set = set(centers_idxs)  # O(1) membership instead of a list scan
    all_idxs = [i for i in range(len(fps)) if i not in centers_set]

    rng = random.Random(0)
    rng.shuffle(all_idxs)

    subset_size = min(subset_size, len(all_idxs))
    # cap on the full cluster size, center included; at least 1 so that
    # truncation can never drop the center itself
    max_cluster_size = max(subset_size // n_centers, 1)

    centers_inchis = df[centers_idxs]["InChI"]

    print("Picked centers, start cluster assignment")

    # cluster assignment is parallelized for efficiency
    # we do this in batches and parallelize inside them to finish quickly if possible
    batch_size = 100000
    num_batches = int(np.ceil(len(all_idxs) / batch_size))
    for idxs_batch in tqdm(batch_iter(all_idxs, batch_size), total=num_batches):
        # once every cluster is full, further batches cannot change the result
        if all(len(members) >= max_cluster_size for members in clusters.values()):
            break

        data = pl.DataFrame({"InChI": df[idxs_batch]["InChI"], "idx": idxs_batch})

        clusters_batches = run_in_parallel(
            assign_clusters,
            data,
            n_jobs=-1,
            batch_size=10000,
            centers_inchis=centers_inchis,
            centers_idxs=centers_idxs,
        )
        for clusters_batch in clusters_batches:
            for cluster_idx, mol_idxs in clusters_batch.items():
                members = clusters[cluster_idx]
                free_slots = max_cluster_size - len(members)
                if free_slots > 0:
                    members.extend(mol_idxs[:free_slots])

    print("Created clusters, selecting and saving to file")
    selected_idxs = [i for members in clusters.values() for i in members]
    subset_df = df[selected_idxs]

    return subset_df, selected_idxs


def batch_iter(iterable: list, batch_size: int):
    """
    Yield consecutive slices of a sequence, each with at most ``batch_size`` items.

    The last slice is shorter when the length is not a multiple of
    ``batch_size``; an empty sequence yields nothing.

    :param iterable: sequence supporting ``len`` and slicing
    :param batch_size: maximum number of items per yielded slice
    """
    total = len(iterable)
    starts = range(0, total, batch_size)
    yield from (iterable[start : start + batch_size] for start in starts)


def assign_clusters(
    data: pl.DataFrame,
    centers_inchis: pl.Series,
    centers_idxs: list[int],
) -> dict:
    """
    Assign every molecule in ``data`` to its most similar cluster center.

    Fingerprints for both the molecules and the centers are computed on each
    call, and similarity is Tanimoto similarity between them.

    :param data: frame with "InChI" and "idx" columns
    :param centers_inchis: InChI strings of the centers
    :param centers_idxs: center indices, in the same order as ``centers_inchis``
    :return: mapping from center index to the list of "idx" values assigned to it
    """
    from rdkit.DataStructs import BulkTanimotoSimilarity

    mol_fps = get_fingerprints(data["InChI"])
    center_fps = get_fingerprints(centers_inchis)

    assignment = defaultdict(list)
    # NOTE(review): strict=False stops at the shorter input, so a length
    # mismatch between indices and fingerprints would go unnoticed - confirm
    # that inputs are always fully convertible
    for mol_idx, mol_fp in zip(data["idx"], mol_fps, strict=False):
        similarities = BulkTanimotoSimilarity(mol_fp, center_fps)
        nearest = np.argmax(similarities)
        assignment[centers_idxs[nearest]].append(mol_idx)

    return assignment


def ensure_proper_size(
    df_full: pl.DataFrame,
    df_diverse_subset: pl.DataFrame,
    selected_diverse_idxs: list[int],
    target_subset_size: int,
) -> pl.DataFrame:
    """
    Top up the diverse subset with randomly sampled rows to reach the target size.

    Rows are sampled without replacement, with a fixed seed, from the rows of
    ``df_full`` that are not already selected. If the subset is already large
    enough it is returned unchanged; if fewer rows remain than are needed, all
    remaining rows are added.

    :param df_full: full data frame
    :param df_diverse_subset: rows already selected
    :param selected_diverse_idxs: row indices in ``df_full`` of the selected rows
    :param target_subset_size: desired number of rows in the result
    :return: ``df_diverse_subset`` with the sampled rows appended
    """
    # select rows not in diverse subset
    all_idx = np.arange(len(df_full))
    mask = np.ones(len(df_full), dtype=bool)
    mask[selected_diverse_idxs] = False
    remaining_idx = all_idx[mask]

    # never request a negative sample or more rows than are available,
    # both of which make rng.choice raise
    needed = min(target_subset_size - len(df_diverse_subset), len(remaining_idx))
    if needed <= 0:
        return df_diverse_subset

    # sample randomly from remaining and combine
    rng = np.random.default_rng(seed=0)
    rest_indices = rng.choice(remaining_idx, size=needed, replace=False)
    df_random = df_full[rest_indices]

    return pl.concat([df_diverse_subset, df_random])


# Run the click command when this file is executed as a script; click parses
# the command-line options and passes them to build_and_save_subset.
if __name__ == "__main__":
    build_and_save_subset()
