from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
import numpy as np
import pandas as pd
import torch
from torch.utils import data
import sys
from torch import nn
import os
from optparse import OptionParser

#####################################
# ---- Train-test split
#####################################

def train_test_split(df, train_percent, seed=2022):
    """
    Randomly split ``df`` into a train set and a test set.

    The global NumPy RNG is reseeded with ``seed`` (a side effect on global
    state) and the row positions are shuffled. The first
    ``int(train_percent * len(df))`` shuffled rows form the train set; the
    rest form the test set.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset to split. Any index is accepted; rows are selected by position.
    train_percent : float
        Fraction of rows assigned to the train set.
    seed : int
        Seed passed to ``np.random.seed``.

    Returns
    -------
    (train, test) : tuple of pandas.DataFrame
    """
    np.random.seed(seed=seed)
    m = len(df.index)
    # Shuffle *positions* 0..m-1 rather than index labels: the result is fed to
    # the positional ``iloc``. For a default RangeIndex this is the very same
    # shuffle the label-based version produced.
    perm = np.random.permutation(m)
    train_end = int(train_percent * m)
    train = df.iloc[perm[:train_end]]
    test = df.iloc[perm[train_end:]]
    return train, test

#####################################
# ---- Train-test-validation split
#####################################

def train_val_test_split(df, train_percent, validate_percent, seed=2022):
    """
    Randomly split ``df`` into train, validation and test sets.

    The global NumPy RNG is reseeded with ``seed`` (a side effect on global
    state) and the row positions are shuffled. With ``m = len(df)``, the first
    ``int(train_percent * m)`` shuffled rows go to train, the next
    ``int(validate_percent * m)`` rows to validation, and the rest to test.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset to split. Any index is accepted; rows are selected by position.
    train_percent : float
        Fraction of rows assigned to the train set.
    validate_percent : float
        Fraction of rows assigned to the validation set.
    seed : int
        Seed passed to ``np.random.seed``.

    Returns
    -------
    (train, validate, test) : tuple of pandas.DataFrame
    """
    np.random.seed(seed=seed)
    m = len(df.index)
    # Shuffle *positions* (not index labels) because selection uses ``iloc``;
    # identical to the label-based shuffle for a default RangeIndex.
    perm = np.random.permutation(m)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    return train, validate, test

#####################################
# ---- Bootstrap
#####################################

def bootstrap(df, train_percent, seed=2022):
    """
    Draw a bootstrap resample of ``df`` (rows sampled uniformly with replacement).

    The global NumPy RNG is reseeded with ``seed`` (a side effect on global
    state). ``int(len(df) * train_percent)`` row positions are drawn, so the
    result may contain repeated rows and repeated index labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset to resample.
    train_percent : float
        Size of the resample as a fraction of ``len(df)``.
    seed : int
        Seed passed to ``np.random.seed``.

    Returns
    -------
    pandas.DataFrame
        The sampled rows, in draw order.
    """
    np.random.seed(seed=seed)
    n_rows = df.shape[0]
    n_draws = int(n_rows * train_percent)
    # Giving the population size as an int lets NumPy sample positions
    # 0..n_rows-1 directly instead of first materialising a Python list of
    # them; for the same seed the drawn positions are identical.
    indices = np.random.choice(n_rows, n_draws)
    return df.iloc[indices]

#####################################
# ---- Metrics
#####################################

def my_acc(predictions, labels):
    """
    Accuracy: the fraction of entries where ``predictions == labels``.

    The inputs are compared elementwise (numpy arrays or pandas Series in
    this file). The count is divided by ``predictions.shape[0]``, so empty
    input raises ZeroDivisionError.
    """
    matches = predictions == labels
    return float(matches.sum()) / float(predictions.shape[0])

def my_ppr(predictions):
    """
    Positive prediction rate: the share of entries in ``predictions`` equal to 1.

    Divides by ``predictions.shape[0]``, so empty input raises ZeroDivisionError.
    """
    positives = predictions == 1
    return float(positives.sum()) / float(predictions.shape[0])

def my_tpr(predictions, labels):
    """
    True positive rate: #(prediction == 1 and label == 1) / #(label == 1).

    Raises ZeroDivisionError when no label equals 1.
    """
    actual_positive = labels == 1
    hits = np.logical_and(predictions == 1, actual_positive)
    return float(hits.sum()) / float(actual_positive.sum())

