#!/usr/bin/env python3

import numpy as np

from setting import Setting
from system import System

import responses
import classifiers
import plot

################################################################################
# Default setting used in paper (Fig. 3)

# W_1 - W_0 vs phi
#
#           _____<___
#          /
# 0 ------s---------
#        /
#       /
# -->--'
#
# U01 > U11 > U10 > U00

s1 = Setting(
    mu=np.array([0.5, 0.5]),
    U=np.array([[0.1, 5.5], [0.5, 1.0]]),
    V=np.array([[0.5, -0.5], [-0.25, 1.0]]),
)

# Guard the ordering U01 > U11 > U10 > U00 sketched above, one link of the
# chain per assertion.
assert s1.U[0,1] > s1.U[1,1]
assert s1.U[1,1] > s1.U[1,0]
assert s1.U[1,0] > s1.U[0,0]

# s1.sanity_check()

# Copies of s1 made with a different mu: mu_1 = 0.6, 0.7, 0.8 and 0.9.
# Both entries of every mu are written out explicitly.
s1_6, s1_7, s1_8, s1_9 = (
    s1.copy(mu=np.array(group_shares))
    for group_shares in ([0.6, 0.4], [0.7, 0.3], [0.8, 0.2], [0.9, 0.1])
)

################################################################################
# Poster / Slides example

# W_1 - W_0 vs phi
#           _<_
#          /   \
# 0 ------s-----------
#        /
#       /
# -->--'
#
# U01 > U11 > U10
#       U00 = U10

s_farm_fish = Setting(
    mu=np.array([0.5, 0.5]),
    U=np.array([
        # no loan; loan
        [1, 3], # fish
        [1, 2]  # farm
    ]),
    V=np.array([
        [0, -1],
        [0, 3]
    ])
)
# Guard the ordering this example is meant to have (U01 > U11 > U10 and
# U00 = U10), in the same way the other settings in this file assert theirs.
assert s_farm_fish.U[0,1] > s_farm_fish.U[1,1] > s_farm_fish.U[1,0]
assert s_farm_fish.U[0,0] == s_farm_fish.U[1,0]

# s_farm_fish.sanity_check()

################################################################################
# ensure both unstable and stable hyperplanes

# W_1 - W_0 vs phi
#           _<_
#          /   \
# 0 ------s-----u-----
#        /       `->--
#       /
# -->--'
#
# U01 > U11 > U10
#       U00 > U10

s2 = Setting(
    mu=np.array([0.5, 0.5]),
    U=np.array([[0.5, 1.5], [0.1, 1.0]]),
    V=np.array([[1, 0], [0, 1]]),
)

# U01 > U11 > U10, checked link by link, followed by U00 > U10.
assert s2.U[0,1] > s2.U[1,1]
assert s2.U[1,1] > s2.U[1,0]
assert s2.U[0,0] > s2.U[1,0]

# s2.sanity_check()

################################################################################
# ensure only unstable hyperplane

# W_1 - W_0 vs phi
#   _<__
#       \
# 0 -----u---------
#         \
#          \
#           '-->---
#
# U00 > U10
# U11 > U10
# U11 > U01

s3 = Setting(
    mu=np.array([0.5, 0.5]),
    U=np.array([[0.5, 0.5], [0.1, 1.5]]),
    V=np.array([[10, 0], [1, 1.5]]),
)

# Each pair names (larger entry, smaller entry) of U:
# U00 > U10, U11 > U10 and U11 > U01.
assert all(
    s3.U[larger] > s3.U[smaller]
    for larger, smaller in (
        ((0, 0), (1, 0)),
        ((1, 1), (1, 0)),
        ((1, 1), (0, 1)),
    )
)

# s3.sanity_check()

################################################################################
# Model of Zhang et al. (2020)
# How Do Fair Decisions Fare in Long-term Qualification?

s_markov = Setting(
    mu=np.array([0.5, 0.5]),
    # For this model U is used as the transition matrix T: each entry is the
    # proportion of the population with the given outcome becoming qualified
    # (cf. Condition 1B of the referenced paper).
    #   row 0: true negative, false positive
    #   row 1: false negative, true positive
    U=np.array([[0.2, 0.5], [0.1, 0.8]]),
    V=np.array([[0, -1], [0, 1.3]]),
)

