import os, sys
import numpy as np
import math
import pandas as pd
from scipy.stats import chi2
from random import seed, shuffle
import matplotlib.pyplot as plt
from scipy.optimize import minimize,root_scalar
from scipy.stats import multivariate_normal
import statsmodels.api as sm
from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score
# Fix the seeds of both Python's `random` module (used by `shuffle` below)
# and NumPy's legacy global RNG so that runs are reproducible.
SEED = 1122334455
seed(SEED)
np.random.seed(SEED)

# Some of the following code is adapted from https://github.com/mbilalzafar/fair-classification

# Generate a synthetic dataset: two Gaussian class clusters plus one binary sensitive feature.
def generate_synthetic_data(plot_data=False):
    """Generate the synthetic data: two non-sensitive features and one sensitive feature.

    A sensitive feature value of 0.0 means the example is considered to be in
    the protected group (e.g., female) and 1.0 means it is in the non-protected
    group (e.g., male).

    Parameters
    ----------
    plot_data : bool, default False
        If True, scatter-plot the first 200 points on the current matplotlib
        axes. The figure is neither shown nor saved here.

    Returns
    -------
    X : ndarray, shape (2000, 2)
        Non-sensitive features: 1000 points per class, shuffled.
    y : ndarray, shape (2000,)
        Class labels, 1.0 (positive class) or -1.0 (negative class).
    x_control : dict
        ``{"s1": ndarray of shape (2000,)}`` holding the sensitive feature.
    """
    n_samples = 1000  # number of data points generated per class
    # Determines the initial discrimination in the data -- decrease it to
    # generate more discrimination.
    disc_factor = math.pi / 3.0

    def gen_gaussian(mean_in, cov_in, class_label):
        # NOTE(review): each frozen distribution gets its own RNG seeded with
        # SEED, so both class clusters appear to be built from the same
        # underlying standard-normal draws. Kept unchanged so previously
        # generated data stay reproducible -- confirm this is intended.
        nv = multivariate_normal(mean=mean_in, cov=cov_in, seed=SEED)
        X = nv.rvs(n_samples)
        y = np.ones(n_samples, dtype=float) * class_label
        return nv, X, y

    # Non-sensitive features: one Gaussian cluster per class.
    mu1, sigma1 = [2, 2], [[5, 1], [1, 5]]
    mu2, sigma2 = [-2, -2], [[10, 1], [1, 3]]
    nv1, X1, y1 = gen_gaussian(mu1, sigma1, 1)  # positive class
    nv2, X2, y2 = gen_gaussian(mu2, sigma2, -1)  # negative class

    # Join the positive and negative class clusters, then shuffle them.
    X = np.vstack((X1, X2))
    y = np.hstack((y1, y2))
    perm = list(range(0, n_samples * 2))
    shuffle(perm)
    X = X[perm]
    y = y[perm]

    # Sensitive feature: rotate each point by disc_factor and put it in group
    # 1.0 with probability p1 / (p1 + p2), where p1 and p2 are the densities of
    # the positive and negative clusters at the rotated point.
    rotation_mult = np.array(
        [[math.cos(disc_factor), -math.sin(disc_factor)], [math.sin(disc_factor), math.cos(disc_factor)]])
    X_aux = np.dot(X, rotation_mult)

    x_control = []  # holds the sensitive feature value of each point
    for x in X_aux:
        p1 = nv1.pdf(x)
        p2 = nv2.pdf(x)
        p1 = p1 / (p1 + p2)  # normalised so that the two probabilities sum to 1
        r = np.random.uniform()  # uniform random number in [0, 1)
        x_control.append(1.0 if r < p1 else 0.0)  # 1.0 -> male, 0.0 -> female
    x_control = np.array(x_control)

    if plot_data:
        num_to_draw = 200  # draw only a small number of points to avoid clutter
        x_draw = X[:num_to_draw]
        y_draw = y[:num_to_draw]
        x_control_draw = x_control[:num_to_draw]

        X_s_0 = x_draw[x_control_draw == 0.0]
        X_s_1 = x_draw[x_control_draw == 1.0]
        y_s_0 = y_draw[x_control_draw == 0.0]
        y_s_1 = y_draw[x_control_draw == 1.0]
        plt.scatter(X_s_0[y_s_0 == 1.0][:, 0], X_s_0[y_s_0 == 1.0][:, 1], color='green', marker='x', s=30,
                    linewidth=1.5, label="Prot. +ve")
        plt.scatter(X_s_0[y_s_0 == -1.0][:, 0], X_s_0[y_s_0 == -1.0][:, 1], color='red', marker='x', s=30,
                    linewidth=1.5, label="Prot. -ve")
        plt.scatter(X_s_1[y_s_1 == 1.0][:, 0], X_s_1[y_s_1 == 1.0][:, 1], color='green', marker='o', facecolors='none',
                    s=30, label="Non-prot. +ve")
        plt.scatter(X_s_1[y_s_1 == -1.0][:, 0], X_s_1[y_s_1 == -1.0][:, 1], color='red', marker='o', facecolors='none',
                    s=30, label="Non-prot. -ve")

        # Hide ticks and tick labels; they are not needed to see the data
        # distribution. Booleans are required: matplotlib deprecated the
        # 'on'/'off' strings, and a non-empty string such as 'off' is truthy.
        plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        plt.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
        plt.legend(loc=2, fontsize=15)
        plt.xlim((-15, 10))
        plt.ylim((-10, 15))

    x_control = {"s1": x_control}  # all the sensitive features are stored in a dictionary
    return X, y, x_control
