"""
Helper functions for finding appropriate thresholds.

Many agents use classifiers to calculate continuous scores and then use a
threshold to transform those scores into decisions that optimize some reward.
The helper functions in this module are intended to aid with choosing those
thresholds.

Code based on: 
https://github.com/google/ml-fairness-gym
"""

import bisect
import enum
from absl import logging
import attr
import numpy as np
import scipy.optimize
import scipy.spatial
from sklearn import metrics as sklearn_metrics


class ThresholdPolicy(enum.Enum):
    """Identifiers for the threshold-selection policies.

    Each member's value is the lowercase string form of its name, so a member
    can be looked up from its string with `ThresholdPolicy("...")`.

    NOTE(review): what each policy actually does is decided by the code that
    consumes this enum, which is not part of this class — confirm there.
    """

    # One entry per policy name; values are plain strings.
    SINGLE_THRESHOLD = "single_threshold"
    MAXIMIZE_REWARD = "maximize_reward"
    EQUALIZE_OPPORTUNITY = "equalize_opportunity"


@attr.s
class RandomizedThreshold(object):
    """A discrete probability distribution over decision thresholds.

    Attributes:
      values: The candidate thresholds. Defaults to a fresh `[0.0]`.
      weights: The probability attached to each entry of `values`, in the same
        order. Defaults to a fresh `[1.0]`.
      rng: The random source used by `sample`. Defaults to a newly constructed
        `np.random.RandomState`.
      tpr_target: Stored as given (default `None`); this class never reads it.
    """

    values = attr.ib(factory=lambda: [0.0])
    weights = attr.ib(factory=lambda: [1.0])
    rng = attr.ib(factory=np.random.RandomState)
    tpr_target = attr.ib(default=None)

    def smoothed_value(self):
        """Collapses the distribution into a single deterministic threshold.

        Returns:
          The weighted average of `values`, except when there are exactly two
          weights and the smaller one is below 1e-4; then the value carrying
          the larger weight is returned as-is.
        """
        has_two_options = len(self.weights) == 2
        if has_two_options and min(self.weights) < 1e-4:
            # A near-zero weight is probably an optimization artifact, so
            # snap to the dominant threshold instead of averaging.
            dominant = np.argmax(self.weights)
            return self.values[dominant]
        return np.dot(self.weights, self.values)

    def sample(self):
        """Draws one threshold from `values` with probabilities `weights`."""
        drawn = self.rng.choice(self.values, p=self.weights)
        return drawn

    def iteritems(self):
        """Returns an iterator of (weight, value) pairs."""
        pairs = zip(self.weights, self.values)
        return pairs


def convex_hull_roc(roc):
    """Returns an roc curve without the points inside the convex hull.

    Points below the fpr=tpr line corresponding to random performance are also
    removed.

    The input is returned unchanged (the same object) when the hull cannot be
    computed: if any rate is NaN, if there are fewer than 3 points, or if the
    Qhull solver raises an error.

    Args:
      roc: A tuple of equal-length sequences (lists or numpy arrays) containing
        (false_positive_rates, true_positive_rates, thresholds). This is the
        same format returned by sklearn.metrics.roc_curve.

    Returns:
      A tuple of three lists (fprs, tprs, thresholds) restricted to the points
      that are vertices of the convex hull, in their original order.
    """
    fprs, tprs, thresholds = roc
    if np.isnan(fprs).any() or np.isnan(tprs).any():
        logging.debug("Convex hull solver does not handle NaNs.")
        return roc
    if len(fprs) < 3:
        logging.debug("Convex hull solver does not curves with < 3 points.")
        return roc

    # Add (fpr=1, tpr=0) to the convex hull to remove any points below the
    # random-performance line. np.append concatenates for lists and arrays
    # alike; `fprs + [1]` would instead add 1 to every element of an array
    # and silently drop the anchor point.
    hull_fprs = np.append(np.asarray(fprs, dtype=float), 1.0)
    hull_tprs = np.append(np.asarray(tprs, dtype=float), 0.0)
    points = np.vstack([hull_fprs, hull_tprs]).T

    # `scipy.spatial.qhull` is a deprecated namespace in recent SciPy; prefer
    # the public name and fall back only for older releases.
    qhull_error = getattr(scipy.spatial, "QhullError", None)
    if qhull_error is None:
        qhull_error = scipy.spatial.qhull.QhullError

    try:
        hull = scipy.spatial.ConvexHull(points)
    except qhull_error:
        logging.debug("Convex hull solver failed.")
        return roc
    vertices = set(hull.vertices)

    # The appended anchor has index len(fprs), so enumerating the original
    # sequences never emits it.
    return (
        [fpr for idx, fpr in enumerate(fprs) if idx in vertices],
        [tpr for idx, tpr in enumerate(tprs) if idx in vertices],
        [thresh for idx, thresh in enumerate(thresholds) if idx in vertices],
    )