################################################################################
# Model of Coate & Loury (1993)
# Will Affirmative-Action Policies Eliminate Negative Stereotypes?

s_best_response = Setting(
    mu=np.array([0.5, 0.5]),
    # cf. https://inequality.stanford.edu/sites/default/files/media/_media/pdf/Reference%20Media/Coate%20and%20Loury_1993_Discrimination%20and%20Prejudice.pdf
    # U is used as the payoff matrix. Two orderings are assumed:
    #   U01 > U11 to satisfy Beta(0) = 0
    #   U00 > U10 to satisfy Beta(1) = 0
    U=np.array([[0.5, 2.5], [0.1, 2.0]]),
    V=np.array([[0, -500], [0, 1.0]]),
    q0_mean=-1.5,
    q1_mean=1.5,
)

# Check both assumed orderings; each pair is (larger entry, smaller entry).
assert all(
    s_best_response.U[larger] > s_best_response.U[smaller]
    for larger, smaller in (((0, 1), (1, 1)), ((0, 0), (1, 0)))
)

# responses.plot_loury(s_best_response)

################################################################################

def display(setting, response_func, classifier_funcs, bgvars=(0, 1, 1, 1),
        res=64, strength_parameter=0.1, rate_limit=1, savename=None):
    '''
    Build a System for every enabled classifier, run its calculation and
    hand the results to the matching `plot` function.

    setting: Setting object comprising parameters
    response_func: function from `responses` to model population response
    classifier_funcs: 4-tuple with each entry either 0 or 1 to switch off/on
                      use of each classifier separately, in the order
                      (DP, GI, FC, LZ) = (Demographic Parity,
                      Group-Independent Policy, Feedback Control,
                      Laissez-Faire)
    bgvars: 4-tuple of 0/1 switches; shaded background for group 1 specific
             variables to display in separate rows:
             (None, Acceptance Rate, False Positive Rate, False Negative Rate)
    res: number of points per axis to calculate vector field
    strength_parameter: to use with feedback control (passed to that
                        classifier only and shown in its title)
    rate_limit: max acceptance rate any classifier may allow
    savename: name to save file under in `images/` directory

    Layout selection:
        one background variable, one classifier   -> plot.plot_single
        one background variable, several          -> plot.plot_row
        several background variables, one         -> plot.plot_vars
        otherwise                                 -> plot.plot_grid
    '''

    # (title, classifier, classifier-specific keyword arguments), listed in
    # the same order as the switches in `classifier_funcs`.
    catalogue = (
        ('Demographic Parity', classifiers.demographic_parity_bayes, {}),
        ('Group-Independent Policy', classifiers.group_independent_bayes, {}),
        (
            r'($\epsilon =' + f'{strength_parameter}' + '$)' + '\n Feedback Control',
            classifiers.feedback_control_perturbed_bayes,
            dict(strength_parameter=strength_parameter),
        ),
        ('Laissez-Faire', classifiers.laissez_faire_bayes, {}),
    )

    systems = []
    for index, (title, classifier, extra_kwargs) in enumerate(catalogue):
        switch = classifier_funcs[index]
        if not switch:
            # Disabled classifier: do not construct a System that would be
            # thrown away immediately.
            continue
        system = System(
            title, setting,
            classifier,
            response_func,
            res=res,
            cls_kwargs=dict(rate_limit=rate_limit, **extra_kwargs)
        )
        # List repetition keeps the meaning of the switch as a multiplier.
        systems.extend([system] * switch)

    for system in systems:
        # Titles may contain a line break for plotting; flatten it for the log.
        print(system.name.replace('\n', ' '))
        system.calculate()

    single_variable = sum(bgvars) == 1
    single_classifier = sum(classifier_funcs) == 1

    if single_variable and single_classifier:
        plot.plot_single(systems, savename, bgvars)
    elif single_variable:
        plot.plot_row(systems, savename, bgvars)
    elif single_classifier:
        plot.plot_vars(systems, savename, bgvars)
    else:
        plot.plot_grid(systems, savename, bgvars)

