import numpy as np
import pandas as pd

from scipy.optimize import linprog
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.model_selection import train_test_split

def if_metric(x, y_pred, tol=1e-6):
    """Measure how fast predictions change relative to input distance.

    For every unordered pair of samples ``(i, j)`` the ratio
    ``|y_pred[i] - y_pred[j]| / ||x[i] - x[j]||_2`` is computed. Pairs whose
    distance is below ``tol`` are skipped to avoid dividing by (near) zero.

    Args:
        x: Sequence of feature vectors, one per sample (array-like, indexed
            by sample along the first axis).
        y_pred: Sequence of scalar predictions aligned with ``x``.
        tol: Minimum distance for a pair to be taken into account.

    Returns:
        A tuple ``(max_ratio, avg_ratio)``: the largest and the mean ratio
        over all retained pairs. When no pair is retained the result is
        ``(0, nan)``.
    """
    points = np.asarray(x, dtype=float)
    preds = np.asarray(y_pred, dtype=float)
    n = len(points)

    # Normalise to one flat feature vector per sample so the row-wise norm
    # below equals the norm of the flattened difference.
    if points.ndim == 1:
        points = points[:, None]
    elif points.ndim > 2:
        points = points.reshape(n, int(np.prod(points.shape[1:])))

    max_ratio = 0
    total_ratio = 0.0
    count = 0

    # Each sample is compared against all later samples in a single NumPy
    # operation, so every unordered pair is still visited exactly once.
    for i in range(n - 1):
        dists = np.linalg.norm(points[i + 1:] - points[i], axis=1)
        keep = dists >= tol  # also drops NaN distances
        if not keep.any():
            continue
        ratios = np.abs(preds[i + 1:][keep] - preds[i]) / dists[keep]
        max_ratio = max(max_ratio, ratios.max())
        total_ratio += ratios.sum()
        count += int(np.count_nonzero(keep))

    avg_ratio = total_ratio / count if count > 0 else float('nan')
    return max_ratio, avg_ratio

def F_0(score, y, t):
    """Return the share of all samples with label 0 and score at most ``t``.

    The count of samples satisfying both conditions is divided by the total
    number of samples (``score.shape[0]``), not by the number of label-0
    samples.
    """
    at_or_below = score <= t
    labelled_zero = y == 0
    matches = np.logical_and(at_or_below, labelled_zero)
    return matches.sum() / score.shape[0]

def F_1(score, y, t):
    """Return the share of all samples with label 1 and score at most ``t``.

    The count of samples satisfying both conditions is divided by the total
    number of samples (``score.shape[0]``), not by the number of label-1
    samples.
    """
    at_or_below = score <= t
    labelled_one = y == 1
    matches = np.logical_and(at_or_below, labelled_one)
    return matches.sum() / score.shape[0]

def expected_outcome(p, s):
    """Integrate the piecewise-constant density ``p`` from -3.9 up to ``s``.

    ``p[k]`` is the density on the bin ``[-3.9 + 0.1 * k, -3.9 + 0.1 * (k + 1))``.
    The result is the mass of all bins lying entirely below ``s`` plus the
    partial mass of the bin containing ``s``.

    Args:
        p: Sequence of per-bin density values (bin width 0.1, first bin
            starting at -3.9).
        s: Upper integration limit.

    Returns:
        The accumulated mass: 0 when ``s`` lies at or below the first bin
        edge, and the total mass ``sum(p) * 0.1`` when ``s`` lies beyond the
        last bin.
    """
    density = np.asarray(p, dtype=float)
    offset = s + 3.9

    # Below the support nothing has accumulated yet. Without this guard a
    # negative offset would yield a negative mass or a wrap-around slice.
    if offset <= 0:
        return 0.0

    k = int(offset / 0.1)

    # Past the last bin the whole mass has been accumulated; indexing
    # density[k] would be out of range here.
    if k >= len(density):
        return float(np.sum(density) * 0.1)

    return float(np.sum(density[:k]) * 0.1 + density[k] * (offset - k * 0.1))

def cost(t, s):
    """Return the quadratic cost of lifting score ``s`` to threshold ``t``.

    The cost is ``100 * (t - s) ** 2`` when ``s`` lies in the band
    ``[t - 0.1, t)``. It is 0 when ``s`` already meets the threshold and
    also 0 when ``s`` is more than 0.1 below it.
    """
    already_passing = s >= t
    out_of_reach = s < t - 0.1
    if already_passing or out_of_reach:
        return 0
    shortfall = t - s
    return 100 * shortfall ** 2
    
