import numpy as np
import pandas as pd

from scipy.optimize import linprog
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.model_selection import train_test_split

def if_metric(x, y_pred, tol=1e-6):
    """Measure how fast predictions change relative to input distance.

    For every unordered pair of samples the ratio
    ``|y_pred[i] - y_pred[j]| / ||x[i] - x[j]||`` is computed.  Pairs whose
    input distance is below ``tol`` are ignored.

    Args:
        x: Indexable collection of inputs (scalars or vectors).
        y_pred: Indexable collection of predictions aligned with ``x``.
        tol: Minimum input distance for a pair to be counted.

    Returns:
        Tuple ``(max_ratio, avg_ratio)``.  When no pair is counted the
        maximum is ``0`` and the average is ``nan``.
    """
    num_points = len(x)
    worst_slope = 0
    slope_sum = 0
    pairs_used = 0

    for first in range(num_points - 1):
        for second in range(first + 1, num_points):
            separation = np.linalg.norm(x[first] - x[second])
            # Near-identical inputs would blow the ratio up; leave them out.
            if not separation < tol:
                slope = np.abs(y_pred[first] - y_pred[second]) / separation
                worst_slope = max(worst_slope, slope)
                slope_sum += slope
                pairs_used += 1

    if pairs_used > 0:
        mean_slope = slope_sum / pairs_used
    else:
        mean_slope = float('nan')
    return worst_slope, mean_slope

def F_0(score, y, t):
    """Fraction of all samples that have ``score <= t`` and label ``y == 0``."""
    at_or_below = score <= t
    label_is_zero = y == 0
    matches = np.sum(at_or_below & label_is_zero)
    return matches / score.shape[0]

def F_1(score, y, t):
    """Fraction of all samples that have ``score <= t`` and label ``y == 1``."""
    at_or_below = score <= t
    label_is_one = y == 1
    matches = np.sum(at_or_below & label_is_one)
    return matches / score.shape[0]

def expected_outcome(p, s):
    """Accumulate the piecewise-constant density ``p`` from 1.0 up to ``s``.

    ``p[k]`` is treated as the density on the width-0.5 bin starting at
    ``1.0 + 0.5 * k``.  Bins lying entirely below ``s`` contribute their
    full mass; the bin containing ``s`` contributes the covered part only.

    Args:
        p: NumPy array of per-bin densities.
        s: Upper integration limit.

    Returns:
        The accumulated mass.
    """
    offset = s - 1.0
    k = int(offset / 0.5)
    complete_bins = sum(p[0:k] * 0.5)
    partial_bin = p[k] * (offset - k * 0.5)
    return complete_bins + partial_bin

def cost(t, s):
    """Return the cost for score ``s`` under threshold ``t``.

    The cost is ``(t - s) ** 2`` when ``s`` lies in ``[t - 1, t)``; scores at
    or above the threshold, and scores more than 1.0 below it, cost ``0``.
    """
    already_passing = s >= t
    out_of_reach = s < t - 1.0
    if already_passing or out_of_reach:
        return 0
    gap = t - s
    return gap ** 2

def expected_cost(p, s):
    """Expected quadratic cost of score ``s`` under a random threshold.

    The threshold ``t`` has piecewise-constant density ``p[k]`` on the
    width-0.5 bin ``[1.0 + 0.5 * k, 1.0 + 0.5 * (k + 1))``.  The cost of a
    threshold is ``(t - s) ** 2`` for ``s < t <= s + 1`` and zero otherwise,
    so the expectation is the integral of ``p(t) * (t - s) ** 2`` over the
    window ``[s, s + 1]``.

    Bins are indexed from zero for every ``s``.  The previous low-score
    branches read ``p[1]``/``p[2]`` where the first two bins were meant,
    which made the result discontinuous at ``s = 1.0``.  Bins beyond the end
    of ``p`` are treated as having zero density instead of raising
    ``IndexError``.

    Args:
        p: Sequence of per-bin densities.
        s: Score at which to evaluate the expected cost.

    Returns:
        The expected cost; ``0.0`` when the window ``[s, s + 1]`` does not
        overlap any bin.
    """
    bin_start = 1.0
    bin_width = 0.5

    # First bin that can overlap [s, s + 1].  A window of length 1 touches at
    # most three consecutive width-0.5 bins.
    first_bin = max(int((s - bin_start) / bin_width), 0)
    last_bin = min(first_bin + 3, len(p))

    total = 0.0
    for k in range(first_bin, last_bin):
        # Work with offsets (t - s), clipped to the cost window [0, 1].
        left = max(bin_start + k * bin_width - s, 0.0)
        right = min(bin_start + (k + 1) * bin_width - s, 1.0)
        if right > left:
            # Integral of u**2 over [left, right] is (right**3 - left**3) / 3.
            total += (1 / 3) * p[k] * (right ** 3 - left ** 3)
    return total

