import numpy as np
import pandas as pd

from scipy.optimize import linprog
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.model_selection import train_test_split

def if_metric(x, y_pred, tol=1e-6):
    """Measure individual fairness as a Lipschitz-style ratio over all sample pairs.

    For every unordered pair (i, j) with ``||x[i] - x[j]|| >= tol`` the ratio
    ``|y_pred[i] - y_pred[j]| / ||x[i] - x[j]||`` is computed; pairs closer than
    ``tol`` are skipped to avoid dividing by (near-)zero distances.

    Each row is compared against all later rows in a single vectorized NumPy
    call, so memory stays O(n) while the inner loop runs in native code.

    Args:
        x: feature matrix of shape (n, d); a 1-D array is treated as (n, 1).
        y_pred: length-n array of predictions (or any per-sample values).
        tol: minimum pairwise distance for a pair to be included.

    Returns:
        ``(max_ratio, avg_ratio)``. ``max_ratio`` is 0 and ``avg_ratio`` is NaN
        when no pair qualifies (e.g. fewer than two samples).
    """
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        x = x[:, None]
    y_pred = np.asarray(y_pred, dtype=float)

    n = len(x)
    max_ratio = 0
    total_ratio = 0.0
    count = 0

    for i in range(n - 1):
        dists = np.linalg.norm(x[i + 1:] - x[i], axis=1)
        # Written as "not < tol" to mirror the original skip condition exactly.
        keep = ~(dists < tol)
        if not keep.any():
            continue
        ratios = np.abs(y_pred[i + 1:][keep] - y_pred[i]) / dists[keep]
        max_ratio = max(max_ratio, ratios.max())
        total_ratio += ratios.sum()
        count += int(keep.sum())

    avg_ratio = total_ratio / count if count > 0 else float('nan')
    return max_ratio, avg_ratio

def F_0(score, y, t):
    """Return the share of ALL samples that have label 0 and a score at or below t.

    The denominator is the total sample count, not the number of label-0 samples.
    """
    at_or_below = score <= t
    is_negative = y == 0
    return np.sum(at_or_below & is_negative) / len(score)

def F_1(score, y, t):
    """Return the share of ALL samples that have label 1 and a score at or below t.

    The denominator is the total sample count, not the number of label-1 samples.
    """
    at_or_below = score <= t
    is_positive = y == 1
    return np.sum(at_or_below & is_positive) / len(score)

def F(score, t):
    """Return the empirical CDF of ``score`` at ``t`` (share of samples with score <= t)."""
    at_or_below = score <= t
    return np.sum(at_or_below) / len(score)

def expected_outcome(p, s, lower=-3.9, width=0.1):
    """Return P(T <= s) for a threshold T with piecewise-constant density ``p``.

    Bin k of the density covers ``[lower + k * width, lower + (k + 1) * width)``
    and has density ``p[k]``. The result is the integral of that density from
    ``lower`` up to ``s``: all full bins below ``s`` plus the covered fraction of
    the bin that contains ``s``.

    Values of ``s`` below the grid give 0, and values above it give the total
    mass, instead of a negative partial-bin term or an IndexError.

    Args:
        p: per-bin densities (array-like of length n_bins).
        s: evaluation point.
        lower: left edge of the first bin (default -3.9).
        width: bin width (default 0.1).
    """
    p = np.asarray(p, dtype=float)
    upper = lower + len(p) * width
    if s <= lower:
        return 0.0
    if s >= upper:
        return sum(p * width)
    # Clamp k so that floating-point rounding near the top edge cannot index
    # past the last bin; the partial term then covers the rest of that bin.
    k = min(int((s - lower) / width), len(p) - 1)
    return sum(p[0:k] * width) + p[k] * (s - lower - k * width)

def cost(t, s):
    """Return the quadratic cost associated with score ``s`` and threshold ``t``.

    The cost is zero when ``s`` is already at or above ``t``, and also when
    ``s`` lies more than 0.1 below ``t``. Inside the band ``[t - 0.1, t)`` it is
    ``100 * (t - s) ** 2``.
    """
    too_far_below = s < t - 0.1
    if s >= t or too_far_below:
        return 0
    return 100 * (t - s) ** 2
    