def expected_cost(p, s):
    """Return the expected quadratic cost at score ``s`` under density ``p``.

    ``p[k]`` is a piecewise-constant density on the bin
    ``[-3.9 + 0.1 * k, -3.9 + 0.1 * (k + 1))``. The value returned is the
    integral of ``100 * (t - s) ** 2 * p(t)`` for ``t`` in ``[s, s + 0.1]``,
    restricted to the part of that window covered by the bins. The window
    touches at most two neighbouring bins, which gives the two cubic terms.

    Args:
        p: Sequence of per-bin density values (bin width 0.1, first bin
            starting at -3.9).
        s: Score at which the expected cost is evaluated.

    Returns:
        The expected cost; 0 when the window ``[s, s + 0.1]`` does not
        overlap any bin.
    """
    n_bins = len(p)

    if s <= -3.9:
        gap = -3.9 - s
        # The window ends before the first bin starts: nothing to integrate.
        # Without this guard the cubic difference below turns negative.
        if gap >= 0.1:
            return 0.0
        return (100 / 3) * p[0] * (0.1 ** 3 - gap ** 3)

    k = int((s + 3.9) / 0.1)
    # The score lies beyond the last bin, where the density is zero.
    if k >= n_bins:
        return 0.0

    e_k = -3.9 + (k + 1) * 0.1
    within_bin = (100 / 3) * p[k] * ((e_k - s) ** 3)

    # In the last bin the window spills past the grid, so there is no
    # following bin to contribute (p[k + 1] would be out of range).
    if k + 1 >= n_bins:
        return within_bin

    return within_bin + (100 / 3) * p[k + 1] * (0.1 ** 3 - (e_k - s) ** 3)

def _positive_rate_gap(pred, group):
    """Return the absolute gap in positive-prediction rate between groups 0 and 1."""
    pred_group_0 = pred[group == 0]
    pred_group_1 = pred[group == 1]
    return abs(np.sum(pred_group_0) / len(pred_group_0) - np.sum(pred_group_1) / len(pred_group_1))


def _lp_objective(train_s, train_y, n_bins=80):
    """Build the linear-programme cost vector, one coefficient per threshold bin.

    For a threshold edge ``e`` the quantity
    ``0.1 * (F_1(e - 0.1) - F_1(-4.0)) + 0.9 * (F_0(4.0) - F_0(e - 0.1))``
    is evaluated; the coefficient of bin ``k`` is the average of that
    quantity over the bin's left and right edge.
    """
    # These two terms do not depend on the bin, so evaluate them once.
    f1_floor = F_1(train_s, train_y, -4.0)
    f0_ceiling = F_0(train_s, train_y, 4.0)

    coeffs = np.zeros(n_bins)
    for k in range(n_bins):
        s_k = -3.9 + k * 0.1
        e_k = -3.9 + (k + 1) * 0.1

        coeffs[k] = 0.5 * (0.1 * F_1(train_s, train_y, s_k - 0.1) - 0.1 * f1_floor + \
            0.9 * f0_ceiling - 0.9 * F_0(train_s, train_y, s_k - 0.1) + \
                0.1 * F_1(train_s, train_y, e_k - 0.1) - 0.1 * f1_floor + \
                    0.9 * f0_ceiling - 0.9 * F_0(train_s, train_y, e_k - 0.1))
    return coeffs


