#!/usr/bin/env python

import numpy as np
from scipy import optimize as opt

from matplotlib import pyplot as plt

from util import dot

# Tolerance on |admission rate - rate_limit| at which the binary searches stop.
epsilon = 0.005

# Extreme finite values of Python's float. Note that min_float is the most
# negative finite value (equal to -max_float), not the smallest positive one.
max_float, min_float = np.finfo(float).max, np.finfo(float).min

'''
A namespace for classifier functions

Signature for each function:

Args:
    setting: (Setting object) - current setting parameters
    state: (State object) - current state
    rate_limit: (float in [0, 1]) - maximum allowed overall acceptance rate (group acceptance rates weighted by setting.mu)

Returns:
    phi: (numpy array) - classifier's feature threshold for each group
'''

def group_independent_bayes(setting, state, rate_limit=1):
    '''
    Bayes-optimal classifier with same feature threshold for all groups

    The unconstrained threshold comes from Eq. 6 using the population-average
    qualification rate ``state.avg``. If that threshold admits at least
    ``rate_limit`` of the population (acceptance rates ``setting.beta``
    weighted by ``setting.mu``), the common threshold is instead found by
    bisection. The bisection stops once the admission rate is within
    ``epsilon`` of ``rate_limit``.

    Args:
        setting: (Setting object) - current setting parameters
        state: (State object) - current state
        rate_limit: float: [0, 1] - max allowed acceptance rate

    Returns:
        phi: (numpy array) - the same threshold for every group, shaped like
            ``state.sg``
    '''
    inv_q1_q0, xi = setting.inv_q1_q0, setting.xi

    def admission_rate(phi):
        # per-group acceptance rates weighted by the group weights setting.mu
        return dot(setting.beta(phi, state), setting.mu)

    # Eq. 6
    nat_phi = inv_q1_q0( xi * (1 - state.avg) / state.avg ) * np.ones(state.sg.shape)

    # we naturally have stricter standards than limited resources require
    if admission_rate(nat_phi) < rate_limit:
        return nat_phi

    # binary search for phi that achieves limiting acceptance rate. The bracket
    # runs from 10 * sigma below q0_mean to 10 * sigma above q1_mean. The search
    # starts at its midpoint so phi always lies inside [l, r].
    l = -10 * setting.sigma + setting.q0_mean
    r = 10 * setting.sigma + setting.q1_mean
    phi = (l + r) / 2

    admitted = admission_rate(phi)

    # Each step halves [l, r]; 100 steps shrink it by 2**100, far beyond double
    # precision. The cap stops the loop instead of spinning forever when the
    # tolerance cannot be met (e.g. if the admission rate is discontinuous in
    # phi).
    for _ in range(100):
        # difference between admission rate and rate_limit
        if abs(admitted - rate_limit) <= epsilon:
            break

        if admitted > rate_limit: # threshold too low
            l = phi
            phi = (phi + r) / 2
        else: # threshold too high
            r = phi
            phi = (phi + l) / 2

        admitted = admission_rate(phi)

    return phi * np.ones(state.sg.shape)

############################################################################

