import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from itertools import combinations
from scipy.optimize import linprog, minimize
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.model_selection import train_test_split

def if_metric(x, y_pred, tol=1e-6):
    """Summarise how fast predictions change relative to input distance.

    For every unordered pair ``(i, j)`` the ratio
    ``|y_pred[i] - y_pred[j]| / ||x[i] - x[j]||`` is computed (Euclidean
    norm). Pairs closer than ``tol`` are ignored.

    Args:
        x: Indexable collection of feature vectors (``x[i] - x[j]`` must be
            valid input for ``np.linalg.norm``).
        y_pred: Indexable collection of per-sample outputs, aligned with ``x``.
        tol: Minimum distance for a pair to be taken into account.

    Returns:
        ``(max_ratio, avg_ratio)``. ``max_ratio`` stays ``0`` and
        ``avg_ratio`` is ``nan`` when no pair is farther apart than ``tol``.
    """
    largest = 0
    running_total = 0
    used_pairs = 0

    # combinations() yields (i, j) with i < j in the same order as two
    # nested index loops, so the accumulation order is unchanged.
    for i, j in combinations(range(len(x)), 2):
        gap = np.linalg.norm(x[i] - x[j])
        if gap < tol:
            # Skip (near-)identical inputs to avoid dividing by ~0.
            continue
        slope = np.abs(y_pred[i] - y_pred[j]) / gap
        largest = max(largest, slope)
        running_total += slope
        used_pairs += 1

    if used_pairs > 0:
        return largest, running_total / used_pairs
    return largest, float('nan')

def F_0(score, y, t):
    """Return the fraction of samples with ``score <= t`` and label ``0``.

    The count is divided by the total number of samples in ``score``
    (``score.shape[0]``), not by the number of label-0 samples.
    """
    selected = np.logical_and(score <= t, y == 0)
    return selected.sum() / score.shape[0]

def F_1(score, y, t):
    """Return the fraction of samples with ``score <= t`` and label ``1``.

    The count is divided by the total number of samples in ``score``
    (``score.shape[0]``), not by the number of label-1 samples.
    """
    selected = np.logical_and(score <= t, y == 1)
    return selected.sum() / score.shape[0]

def F(score, t):
    """Return the empirical CDF of ``score`` at ``t``: the share of entries ``<= t``."""
    at_or_below = score <= t
    return at_or_below.sum() / score.shape[0]

def expected_outcome(p, s):
    """Integrate the piecewise-constant density ``p`` from ``1.0`` up to ``s``.

    ``p[k]`` is treated as the density on the bin
    ``[1.0 + 0.5 * k, 1.0 + 0.5 * (k + 1))``. The result is the mass of all
    bins entirely below ``s`` plus the covered part of the bin containing
    ``s``.

    Args:
        p: NumPy array of per-bin densities (bin width 0.5).
        s: Upper integration limit.

    Returns:
        The accumulated mass of ``p`` on ``[1.0, s]``.
    """
    offset = s - 1.0
    k = int(offset / 0.5)          # index of the bin that contains ``s``
    whole_bins = p[0:k] * 0.5      # mass of every bin fully below ``s``
    partial_bin = p[k] * (offset - k * 0.5)
    return sum(whole_bins) + partial_bin

def cost(t, s):
    """Return the squared shortfall of score ``s`` below threshold ``t``.

    The cost is ``(t - s) ** 2`` only while ``s`` lies in ``[t - 1.0, t)``.
    Scores at or above the threshold, and scores more than ``1.0`` below it,
    cost ``0``.
    """
    # The two comparisons are kept exactly as written: rearranging them
    # (e.g. testing ``t - s``) can round differently at the window edges.
    if s >= t or s < t - 1.0:
        return 0
    return (t - s) ** 2

