import matplotlib.pyplot as plt
import matplotlib
# Render all figure text through LaTeX so the axis labels below can use
# math mode and amsmath commands such as \text{...}.
matplotlib.rcParams['text.usetex'] = True
# Matplotlib >= 3.3 expects a single string for the LaTeX preamble; the older
# list-of-strings form is deprecated and rejected by recent releases.
matplotlib.rcParams['text.latex.preamble'] = '\n'.join([
    r'\usepackage{amsmath}',
    r'\usepackage{amssymb}',
])

import numpy as np
from utils_2 import simul_x_y_a, plot_decision, reductions_prob
from sklearn.linear_model import LogisticRegression
from fairlearn.reductions import ExponentiatedGradient, DemographicParity
from metrics import accuracy_trace

# Registry of base estimators to evaluate. names[i] is the short label for
# classifiers[i]; later in the script each label is combined with a feature
# degree to form result keys.
classifiers = [LogisticRegression(solver='liblinear', fit_intercept=True)]
names = ['logistic']

# Polynomial feature degrees to try for every classifier.
powers = [1]



# Font size setting; not referenced anywhere else in the visible script.
text_size = 20


# Seed NumPy's global RNG so the simulated data below are reproducible.
# NOTE(review): assumes simul_x_y_a draws from np.random's global state — confirm in utils_2.
seed = 1444
np.random.seed(seed)

# Distribution parameters forwarded unchanged to simul_x_y_a for both splits.
# NOTE(review): their exact meaning (mean/covariance scaling, skew, rotation)
# is defined in utils_2 — verify there.
mu_mult = 2.
cov_mult = 1.
skew = 5.
rotate = 0.
# 2x2 mixing proportions for the training sample (entries sum to 1).
# Presumably rows/columns index the sensitive attribute and the label — TODO confirm in utils_2.
train_prop_mtx = [[0.4, 0.1],[0.4, 0.1]]
# Passed as `label_protected` to accuracy_trace: the sensitive-attribute value
# treated as the protected group.
lp = 1
train_x, train_a, train_y = simul_x_y_a(train_prop_mtx, n=2000, mu_mult=mu_mult, 
                                        cov_mult=cov_mult, skew=skew, rotate=rotate)
# Test sample drawn with every cell of the 2x2 proportion matrix equal to 0.25.
test_prop_mtx = [[0.25, 0.25],[0.25, 0.25]]
test_x, test_a, test_y = simul_x_y_a(test_prop_mtx, n=5000, mu_mult=mu_mult, 
                                     cov_mult=cov_mult, skew=skew, rotate=rotate)

plot = False   # set True to draw decision boundaries for every model
switch = {}    # name -> [p_minor where the fair model first beats the baseline, fair accuracy there]
acc_dict = {}  # result label -> accuracy trace returned by accuracy_trace


def _poly_features(x, power):
    """Return the feature matrix used for polynomial degree ``power``.

    Degree 2 appends element-wise squares to the raw features; any other
    degree uses a copy of the raw features. Training, testing and the
    decision-boundary plots all go through this helper so the models are
    always evaluated in the same feature space they were trained in.
    """
    if power == 2:
        return np.hstack((x, x**2))
    return np.copy(x)


for cl, nc in zip(classifiers, names):
    for p in powers:

        name = nc + '-' + str(p)
        train_x_p = _poly_features(train_x, p)
        test_x_p = _poly_features(test_x, p)

        ## Base classifier
        # NOTE(review): `base` is the same object as `cl`, which is also handed
        # to ExponentiatedGradient below; this relies on fairlearn fitting
        # internal copies rather than refitting the estimator it is given —
        # confirm for the installed fairlearn version.
        base = cl
        base.fit(train_x_p, train_y)
        base_predict = base.predict(test_x_p)

        ## Reductions with a very loose constraint (eps=10), i.e. effectively no fairness
        eps_base = 10.
        # NOTE(review): `T=` is the pre-0.5 fairlearn spelling (later `max_iter`).
        mitigator_base = ExponentiatedGradient(cl, DemographicParity(), eps=eps_base, T=25)
        mitigator_base.fit(train_x_p, train_y, sensitive_features=train_a)
        y_pred_mitigated_base = mitigator_base.predict(test_x_p)
        acc_base, p_minor_grid = accuracy_trace(test_y, y_pred_mitigated_base, test_a,
                                                verbose=True, label_protected=lp)
        acc_dict[name] = acc_base

        ## Reduction classifier with a tight demographic-parity constraint
        eps = 0.1
        mitigator = ExponentiatedGradient(cl, DemographicParity(), eps=eps, T=25)
        mitigator.fit(train_x_p, train_y, sensitive_features=train_a)
        y_pred_mitigated = mitigator.predict(test_x_p)
        acc_fair, _ = accuracy_trace(test_y, y_pred_mitigated, test_a, verbose=True, label_protected=lp)
        acc_dict[name + '-fair'] = acc_fair

        # First grid point where the fair model beats the baseline. argmax on
        # an all-False mask returns 0, which would misreport the first grid
        # point as the switch, so record NaNs when the fair model never wins.
        fair_wins = np.asarray(acc_fair) > np.asarray(acc_base)
        if fair_wins.any():
            switch_idx = fair_wins.argmax()
            switch[name] = [np.round(p_minor_grid[switch_idx], 2), np.round(acc_fair[switch_idx], 3)]
        else:
            switch[name] = [np.nan, np.nan]

        ## Plot decisions
        if plot:
            # `name` is "<classifier>-<degree>", so compare the classifier label itself.
            if nc != 'svm':
                # Map plot inputs into the same feature space used for training.
                plot_decision(test_x, test_a, test_y,
                              lambda x: base.predict_proba(_poly_features(x, p))[:, 1],
                              title='Base ' + name)
                plot_decision(test_x, test_a, test_y,
                              lambda x: reductions_prob(mitigator, _poly_features(x, p), 20),
                              title='Reductions ' + name)
            else:
                print('\n', name, 'does not support probabilities')
            
#print(switch)

# Line styling shared by every accuracy curve.
# NOTE(review): both `linestyle` (here) and a per-curve `dashes` pattern are
# passed to plt.plot; confirm which one matplotlib ends up applying.
style = {"linewidth":2, "markeredgewidth":5, "linestyle":':'}

# Per-result plotting attributes: [color, dash pattern, marker, legend label].
# Keys must match the labels stored in acc_dict by the experiment loop.
prop_dict = {'logistic-1':['green',(5,0),'o', r'Baseline on $P^\star$'],
            'logistic-1-fair':['magenta',(3,3,2,2),'x', r'Fair on $P^\star$'],
            }

# Every trace is plotted against the p_minor grid returned by accuracy_trace.
x_ind = p_minor_grid
for m, data in acc_dict.items():
    color, dash_style, marker, legend = prop_dict[m]
    plt.plot(x_ind, data, color=color, label=legend, dashes=dash_style, marker=marker, **style)

# Build the legend once, after all curves exist (skipped when nothing was plotted).
if acc_dict:
    plt.legend(loc='lower left', numpoints=1, fontsize=15)

plt.tick_params('both', labelsize = 14)
plt.grid()
plt.xlabel(r'$p_{\text{minor}}$ in $P^\star$', size=18, labelpad=5)
plt.ylabel('Accuracy', size=18, labelpad=0)
plt.show()