def expected_cost(p, s):
    """Return E[cost(T, s)] for a threshold T with piecewise-constant density ``p``.

    Bin k covers ``[-3.9 + 0.1 k, -3.9 + 0.1 (k + 1))`` with density ``p[k]``.
    Following ``cost``, the cost is ``100 * (T - s) ** 2`` for T in
    ``(s, s + 0.1]`` and 0 otherwise. Because that window has the same width as
    a bin, it overlaps at most two bins:
        * bin k (the one containing s): integral over ``[s, e_k]``;
        * bin k + 1: integral over ``[e_k, s + 0.1]``.
    Each piece integrates to ``(100 / 3) * density * (upper - s)**3 - (lower - s)**3``.

    Bins outside ``p`` are treated as zero density. This makes scores far below
    or above the grid return 0 instead of a negative value or an IndexError.
    """
    lower = -3.9
    width = 0.1
    n_bins = len(p)

    def density(j):
        # Density is zero outside the 0..n_bins-1 grid.
        return p[j] if 0 <= j < n_bins else 0.0

    if s <= lower:
        gap = lower - s
        if gap >= width:
            # The window (s, s + 0.1] lies entirely below the first bin.
            return 0.0
        return (100 / 3) * density(0) * (width ** 3 - gap ** 3)

    k = int((s - lower) / width)
    e_k = lower + (k + 1) * width
    head = (e_k - s) ** 3
    return (100 / 3) * density(k) * head + (100 / 3) * density(k + 1) * (width ** 3 - head)

def _positive_rate(pred):
    """Return the share of 1s in a 0/1 prediction array."""
    return np.sum(pred) / len(pred)


def _lp_coefficients(scores, labels, n_bins=80):
    """Return the ``(A, B)`` LP coefficient vectors for one group.

    Bin k covers thresholds ``[s_k, e_k)`` with ``s_k = -3.9 + 0.1 k`` and
    ``e_k = s_k + 0.1``. Each entry averages an integrand over the two bin
    endpoints t (each shifted down by 0.1) and multiplies by the bin width 0.1:

    * ``A[k]``: ``0.1 * (share with label 1 and score in (-4.0, t - 0.1])
      + 0.9 * (share with label 0 and score in (t - 0.1, 4.0])``;
    * ``B[k]``: share with score in ``(t - 0.1, 4.0]``.

    Shares are relative to ``len(scores)``, i.e. the group size.
    """
    A = np.zeros(n_bins)
    B = np.zeros(n_bins)
    for k in range(n_bins):
        s_k = -3.9 + k * 0.1
        e_k = -3.9 + (k + 1) * 0.1
        A[k] = 0.5 * (0.1 * F_1(scores, labels, s_k - 0.1) - 0.1 * F_1(scores, labels, -4.0) +
                      0.9 * F_0(scores, labels, 4.0) - 0.9 * F_0(scores, labels, s_k - 0.1) +
                      0.1 * F_1(scores, labels, e_k - 0.1) - 0.1 * F_1(scores, labels, -4.0) +
                      0.9 * F_0(scores, labels, 4.0) - 0.9 * F_0(scores, labels, e_k - 0.1)) * 0.1
        B[k] = 0.5 * (F(scores, 4.0) - F(scores, s_k - 0.1) +
                      F(scores, 4.0) - F(scores, e_k - 0.1)) * 0.1
    return A, B


def _sample_threshold_predictions(scores, density, bin_width=0.1, lower=-3.9, upper=4.1):
    """Predict 0/1 by drawing one random threshold per sample from ``density``.

    ``density[k]`` is the threshold density on bin k of the evenly spaced grid
    from ``lower`` to ``upper`` (``len(density)`` bins). A bin is drawn in
    proportion to its mass, then a point uniformly inside it, and the sample is
    predicted positive when ``score + 0.1`` reaches that threshold.

    This uses the global NumPy RNG: one ``choice`` call followed by one ``rand``
    call.
    """
    # One edge more than there are bins; sizing by the density length keeps the
    # edge spacing equal to bin_width.
    bin_edges = np.linspace(lower, upper, len(density) + 1)
    # Clip guards against tiny negative round-off in the LP solution.
    prob_mass = np.clip(density, 0.0, None) * bin_width
    prob_mass = prob_mass / prob_mass.sum()

    bin_indices = np.random.choice(len(density), size=len(scores), p=prob_mass)
    samples = bin_edges[bin_indices] + np.random.rand(len(scores)) * bin_width
    return ((scores + 0.1) >= samples).astype(int)