def sample_cov(w, x_arr, x_control):
    """Return the per-sample centred cross-products between sensitive feature(s) and scores.

    Row ``i`` is ``(x_control[i] - mean(x_control)) * (x_arr[i] @ w - mean(x_arr @ w))``.
    The column means of the result are the (1/n-normalised) sample covariances
    between the sensitive feature(s) and the linear scores. ``sim`` passes
    these rows to ``el_test``.

    Parameters
    ----------
    w : array-like, shape (n_features,)
        Feature weights.
    x_arr : ndarray, shape (n_samples, n_features)
        Feature matrix.
    x_control : array-like, shape (n_samples,) or (n_samples, d)
        Sensitive feature(s).

    Returns
    -------
    ndarray, shape (n_samples, d)
        ``d`` is 1 when ``x_control`` is 1-D.
    """
    x_control = np.asarray(x_control, dtype=float)
    scores = x_arr @ w
    centred_scores = scores - np.mean(scores)
    centred_control = x_control - np.mean(x_control, axis=0)
    # Give a single sensitive feature an explicit column axis so the result is always 2-D.
    if centred_control.ndim == 1:
        centred_control = centred_control[:, None]
    return centred_control * centred_scores[:, None]
def el_test(sv):
    """Run empirical-likelihood inference on the mean of ``sv`` via statsmodels.

    Returns
    -------
    ci : tuple
        ``DescStat(sv).ci_mean()`` with statsmodels' default settings, i.e. the
        EL confidence interval for the mean.
    p : tuple
        ``DescStat(sv).test_mean(0)``, the EL test that the mean is zero
        (per the statsmodels docs: log-likelihood ratio and p-value).
    """
    descriptive = sm.emplike.DescStat(sv)
    return descriptive.ci_mean(), descriptive.test_mean(0)

def _logistic_loss(w, X, y, return_arr=None):
    """Compute the logistic loss of a linear model (adapted from scikit-learn).

    Parameters
    ----------
    w : ndarray, shape (n_features,) or (n_features + 1,)
        Coefficient vector. When it has ``n_features + 1`` entries, the last
        one is the empirical-likelihood multiplier that ``get_parameters``
        optimises alongside the weights. It does not enter the loss.
    X : array-like, shape (n_samples, n_features)
        Training data.
    y : ndarray, shape (n_samples,)
        Labels; the margin ``y * (X @ w)`` assumes they are in {-1, +1}.
    return_arr : bool, optional
        If truthy, return the per-sample losses instead of their sum.

    Returns
    -------
    float or ndarray of shape (n_samples,)
        Sum of the per-sample losses, or the per-sample losses themselves.
    """
    n_features = np.shape(X)[1]
    # Only the feature weights enter the margin; any trailing EL multiplier is ignored.
    yz = y * np.dot(X, w[:n_features])
    # Logistic loss is the negative of the log of the logistic function.
    losses = -log_logistic(yz)
    return losses if return_arr else np.sum(losses)