def experiment(upper_bound):
    """Compare a fixed-threshold and a random-threshold classifier on the law data.

    The data is read from ``./law/law_data.csv``. For each seed in 42..46 it
    is split 60/20/20 into train/validation/test, a linear regression of
    ``zgpa`` on the features is fitted on the training part, and two
    classifiers are evaluated on the test part:

    * deterministic: predict 1 when ``score + 0.1 >= best_t``, where
      ``best_t`` maximises macro-F1 over a grid of thresholds;
    * randomized: predict 1 when ``score + 0.1 > T``, where ``T`` is drawn
      per sample from a piecewise-constant density over 80 bins of width 0.1
      starting at -3.9. The density minimises a linear objective subject to
      integrating to one and being capped at ``upper_bound``.

    For both classifiers the macro-F1 (printed as "test acc"), the maximum
    ratio returned by ``if_metric`` on the per-sample cost, and the gap in
    positive-prediction rate between ``racetxt == 0`` and ``racetxt == 1``
    are recorded.

    Args:
        upper_bound: Cap on the density value of every threshold bin.

    Returns:
        Six lists with one entry per seed:
        ``(base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps)``.

    Raises:
        RuntimeError: If the linear programme has no solution (for example
            when ``upper_bound`` is too small for the density to integrate
            to one).
    """
    data = pd.read_csv("./law/law_data.csv")
    x = np.array(data[["decile1b", "decile3", "lsat", "ugpa", "zfygpa", "fulltime", "fam_inc", "male", "racetxt", "tier"]])
    s = np.array(data["zgpa"])
    y = np.array(data["pass_bar"])
    a = np.array(data["racetxt"])

    base_accs, base_ifs, base_sps = [], [], []
    random_accs, random_ifs, random_sps = [], [], []

    for seed in range(42, 47):
        np.random.seed(seed)

        # 60% train, then the remaining 40% is halved into validation / test.
        train_x, valid_x, train_y, valid_y, train_s, valid_s, train_a, valid_a = train_test_split(x, y, s, a, test_size=0.4, random_state=seed)
        valid_x, test_x, valid_y, test_y, valid_s, test_s, valid_a, test_a = train_test_split(valid_x, valid_y, valid_s, valid_a, test_size=0.5, random_state=seed)

        model = LinearRegression()
        model.fit(train_x, train_s)
        test_score = model.predict(test_x)

        ### find best deterministic model
        # NOTE(review): the threshold is tuned on the observed zgpa values of
        # train + validation, but applied to the model's predicted scores on
        # the test split below -- confirm this is intended.
        train_valid_s = np.concatenate([train_s, valid_s], axis=-1)
        train_valid_y = np.concatenate([train_y, valid_y], axis=-1)
        # Starting from -inf guarantees best_t is always bound; ties keep the
        # earliest threshold on the grid.
        best_f1 = -np.inf
        best_t = None
        for t in np.arange(-3.9, 4.1, 0.1):
            train_valid_pred = ((train_valid_s + 0.1) >= t).astype(int)
            f1 = f1_score(train_valid_y, train_valid_pred, average="macro")
            if f1 > best_f1:
                best_f1 = f1
                best_t = t

        ### evaluate the deterministic model
        test_pred = ((test_score + 0.1) >= best_t).astype(int)
        test_acc = f1_score(test_y, test_pred, average="macro")
        print("test acc = {}".format(test_acc))
        base_accs.append(test_acc)

        test_cost = [cost(best_t, score) for score in test_score]
        if_ratio, __ = if_metric(test_x, test_cost)
        print("if ratio = {}".format(if_ratio))
        base_ifs.append(if_ratio)

        diff = _positive_rate_gap(test_pred, test_a)
        print("group fairness = {}".format(diff))
        base_sps.append(diff)

        ### find random classifier
        # Minimise c @ p subject to sum(p) == 10 (the density integrates to
        # one at bin width 0.1) and 0 <= p[k] <= upper_bound.
        c = _lp_objective(train_s, train_y)
        A_eq = np.ones((1, 80))
        b_eq = [10]
        bounds = [(0, upper_bound)] * 80
        res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')
        if not res.success:
            raise RuntimeError(
                "linprog failed for upper_bound = {}: {}".format(upper_bound, res.message)
            )
        p = res.x
        print("p = {}".format(p))

        # Draw one threshold per test sample: pick a bin according to its
        # mass, then a uniform position inside that bin. The order of the two
        # random draws is kept fixed so results stay reproducible per seed.
        bin_width = 0.1
        bin_edges = np.linspace(-3.9, 4.1, len(p) + 1)
        prob_mass = p * bin_width
        prob_mass /= prob_mass.sum()

        bin_indices = np.random.choice(len(p), size=len(test_score), p=prob_mass)
        lefts = bin_edges[bin_indices]
        samples = lefts + np.random.rand(len(test_score)) * bin_width

        test_pred = ((test_score + 0.1) > samples).astype(int)
        acc = f1_score(test_y, test_pred, average="macro")
        print("test acc = {}".format(acc))
        random_accs.append(acc)

        diff = _positive_rate_gap(test_pred, test_a)
        print("group fairness = {}".format(diff))
        random_sps.append(diff)

        test_cost_expected = [expected_cost(p, score) for score in test_score]
        if_ratio, __ = if_metric(test_x, test_cost_expected)
        print("if ratio = {}".format(if_ratio))
        random_ifs.append(if_ratio)

    return base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps

def main():
    """Run ``experiment`` for several density caps and log summary statistics.

    ``law_results.txt`` is emptied first. After each cap has been evaluated,
    three lines are appended: the cap itself, then mean and standard
    deviation of the three metrics for the deterministic ("d") and the
    randomized ("r") classifier.
    """
    results_path = "law_results.txt"

    # Truncate any previous results; every sweep step appends below.
    with open(results_path, "w"):
        pass

    for upper_bound in [1, 0.8, 0.4, 0.3]:
        base_accs, base_ifs, base_sps, random_accs, random_ifs, random_sps = experiment(upper_bound)

        with open(results_path, "a") as out:
            out.write("L = {}\n".format(upper_bound))

            # Flatten to (mean, std) per metric, in the order acc, if, sp.
            deterministic_stats = [
                stat(values)
                for values in (base_accs, base_ifs, base_sps)
                for stat in (np.mean, np.std)
            ]
            out.write("d acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*deterministic_stats))

            randomized_stats = [
                stat(values)
                for values in (random_accs, random_ifs, random_sps)
                for stat in (np.mean, np.std)
            ]
            out.write("r acc = {:.3f} + {:.3f}, if = {:.3f} + {:.3f}, sp = {:.3f} + {:.3f}\n".format(*randomized_stats))

# Run the full sweep only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()