def _threshold_from_tpr(roc, tpr_target, rng):
    """Returns a `RandomizedThreshold` that achieves `tpr_target`.

    For an arbitrary value of tpr_target in [0, 1], there may not be a single
    threshold that achieves that tpr_value on our data. In this case, we
    interpolate between the two closest achievable points on the discrete ROC
    curve.

    See e.g., Theorem 1 of Scott et al (1998)
    "Maximum realisable performance: a principled method for enhancing
    performance by using multiple classifiers in variable cost problem domains"
    http://mi.eng.cam.ac.uk/reports/svr-ftp/auto-pdf/Scott_tr320.pdf

    Args:
      roc: A tuple (fpr, tpr, thresholds) as returned by sklearn's roc_curve
        function. The tpr sequence must be sorted in ascending order, because
        it is searched with `bisect`.
      tpr_target: A float between [0, 1], the target value of TPR that we would
        like to achieve.
      rng: A `np.RandomState` object that will be used in the returned
        RandomizedThreshold.

    Returns:
      A RandomizedThreshold that achieves the target TPR value. If the target
      lies outside the range of TPR values on the curve, a single threshold
      at the nearest end of the curve is returned instead.
    """
    # First filter out points that are not on the convex hull.
    _, tpr_list, thresh_list = convex_hull_roc(roc)

    idx = bisect.bisect_left(tpr_list, tpr_target)

    # TPR target is larger than any of the TPR values in the list. In this case,
    # use the threshold of the last point, i.e. the one with the largest TPR.
    if idx == len(tpr_list):
        return RandomizedThreshold(
            weights=[1], values=[thresh_list[-1]], rng=rng, tpr_target=tpr_target
        )

    # TPR target is exactly achievable by an existing threshold. In this case,
    # do not randomize between two different thresholds. Use a single threshold
    # with probability 1.
    if tpr_list[idx] == tpr_target:
        return RandomizedThreshold(
            weights=[1], values=[thresh_list[idx]], rng=rng, tpr_target=tpr_target
        )

    # TPR target is smaller than any of the TPR values in the list. Without
    # this guard, `idx - 1` below would wrap around to the last point. Use the
    # threshold of the first point, i.e. the one with the smallest TPR.
    if idx == 0:
        return RandomizedThreshold(
            weights=[1], values=[thresh_list[0]], rng=rng, tpr_target=tpr_target
        )

    # Interpolate between adjacent thresholds. Since we are only considering
    # points on the convex hull of the roc curve, we only need to consider
    # interpolating between pairs of adjacent points.
    alpha = _interpolate(x=tpr_target, low=tpr_list[idx - 1], high=tpr_list[idx])
    return RandomizedThreshold(
        weights=[alpha, 1 - alpha],
        values=[thresh_list[idx - 1], thresh_list[idx]],
        rng=rng,
        tpr_target=tpr_target,
    )


def _interpolate(x, low, high):
    """returns a such that a*low + (1-a)*high = x."""
    assert low <= x <= high, (
        "x is not between [low, high]: Expected %s <= %s <=" " %s"
    ) % (low, x, high)
    alpha = 1 - ((x - low) / (high - low))
    assert np.abs(alpha * low + (1 - alpha) * high - x) < 1e-6
    return alpha


