import numpy as np
from sklearn.ensemble import HistGradientBoostingClassifier

"""
Algorithm from "Classifying without discriminating"
https://ieeexplore.ieee.org/abstract/document/4909197
"""


def rank(
    features: np.ndarray,
    labels: np.ndarray,
    in_group: np.ndarray,
    learn_on_majority: bool = False,
    in_collective: np.ndarray = None,
    seed=0,
):
    """Rank the samples that are candidates for a label flip.

    A ``HistGradientBoostingClassifier`` is fitted and every row of
    ``features`` is scored with column 1 of ``predict_proba`` (the
    probability of label 1 when the labels are 0/1, which is what the
    ``labels == 0`` / ``labels == 1`` tests below assume).

    Args:
        features: 2-D feature matrix, one row per sample.
        labels: Per-sample labels; 0 and 1 are the values tested for.
        in_group: Per-sample indicator of group membership. Any truthy /
            falsy encoding (bool or 0/1) is accepted.
        learn_on_majority: If true, fit the classifier only on the rows
            outside the group; otherwise fit on all rows.
        in_collective: Optional per-sample mask; when given, promotion
            candidates are restricted to samples inside it.
        seed: ``random_state`` passed to the classifier.

    Returns:
        ``(to_promote, to_demote)`` index arrays into the input rows.
        ``to_promote`` holds in-group samples labelled 0, ordered from
        highest to lowest score. ``to_demote`` holds out-of-group samples
        labelled 1, ordered from lowest to highest score.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)
    # Normalise to a boolean mask: ``~`` on a 0/1 integer array is a bitwise
    # NOT (yielding -1/-2), which would index the wrong rows instead of
    # selecting the samples outside the group.
    in_group = np.asarray(in_group, dtype=bool)
    out_of_group = ~in_group

    # Drop every feature column that is identical to the group indicator.
    group_column = in_group.astype(float).reshape(-1, 1)
    duplicates_group = np.all(features == group_column, axis=0)
    features = features[:, ~duplicates_group]

    ranker = HistGradientBoostingClassifier(random_state=seed)
    if learn_on_majority:
        ranker.fit(features[out_of_group], labels[out_of_group])
    else:
        ranker.fit(features, labels)

    promotable = np.logical_and(in_group, labels == 0)
    if in_collective is not None:
        promotable = np.logical_and(promotable, in_collective)
    promotable = np.flatnonzero(promotable)
    demotable = np.flatnonzero(np.logical_and(out_of_group, labels == 1))

    class_proba = ranker.predict_proba(features)[:, 1]

    # Ascending by score; filtering keeps that order within each candidate set.
    sorted_idx = np.argsort(class_proba)
    # Promotion candidates are reversed so the highest score comes first.
    to_promote = sorted_idx[np.isin(sorted_idx, promotable)][::-1]
    to_demote = sorted_idx[np.isin(sorted_idx, demotable)]

    return to_promote, to_demote


def get_flip_number(
    labels: np.ndarray,
    in_group: np.ndarray,
):
    """Compute how many labels to flip to equalise the positive rates.

    With ``n_in`` / ``n_out`` the number of samples inside / outside the
    group and ``pos_in`` / ``pos_out`` the number of those labelled 1, the
    result is ``(n_in * pos_out - n_out * pos_in) // (n_in + n_out)``.

    The value is floored, and it is negative when the numerator is
    negative, i.e. when the in-group positive rate exceeds the
    out-of-group one.

    Args:
        labels: Per-sample labels; samples equal to 1 count as positive.
        in_group: Per-sample indicator of group membership.

    Returns:
        The (possibly negative) floored flip count.
    """
    outside = np.logical_not(in_group)
    positive = labels == 1

    n_in = np.count_nonzero(in_group)
    n_out = np.count_nonzero(outside)
    pos_out = np.count_nonzero(np.logical_and(outside, positive))
    pos_in = np.count_nonzero(np.logical_and(in_group, positive))

    numerator = (n_in * pos_out) - (n_out * pos_in)
    return numerator // (n_in + n_out)


def cnd(
    features: np.ndarray,
    labels: np.ndarray,
    in_group: np.ndarray,
    num_to_flip: int = None,
    seed: int = 0,
):
    """Return a copy of ``labels`` with the top-ranked candidates flipped.

    Candidates come from ``rank`` with the classifier fitted on all
    samples. The first ``num_to_flip`` promotion candidates are set to 1
    and the first ``num_to_flip`` demotion candidates are set to 0.

    Args:
        features: 2-D feature matrix, one row per sample.
        labels: Per-sample labels; not modified.
        in_group: Per-sample indicator of group membership.
        num_to_flip: Number of labels to flip in each direction. When
            ``None`` it is computed with ``get_flip_number``.
        seed: Random seed forwarded to ``rank``.

    Returns:
        A new label array with the flips applied.
    """
    # Every argument is passed by keyword: ``rank``'s fourth positional
    # parameter is ``learn_on_majority``, so passing ``seed`` positionally
    # collides with the keyword and raises TypeError.
    to_promote, to_demote = rank(
        features=features,
        labels=labels,
        in_group=in_group,
        learn_on_majority=False,
        seed=seed,
    )
    if num_to_flip is None:
        num_to_flip = get_flip_number(labels, in_group)
    # A negative count would make the slices below mean "all but the last
    # k" and flip almost every candidate; treat it as nothing to flip.
    num_to_flip = max(num_to_flip, 0)

    new_labels = labels.copy()
    new_labels[to_promote[:num_to_flip]] = 1
    new_labels[to_demote[:num_to_flip]] = 0
    return new_labels


def collective_cnd(
    features: np.ndarray,
    labels: np.ndarray,
    in_group: np.ndarray,
    in_collective: np.ndarray,
    num_to_flip: int = None,
    seed: int = 0,
):
    """Return a copy of ``labels`` with collective members promoted.

    Candidates come from ``rank`` with the classifier fitted only on the
    samples outside the group and promotion restricted to samples inside
    ``in_collective``. Only promotions (label set to 1) are applied; the
    demotion ranking is discarded.

    Args:
        features: 2-D feature matrix, one row per sample.
        labels: Per-sample labels; not modified.
        in_group: Per-sample indicator of group membership.
        in_collective: Per-sample mask of the samples eligible for promotion.
        num_to_flip: Number of labels to promote. When ``None`` it is
            computed with ``get_flip_number``.
        seed: Random seed forwarded to ``rank``.

    Returns:
        A new label array with the promotions applied.
    """
    to_promote, _ = rank(
        features=features,
        labels=labels,
        in_group=in_group,
        learn_on_majority=True,
        in_collective=in_collective,
        seed=seed,
    )
    if num_to_flip is None:
        num_to_flip = get_flip_number(labels, in_group)
    # A negative count would make the slice below mean "all but the last
    # k" and promote almost every candidate; treat it as nothing to flip.
    num_to_flip = max(num_to_flip, 0)

    new_labels = labels.copy()
    new_labels[to_promote[:num_to_flip]] = 1
    return new_labels