def experiment(upper_bound, gp_bound):
    """Compare deterministic vs. LP-randomized group thresholds on the law dataset.

    For each seed in 42..46 the data are split 60/20/20 (train/valid/test). A
    ``LinearRegression`` is fit to predict ``zgpa`` from the feature columns.
    Two classifiers for ``pass_bar`` are then built, with groups defined by
    ``racetxt == 0`` and ``racetxt == 1`` (``racetxt`` is also one of the
    features; rows with other values are excluded from the grouped evaluation):

    * deterministic: one threshold per group, grid-searched on the training
      ``zgpa`` values to maximise macro F1 subject to a training positive-rate
      gap of at most 0.15;
    * random: per-group threshold densities over 80 bins on [-3.9, 4.1),
      obtained from a linear program. Densities are capped at ``upper_bound``,
      and the LP's positive-rate gap is bounded by ``gp_bound``.

    Both classifiers are evaluated on the test split. Predicted scores are
    shifted by +0.1 before thresholding. The metrics are macro F1, the
    ``if_metric`` maximum ratio, and the absolute positive-rate gap between the
    two groups.

    Args:
        upper_bound: per-bin upper bound on the LP density variables.
        gp_bound: bound on the LP's group positive-rate gap.

    Returns:
        Six lists with one entry per seed:
        ``base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps``.

    Raises:
        RuntimeError: if no threshold pair meets the gap constraint, or the LP
            fails to solve.
    """
    data = pd.read_csv("./law/law_data.csv")
    x = np.array(data[["decile1b", "decile3", "lsat", "ugpa", "zfygpa", "fulltime", "fam_inc", "male", "racetxt", "tier"]])
    s = np.array(data["zgpa"])
    y = np.array(data["pass_bar"])
    a = np.array(data["racetxt"])

    base_accs, base_ifs, base_sps = [], [], []
    random_accs, random_ifs, random_sps = [], [], []
    thresholds = np.arange(-3.9, 4.1, 0.1)

    for seed in range(42, 47):
        np.random.seed(seed)

        # 60/20/20 split. The validation part is not used below, but the second
        # split is kept so that the test set is the same as before.
        train_x, valid_x, train_y, valid_y, train_s, valid_s, train_a, valid_a = train_test_split(
            x, y, s, a, test_size=0.4, random_state=seed)
        valid_x, test_x, valid_y, test_y, valid_s, test_s, valid_a, test_a = train_test_split(
            valid_x, valid_y, valid_s, valid_a, test_size=0.5, random_state=seed)

        model = LinearRegression()
        model.fit(train_x, train_s)
        test_score = model.predict(test_x)

        # Per-group views. Group 0 is a == 0 and group 1 is a == 1; wherever
        # the groups are concatenated, group 0 comes first.
        train_mask_0, train_mask_1 = train_a == 0, train_a == 1
        test_mask_0, test_mask_1 = test_a == 0, test_a == 1
        train_s_0, train_s_1 = train_s[train_mask_0], train_s[train_mask_1]
        train_y_0, train_y_1 = train_y[train_mask_0], train_y[train_mask_1]
        train_y_both = np.concatenate([train_y_0, train_y_1])
        test_score_0, test_score_1 = test_score[test_mask_0], test_score[test_mask_1]
        test_y_both = np.concatenate([test_y[test_mask_0], test_y[test_mask_1]])
        test_x_both = np.concatenate([test_x[test_mask_0], test_x[test_mask_1]], axis=0)

        print("### deterministic ###")
        # Grid-search group thresholds on the *true* training zgpa values.
        # (A looser gap of 0.20 was previously tried here.)
        best_acc = 0.0
        best_ts = None
        for t_0 in thresholds:
            train_pred_0 = (train_s_0 >= t_0).astype(int)
            rate_0 = _positive_rate(train_pred_0)
            for t_1 in thresholds:
                train_pred_1 = (train_s_1 >= t_1).astype(int)
                if abs(rate_0 - _positive_rate(train_pred_1)) <= 0.15:
                    acc = f1_score(train_y_both, np.concatenate([train_pred_0, train_pred_1]), average="macro")
                    if acc > best_acc:
                        best_acc = acc
                        best_ts = (t_0, t_1)
        if best_ts is None:
            raise RuntimeError("no threshold pair satisfies the positive-rate constraint (seed {})".format(seed))
        t_0, t_1 = best_ts

        # Evaluate on the test split using the regression's predicted scores.
        test_pred_0 = (test_score_0 + 0.1 >= t_0).astype(int)
        test_pred_1 = (test_score_1 + 0.1 >= t_1).astype(int)
        test_pred = np.concatenate([test_pred_0, test_pred_1])

        test_acc = f1_score(test_y_both, test_pred, average="macro")
        print("test_acc = {}".format(test_acc))
        base_accs.append(test_acc)

        # IF is measured on the hard predictions. The per-sample `cost` is an
        # alternative target that was tried previously.
        if_ratio, __ = if_metric(test_x_both, test_pred)
        print("if ratio = {}".format(if_ratio))
        base_ifs.append(if_ratio)

        diff = abs(_positive_rate(test_pred_0) - _positive_rate(test_pred_1))
        print("group fairness = {}".format(diff))
        base_sps.append(diff)

        print("### random ###")
        A0, B0 = _lp_coefficients(train_s_0, train_y_0)
        A1, B1 = _lp_coefficients(train_s_1, train_y_1)
        n = len(A0)

        # Objective: group-weighted error (weights 0.4 / 0.6).
        c = np.concatenate([0.4 * A0, 0.6 * A1])
        # Each group's density integrates to 1 over the grid (bin width 0.1).
        A_eq = np.vstack([
            np.concatenate([np.full(n, 0.1), np.zeros(n)]),
            np.concatenate([np.zeros(n), np.full(n, 0.1)]),
        ])
        b_eq = np.array([1.0, 1.0])
        # |B0 . p0 - B1 . p1| <= gp_bound, written as two one-sided constraints.
        gap_row = np.concatenate([B0, -B1])
        A_ub = np.vstack([gap_row, -gap_row])
        b_ub = np.array([gp_bound, gp_bound])
        bounds = [(0, upper_bound)] * (2 * n)

        res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')
        if not res.success:
            raise RuntimeError("linprog failed (upper_bound={}, gp_bound={}, seed={}): {}".format(
                upper_bound, gp_bound, seed, res.message))
        p0, p1 = res.x[:n], res.x[n:]

        # Keep the RNG order: group 0 is sampled before group 1.
        test_pred_0 = _sample_threshold_predictions(test_score_0, p0)
        test_pred_1 = _sample_threshold_predictions(test_score_1, p1)
        test_pred = np.concatenate([test_pred_0, test_pred_1], axis=-1)

        acc = f1_score(test_y_both, test_pred, average="macro")
        print("test acc = {}".format(acc))
        random_accs.append(acc)

        # IF is measured on each sample's expected cost under its group's
        # threshold density. expected_outcome(p, s + 0.1) is an alternative
        # target that was tried previously.
        test_cost_expected = np.concatenate([
            [expected_cost(p0, v) for v in test_score_0],
            [expected_cost(p1, v) for v in test_score_1],
        ], axis=-1)
        if_ratio, __ = if_metric(test_x_both, test_cost_expected)
        print("if ratio = {}".format(if_ratio))
        random_ifs.append(if_ratio)

        diff = abs(_positive_rate(test_pred_0) - _positive_rate(test_pred_1))
        print("group fairness = {}".format(diff))
        random_sps.append(diff)

    return base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps

