#!/usr/bin/env python

r'''
A namespace for population response functions.

Signature for each response function:

Args:
    setting: (Setting object) - current setting parameters
    state: (State object) - current state
    phi: (numpy array) - classifier's feature threshold for each group

Returns:
    (State object): velocity vector $(s_g[t+1] - s_g[t] : g \in \mathcal{G})$
'''

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import optimize as opt

from state import State
import util

# Machine epsilon for Python floats; used below to step just inside the
# endpoints of the unit interval without hitting them exactly.
epsilon = np.finfo(float).eps

# Global matplotlib styling applied at import time: serif text with
# Computer Modern math, and Type 42 (TrueType) font embedding for
# PDF / PostScript output.
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Nimbus Roman'
mpl.rcParams['mathtext.fontset'] = 'cm'
mpl.rcParams['mathtext.rm'] = 'serif'
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

def replicator_equation(setting, state, phi):
    '''
    Replicator-equation response (studied in the main paper).

    Args:
        setting: (Setting object) - reads U, Q1, Q0 and mu
        state: (State object) - reads sg, the proportion qualified per group
        phi: (numpy array) - classifier's feature threshold for each group

    Returns:
        (State object): velocity vector, i.e. the replicator update of sg
        minus sg itself
    '''

    payoff = setting.U
    qualified_share = state.sg

    # Q1(phi) / Q0(phi): proportion of qualified / unqualified agents that
    # are rejected at threshold phi. Each is used as the interpolation
    # weight between the accepted payoff U[y, 1] and the rejected payoff
    # U[y, 0].
    rejected_qualified = setting.Q1(phi)
    rejected_unqualified = setting.Q0(phi)

    fitness_qualified = (
        payoff[1, 1] + (payoff[1, 0] - payoff[1, 1]) * rejected_qualified
    )
    fitness_unqualified = (
        payoff[0, 1] + (payoff[0, 0] - payoff[0, 1]) * rejected_unqualified
    )

    # average fitness of each group, weighting the two strategies by the
    # current proportion qualified
    mean_fitness = (
        qualified_share * fitness_qualified
        + (1 - qualified_share) * fitness_unqualified
    )

    # replicator step: the qualified share is rescaled by its fitness
    # relative to the group average
    next_share = qualified_share * fitness_qualified / mean_fitness

    # velocity vector in state space, where unit time separates each round
    return State(setting.mu, next_share - qualified_share)

def markov(setting, state, phi):
    '''
    Markov-transition response, as used by Zhang et al.

    We use U as the transition matrix T, which yields the proportion
    becoming qualified conditioned on each outcome (Y, hat{Y}) in the
    previous round.

    Args:
        setting: (Setting object) - reads U, Q0, Q1 and mu
        state: (State object) - reads sg, the proportion qualified per group
        phi: (numpy array) - classifier's feature threshold for each group

    Returns:
        (State object): velocity vector, i.e. the new qualification rate of
        each group minus its current rate sg
    '''

    U, Q0, Q1, mu, sg = (
        setting.U, setting.Q0, setting.Q1, setting.mu, state.sg
    )

    # rejection rates among unqualified / qualified agents; evaluated once
    # and reused for both the rejected and the accepted fractions
    rejected_0 = Q0(phi)
    rejected_1 = Q1(phi)

    # indexed as [Y, hat{Y}, group]
    outcomes = np.array([
        [
            rejected_0 * (1 - sg),        # true negative fraction per group
            (1 - rejected_0) * (1 - sg),  # false positive fraction per group
        ],
        [
            rejected_1 * sg,              # false negative fraction per group
            (1 - rejected_1) * sg,        # true positive fraction per group
        ]
    ])

    # move the group axis to the front: one (Y, hat{Y}) table per group
    per_group = outcomes.transpose((2, 0, 1))

    # one entry per group, however many groups the inputs describe
    new_sg = [util.dot(U, group_outcomes) for group_outcomes in per_group]

    return State(mu, np.array(new_sg) - sg)