def log_logistic(X):
    """Compute the log of the logistic function, ``log(1 / (1 + e ** -x))``, element-wise.

    Adapted from scikit-learn:
    https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/extmath.py
    See the blog post describing this implementation:
    http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression/

    The computation is numerically stable because positive and non-positive
    values are handled separately::

        -log(1 + exp(-x_i))     if x_i > 0
        x_i - log(1 + exp(x_i)) if x_i <= 0

    ``log1p`` evaluates ``log(1 + .)`` so that tiny ``exp`` terms are not
    rounded away.

    Parameters
    ----------
    X : array-like, at most 1-D
        Argument to the logistic function. Non-floating input is promoted to float.

    Returns
    -------
    out : ndarray
        Log of the logistic function at every element of ``X``, same shape as ``X``.

    Raises
    ------
    ValueError
        If ``X`` has more than one dimension.
    """
    X = np.asarray(X)
    if X.ndim > 1:
        raise ValueError("Array of samples cannot be more than 1-D!")
    # The results are fractional: an integer output buffer would truncate them.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(float)
    out = np.empty_like(X)

    idx = X > 0
    out[idx] = -np.log1p(np.exp(-X[idx]))
    out[~idx] = X[~idx] - np.log1p(np.exp(X[~idx]))
    return out



def get_parameters(x, y, x_control, loss_function, el, alpha=0.05):
    """Minimise ``loss_function`` with SLSQP, optionally under the EL fairness constraints.

    The optimisation vector has ``x.shape[1] + 1`` entries: one weight per
    feature, followed by the multiplier used by the EL constraints. The
    starting point is drawn uniformly from (-1, 0] using NumPy's global RNG.

    Parameters
    ----------
    x : ndarray, shape (n_samples, n_features)
        Training features.
    y : ndarray, shape (n_samples,)
        Training labels.
    x_control : array-like
        Sensitive feature; only used when ``el`` is not False.
    loss_function : callable
        Called as ``loss_function(params, x, y)``.
    el : bool
        When not False, add the constraints built by ``get_constraint``.
    alpha : float, default 0.05
        Significance level forwarded to ``get_constraint``.

    Returns
    -------
    ndarray, shape (n_features + 1,)
        The optimiser's final point. Convergence is not checked.
    """
    n_params = x.shape[1] + 1
    constraints = [] if el == False else get_constraint(x, x_control, alpha)
    result = minimize(
        fun=loss_function,
        x0=-np.random.rand(n_params),
        args=(x, y),
        method='SLSQP',
        options={"maxiter": 100000},
        constraints=constraints,
    )
    return result.x

def get_constraint(x_train, x_control_train, alpha):
    """Build the SLSQP constraint specifications for the EL fairness constraint.

    Returns a tuple of four ``scipy.optimize.minimize`` constraint dicts, in order:

    - ``EL_constraint1``: inequality; also receives ``alpha``.
    - ``EL_constraint2``: equality.
    - ``EL_constraint3``: inequality.
    - ``EL_constraint4``: inequality.

    Every constraint function receives ``x_train`` and ``x_control_train``
    after the optimisation vector.
    """
    data_args = (x_train, x_control_train)
    return (
        {'type': 'ineq', 'fun': EL_constraint1, 'args': data_args + (alpha,)},
        {'type': 'eq', 'fun': EL_constraint2, 'args': data_args},
        {'type': 'ineq', 'fun': EL_constraint3, 'args': data_args},
        {'type': 'ineq', 'fun': EL_constraint4, 'args': data_args},
    )

def log_star(x, n):
    """Return the pseudo-logarithm used in empirical likelihood.

    Equals ``log(x)`` for ``x >= 1/n``. Below that threshold it switches to the
    quadratic ``log(1/n) - 3/2 + 2*n*x - (n*x)**2 / 2``. The quadratic matches
    ``log`` in value and first two derivatives at ``x = 1/n`` and stays finite
    for ``x <= 0``.
    """
    cutoff = 1 / n
    if x >= cutoff:
        return math.log(x)
    return math.log(cutoff) - 1.5 + 2 * n * x - (n * x) ** 2 / 2