def laissez_faire_bayes(setting, state, rate_limit=1):
    '''
    Bayes-optimal classifier with per-group thresholds (cf. Coate&Loury)

    Each group first gets its own unconstrained threshold from Eq. 6. The
    group's qualification rate s is used in place of the population average.
    If these thresholds admit at least ``rate_limit`` of the population
    (acceptance rates ``setting.beta`` weighted by ``setting.mu``), admissions
    are reduced by, in order of preference:

    1. raising only the threshold of the group with the lower positive
       predictive rate (PPR), at most until its PPR equals the other group's;
    2. otherwise raising both thresholds together by bisecting over a common
       PPR in [higher group PPR, 1].

    Both bisections stop once the admission rate is within ``epsilon`` of
    ``rate_limit``, or after a fixed number of halvings. The PPR logic indexes
    groups 0 and 1 only, so exactly two groups are assumed.

    Args:
        setting: (Setting object) - current setting parameters
        state: (State object) - current state
        rate_limit: float: [0, 1] - max allowed acceptance rate

    Returns:
        phi: (numpy array) - classifier's feature threshold for each group
    '''
    inv_q1_q0, xi = setting.inv_q1_q0, setting.xi

    # Every halving shrinks the bracket by 2. 100 halvings are far beyond
    # double precision, so the cap only stops loops that can never converge.
    max_steps = 100

    def admission_rate(phis):
        # per-group acceptance rates weighted by the group weights setting.mu
        return dot(setting.beta(phis, state), setting.mu)

    def search(to_phis, lo, hi):
        '''
        Bisect x in [lo, hi] until to_phis(x) admits about rate_limit.

        Assumes the admission rate decreases as x grows. Returns the last
        thresholds tried.
        '''
        x = (lo + hi) / 2
        phis = to_phis(x)
        admitted = admission_rate(phis)

        for _ in range(max_steps):
            # difference between admission rate and rate_limit
            if abs(admitted - rate_limit) <= epsilon:
                break

            if admitted > rate_limit: # threshold too low
                lo = x
                x = (x + hi) / 2
            else: # threshold too high
                hi = x
                x = (x + lo) / 2

            phis = to_phis(x)
            admitted = admission_rate(phis)

        return phis

    # independent thresholds (must apply Eq. 6 for each group)
    nat_phi = np.array(
        [inv_q1_q0( xi * (1 - s) / s ) for s in state.sg]
    ).reshape(state.sg.shape)

    # we naturally have stricter standards than limited resources require
    if admission_rate(nat_phi) < rate_limit:
        return nat_phi

    # We need to admit fewer people; raise threshold on group that
    # is currently less likely to be qualified if admitted.
    # ppr is positive prediction rate (prob Y=1 | \hat{Y} = 1)
    group_ppr = setting.ppr(nat_phi, state)
    ppr_high = max(group_ppr)

    # group indices
    if group_ppr[0] < group_ppr[1]:
        low_g, high_g = 0, 1
    else:
        low_g, high_g = 1, 0

    def with_low_phi(phi):
        # high-PPR group keeps its natural threshold; low-PPR group gets phi
        phis = nat_phi[high_g] * np.ones(2)
        phis[low_g] = phi
        return phis

    # threshold for group with lower PPR to match PPR of other group
    low_g_phi_max = setting.inv_ppr(ppr_high, state)[low_g]

    # solution does not need to change threshold for group with higher PPR
    if admission_rate(with_low_phi(low_g_phi_max)) < rate_limit:
        return search(with_low_phi, nat_phi[low_g], low_g_phi_max)

    # We must raise the threshold on both groups to satisfy rate_limit.
    # Bisect over a group-independent PPR (higher threshold implies higher ppr)
    # and map it to per-group thresholds via inv_ppr.
    return search(lambda p: setting.inv_ppr(p, state), ppr_high, 1.0)

############################################################################


