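"""Mutual information between a continuous and a discrete variable.

The nearest-neighbour estimator (kept commented out at the bottom of this file)
was taken from: https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0087357&type=printable
"""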
import numpy as np


def _make_bins_fixed_width(x: np.ndarray, n_bins: int) -> np.ndarray:
    assert len(x.shape) == 1
    cmax = np.max(x)
    cmin = np.min(x)
    bin_width = (cmax - cmin) / n_bins

    bin_assignments = np.zeros(x.shape, dtype=np.int32)

    for i in range(n_bins):
        if i == 0:
            bmin = -np.infty
        else:
            bmin = cmin + i * bin_width

        if i == n_bins - 1:
            bmax = np.infty
        else:
            bmax = cmin + (i + 1) * bin_width

        bin_assignments[(bmin <= x) & (x < bmax)] = i

    return bin_assignments


# TODO: Consider an alternative binning scheme: split the top 95% of the range
# into bins holding equal numbers of examples, and put the lowest 5% in a bin
# of its own.


def compute_mi_fixed_width_bins(continuous: np.ndarray, discrete: np.ndarray, n_bins: int):
    """Estimate the mutual information (in bits) between a continuous and a discrete variable.

    The continuous samples are discretised into ``n_bins`` equal-width bins
    (via ``_make_bins_fixed_width``) and the plug-in estimate

        I(B; D) = sum_{b, d} p(b, d) * log2(p(b, d) / (p(b) * p(d)))

    is computed from the empirical joint distribution of (bin, discrete label).

    Args:
        continuous: 1-D array of continuous samples.
        discrete: 1-D integer or boolean array of labels, paired element-wise
            with ``continuous``.
        n_bins: Number of equal-width bins used for ``continuous``; must be >= 1.

    Returns:
        The estimated mutual information in bits, as a Python float.

    Raises:
        AssertionError: If the inputs are not 1-D or ``discrete`` is not an
            integer/boolean array.
        ValueError: If ``n_bins < 1`` or the arrays differ in length.
    """
    assert len(continuous.shape) == len(discrete.shape) == 1
    # `np.bool_` works on every NumPy version; the bare `np.bool` alias is
    # missing on NumPy 1.24-1.26.
    assert np.issubdtype(discrete.dtype, np.integer) or discrete.dtype == np.bool_
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if continuous.shape[0] != discrete.shape[0]:
        raise ValueError(
            f"continuous and discrete must have the same length, "
            f"got {continuous.shape[0]} and {discrete.shape[0]}"
        )

    bin_assignments = _make_bins_fixed_width(continuous, n_bins=n_bins)
    # Map every discrete label to a column index 0..k-1.
    discrete_vals, discrete_idx = np.unique(discrete, return_inverse=True)

    # Joint histogram with exactly one row per bin and one column per label.
    # (Looping over each sample's bin, as before, counted a bin once per sample.)
    joint = np.zeros((n_bins, discrete_vals.shape[0]), dtype=np.float64)
    np.add.at(joint, (bin_assignments, discrete_idx.ravel()), 1.0)
    joint /= continuous.shape[0]

    p_bins = joint.sum(axis=1, keepdims=True)
    p_discs = joint.sum(axis=0, keepdims=True)

    # Cells with p(b, d) == 0 contribute nothing (p log p -> 0 as p -> 0);
    # skipping them also avoids log(0) / division warnings.
    nonzero = joint > 0
    independent = (p_bins * p_discs)[nonzero]
    return float(np.sum(joint[nonzero] * np.log2(joint[nonzero] / independent)))


# def deprecated_compute_mi(continuous: np.ndarray, discrete: np.ndarray, k: int):
#     from scipy.special import digamma
#     # Taken from: https://journals.plos.org/plosone/article/file?id=10.1371/journal.pone.0087357&type=printable
#     # This seems to be not a good fit for what I'm using it for.

#     # NOTE: This implementation is not the fastest and probably has issues when there
#     # are a lot of data points.
#     assert len(continuous.shape) == len(discrete.shape) == 1
#     assert np.issubdtype(discrete.dtype, np.integer) or discrete.dtype == np.bool

#     N = continuous.shape[0]

#     mi = 0

#     discrete_vals = set(discrete)
#     for val in discrete_vals:
#         subset_continuous = continuous[discrete == val]

#         N_val = subset_continuous.shape[0]

#         subset_distances = np.abs(subset_continuous[:, None] - subset_continuous[None, :])
#         subset_to_all_distances = np.abs(subset_continuous[:, None] - continuous[None, :])

#         subset_sorted_distances = np.sort(subset_distances, axis=-1)

#         # The k + 1 is needed to ignore the distance from the datapoint to itself.
#         d = subset_sorted_distances[:, k + 1]

#         # The -= 1 is needed to exclude the distance from the datapoint to itself.
#         m = (subset_to_all_distances <= d[:, None]).astype(np.int32).sum(axis=-1)
#         m -= 1

#         pe_mis = digamma(N) - digamma(N_val) + digamma(k) - digamma(m)
#         mi += pe_mis.sum() * N_val / N

#     return mi