def main():
    """Sweep the positive-rate gap bound and write metric summaries to a text file.

    The density upper bound ``L`` passed to ``experiment`` stays fixed at 1.
    To sweep ``L`` instead, iterate over e.g. ``[1, 0.8, 0.4, 0.3, 0.2]`` with a
    fixed gap bound. Each output line reports ``mean + std`` across seeds for
    the deterministic (``d``) and random (``r``) classifiers.
    """
    upper_bound = 1
    gp_bounds = [0.1, 0.08, 0.06, 0.04, 0.02]
    results_path = "law_group_results.txt"

    # Truncate once up front, then append after each bound so that the results
    # already written survive a later failure.
    with open(results_path, "w"):
        pass

    for gp_bound in gp_bounds:
        base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps = experiment(upper_bound, gp_bound)
        with open(results_path, "a") as f:
            # Label both knobs: "L" is the density upper bound inside experiment().
            f.write("L = {}, gp_bound = {}\n".format(upper_bound, gp_bound))
            f.write("d acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(np.mean(base_accs), np.std(base_accs),
                                                                         np.mean(base_ifs), np.std(base_ifs),
                                                                         np.mean(base_sps), np.std(base_sps)))
            f.write("r acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(np.mean(random_accs), np.std(random_accs),
                                                                         np.mean(random_ifs), np.std(random_ifs),
                                                                         np.mean(random_sps), np.std(random_sps)))


if __name__ == "__main__":
    # Run the full sweep only when executed as a script, not when imported.
    main()