def my_fpr(predictions, labels):
    """
    False positive rate: #(prediction == 1 and label == 0) / #(label == 0).

    Raises ZeroDivisionError when no label equals 0.
    """
    actual_negative = labels == 0
    false_alarms = np.logical_and(predictions == 1, actual_negative)
    return float(false_alarms.sum()) / float(actual_negative.sum())

#####################################
# ---- Fairness metrics
#####################################

def my_acc_diff(predictions, labels, group):
    """
    Accuracy gap: accuracy on ``group == 1`` minus accuracy on ``group == 0``.

    ``group`` is compared elementwise with 1 and 0 to build boolean masks
    that index both ``predictions`` and ``labels``.
    """
    in_one = group == 1
    in_zero = group == 0
    return my_acc(predictions[in_one], labels[in_one]) - my_acc(predictions[in_zero], labels[in_zero])

def my_ppr_diff(predictions, group):
    """
    Statistical-parity gap: positive prediction rate on ``group == 1`` minus
    that on ``group == 0``.
    """
    rate_one, rate_zero = [my_ppr(predictions[group == g]) for g in (1, 0)]
    return rate_one - rate_zero

def my_ppr_new(predictions, group):
    """
    Compare each group's positive prediction rate with the overall rate.

    Returns the pair ``(ppr_group1 - ppr_all, ppr_all - ppr_group0)``, where
    ``ppr_all`` is the rate over all of ``predictions``.
    """
    rate_one, rate_zero = (my_ppr(predictions[group == g]) for g in (1, 0))
    rate_all = my_ppr(predictions)
    return rate_one - rate_all, rate_all - rate_zero

def my_tpr_diff(predictions, labels, group):
    """
    Equal-opportunity gap: true positive rate on ``group == 1`` minus that on
    ``group == 0``.
    """
    in_one = group == 1
    in_zero = group == 0
    return my_tpr(predictions[in_one], labels[in_one]) - my_tpr(predictions[in_zero], labels[in_zero])

def my_fpr_diff(predictions, labels, group):
    """
    False-positive-rate gap: FPR on ``group == 1`` minus FPR on ``group == 0``.
    """
    in_one = group == 1
    in_zero = group == 0
    return my_fpr(predictions[in_one], labels[in_one]) - my_fpr(predictions[in_zero], labels[in_zero])

#####################################
# ---- Experiment
#####################################