def EL_constraint1(model, x_arr, x_control, alpha):
    """Inequality constraint: keep the EL ratio statistic below a chi-square(1) critical value.

    Notation:

    - ``m`` is the number of features and ``lam = model[m]``.
    - ``s = x_arr @ model[:m]`` are the scores.
    - ``z_i = (x_control[i] - mean(x_control)) * (s_i - mean(s))``.

    Returns ``thr - 2 * sum_i log_star(1 + lam * z_i, n)``. Here ``thr`` is
    the ``1 - alpha`` quantile of chi-square with 1 degree of freedom. It is
    replaced by 9999 when infinite (e.g. ``alpha == 0``). SLSQP treats a
    non-negative value as feasible.
    """
    n_samples = x_arr.shape[0]
    n_features = x_arr.shape[1]
    lam = model[n_features]
    scores = x_arr @ model[:n_features]
    score_mean = np.mean(scores)
    control_mean = np.mean(x_control)
    cross_products = ((x_control[i] - control_mean) * (scores[i] - score_mean) for i in range(n_samples))
    log_lik = sum(log_star(1 + np.dot(lam, z), n_samples) for z in cross_products)
    threshold = chi2(1).ppf(1 - alpha)
    if threshold == math.inf:
        threshold = 9999
    return threshold - 2 * log_lik


def EL_constraint2(model, x_arr, x_control):
    """Equality constraint: ``lam = model[m]`` must solve the EL estimating equation.

    Returns ``(1/n) * sum_i z_i / (1 + lam * z_i)``, where:

    - ``m`` is the number of features.
    - ``z_i = (x_control[i] - mean(x_control)) * (s_i - mean(s))``.
    - ``s = x_arr @ model[:m]`` are the scores.

    SLSQP drives this value to zero.
    """
    n_samples = x_arr.shape[0]
    n_features = x_arr.shape[1]
    lam = model[n_features]
    scores = x_arr @ model[:n_features]
    score_mean = np.mean(scores)
    control_mean = np.mean(x_control)
    cross_products = ((x_control[i] - control_mean) * (scores[i] - score_mean) for i in range(n_samples))
    total = sum(z / (1 + np.dot(lam, z)) for z in cross_products)
    return total / n_samples

def EL_constraint3(model, x_arr, x_control):
    """Inequality constraint: lower bound on the EL multiplier ``lam = model[m]``.

    Notation:

    - ``m`` is the number of features.
    - ``s = x_arr @ model[:m]`` are the scores.
    - ``z_i = (x_control[i] - mean(x_control)) * (s_i - mean(s))``.
    - ``z_max`` is the largest positive ``z_i``.

    Enforces ``lam >= -(1 - 1/n) / z_max``, i.e. ``1 + lam * z_max >= 1/n``.
    Returns ``lam - (1 - 1/n) / (0 - z_max)``. SLSQP treats a non-negative
    value as feasible.

    If no ``z_i`` is positive, the bound is vacuous (its limit is ``-inf``).
    The original code then divided by zero (ZeroDivisionError). Instead, the
    finite stand-in ``-9999`` is used, in the same way ``EL_constraint1``
    replaces an infinite threshold with 9999.
    """
    n_samples = x_arr.shape[0]
    n_features = x_arr.shape[1]
    lam = model[n_features]
    scores = x_arr @ model[:n_features]
    # Flatten to one value per sample (also accepts an (n, 1) column).
    control = np.asarray(x_control, dtype=float).reshape(n_samples)
    z = (control - np.mean(control)) * (scores - np.mean(scores))
    z_max = z.max(initial=0)
    if z_max > 0:
        lower = (1 - 1 / n_samples) / (0 - z_max)
    else:
        lower = -9999  # finite stand-in for the vacuous -inf bound
    return lam - lower
def EL_constraint4(model, x_arr, x_control):
    """Inequality constraint: upper bound on the EL multiplier ``lam = model[m]``.

    Notation:

    - ``m`` is the number of features.
    - ``s = x_arr @ model[:m]`` are the scores.
    - ``z_i = (x_control[i] - mean(x_control)) * (s_i - mean(s))``.
    - ``z_min`` is the most negative ``z_i``.

    Enforces ``lam <= (1 - 1/n) / (-z_min)``, i.e. ``1 + lam * z_min >= 1/n``.
    Returns ``(1 - 1/n) / (0 - z_min) - lam``. SLSQP treats a non-negative
    value as feasible.

    If no ``z_i`` is negative, the bound is vacuous (its limit is ``+inf``).
    The original code then divided by zero (ZeroDivisionError). Instead, the
    finite stand-in ``9999`` is used, in the same way ``EL_constraint1``
    replaces an infinite threshold with 9999.
    """
    n_samples = x_arr.shape[0]
    n_features = x_arr.shape[1]
    lam = model[n_features]
    scores = x_arr @ model[:n_features]
    # Flatten to one value per sample (also accepts an (n, 1) column).
    control = np.asarray(x_control, dtype=float).reshape(n_samples)
    z = (control - np.mean(control)) * (scores - np.mean(scores))
    z_min = z.min(initial=0)
    if z_min < 0:
        upper = (1 - 1 / n_samples) / (0 - z_min)
    else:
        upper = 9999  # finite stand-in for the vacuous +inf bound
    return upper - lam