def single_threshold(predictions, labels, weights, cost_matrix):
    """Finds a single threshold that maximizes reward.

    All instances are wrapped into one group and handed to
    `equality_of_opportunity_thresholds`; the resulting threshold distribution
    is then collapsed to one number with `smoothed_value`.

    Args:
      predictions: A list of float predictions.
      labels: A list of binary labels.
      weights: A list of instance weights.
      cost_matrix: A CostMatrix.

    Returns:
      A single threshold that maximizes reward.
    """
    group_key = "dummy"
    per_group_thresholds = equality_of_opportunity_thresholds(
        {group_key: predictions},
        {group_key: labels},
        {group_key: weights},
        cost_matrix,
    )
    only_threshold = per_group_thresholds[group_key]
    return only_threshold.smoothed_value()


def equality_of_opportunity_thresholds(
    group_predictions, group_labels, group_weights, cost_matrix, rng=None
):
    """Finds thresholds that equalize opportunity while maximizing reward.

    Using the definition from "Equality of Opportunity in Supervised Learning" by
    Hardt et al., equality of opportunity constraints require that the classifier
    have equal true-positive rate for all groups and can be enforced as a
    post-processing step on a threshold-based binary classifier by creating
    group-specific thresholds.

    Since there are many different thresholds where equality of opportunity
    constraints can hold, we simultaneously maximize reward described by a reward
    matrix.

    Args:
      group_predictions: A dict mapping from group identifiers to predictions for
        instances from that group.
      group_labels: A dict mapping from group identifiers to labels for instances
        from that group.
      group_weights: A dict mapping from group identifiers to weights for
        instances from that group. May be None, and individual groups may be
        missing or mapped to None; such groups get a weight of 1 per instance.
        The dict passed in is not modified.
      cost_matrix: A CostMatrix.
      rng: A `np.random.RandomState`. A new one is created when None.

    Returns:
      A dict mapping from group identifiers to `RandomizedThreshold` objects
      such that recall is equal for all groups.

    Raises:
      ValueError if the keys of group_predictions and group_labels are not the
        same.
      AssertionError if, for some group, the labels, weights and predictions do
        not all have the same length.
    """

    if set(group_predictions.keys()) != set(group_labels.keys()):
        raise ValueError("group_predictions and group_labels have mismatched keys.")

    if rng is None:
        rng = np.random.RandomState()

    groups = sorted(group_predictions.keys())
    roc = {}

    # Work on a shallow copy so that filling in default weights below does not
    # write into the caller's dict.
    group_weights = {} if group_weights is None else dict(group_weights)

    for group in groups:
        if group_weights.get(group) is None:
            # If weights is unspecified, use equal weights.
            group_weights[group] = [1 for _ in group_labels[group]]

        assert (
            len(group_labels[group])
            == len(group_weights[group])
            == len(group_predictions[group])
        )

        fprs, tprs, thresholds = sklearn_metrics.roc_curve(
            y_true=group_labels[group],
            y_score=group_predictions[group],
            sample_weight=group_weights[group],
        )

        # Replace NaN true-positive rates with 0 so later steps can use them.
        roc[group] = (fprs, np.nan_to_num(tprs), thresholds)

    def negative_reward(tpr_target):
        """Returns negative reward suitable for optimization by minimization."""

        my_reward = 0
        for group in groups:
            weights_ = []
            predictions_ = []
            labels_ = []
            # A randomized threshold is evaluated by repeating the group's
            # instances once per candidate threshold, with each copy's weight
            # scaled by the probability of that threshold.
            for thresh_prob, threshold in _threshold_from_tpr(
                roc[group], tpr_target, rng=rng
            ).iteritems():
                labels_.extend(group_labels[group])
                for weight, prediction in zip(
                    group_weights[group], group_predictions[group]
                ):
                    weights_.append(weight * thresh_prob)
                    predictions_.append(prediction >= threshold)
            # NOTE(review): the elementwise product below presumably expects
            # both matrices to have the same shape; confirm the behavior when
            # a group's labels and decisions contain only one class.
            confusion_matrix = sklearn_metrics.confusion_matrix(
                labels_, predictions_, sample_weight=weights_
            )

            my_reward += np.multiply(confusion_matrix, cost_matrix.as_array()).sum()
        return -my_reward

    # Search for the shared TPR target in [0, 1] that maximizes total reward.
    opt = scipy.optimize.minimize_scalar(
        negative_reward, bounds=[0, 1], method="bounded", options={"maxiter": 100}
    )
    return {group: _threshold_from_tpr(roc[group], opt.x, rng=rng) for group in groups}