def experiment(eps, seed_for_reproduce, sqrt_rho, data_path='Adult/data_adult.csv', num_epochs=20000):
    """
    Fit a fairness-constrained linear classifier on a bootstrap resample and print metrics.

    Steps:
      1. Load ``data_path``. The CSV needs a ``'label'`` column (0/1 targets for
         the logistic loss) and a ``'gender_Male'`` column. Every column except
         ``'label'`` is a feature, ``'gender_Male'`` included.
      2. Take the training set from ``bootstrap(full_df, 0.5, seed)``. The
         "test" set is the FULL dataset, so it overlaps the training rows.
      3. Initialise ``theta`` from an unpenalised scikit-learn logistic
         regression without intercept (scores are ``X @ theta``).
      4. Run ``num_epochs`` outer iterations:
         - Each takes ``num_epochs_sub`` Adam steps on (theta, mu, nu) for
           ``objective + lambd * dual_constraint`` with ``lambd`` fixed.
         - It then takes a projected ascent step on ``lambd >= 0`` using the
           hard-indicator version of the constraint.
      5. Print eleven lines:
         - ``eps`` and ``seed_for_reproduce``;
         - train/test error, parity violation, objective and penalised objective;
         - the smoothed test constraint.

    Args:
        eps: tolerance subtracted from the positive-rate gap in the constraint.
        seed_for_reproduce: seed for the bootstrap resample and ``torch.manual_seed``.
        sqrt_rho: ``RHO = sqrt_rho ** 2`` enters the dual constraint as ``RHO / n``.
        data_path: CSV file to load (defaults to the original hard-coded path).
        num_epochs: number of outer (lambda) iterations (default 20000, as before).

    Returns:
        None; all results are printed to stdout.
    """
    full_df = pd.read_csv(data_path)

    sensitive_label = 'gender_Male'
    RHO = sqrt_rho ** 2
    l2_coef = 0.001  # ridge weight on theta in the training objective

    # Training set: bootstrap resample of half the dataset size.
    # Evaluation ("test") set: the full dataset, training rows included.
    train_df = bootstrap(full_df, 0.5, seed=seed_for_reproduce)
    test_df = full_df

    x_train = train_df.drop(columns=['label'])
    y_train = train_df['label']
    x_test = test_df.drop(columns=['label'])
    y_test = test_df['label']

    y_true = torch.tensor(y_train.values, dtype=torch.float)
    y_true_test = torch.tensor(y_test.values, dtype=torch.float)

    n = train_df.shape[0]

    sigmoid = nn.Sigmoid()
    # Mean logistic loss computed directly from the scores (logits). The
    # log(sigmoid(s)) / log(1 - sigmoid(s)) form returns -inf once sigmoid
    # saturates in float32 (scores above roughly 17). That -inf turns into
    # NaN when multiplied by a zero label weight.
    logistic_loss = nn.functional.binary_cross_entropy_with_logits

    # Unpenalised logistic regression, used only to initialise theta.
    # scikit-learn implements penalty='none'/None as the l2 penalty with
    # C=inf. Passing C=np.inf directly works on every version; the string
    # 'none' was removed in scikit-learn 1.4.
    logisticRegr = LogisticRegression(max_iter=2000, fit_intercept=False, C=np.inf)
    logisticRegr.fit(x_train, y_train)

    def indicator_surrogate(x):
        # Smooth stand-in for the step function 1[x > 0].
        return sigmoid(2 * x)

    class linear_model():
        # Callable wrapping a feature matrix: maps theta to scores X @ theta.
        def __init__(self, df):
            self.df = torch.tensor(df.values, dtype=torch.float)

        def __call__(self, para):
            return self.df.matmul(para)

    def group_terms(frame):
        # Float indicator tensors of the two sensitive groups in ``frame``,
        # followed by their empirical shares.
        is_major = frame[sensitive_label] == 1
        is_minor = frame[sensitive_label] == 0
        return (torch.tensor(is_major.values, dtype=torch.float),
                torch.tensor(is_minor.values, dtype=torch.float),
                is_major.mean(),
                is_minor.mean())

    indicator_major, indicator_minor, prob_major, prob_minor = group_terms(train_df)
    (indicator_major_test, indicator_minor_test,
     prob_major_test, prob_minor_test) = group_terms(test_df)

    def parity_losses(positive, ind_major, ind_minor, p_major, p_minor):
        # Per-sample terms. When the indicators and shares come from the same
        # frame, their mean is: (rate of ``positive`` in the major group)
        # - (rate in the minor group) - eps.
        return positive * ind_major / p_major - positive * ind_minor / p_minor - eps

    def dual_value(losses, mu, nu):
        # Dual expression in (mu, nu) over the per-sample losses. It divides
        # by mu, so mu must stay positive.
        return torch.mean((losses - nu + 2 * mu) ** 2) / 4 / mu + (RHO / n - 1) * mu + nu

    def dual_constraint(scores, mu, nu):
        # Smoothed (differentiable) constraint used in the sub-problem.
        losses = parity_losses(indicator_surrogate(scores), indicator_major,
                               indicator_minor, prob_major, prob_minor)
        return dual_value(losses, mu, nu)

    def dual_constraint_exact(scores, mu, nu):
        # Same constraint with the hard indicator 1[score > 0]; used for the lambda step.
        losses = parity_losses(scores > 0, indicator_major, indicator_minor,
                               prob_major, prob_minor)
        return dual_value(losses, mu, nu)

    def evaluate_constraint_test(scores):
        # Mean smoothed parity term on the evaluation set (no dual reformulation).
        losses = parity_losses(indicator_surrogate(scores), indicator_major_test,
                               indicator_minor_test, prob_major_test, prob_minor_test)
        return losses.mean()

    ########################################
    # ---- Final version of algorithm ---- #
    ########################################

    num_epochs_sub = 10

    torch.manual_seed(seed_for_reproduce)

    # Initialization
    theta = torch.tensor(logisticRegr.coef_[0], dtype=torch.float, requires_grad=True)
    mu = torch.tensor([0.1], requires_grad=True)
    nu = torch.tensor([0.1], requires_grad=True)
    lambd = torch.tensor([0.1], requires_grad=False)

    phi = linear_model(x_train)
    phi_test = linear_model(x_test)

    learning_rate_theta = 0.01
    learning_rate_mu = 0.01
    learning_rate_nu = 0.01
    learning_rate_lambda = 0.1
    optimizer_theta = torch.optim.Adam([theta], lr=learning_rate_theta)
    optimizer_mu = torch.optim.Adam([mu], lr=learning_rate_mu)
    optimizer_nu = torch.optim.Adam([nu], lr=learning_rate_nu)

    for _epoch in range(num_epochs):
        # Dual-function sub-problem: Adam steps on (theta, mu, nu) with lambd
        # held fixed. lambd only changes in the projected step below, so it
        # needs no projection here.
        # NOTE(review): mu is not projected, although dual_value divides by
        # it -- confirm the iterates stay positive.
        for _step in range(num_epochs_sub):
            scores = phi(theta)
            objective = logistic_loss(scores, y_true) + l2_coef * theta.norm() ** 2
            dual_function = objective + lambd * dual_constraint(scores, mu, nu)

            dual_function.backward(inputs=[theta, mu, nu])

            optimizer_theta.step()
            optimizer_mu.step()
            optimizer_nu.step()

            optimizer_theta.zero_grad()
            optimizer_mu.zero_grad()
            optimizer_nu.zero_grad()

        # Projected ascent step on lambd >= 0 using the exact (hard) constraint.
        with torch.no_grad():
            scores = phi(theta)
            lambd.copy_(torch.max(torch.zeros(1),
                                  lambd + learning_rate_lambda * dual_constraint_exact(scores, mu, nu)))

    # Final evaluation; no gradients are needed.
    with torch.no_grad():
        scores_train = phi(theta)
        scores_test = phi_test(theta)
        penalty = l2_coef * theta.norm() ** 2
        objective_train = logistic_loss(scores_train, y_true)
        objective_test = logistic_loss(scores_test, y_true_test)
        constraint_test = evaluate_constraint_test(scores_test)
    train_predictions = (scores_train > 0).int().numpy()
    test_predictions = (scores_test > 0).int().numpy()

    print(eps)
    print(seed_for_reproduce)
    print(1 - my_acc(train_predictions, train_df['label']))                 # V3: Training error
    print(my_ppr_diff(train_predictions, train_df[sensitive_label]) - eps)  # V4: Training constraint violation
    print(objective_train.item())                                           # V5: Training objective
    print((objective_train + penalty).item())                               # V6: Training objective + penalty
    print(1 - my_acc(test_predictions, test_df['label']))                   # V7: Test error
    print(my_ppr_diff(test_predictions, test_df[sensitive_label]) - eps)    # V8: Test constraint violation
    print(objective_test.item())                                            # V9: Test objective
    print((objective_test + penalty).item())                                # V10:Test objective + penalty
    print(constraint_test.item())                                           # V11:Test constraint


def parse_args(argv=None):
    """
    Parse the experiment's command-line options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``,
        which is ``OptionParser.parse_args``'s own default. Passing an
        explicit list allows programmatic use.

    Returns
    -------
    optparse.Values
        Has these attributes:
        - ``eps`` (float, default 0.01)
        - ``seed_for_reproduce`` (int, default 2022)
        - ``sqrt_rho`` (float, default 2.0)
        Positional arguments are accepted and ignored.
    """
    parser = OptionParser()

    parser.add_option("--eps", type="float", dest="eps", default=0.01,
                      help="tolerance eps in the fairness constraint [default: %default]")
    parser.add_option("--seed", type="int", dest="seed_for_reproduce", default=2022,
                      help="seed for the bootstrap resample and torch [default: %default]")
    parser.add_option("--rho", type="float", dest="sqrt_rho", default=2.0,
                      help="square root of RHO used in the dual constraint [default: %default]")

    options, _ = parser.parse_args(argv)

    return options

def main():
    """Entry point: parse the command line, echo the options, and run one experiment."""
    cli_options = parse_args()
    print(cli_options)
    experiment(
        cli_options.eps,
        cli_options.seed_for_reproduce,
        cli_options.sqrt_rho,
    )


if __name__ == '__main__':
    main()