def train_model(x_train, y_train, x_control_train, x_test, y_test, x_control_test, loss_function, el, alpha):
    """Fit a (possibly EL-constrained) linear classifier and evaluate it on the test split.

    Predictions are ``sign(x_test @ w)``. Groups are ``s = x_control_test``,
    with 0 = protected and 1 = non-protected.

    Returns
    -------
    w : ndarray, shape (n_features,)
        Fitted feature weights (the trailing EL multiplier is dropped).
    p_rule : float
        ``100 * P(yhat=1 | s=0) / P(yhat=1 | s=1)``.
    test_score : float
        Test accuracy.
    f1 : float
        Test F1 score (sklearn default ``pos_label=1``).
    dp : float
        Demographic-parity gap ``P(yhat=1 | s=1) - P(yhat=1 | s=0)``.
    eo : float
        Equal-opportunity (true-positive-rate) gap
        ``P(yhat=1 | y=1, s=1) - P(yhat=1 | y=1, s=0)``.

    Notes
    -----
    The original code had ``dp`` and ``eo`` swapped. It also computed both as
    differences of mean +/-1 predictions, i.e. twice the rate gaps.
    """
    w = get_parameters(x_train, y_train, x_control_train, loss_function, el, alpha)
    w = w[:-1]  # drop the EL multiplier; keep only the feature weights

    # Convert before any element-wise comparison (a list == 0 would give a scalar).
    y_test = np.asarray(y_test)
    x_control_test = np.asarray(x_control_test)
    y_test_predicted = np.sign(np.dot(x_test, w))

    test_score = accuracy_score(y_test, y_test_predicted)
    f1 = f1_score(y_test, y_test_predicted)

    predicted_pos = y_test_predicted == 1
    prot = x_control_test == 0
    non_prot = x_control_test == 1
    actual_pos = y_test == 1

    prot_pos = predicted_pos[prot].mean()
    non_prot_pos = predicted_pos[non_prot].mean()
    dp = non_prot_pos - prot_pos
    eo = predicted_pos[actual_pos & non_prot].mean() - predicted_pos[actual_pos & prot].mean()
    p_rule = (prot_pos / non_prot_pos) * 100.0
    return w, p_rule, test_score, f1, dp, eo

def sim(data_dir='./sim_data'):
    """Run the EL-based fairness experiment (Section 4.2, trade-off between accuracy and fairness).

    Loads the saved simulated train/test split from ``data_dir``. It then fits
    an unconstrained logistic model, and EL-constrained models for
    ``alpha = 0.0, 0.1, ..., 0.9``. For each model it prints the test metrics
    and the EL confidence interval for the covariance between the sensitive
    feature and the test scores.

    Parameters
    ----------
    data_dir : str, default './sim_data'
        Directory holding ``x_train.npy``, ``y_train.npy``,
        ``x_control_train.npy``, ``x_test.npy``, ``y_test.npy`` and
        ``x_control_test.npy``. The ``x_control_*`` files hold a pickled dict
        whose ``'s1'`` entry is the sensitive feature.

    Returns
    -------
    pandas.DataFrame
        One row per constrained ``alpha``, with columns
        ``['alpha', 'acc', 'p', 'f1', 'dp', 'eo', 'lower', 'upper']``.
    """
    def load(name, **kwargs):
        return np.load(os.path.join(data_dir, name), **kwargs)

    x_train = load('x_train.npy')
    y_train = load('y_train.npy')
    # SECURITY: allow_pickle=True can execute code embedded in the file; only load trusted data.
    x_control_train = load('x_control_train.npy', allow_pickle=True).item()['s1']
    x_test = load('x_test.npy')
    y_test = load('y_test.npy')
    x_control_test = load('x_control_test.npy', allow_pickle=True).item()['s1']
    loss_function = _logistic_loss

    # Unconstrained model (alpha is unused when el is False).
    w, p_rule, acc, f1, dp, eo = train_model(x_train, y_train, x_control_train, x_test, y_test,
                                             x_control_test, loss_function, False, 0.05)
    ci, p = el_test(sample_cov(w, x_test, x_control_test))
    print("Unconstrained model")
    print(acc, p_rule, f1, dp, eo, ci, p)

    print("Constrained model")
    alphas = np.arange(10) / 10
    res = np.zeros((len(alphas), 8))
    for row, alpha in enumerate(alphas):
        w, p_rule, acc, f1, dp, eo = train_model(x_train, y_train, x_control_train, x_test, y_test,
                                                 x_control_test, loss_function, True, alpha)
        print(alpha, acc, p_rule, f1, dp, eo)
        ci, _ = el_test(sample_cov(w, x_test, x_control_test))
        res[row] = [alpha, acc, p_rule, f1, dp, eo, ci[0], ci[1]]
    res = pd.DataFrame(res, columns=['alpha', 'acc', 'p', 'f1', 'dp', 'eo', 'lower', 'upper'])

    print(res)
    return res