def demographic_parity_bayes(setting, state, rate_limit=1):
    '''
    Bayes-optimal classifier constrained by demographic parity

    Per-group thresholds are obtained from the utility matrix V perturbed by a
    vector gamma (Eq. 136a). gamma_b is tied to gamma_a by the group weights
    (Eq. 136b). gamma_a is found by bisection on [-1.5, 1.5] so that both
    groups have equal acceptance rates.

    That solution is compared against the uniform thresholds +5 and -5, and
    the candidate with the highest utility is kept. These two thresholds are
    presumably far enough out to reject or accept almost everyone, which
    satisfies parity approximately. If the kept candidate admits at least
    ``rate_limit`` of the population, ``setting.inv_beta`` supplies the
    thresholds that admit ``rate_limit`` from each group instead.

    Args:
        setting: (Setting object) - current setting parameters
        state: (State object) - current state
        rate_limit: float: [0, 1] - max allowed acceptance rate

    Returns:
        phi: (numpy array) - classifier's feature threshold for each group

    Raises:
        ValueError: from scipy.optimize.bisect when the acceptance-rate
            difference does not change sign on [-1.5, 1.5].
        RuntimeError: when the bisection fails to converge; a diagnostic plot
            is shown before re-raising.
    '''

    V, Q1, Q0, inv_q1_q0, sg, mu = (
        setting.V, setting.Q1, setting.Q0, setting.inv_q1_q0, state.sg, state.mu
    )

    def perturbed(state, gamma):
        '''get thresholds phi, or None if Eq. 136a yields a negative ratio'''

        # Eq. 136a
        n = (V[0,0] - V[0,1] - gamma)
        d = (V[1,1] - V[1,0] + gamma)

        # n / d should range from 0 (phi = -inf) to inf (phi = inf). Boundary
        # solutions are checked separately after root finding.
        with np.testing.suppress_warnings() as sup:

            # older numpy reports "true_divide", newer versions report "divide"
            sup.filter(RuntimeWarning, 'invalid value encountered in (true_)?divide')

            ratio = (n / d) * (1 - state.sg) / state.sg
            if any(ratio < 0):
                return None

            return inv_q1_q0(ratio)

    def diff(gamma_a):
        '''
        difference in acceptance rates (group 0 minus group 1)
        '''

        gamma_b = - mu[0] / mu[1] * gamma_a # (Eq. 136b)

        gamma = np.array([gamma_a, gamma_b]).reshape((2,))

        phi = perturbed(state, gamma)

        if phi is None:
            # no valid thresholds; the fallback keeps the sign of gamma_a
            return gamma_a

        beta = setting.beta(phi, state)

        return beta[0] - beta[1]

    def u(phi):
        return (
            dot(     Q0(phi),  (V[0,0] * (1 - sg))) + # utility term for true negative
            dot((1 - Q0(phi)), (V[0,1] * (1 - sg))) + # utility term for false positive
            dot(     Q1(phi),  (V[1,0] * sg)) + # utility term for false negative
            dot((1 - Q1(phi)), (V[1,1] * sg)) # utility term for true positive
        )

    # root finding with scipy library code
    try:
        gamma_a = opt.bisect(diff, -1.5, 1.5)
    except RuntimeError:
        # Diagnostic plot, then re-raise: without a root there is no gamma_a
        # to continue with.
        x = np.linspace(-1.5, 1.5, 500)
        plt.plot(x, [u(a) for a in x], label='u')
        plt.plot(x, [diff(a) for a in x], label='v')
        plt.legend()
        plt.show()
        raise

    # interior solution satisfying demographic parity
    gamma_b = - mu[0] / mu[1] * gamma_a # (Eq. 136b)
    phi = perturbed(state, np.array([gamma_a, gamma_b]))

    # its utility (-inf if it has no valid thresholds, so a boundary wins)
    utility = u(phi) if phi is not None else -np.inf

    # Handle boundaries: track the best utility so far, so -5 only wins if
    # it beats both the interior solution and +5.
    high_utility = u(5)
    if high_utility >= utility:
        phi = 5 * np.ones(2)
        utility = high_utility

    low_utility = u(-5)
    if low_utility > utility:
        phi = -5 * np.ones(2)

    # we naturally have stricter standards than limited resources require
    if dot(setting.beta(phi, state), setting.mu) < rate_limit:
        return phi

    # we must admit 'rate_limit' from each group. Find the thresholds that do this.
    return setting.inv_beta(rate_limit, state)

############################################################################

def feedback_control_perturbed_bayes(setting, state, strength_parameter=0.1, rate_limit=1):
    '''
    Feedback Control mechanism

    Starts from the common threshold of group_independent_bayes. That is the
    Eq. 6 threshold, raised by bisection if it admits at least rate_limit.
    Each group's threshold is then shifted by the Eq. 26 perturbation, which
    is built from ``state.delta`` and the group weights ``state.mu``, divided
    by s * (1 - s) for the group's qualification rate s, and scaled by
    ``strength_parameter``.

    Args:
        setting: (Setting object) - current setting parameters
        state: (State object) - current state
        strength_parameter: float - multiplier applied to every group's
            perturbation
        rate_limit: float: [0, 1] - max allowed acceptance rate

    Returns:
        phi: (numpy array) - classifier's feature threshold for each group
    '''
    mu, sg, delta, n_groups = state.mu, state.sg, state.delta, state.n

    # Common threshold for all groups. Copied as a float array so the in-place
    # per-group updates below cannot touch anything shared.
    phi = np.array(group_independent_bayes(setting, state, rate_limit), dtype=float)

    for g in range(1, n_groups + 1):

        # (Eq. 26) perturbation for gth component of Phi vector
        prbn = 0
        for h in range(1, g):
            prbn -= delta[h-1] * mu[:h].sum()
        for h in range(g, n_groups):
            prbn += delta[h-1] * mu[h:].sum()

        s = sg[g-1]
        phi[g-1] += -(prbn / (s * (1 - s))) * strength_parameter

    return phi