def _group_fairness_gap(pred, y, a):
    """Return the larger between-group gap in positive-prediction rate.

    The rate of positive predictions is computed per group in ``a``,
    separately for samples with ``y == 1`` and with ``y == 0``.  The result
    is the larger of the two absolute differences between the first two
    groups returned by ``np.unique(a)``; exactly two groups are expected.
    """
    rates_y1 = []
    rates_y0 = []
    for group in np.unique(a):
        in_group = a == group
        preds_y1 = pred[in_group & (y == 1)]
        preds_y0 = pred[in_group & (y == 0)]
        rates_y1.append(np.sum(preds_y1) / len(preds_y1))
        rates_y0.append(np.sum(preds_y0) / len(preds_y0))
    return max(abs(rates_y1[0] - rates_y1[1]), abs(rates_y0[0] - rates_y0[1]))


def experiment(upper_bound):
    """Compare a deterministic and a randomized threshold classifier.

    Reads ``./data/scores.csv`` (columns ``Score``, ``Default``, ``Race``),
    keeps the rows whose race is ``"Black"`` or ``"Non-Hispanic white"``, and
    repeats the following for seeds 42..46 on a 60/20/20 split:

    * deterministic: pick the threshold in ``1.0, 1.5, ..., 100.5`` with the
      best F1 on train+validation, then evaluate on the test split;
    * randomized: solve a linear program for a piecewise-constant threshold
      density (200 bins of width 0.5 starting at 1.0, each density value in
      ``[0, upper_bound]``), sample one threshold per test point, and
      evaluate on the test split.

    Args:
        upper_bound: Upper bound on every per-bin density value.

    Returns:
        Six lists with one entry per seed: ``base_accs``, ``base_ifs``,
        ``base_sps``, ``random_accs``, ``random_ifs``, ``random_sps``
        (macro-F1, maximum ``if_metric`` ratio and group-fairness gap for
        the deterministic and the randomized classifier respectively).

    Raises:
        RuntimeError: If the linear program cannot be solved, e.g. because
            ``upper_bound`` is too small for the densities to sum to 2.
    """
    data = pd.read_csv("./data/scores.csv")
    s = np.array(data["Score"])
    y = np.array(data["Default"])
    a = np.array(data["Race"])

    mask = (a == "Black") | (a == "Non-Hispanic white")
    s = s[mask]
    y = y[mask]
    a = a[mask]

    n_bins = 200
    bin_width = 0.5
    thresholds = np.arange(1.0, 101.0, 0.5)

    base_accs = []
    base_ifs = []
    base_sps = []
    random_accs = []
    random_ifs = []
    random_sps = []
    for seed in range(42, 47):
        np.random.seed(seed)

        train_y, valid_y, train_s, valid_s, train_a, valid_a = train_test_split(y, s, a, test_size=0.4, random_state=seed)
        valid_y, test_y, valid_s, test_s, valid_a, test_a = train_test_split(valid_y, valid_s, valid_a, test_size=0.5, random_state=seed)

        print("### deterministic ###")
        ### find best deterministic model
        # The pooled train+validation data does not depend on the threshold,
        # so build it once rather than on every iteration.
        train_valid_s = np.concatenate([train_s, valid_s], axis=-1)
        train_valid_y = np.concatenate([train_y, valid_y], axis=-1)

        # Start from the first threshold so best_t is always defined for this
        # seed, even if no threshold reaches a positive F1.
        best_f1 = 0
        best_t = thresholds[0]
        for t in thresholds:
            train_valid_pred = ((train_valid_s + 1.0) >= t).astype(int)
            # NOTE(review): selection uses the default (binary) F1 while the
            # evaluation below uses macro-F1 -- confirm this is intended.
            acc = f1_score(train_valid_y, train_valid_pred)
            if acc > best_f1:
                best_f1 = acc
                best_t = t

        ### evaluate the deterministic model
        test_pred = ((test_s + 1.0) >= best_t).astype(int)
        test_acc = f1_score(test_y, test_pred, average="macro")
        print("test acc = {}".format(test_acc))
        base_accs.append(test_acc)

        test_cost = [cost(best_t, score) for score in test_s]
        if_ratio, __ = if_metric(test_s, test_cost)
        print("if ratio = {}".format(if_ratio))
        base_ifs.append(if_ratio)

        diff = _group_fairness_gap(test_pred, test_y, test_a)
        print("group fairness = {}".format(diff))
        base_sps.append(diff)

        print("### random ###")
        ### find random classifier
        # These two terms do not depend on the bin; evaluate them once.
        f1_at_zero = F_1(train_s, train_y, 0.0)
        f0_at_top = F_0(train_s, train_y, 100.0)

        # LP objective: per-bin coefficient built from the F_1 / F_0 values
        # at both bin edges (shifted by -1.0), averaged and scaled by the
        # bin width.
        c = np.zeros(n_bins)
        for k in range(n_bins):
            s_k = 1.0 + k * 0.5
            e_k = 1.0 + (k + 1) * 0.5

            c[k] = 0.5 * (F_1(train_s, train_y, s_k - 1.0) - f1_at_zero + \
                f0_at_top - F_0(train_s, train_y, s_k - 1.0) + \
                    F_1(train_s, train_y, e_k - 1.0) - f1_at_zero + \
                        f0_at_top - F_0(train_s, train_y, e_k - 1.0)) * 0.5

        # Densities sum to 2 so that the total mass sum(p) * 0.5 equals 1.
        A_eq = np.ones((1, n_bins))
        b_eq = [2]
        bounds = [(0, upper_bound) for _ in range(n_bins)]
        res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')
        if not res.success:
            raise RuntimeError(
                "linprog failed for upper_bound={}: {}".format(upper_bound, res.message)
            )
        p = res.x

        bin_edges = np.linspace(1.0, 101.0, len(p) + 1)
        # Clip solver round-off below zero; np.random.choice rejects negative
        # probabilities.
        prob_mass = np.clip(p, 0.0, None) * bin_width
        prob_mass /= prob_mass.sum()

        # Draw one threshold per test point: pick a bin, then a uniform
        # position inside it.
        bin_indices = np.random.choice(len(p), size=len(test_s), p=prob_mass)
        lefts = bin_edges[bin_indices]
        samples = lefts + np.random.rand(len(test_s)) * bin_width

        test_pred = ((test_s + 1.0) > samples).astype(int)
        acc = f1_score(test_y, test_pred, average="macro")
        print("test acc = {}".format(acc))
        random_accs.append(acc)

        test_cost_expected = [expected_cost(p, score) for score in test_s]
        if_ratio, __ = if_metric(test_s, test_cost_expected)
        print("if ratio = {}".format(if_ratio))
        random_ifs.append(if_ratio)

        diff = _group_fairness_gap(test_pred, test_y, test_a)
        print("group fairness = {}".format(diff))
        random_sps.append(diff)

    return base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps

def main():
    """Run ``experiment`` for each density bound and log summary statistics.

    ``credit_results.txt`` is emptied first.  After each bound finishes, its
    mean and standard deviation per metric are appended, so partial results
    remain on disk if a later run fails.
    """
    results_path = "credit_results.txt"

    # Truncate any results left over from a previous run.
    with open(results_path, "w"):
        pass

    for upper_bound in (1, 0.5, 0.25, 0.1, 0.05):
        per_seed_metrics = experiment(upper_bound)
        # Flatten to (mean, std) pairs in the order experiment() returns them:
        # deterministic acc/if/sp first, then randomized acc/if/sp.
        summary = [
            stat
            for series in per_seed_metrics
            for stat in (np.mean(series), np.std(series))
        ]
        with open(results_path, "a") as out:
            out.write("L = {}\n".format(upper_bound))
            out.write("d acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*summary[:6]))
            out.write("r acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*summary[6:]))

# Run the full sweep only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()