def cp_sim(mu=None, cov=None, theta=None, n=200, alpha=0.05, n_rep=2000, random_seed=1):
    """Monte-Carlo coverage study of the EL confidence interval for a covariance (Section 4.1).

    Each replication:

    - draws ``n`` rows of ``(s, x1, x2)`` from a trivariate normal;
    - forms ``y = theta[0] * x1 + theta[1] * x2``;
    - computes the centred cross-products ``(s_i - mean(s)) * (y_i - mean(y))``;
    - computes the statsmodels EL confidence interval for their mean.

    The true covariance ``Cov(s, y) = theta[0] * cov[0][1] + theta[1] * cov[0][2]``
    counts as covered when it lies strictly inside the interval.

    NOTE(review): the mean of the centred cross-products estimates
    ``(n - 1) / n * Cov(s, y)``, a small finite-sample bias relative to the
    target. Confirm this is acceptable.

    Parameters
    ----------
    mu : sequence of 3 floats, default [0, 0, 0]
        Mean of ``(s, x1, x2)``.
    cov : 3x3 nested sequence, default [[2, 1/2, 1], [1/2, 1, 0], [1, 0, 1]]
        Covariance of ``(s, x1, x2)``.
    theta : sequence of 2 floats, default [2, 1]
        Coefficients of ``x1`` and ``x2`` in ``y``.
    n : int, default 200
        Sample size per replication.
    alpha : float, default 0.05
        Significance level of the EL interval.
    n_rep : int, default 2000
        Number of Monte-Carlo replications.
    random_seed : int, default 1
        Passed to ``np.random.seed``; note this resets NumPy's global RNG.

    Returns
    -------
    tuple
        ``(n, coverage, mean_lower, mean_upper, mean_length)``.
    """
    # None defaults avoid shared mutable default arguments.
    if mu is None:
        mu = [0, 0, 0]
    if cov is None:
        cov = [[2, 1 / 2, 1], [1 / 2, 1, 0], [1, 0, 1]]
    if theta is None:
        theta = [2, 1]
    sigma = theta[0] * cov[0][1] + theta[1] * cov[0][2]
    np.random.seed(random_seed)

    covered = 0
    lower_total = 0
    upper_total = 0
    length_total = 0
    for rep in range(n_rep):
        if rep % 200 == 0:
            print(n, rep, n_rep)
        s, x1, x2 = np.random.multivariate_normal(mu, cov, n).T
        y = theta[0] * x1 + theta[1] * x2
        cross_products = (s - np.mean(s)) * (y - np.mean(y))
        lower, upper = sm.emplike.DescStat(cross_products).ci_mean(sig=alpha)
        length_total = length_total + upper - lower
        lower_total = lower_total + lower
        upper_total = upper_total + upper
        if lower < sigma < upper:
            covered = covered + 1
    print(covered / n_rep)
    return n, covered / n_rep, lower_total / n_rep, upper_total / n_rep, length_total / n_rep

if __name__ == '__main__':
    # Section 4.2 experiment on the saved simulated data.
    sim()
    # Section 4.1 coverage study for several sample sizes. Each row of `res`
    # is (n, coverage, mean lower bound, mean upper bound, mean CI length).
    sample_sizes = [100, 200, 500, 800, 1200]
    res = np.zeros([len(sample_sizes), 5])
    for row, nn in enumerate(sample_sizes):
        res[row] = cp_sim(mu=[0, 0, 0], cov=[[2, 1 / 2, 1], [1 / 2, 1, 0], [1, 0, 1]],
                          theta=[3, 2], n=nn, alpha=0.05)
    print(res)