def best_response(setting, state, phi):
    '''
    Best-response dynamics following Coate & Loury.

    Coate and Loury assume a group-independent distribution of non-negative
    investment costs for agents to become qualified; agents privately know
    their own costs. We represent this distribution with a CDF G(b): the
    fraction of agents that invest in qualification for a given gross
    benefit b of doing so (with G(x) = 0 for x <= 0).

    Here G is constructed by inverting the expected-benefit curve `eu_phi`
    by bisection to recover a threshold, and then mapping that threshold
    back to a qualification rate.

    Args:
        setting: (Setting object) - reads U, Q0, Q1, inv_q1_q0, q1_mean,
            q0_mean, xi and mu
        state: (State object) - reads sg, the proportion qualified per group
        phi: (numpy array) - classifier's feature threshold for each group

    Returns:
        (State object): velocity vector, i.e. G(expected benefit) - sg for
        each group

    Raises:
        ValueError: if a group's expected benefit lies outside
            [min_eu, max_eu], the range over which the benefit curve is
            inverted.
        RuntimeError: if the bisection bracket collapses before the target
            benefit is matched (previously this looped forever).
    '''

    sg = state.sg
    mu = setting.mu
    U, Q0, Q1, inv_q1_q0, q1_mean, q0_mean, xi = (
        setting.U, setting.Q0, setting.Q1, setting.inv_q1_q0,
        setting.q1_mean, setting.q0_mean, setting.xi
    )

    def q0(x):
        return util.gaussian(x, q0_mean, 1)

    def q1(x):
        return util.gaussian(x, q1_mean, 1)

    def eu_phi(phi):
        '''expected benefit as function of threshold'''
        return (
            - ( # unqualified
                Q0(phi) * U[0,0] +      # true negative
                (1 - Q0(phi)) * U[0,1]  # false positive
            ) +
            ( # qualified
                Q1(phi) * U[1,0] +      # false negative
                (1 - Q1(phi)) * U[1,1]  # true positive
            )
        )

    def threshold(s):
        '''classifier threshold as function of qualification rate'''
        # (Eq. 6)
        return inv_q1_q0(xi * (1 - s) / s)

    # threshold used when (numerically) nobody is qualified
    phi_at_zero = threshold(epsilon)

    # threshold that maximises the expected benefit of qualifying;
    # res.x is a length-1 array, which is why G(...) is indexed with [0] below
    res = opt.minimize(lambda x: -eu_phi(x), 0)
    max_eu = -res.fun
    crit_phi = res.x

    # invertible domain: thresholds between the maximiser and phi_at_zero,
    # and the matching range of expected benefits
    min_eu = min(eu_phi(phi_at_zero), max_eu)
    phi_dom = [crit_phi, phi_at_zero]

    def phi_eu(eu):
        '''threshold as function of expected benefit (inverse of eu_phi)'''

        if not (min_eu <= eu <= max_eu):
            raise ValueError(f'{min_eu}, {eu}, {max_eu}')

        # Bisection. The branch below moves towards larger thresholds when
        # the current benefit is too high, i.e. it assumes eu_phi is
        # decreasing in phi over the bracket.
        lo = min(phi_dom)
        hi = max(phi_dom)
        phi = (lo + hi) / 2

        current_eu = eu_phi(phi)
        while abs(current_eu - eu) > 0.0001:

            if current_eu > eu:
                lo = phi
                next_phi = (phi + hi) / 2
            else:
                hi = phi
                next_phi = (phi + lo) / 2

            # Once the midpoint stops moving, eu_phi would return the same
            # value on every further iteration, so the loop could never
            # terminate on its own.
            if np.array_equal(next_phi, phi):
                raise RuntimeError(
                    'bisection stalled at threshold '
                    f'{phi} (benefit {current_eu}, target {eu})'
                )

            phi = next_phi
            current_eu = eu_phi(phi)

        return phi

    def inv_threshold(phi):
        '''qualification rate as function of threshold'''
        # inverts threshold(): q1(phi) / q0(phi) = xi * (1 - s) / s
        r = q1(phi) / q0(phi) / xi
        s = 1 / (r + 1)

        # NOTE(review): additive sinusoidal term weighted by
        # util.gaussian(phi, 2, 0.5); its purpose is not evident from this
        # file and it is kept exactly as written.
        return s + np.sin((phi + 0.5 ) * 5) * util.gaussian(phi, 2, 0.5) / 10

    def G(eu):
        '''
        Cumulative distribution function:
        proportion of population choosing to qualify
        given expected benefit of qualification
        '''
        return inv_threshold(phi_eu(eu))

    # expected benefit of becoming qualified over staying unqualified,
    # one entry per group
    benefit = eu_phi(phi)

    new_sg = [G(b)[0] for b in benefit]

    return State(mu, np.array(new_sg) - sg)



def plot_loury(setting, savename=None):
    '''
    Plot the Coate & Loury population response against the classifier's
    threshold curve.

    Two curves are drawn over the same threshold axis: the qualification
    rate returned by `best_response` for each threshold, and the
    qualification rate s for which the classifier would choose that
    threshold.

    Args:
        setting: (Setting object) - reads inv_q1_q0, xi and mu; also passed
            through to `best_response`
        savename: (str or None) - if None the figure is shown
            interactively; otherwise it is written to
            'images/{savename}.pdf' and closed
    '''

    inv_q1_q0, xi = setting.inv_q1_q0, setting.xi

    def threshold(s):
        '''classifier threshold as function of qualification rate'''
        # (Eq. 6)
        return inv_q1_q0(xi * (1 - s) / s)

    s_dom = np.linspace(0.001, 0.999, 100)
    phi_dom = threshold(s_dom)

    # response is state independent, returning velocity
    # so we can get the final state by setting initial state as 0
    state = State(setting.mu, (0, 0))

    # Both groups are given the same threshold, so only the first entry of
    # each response is kept.
    # NOTE(review): relies on State supporting integer indexing - confirm.
    response_states = [
        best_response(setting, state, phi * np.ones(2))[0] for phi in phi_dom
    ]

    plt.figure(figsize=(3,3))

    # Backslashes in the TeX labels are escaped (or the literal is raw) so
    # that '\p' is not parsed as an invalid escape sequence; '\n' is a real
    # newline.
    plt.plot(phi_dom, response_states, label='$s[t+1](\\phi[t])$ \n (Population)')
    plt.plot(phi_dom, s_dom, label='$\\phi[t](s[t])$ \n (Classifier)')
    plt.xlabel(r'Threshold $\phi$')
    plt.ylabel('Qualification rate $s$')
    plt.xlim(-0.5, 4)
    plt.legend()

    if savename is None:
        plt.show()
    else:
        filename = f'images/{savename}.pdf'
        print('saving', filename)
        plt.savefig(filename, bbox_inches='tight')
        plt.close()