def expected_cost(p, s):
    """Return the expected cost of score ``s`` under a random threshold.

    The threshold ``t`` has the piecewise-constant density ``p``, where
    ``p[k]`` applies on the bin ``[1.0 + 0.5 * k, 1.0 + 0.5 * (k + 1))``.
    The cost of a threshold is ``(t - s) ** 2`` for ``t`` in ``(s, s + 1]``
    and ``0`` elsewhere, so the result is::

        sum_k p[k] * integral over (bin k  ∩  (s, s + 1]) of (t - s)^2 dt

    Differences from the previous piecewise formula:

    * For ``s <= 1.0`` the old branches read ``p[1]``/``p[2]`` although the
      thresholds involved lie in bins 0/1, which made the function
      discontinuous at ``s = 1.0``. The 0-based bin layout is now used
      everywhere.
    * Bins outside ``p`` (thresholds below 1.0 or beyond the last bin) carry
      no mass and contribute 0, instead of raising ``IndexError`` or
      producing negative values.

    Args:
        p: Sequence of per-bin threshold densities (bin width 0.5, first
            bin starting at 1.0).
        s: Score to evaluate.

    Returns:
        The expected cost as a float.
    """
    bin_width = 0.5
    first_edge = 1.0
    window_lo = s
    window_hi = s + 1.0

    # Indices of the first/last bin that can overlap (s, s + 1], clipped to
    # the bins that actually exist in ``p``.
    k_first = max(int((window_lo - first_edge) // bin_width), 0)
    k_last = min(int((window_hi - first_edge) // bin_width), len(p) - 1)

    total = 0.0
    for k in range(k_first, k_last + 1):
        bin_lo = first_edge + k * bin_width
        bin_hi = first_edge + (k + 1) * bin_width
        lo = max(window_lo, bin_lo)
        hi = min(window_hi, bin_hi)
        if hi > lo:
            # Antiderivative of (t - s)^2 is (t - s)^3 / 3.
            total += (1 / 3) * p[k] * ((hi - s) ** 3 - (lo - s) ** 3)
    return total

def _best_group_thresholds(scores_1, labels_1, scores_2, labels_2, max_rate_gap=0.10):
    """Grid-search one threshold per group, maximising macro-F1.

    A sample is predicted positive when ``score + 1.0 >= threshold``. Only
    threshold pairs whose positive rates differ by at most ``max_rate_gap``
    are considered.

    Returns:
        ``(t_1, t_2)`` for the first pair reaching the best macro-F1.

    Raises:
        RuntimeError: If no admissible pair reaches a macro-F1 above 0.
    """
    thresholds = np.arange(1.0, 101.0, 0.5)
    y_true = np.concatenate([labels_1, labels_2])

    # Group-2 predictions depend only on t_2, so build them once rather
    # than once per (t_1, t_2) pair.
    preds_2 = [(scores_2 + 1.0 >= t).astype(int) for t in thresholds]
    rates_2 = [np.sum(pred) / len(pred) for pred in preds_2]

    best_score = 0.0
    best_pair = None
    for t_1 in thresholds:
        pred_1 = (scores_1 + 1.0 >= t_1).astype(int)
        rate_1 = np.sum(pred_1) / len(pred_1)
        for t_2, pred_2, rate_2 in zip(thresholds, preds_2, rates_2):
            if abs(rate_1 - rate_2) <= max_rate_gap:
                score = f1_score(y_true, np.concatenate([pred_1, pred_2]), average="macro")
                if score > best_score:
                    best_score = score
                    best_pair = (t_1, t_2)

    if best_pair is None:
        # Previously this surfaced as an unbound (or stale, from an earlier
        # seed) ``best_ts``.
        raise RuntimeError("no threshold pair satisfies the positive-rate constraint")
    return best_pair


def _threshold_lp_coefficients(group_scores, group_labels, num_bins, bin_width):
    """Return the per-bin LP coefficients ``(objective, positive_rate)`` for one group.

    Bin ``k`` covers thresholds ``[1.0 + k * bin_width, 1.0 + (k + 1) * bin_width)``.
    Both coefficients are the average of a quantity over the two bin edges,
    multiplied by the bin width. For a threshold ``t`` the quantities are:

    * objective: ``F_1(t - 1) - F_1(0) + F_0(100) - F_0(t - 1)``, i.e. the
      mass of label-1 samples at or below the cut plus the mass of label-0
      samples above it;
    * positive rate: ``F(100) - F(t - 1)``.
    """
    objective = np.zeros(num_bins)
    positive_rate = np.zeros(num_bins)

    # Threshold-independent terms, hoisted out of the bin loop.
    f1_floor = F_1(group_scores, group_labels, 0.0)
    f0_top = F_0(group_scores, group_labels, 100.0)
    f_top = F(group_scores, 100.0)

    for k in range(num_bins):
        s_k = 1.0 + k * bin_width
        e_k = 1.0 + (k + 1) * bin_width
        objective[k] = 0.5 * (F_1(group_scores, group_labels, s_k - 1.0) - f1_floor +
                              f0_top - F_0(group_scores, group_labels, s_k - 1.0) +
                              F_1(group_scores, group_labels, e_k - 1.0) - f1_floor +
                              f0_top - F_0(group_scores, group_labels, e_k - 1.0)) * bin_width
        positive_rate[k] = 0.5 * (f_top - F(group_scores, s_k - 1.0) +
                                  f_top - F(group_scores, e_k - 1.0)) * bin_width
    return objective, positive_rate


def _solve_threshold_densities(objective_coefs, rate_coefs, num_bins, bin_width, upper_bound, gp_bound):
    """Solve the LP for the per-group threshold densities.

    Minimises the group-averaged objective subject to

    * every group's density integrating to one (bin values sum to
      ``1 / bin_width``),
    * every density value lying in ``[0, upper_bound]``,
    * the expected positive rates of any two groups differing by at most
      ``gp_bound``.

    Returns:
        A list with one length-``num_bins`` density array per group.

    Raises:
        RuntimeError: If ``linprog`` returns no solution vector.
    """
    num_groups = len(objective_coefs)
    num_vars = num_groups * num_bins
    c = np.concatenate([(1 / num_groups) * coef for coef in objective_coefs])

    A_eq = []
    b_eq = []
    for i in range(num_groups):
        row = np.zeros(num_vars)
        row[i * num_bins : (i + 1) * num_bins] = 1
        A_eq.append(row)
        b_eq.append(1.0 / bin_width)

    A_ub = []
    b_ub = []
    for i, j in combinations(range(num_groups), 2):
        row = np.zeros(num_vars)
        row[i * num_bins : (i + 1) * num_bins] = rate_coefs[i]
        row[j * num_bins : (j + 1) * num_bins] = -rate_coefs[j]
        # |rate_i - rate_j| <= gp_bound, written as two one-sided constraints.
        A_ub.append(row)
        b_ub.append(gp_bound)
        A_ub.append(-row)
        b_ub.append(gp_bound)

    bounds = [(0, upper_bound) for _ in range(num_vars)]
    res = linprog(c=c, A_eq=np.array(A_eq), b_eq=np.array(b_eq),
                  A_ub=np.array(A_ub), b_ub=np.array(b_ub),
                  bounds=bounds, method='highs')
    if res.x is None:
        raise RuntimeError("linprog returned no solution: {}".format(res.message))

    return [res.x[i * num_bins : (i + 1) * num_bins] for i in range(num_groups)]


def experiment(upper_bound, gp_bound):
    """Compare a deterministic and a randomized group-threshold classifier.

    Reads ``./data/scores.csv`` (columns ``Score``, ``Default``, ``Race``),
    keeps the rows whose race is "Black" or "Non-Hispanic white", and for
    each seed in ``42..46``:

    1. splits the data 60/20/20 into train/validation/test (validation is
       not used further);
    2. fits per-group deterministic thresholds on the training split and
       evaluates them on the test split;
    3. solves an LP for per-group threshold densities, samples one
       threshold per test sample and evaluates the resulting predictions.

    Each classifier is scored by macro-F1, by the maximum ratio returned by
    ``if_metric`` on the (expected) cost, and by the absolute difference of
    the groups' positive rates.

    Args:
        upper_bound: Upper bound on each density value in the LP.
        gp_bound: Bound on the positive-rate difference between groups in
            the LP.

    Returns:
        ``(base_accs, base_ifs, base_sps, random_accs, random_ifs,
        random_sps)``: six lists with one entry per seed, the first three
        for the deterministic and the last three for the randomized model.
    """
    num_bins = 200
    bin_width = 0.5
    group_1 = "Black"
    group_2 = "Non-Hispanic white"

    data = pd.read_csv("./data/scores.csv")
    s = np.array(data["Score"])
    y = np.array(data["Default"])
    a = np.array(data["Race"])

    mask = (a == group_1) | (a == group_2)
    s = s[mask]
    y = y[mask]
    a = a[mask]

    base_accs = []
    base_ifs = []
    base_sps = []
    random_accs = []
    random_ifs = []
    random_sps = []
    for seed in range(42, 47):
        np.random.seed(seed)

        train_y, valid_y, train_s, valid_s, train_a, valid_a = train_test_split(y, s, a, test_size=0.4, random_state=seed)
        valid_y, test_y, valid_s, test_s, valid_a, test_a = train_test_split(valid_y, valid_s, valid_a, test_size=0.5, random_state=seed)

        print("### deterministic ###")
        ### find best deterministic model
        # NOTE(review): the deterministic baseline always uses a 0.10
        # positive-rate tolerance, independent of ``gp_bound`` - confirm
        # that this is intended.
        train_in_1 = train_a == group_1
        train_in_2 = train_a == group_2
        t_1, t_2 = _best_group_thresholds(train_s[train_in_1], train_y[train_in_1],
                                          train_s[train_in_2], train_y[train_in_2])

        ### evaluate the deterministic model
        test_in_1 = test_a == group_1
        test_in_2 = test_a == group_2
        test_s_group_1 = test_s[test_in_1]
        test_s_group_2 = test_s[test_in_2]
        test_y_group_1 = test_y[test_in_1]
        test_y_group_2 = test_y[test_in_2]
        # Numeric group indicator (0 / 1), built explicitly rather than via
        # zeros_like/ones_like on the array of race labels.
        test_a_group_1 = np.zeros(len(test_s_group_1))
        test_a_group_2 = np.ones(len(test_s_group_2))

        test_pred_group_1 = (test_s_group_1 + 1.0 >= t_1).astype(int)
        test_pred_group_2 = (test_s_group_2 + 1.0 >= t_2).astype(int)

        test_acc = f1_score(np.concatenate([test_y_group_1, test_y_group_2]), np.concatenate([test_pred_group_1, test_pred_group_2]), average="macro")
        print("test_acc = {}".format(test_acc))
        base_accs.append(test_acc)

        test_features = np.concatenate([
            np.concatenate([test_a_group_1, test_a_group_2], axis=0).reshape(-1, 1),
            np.concatenate([test_s_group_1, test_s_group_2], axis=0).reshape(-1, 1)
        ], axis=-1)
        test_cost_group_1 = [cost(t_1, score) for score in test_s_group_1]
        test_cost_group_2 = [cost(t_2, score) for score in test_s_group_2]
        # Fix: the second group's *costs* belong here; the group indicator
        # array used to be passed in their place.
        if_ratio, __ = if_metric(test_features, np.concatenate([test_cost_group_1, test_cost_group_2]))
        print("if ratio = {}".format(if_ratio))
        base_ifs.append(if_ratio)

        test_positive_rate_1 = np.sum(test_pred_group_1) / len(test_pred_group_1)
        test_positive_rate_2 = np.sum(test_pred_group_2) / len(test_pred_group_2)
        diff = abs(test_positive_rate_1 - test_positive_rate_2)
        print("group fairness = {}".format(diff))
        base_sps.append(diff)

        print("### random ###")
        ### find random classifier
        groups = np.unique(test_a)
        objective_coefs = []
        rate_coefs = []
        for race in groups:
            in_group = train_a == race
            objective, positive_rate = _threshold_lp_coefficients(
                train_s[in_group], train_y[in_group], num_bins, bin_width)
            objective_coefs.append(objective)
            rate_coefs.append(positive_rate)

        p = _solve_threshold_densities(objective_coefs, rate_coefs, num_bins, bin_width,
                                       upper_bound, gp_bound)

        # Fix: sample thresholds from the same bins the LP is defined on,
        # i.e. [1 + 0.5 * k, 1 + 0.5 * (k + 1)). The edges used to come from
        # linspace(0.0, 101, 201), which starts at 0 and has width 0.505.
        bin_edges = 1.0 + bin_width * np.arange(num_bins + 1)

        test_score_groups = []
        test_y_groups = []
        test_pred_groups = []
        test_a_groups = []
        for i, race in enumerate(groups):
            in_group = test_a == race
            group_scores = test_s[in_group]
            test_score_groups.append(group_scores)
            test_y_groups.append(test_y[in_group])
            if race == group_1:
                test_a_groups.append(np.zeros(len(group_scores)))
            else:
                test_a_groups.append(np.ones(len(group_scores)))

            # Clip guards against tiny negative solver round-off, which
            # np.random.choice would reject.
            prob_mass = np.clip(p[i], 0.0, None) * bin_width
            prob_mass /= prob_mass.sum()

            # One threshold per sample: pick a bin, then a uniform position
            # inside that bin.
            bin_indices = np.random.choice(num_bins, size=len(group_scores), p=prob_mass)
            lefts = bin_edges[bin_indices]
            samples = lefts + np.random.rand(len(group_scores)) * bin_width

            test_pred_groups.append(((group_scores + 1.0) >= samples).astype(int))

        grouped_pred = np.concatenate(test_pred_groups, axis=-1)
        grouped_y = np.concatenate(test_y_groups, axis=-1)
        # f1_score expects (y_true, y_pred).
        acc = f1_score(grouped_y, grouped_pred, average="macro")
        print("test acc = {}".format(acc))
        random_accs.append(acc)

        test_cost_expected_groups = []
        for i, group_scores in enumerate(test_score_groups):
            test_cost_expected_groups.append([expected_cost(p[i], score) for score in group_scores])
        test_cost_expected = np.concatenate(test_cost_expected_groups, axis=-1)
        grouped_s = np.concatenate(test_score_groups, axis=0)
        grouped_a = np.concatenate(test_a_groups, axis=0)
        if_ratio, __ = if_metric(np.concatenate([grouped_a.reshape(-1, 1), grouped_s.reshape(-1, 1)], axis=-1), test_cost_expected)
        print("if ratio = {}".format(if_ratio))
        random_ifs.append(if_ratio)

        test_positive_rates = [np.sum(test_pred_group) / len(test_pred_group) for test_pred_group in test_pred_groups]
        diff = abs(test_positive_rates[0] - test_positive_rates[1])
        print("group fairness = {}".format(diff))
        random_sps.append(diff)

    return base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps

def main():
    """Sweep the group-fairness bound and log summary statistics.

    Runs ``experiment(1, gp_bound)`` for each value in ``gp_bounds`` and
    appends the mean and standard deviation of every returned metric list
    to ``credit_group_results.txt``, which is emptied first.
    """
    # Not used below: every run passes a fixed upper bound of 1.
    upper_bounds = [1, 0.5, 0.25, 0.1, 0.05]
    gp_bounds = [0.1, 0.08, 0.06, 0.04, 0.02]
    results_path = "credit_group_results.txt"

    # Opening in "w" mode only serves to truncate any earlier results.
    with open(results_path, "w"):
        pass

    for gp_bound in gp_bounds:
        metrics = experiment(1, gp_bound)
        deterministic, randomized = metrics[:3], metrics[3:]
        with open(results_path, "a") as out:
            # NOTE(review): the value written after "L = " is gp_bound,
            # while experiment() names its *upper* bound L - confirm the
            # label before relying on it.
            out.write("L = {}\n".format(gp_bound))
            # Flatten to (mean, std) for acc, if, sp - in that order.
            det_stats = [stat(values) for values in deterministic for stat in (np.mean, np.std)]
            out.write("d acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*det_stats))
            rnd_stats = [stat(values) for values in randomized for stat in (np.mean, np.std)]
            out.write("r acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*rnd_stats))


# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()