def output_parameters(setting, U_as='U'):
    '''
    Render the parameters of `setting` as LaTeX source for an `align*` block.

    setting: object providing `mu` (read at indices 0 and 1) and `U`, `V`
             (read at [i, j] for i, j in {0, 1})
    U_as: symbol under which the entries of `U` are typeset (e.g. 'T');
          None leaves the `U` matrix out of the output altogether
    returns: the LaTeX source as a string
    '''
    def matrix_block(symbol, matrix, indent):
        # One 2x2 bmatrix, preceded by the LaTeX line break that separates it
        # from the previous block. %-formatting is used for the indices so
        # the literal LaTeX braces stay readable.
        rows = [
            ' & '.join(
                symbol + r'_{ %d, \hat{%d} } = ' % (i, j) + str(matrix[i, j])
                for j in (0, 1)
            )
            for i in (0, 1)
        ]
        return ''.join([
            '\\\\\n', indent, '&\\begin{bmatrix}\n',
            indent, (' \\\\\n' + indent).join(rows), '\n',
            indent, '\\end{bmatrix}',
        ])

    mu, U, V = setting.mu, setting.U, setting.V

    pieces = [
        '\n\\begin{align*}\n',
        '    &~\\begin{bmatrix}\n',
        '    \\mu_1 = ', str(mu[0]), ' & \\mu_2 = ', str(mu[1]), '\n',
        '    \\end{bmatrix}',
    ]
    if U_as is not None:
        pieces.append(matrix_block(U_as, U, ' ' * 8))
    pieces.append(matrix_block('V', V, ' ' * 4))
    pieces.append('\n\\end{align*}\n')

    return ''.join(pieces)

################################################################################
# Outputs

#-------------------------------------------------------------------------------
# document parameters

# For each of the three base settings: print its parameters as LaTeX, then
# plot W_1 - W_0 under the file name '<label>_w1w0'.
for _label, _setting in (('s1', s1), ('s2', s2), ('s3', s3)):
    print(output_parameters(_setting))
    _setting.plot_w1w0(_label + '_w1w0')


#-------------------------------------------------------------------------------
# Default setting; no limitation on acceptance rate

# display(s1, responses.replicator_equation, (0, 1, 0, 0), res=16, strength_parameter=0.1, bgvars=(1, 0, 0, 0))
# Every replicator-equation figure below switches on all four classifiers
# and uses 64 points per axis for the vector field.
_all_classifiers = (1, 1, 1, 1)

# One row per figure: (setting, savename, remaining keyword arguments).
_replicator_figures = (
    (s1, 'final', dict(strength_parameter=0.1, bgvars=(0, 1, 0, 0))),

    (s1, 's1', dict(strength_parameter=0.1)),
    (s2, 's2', dict(strength_parameter=0.1)),
    (s3, 's3', dict(strength_parameter=(-0.1))),

    # Group Sizes
    # (s1_6, 's1_6', dict(strength_parameter=0.1)),
    (s1_7, 's1_7', dict(strength_parameter=0.1)),
    # (s1_8, 's1_8', dict(strength_parameter=0.1)),
    (s1_9, 's1_9', dict(strength_parameter=0.1)),

    # Resource constraints: weak (0.6), then strong (0.3)
    (s1, 's1_limit_6', dict(strength_parameter=0.1, rate_limit=0.6)),
    (s1, 's1_limit_3', dict(strength_parameter=0.1, rate_limit=0.3)),
)

for _setting, _savename, _options in _replicator_figures:
    display(
        _setting, responses.replicator_equation, _all_classifiers,
        res=64, savename=_savename, **_options
    )

################################################################################
# Markov:

# Typeset U under the symbol T, since this model uses it as a transition
# matrix.
_markov_latex = output_parameters(s_markov, U_as='T')
print(_markov_latex)

# Feedback control (third switch) is left off for this model.
display(
    setting=s_markov,
    response_func=responses.markov,
    classifier_funcs=(1, 1, 0, 1),
    res=16,
    savename='s_markov',
)

################################################################################
# Coate&Loury:

# U_as=None leaves the U matrix out of the printed parameters.
_loury_latex = output_parameters(s_best_response, U_as=None)
print(_loury_latex)
responses.plot_loury(s_best_response, 's_loury_s_phi')

# Feedback control (third switch) is left off for this model.
display(
    setting=s_best_response,
    response_func=responses.best_response,
    classifier_funcs=(1, 1, 0, 1),
    res=64,
    savename='s_loury',
)
