import random

import click
import numpy as np
import polars as pl
from rdkit.DataStructs import TanimotoSimilarity
from scipy.stats import ks_2samp, wasserstein_distance
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel

# Module-level random generator with a fixed seed (0), shared by every sampling
# step in this file so that repeated runs draw the same rows and pairs.
rng = random.Random(0)
# Shared InChI -> molecule converter used by get_ecfp_fingerprints().
# NOTE(review): valid_only=True presumably drops InChIs that fail to parse, which
# would make the output shorter than the input - confirm against skfp docs.
mol_from_inchi = MolFromInchiTransformer(valid_only=True, suppress_warnings=True)


@click.command()
@click.option(
    "--source-dataset",
    type=click.Path(exists=True),
    help="Parquet file with source of diverse molecules dataset path",
    required=True,
)
@click.option(
    "--diverse-subset",
    type=click.Path(exists=True),
    help="Parquet file with diverse molecules dataset path",
    required=True,
)
@click.option("--n", type=int, default=10**6, help="Number of random pairs to sample")
def main(source_dataset, diverse_subset, n):
    """
    Perform dataset diversity comparison for diverse subset and random subset of a dataset.

    Based on statistics for Tanimoto distance (not similarity!), so for example greater average
    means more different molecules on average.
    """
    # NOTE: the docstring above doubles as the click --help text, so parameter
    # documentation lives in the option help strings instead.
    df_diverse = pl.read_parquet(diverse_subset)
    size = len(df_diverse)

    # select a randomized subset from the whole dataset, with the same number of
    # rows as the diverse subset; row indices must be drawn from the full source
    # length, otherwise only the first `size` rows could ever be selected
    df_source = pl.read_parquet(source_dataset)
    if size > len(df_source):
        raise click.ClickException(
            f"Diverse subset has {size} rows, but source dataset only has {len(df_source)}"
        )
    idxs = rng.sample(range(len(df_source)), size)
    df_random = df_source[idxs]

    # sample n random pairs of molecules for calculating distances
    # we can also subset DataFrames to just those
    # both subsets have `size` rows, so the same positional pairs are used for each
    pairs, indices = sample_pairs(size, n)
    df_random = df_random.with_row_index("idx").filter(pl.col("idx").is_in(indices))
    df_diverse = df_diverse.with_row_index("idx").filter(pl.col("idx").is_in(indices))

    fps_random = _fingerprints_by_index(df_random)
    fps_diverse = _fingerprints_by_index(df_diverse)

    dist_random = tanimoto_distances_distribution(fps_random, pairs)
    dist_diverse = tanimoto_distances_distribution(fps_diverse, pairs)

    compare_distributions(dist_random=dist_random, dist_diverse=dist_diverse)


def _fingerprints_by_index(df: pl.DataFrame) -> dict:
    """
    Compute fingerprints for the "InChI" column of df, keyed by its "idx" column.
    """
    fps = run_in_parallel(
        get_ecfp_fingerprints,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=10000,
        flatten_results=True,
        verbose=True,
    )
    # strict=True: if fewer fingerprints than rows come back, the positional
    # pairing of idx and fingerprint can no longer be trusted, so fail right away
    # instead of building a silently shifted mapping
    return dict(zip(df["idx"].to_list(), fps, strict=True))


def sample_pairs(n: int, k: int) -> tuple[set[tuple[int, int]], list[int]]:
    """
    Sample k unique pairs of indices (i, j) with i != j from indices 0 to n-1.

    Pairs are unordered: each one is stored once, as (i, j) with i > j.

    Returns a tuple of (set of sampled pairs, list of all distinct indices that
    appear in at least one pair).

    Raises ValueError if k is larger than the number of distinct pairs
    n * (n - 1) / 2, since that many unique pairs cannot exist.
    """
    max_pairs = n * (n - 1) // 2 if n > 1 else 0
    if k > max_pairs:
        raise ValueError(
            f"Cannot sample {k} unique pairs from {n} indices, at most {max_pairs} exist"
        )

    pairs: set[tuple[int, int]] = set()
    indices: set[int] = set()
    while len(pairs) < k:
        i, j = rng.sample(range(n), 2)
        # keep only the (larger, smaller) orientation, so that (i, j) and (j, i)
        # are never both stored as separate pairs
        if i <= j:
            continue

        # set semantics make re-adding an already sampled pair a no-op
        pairs.add((i, j))
        indices.update((i, j))

    return pairs, list(indices)


def get_ecfp_fingerprints(inchis: pl.Series):
    """
    Turn InChI strings into molecules and compute their Morgan fingerprints.

    The fingerprint generator is created with its default settings. Returns
    whatever the generator's GetFingerprints() yields for the converted molecules.
    """
    # imported locally rather than at module level
    # NOTE(review): presumably so the function is self-contained when executed
    # in parallel workers - confirm
    from rdkit import rdBase
    from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator

    # silence RDKit application logs, they are just noise here
    rdBase.DisableLog("rdApp.*")

    molecules = mol_from_inchi.transform(inchis)
    generator = GetMorganGenerator()
    return generator.GetFingerprints(molecules)


def tanimoto_distances_distribution(
    idx_to_fp: dict, pairs: list[tuple[int, int]]
) -> np.ndarray:
    """
    Compute Tanimoto distances (1 - similarity) for the given index pairs.

    Each pair (i, j) is looked up in idx_to_fp, and the distances are returned
    as an array in the iteration order of pairs.
    """
    distances = []
    for first, second in pairs:
        similarity = TanimotoSimilarity(idx_to_fp[first], idx_to_fp[second])
        distances.append(1 - similarity)
    return np.array(distances)


def compare_distributions(dist_random: np.ndarray, dist_diverse: np.ndarray) -> None:
    """
    Print summary statistics comparing two distance distributions.

    For both samples this reports mean, median, geometric mean and the 10th,
    25th, 75th and 90th percentiles, plus the Wasserstein distance and the
    two-sample Kolmogorov-Smirnov statistic between them. Every value is printed
    as "name: value" with three decimal places.
    """
    samples = {
        "random": np.array(dist_random),
        "diverse": np.array(dist_diverse),
    }

    def geometric_mean(values):
        # tiny offset keeps log() finite when a distance is exactly zero
        return np.exp(np.mean(np.log(values + 1e-12)))

    summary = {}

    # insertion order defines the print order: statistic first, then sample
    reducers = (("mean", np.mean), ("median", np.median), ("gmean", geometric_mean))
    for label, reducer in reducers:
        for which, values in samples.items():
            summary[f"{label}_{which}"] = float(reducer(values))

    summary["wasserstein"] = float(
        wasserstein_distance(samples["random"], samples["diverse"])
    )

    for pct in (10, 25, 75, 90):
        for which, values in samples.items():
            summary[f"p{pct}_{which}"] = float(np.percentile(values, pct))

    ks_result = ks_2samp(samples["random"], samples["diverse"])
    summary["ks_stat"] = float(ks_result.statistic)

    for name, number in summary.items():
        print(f"{name}: {number:.3f}")


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
