#!/usr/bin/env python3

import numpy as np
import math
import random
import csv
import pickle
import pandas as pd

'''
cgoo_runs.py:
Contains code for running fair regression differentially private algorithms.

1)  gen_data_lin_norm: generates data from a normally distributed independent variable.
2)  gen_data_lin_unif: generates data from a uniformly distributed independent variable.
3)  gen_data_lin_exp: generates data from an exponentially distributed independent variable.
4)  l(c, K, data): calculates per-group MSPE (Mean Squared Prediction Error).
5)  pl(c, K, data, rho): calculates per-group MSPE (Mean Squared Prediction Error) w/ rho-zCDP guarantees.
6)  pl_se(c, K, data, rho, s1, s2): for 2 groups, uses standard errors -- s1, s2 -- to distribute budget of rho.
7)  grad_l(c, K, data): calculates gradient of per-group MSPE.
8)  pgrad_l(c, K, data, rho): calculates gradient of per-group MSPE using budget of rho.
9)  pgrad_l_se(c, K, data, rho, s1, s2): uses standard error information to calculate pgrad_l.
10) f(c, K, data): sums all losses for all groups.
11) grad_f(c, K, data): compute gradient of f.
12) pgrad_f(c, K, data): privately compute gradient of f.
13) g(c, K, data): compute loss for smallest group.
14) grad_g(c, K, data): compute gradient of g.
15) pgrad_g(c, K, data, rho): privately compute gradient of g using rho-zCDP.
16) smooth_max: computes smooth maximum.
17) pgdcgoo: compute with privacy and fairness.
18) pgdcgoo_nonfair: compute with privacy w/o fairness.
19) pgdcgoo_nonfair_se: compute pgdcgoo_nonfair w/ standard error information.
22) mspe: computes Mean Squared Prediction Error.
23) standard_errors: Computes the standard error.
24) synthetic_datasets: Generates synthetic datasets for 2 groups of sizes (n1, n2)
    and slopes (slope1, slope2), runs 'num_trials' trials per privacy budget, and uses
    variance vare for the noise in the dependent variable.
25) simulation: goes through rhos, runs functions, and collects results.

Other functions that go through real-world datasets (instead of synthetic): real_datasets, getlawschooldata

'''

#
# data = {(x_i, y_i, a_i)}
#

# Seed both Python's `random` module and NumPy's global generator so that
# repeated runs of this script draw the same random numbers.
random.seed(2021)
np.random.seed(seed=2021)

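# Generates linear-regression data (X, Y) for a single group.
# n is the number of samples
# varx is the variance of x
# barx is the mean of x
# vare is the variance of the noise added to y
# slope and intercept are the true regression coefficients;
# group_num is the group label stored in the last column of the returned XG.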
def gen_data_lin_norm(n, varx, barx, vare, slope, intercept, group_num):
    """Draw one group's worth of noisy linear data.

    x is sampled from a normal distribution with mean ``barx`` and variance
    ``varx``; y is sampled from a normal distribution centred on
    ``slope * x + intercept`` with variance ``vare``.

    Returns a tuple ``(XG, Y)`` where each row of ``XG`` is
    ``[x, 1, group_num]`` and ``Y`` holds the matching responses.
    """
    count = int(n)
    draws = np.random.normal(barx, math.sqrt(varx), count)

    # Design matrix with an explicit column of ones for the intercept.
    design = np.vstack((np.array(draws), np.ones(count))).transpose()
    # Same matrix with the group label appended as a final column.
    group_column = np.ones(count, dtype=int) * group_num
    labelled = np.vstack((design.transpose(), group_column)).transpose()

    responses = np.random.normal(np.matmul(design, [slope, intercept]),
                                 math.sqrt(vare))
    return (labelled, responses)

# Calculates loss for each group
# returns vec which is loss for each group.
def l(c, K, XG, Y, sizes):
    """Return the sum of squared residuals of coefficients ``c`` for each group.

    The last column of ``XG`` is dropped before predicting. Group ``k`` is
    taken to be the ``sizes[k]`` consecutive rows that follow the rows of the
    groups before it. The result is a length-``K`` array.
    """
    features = XG[:, :-1]
    losses = np.zeros(K)
    start = 0
    for k in range(K):
        stop = start + sizes[k]
        residual = Y[start:stop] - np.matmul(features[start:stop, :], c)
        losses[k] = np.sum(np.dot(residual, residual))
        start = stop
    return losses

# calculates the gradient of l(c, K, data)
def grad_l(c, K, XG, Y, sizes):
    """Return the per-group gradient of the squared-error loss on a subsample.

    For each group, ``int(0.1 * group_size)`` rows are drawn with replacement
    and the gradient ``sum_i -2 * (y_i - x_i . c) * x_i`` is accumulated over
    the drawn rows (the sum is not rescaled to the full group size).

    Parameters:
        c: coefficient vector of length P.
        K: number of groups.
        XG: data matrix whose last column is dropped before use; groups are
            stored as consecutive row blocks.
        Y: responses aligned with the rows of ``XG``.
        sizes: number of rows in each group.

    Returns:
        Array of shape (P, K); column k is the gradient for group k.
    """
    P = len(c)
    vec = np.zeros((K, P))
    X = XG[:, :-1]

    cur = 0
    for k in range(K):
        Xk = X[cur:cur+sizes[k], :]
        Yk = Y[cur:cur+sizes[k]]
        cur = cur+sizes[k]

        # Subsample 10% of the group's rows (with replacement).
        Xksub = np.random.choice(len(Xk), int(len(Xk)*0.1), replace=True)
        Xk = Xk[Xksub]
        Yk = Yk[Xksub]

        # residual @ Xk sums residual_i * x_i over the rows directly, without
        # materialising an m x m diagonal matrix.
        residual = Yk - np.matmul(Xk, c)
        vec[k] = -2*np.matmul(residual, Xk)

    return vec.transpose()

# calculates the gradient of l(c, K, data) with differential privacy and standard errors
def pgrad_l_se(c, K, XG, Y, sizes, rho, splits):
    """Return a noisy, clipped per-group gradient of the squared-error loss.

    For each group, ``int(0.1 * group_size)`` rows are drawn with replacement.
    Each drawn row's gradient ``-2 * (y_i - x_i . c) * x_i`` is scaled down so
    its l2 norm is at most ``Delta`` (= 2), the rows are summed, and Gaussian
    noise with variance ``Delta**2 / (2 * rho) * splits[k]`` is added to every
    coordinate.

    Parameters:
        c: coefficient vector of length P.
        K: number of groups.
        XG: data matrix whose last column is dropped before use; groups are
            stored as consecutive row blocks.
        Y: responses aligned with the rows of ``XG``.
        sizes: number of rows in each group.
        rho: budget used to size the noise variance.
        splits: per-group multipliers applied to the noise variance.

    Returns:
        Array of shape (P, K); column k is the noisy gradient for group k.
    """
    Delta = 2
    sigsq = Delta**2/(2*rho)  # noise variance before the per-group multiplier
    P = len(c)
    vec = np.zeros((K, P))
    X = XG[:, :-1]

    cur = 0
    for k in range(K):
        Xk = X[cur:cur+sizes[k], :]
        Yk = Y[cur:cur+sizes[k]]
        cur = cur+sizes[k]

        # Subsample 10% of the group's rows (with replacement).
        Xksub = np.random.choice(len(Xk), int(len(Xk)*0.1), replace=True)
        Xk = Xk[Xksub]
        Yk = Yk[Xksub]

        # Per-row gradients via broadcasting (no m x m diagonal matrix).
        L = (-2*(Yk-np.matmul(Xk, c)))[:, np.newaxis] * Xk
        # l2-norm clipping: rows longer than Delta are rescaled to length Delta.
        row_norms = np.linalg.norm(L, axis=1)
        L = L / np.maximum(1, row_norms/Delta)[:, np.newaxis]

        # np.random.normal expects a standard deviation, so take the square
        # root of the variance.
        # NOTE(review): a larger splits[k] means MORE noise for group k here;
        # confirm that is the intended direction of the budget split.
        noise = np.random.normal(0, math.sqrt(sigsq*splits[k]), P)
        vec[k] = L.sum(axis=0) + noise

    return vec.transpose()

# calculates the gradient of l(c, K, data) with differential privacy
def pgrad_l(c, K, XG, Y, sizes, rho):
    """Call ``pgrad_l_se`` with the same split, ``1/len(sizes)``, for every group."""
    num_groups = len(sizes)
    uniform_splits = np.ones(num_groups) / num_groups
    return pgrad_l_se(c, K, XG, Y, sizes, rho, uniform_splits)

# private version of l(c, K, data)
# private version of the loss
# delta is the clipping parameter
# rho is the privacy budget which is eps^2/2
def pl(c, K, XG, Y, sizes, rho):
    """Call ``pl_se`` with the same split, ``1/len(sizes)``, for every group."""
    num_groups = len(sizes)
    uniform_splits = np.ones(num_groups) / num_groups
    return pl_se(c, K, XG, Y, sizes, rho, uniform_splits)

# private version of l(c, K, data) that uses standard errors
def pl_se(c, K, XG, Y, sizes, rho, splits):
    """Return a noisy per-group sum of clipped squared residuals.

    Each sample's squared residual is clipped to ``[0, Delta**2]``
    (Delta = 2), the clipped values are summed per group, and Gaussian noise
    with variance ``Delta**4 / (2 * rho) * splits[k]`` is added.

    Parameters:
        c: coefficient vector.
        K: number of groups.
        XG: data matrix whose last column is dropped before use; groups are
            stored as consecutive row blocks.
        Y: responses aligned with the rows of ``XG``.
        sizes: number of rows in each group.
        rho: budget used to size the noise variance.
        splits: per-group multipliers applied to the noise variance.

    Returns:
        Length-``K`` array of noisy group losses.
    """
    Delta = 2  # is the clipping parameter
    sigsq = Delta**4/(2*rho)  # noise variance before the per-group multiplier
    vec = np.zeros(K)
    X = XG[:, :-1]

    cur = 0
    for k in range(K):
        Xk = X[cur:cur+sizes[k], :]
        Yk = Y[cur:cur+sizes[k]]
        cur = cur+sizes[k]

        # Clip every sample's squared error separately. Clipping the already
        # summed loss would cap the whole group at Delta**2.
        squared_residuals = (Yk-np.matmul(Xk, c))**2
        clipped_loss = np.sum(np.clip(squared_residuals, 0, Delta**2))

        # np.random.normal expects a standard deviation, so take the square
        # root of the variance.
        vec[k] = clipped_loss + np.random.normal(0, math.sqrt(sigsq*splits[k]))

    return vec

# calculates the gradient of f on the dataset with differential privacy
def pgrad_f(c, vec):
    """Return a vector of ones with one entry per element of ``vec``.

    ``c`` and the values in ``vec`` are not read; only ``len(vec)`` matters.
    """
    return np.ones(len(vec))

# calculates the value of g on the dataset
# assume smallest group is second group (indexed by 1)
def g(c, K, XG, Y, sizes):
    """Return the loss ``l`` computes for the group at index 1.

    Index 1 is hard-coded: the second group is assumed to be the smallest.
    """
    return l(c, K, XG, Y, sizes)[1]

# calculates the differentially private gradient of g on the dataset
# assume smallest group is second group (indexed by 1)
def pgrad_g(c, vec):
    """Return the gradient of g with respect to the vector of group losses.

    With exactly two groups this is the indicator of index 1 (the second
    group is assumed to be the smallest). For any other number of groups the
    result of ``grad_smooth_max(vec, 1)`` is returned instead. ``c`` is not
    read.
    """
    num_groups = len(vec)
    if num_groups != 2:
        return grad_smooth_max(vec, 1)

    indicator = np.zeros(num_groups)
    indicator[1] = 1
    return indicator

# calculates the smooth maximum
def smooth_max(vec, eta):
    """Return the smooth maximum of ``vec`` with sharpness ``eta``.

    Computes ``sum(vec * exp(eta * vec)) / sum(exp(eta * vec))``. The
    exponent is shifted by its largest value before exponentiating; the shift
    cancels in the ratio, so the result is unchanged, but large entries no
    longer overflow to inf and produce nan.

    Parameters:
        vec: sequence or array of values.
        eta: sharpness parameter multiplying the values inside the exponent.

    Returns:
        The weighted average as a float (nan for an empty input).
    """
    vec = np.asarray(vec, dtype=float)
    if vec.size == 0:
        # An empty input has no maximum; keep the historical nan result.
        return np.nan
    scaled = eta*vec
    weights = np.exp(scaled - np.max(scaled))
    return np.sum(vec*weights)/np.sum(weights)

# calculates the gradient of smooth maximum
def grad_smooth_max(vec, eta):
    """Return the gradient of the smooth maximum with respect to ``vec``.

    With softmax weights ``w = exp(eta * vec) / sum(exp(eta * vec))`` and
    smooth maximum ``s = sum(w * vec)``, the result is
    ``w * (1 + eta * (vec - s))``. The exponent is shifted by its largest
    value before exponentiating so that large entries do not overflow; the
    shift cancels when the weights are normalised.

    Parameters:
        vec: sequence or array of values.
        eta: sharpness parameter multiplying the values inside the exponent.

    Returns:
        Float array with the same length as ``vec``.
    """
    vec = np.asarray(vec, dtype=float)
    if vec.size == 0:
        # Nothing to differentiate; return the empty array.
        return vec
    scaled = eta*vec
    weights = np.exp(scaled - np.max(scaled))
    weights = weights/np.sum(weights)
    smooth = np.sum(weights*vec)
    return weights*(1 + eta*(vec - smooth))

# privacy without fairness (w/ se)
def pgdcgoo_nonfair_se(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y, sizes, splits):
    """Run T noisy, normalised gradient steps on the summed group losses.

    The budget ``rho`` is divided evenly over the T iterations; within an
    iteration one half goes to ``pgrad_l_se`` and the other half to
    ``pl_se``. Each step moves a distance of one along the negative noisy
    gradient and then clips every coefficient to [-10, 10].

    ``ell_f``, ``ell_g``, ``G`` and ``P`` are accepted but not read.
    Returns the final coefficient vector; ``start_c`` is not modified.
    """
    coeffs = np.copy(start_c)
    step_rho = rho / T
    for _ in range(T):
        noisy_group_grads = pgrad_l_se(c=coeffs, K=K, XG=XG, Y=Y, sizes=sizes,
                                       rho=step_rho / 2, splits=splits)
        noisy_group_losses = pl_se(c=coeffs, K=K, XG=XG, Y=Y, sizes=sizes,
                                   rho=step_rho / 2, splits=splits)

        # Chain rule: per-group gradients weighted by d f / d loss.
        direction = np.matmul(noisy_group_grads, pgrad_f(coeffs, noisy_group_losses))
        step = 1 / math.sqrt(np.sum(direction * direction))

        coeffs = np.clip(coeffs - step * direction, -10, 10)
    return coeffs

# our main contribution -- privacy + fairness (w/ se)
def pgdcgoo_se(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y, sizes, splits):
    """Run T noisy, normalised gradient steps on f + G * g.

    The budget ``rho`` is divided evenly over the T iterations; within an
    iteration one half goes to ``pgrad_l_se`` and the other half to
    ``pl_se``. The g term is added only while ``g(c, ...)`` is positive.
    Each step moves a distance of one along the negative noisy gradient and
    then clips every coefficient to [-10, 10].

    ``ell_f``, ``ell_g`` and ``P`` are accepted but not read.
    Returns the final coefficient vector; ``start_c`` is not modified.
    """
    coeffs = np.copy(start_c)
    step_rho = rho / T
    for _ in range(T):
        noisy_group_grads = pgrad_l_se(c=coeffs, K=K, XG=XG, Y=Y, sizes=sizes,
                                       rho=step_rho / 2, splits=splits)
        noisy_group_losses = pl_se(c=coeffs, K=K, XG=XG, Y=Y, sizes=sizes,
                                   rho=step_rho / 2, splits=splits)

        # NOTE(review): this test reads the exact (un-noised) loss from l();
        # confirm that is acceptable for the privacy accounting.
        if g(coeffs, K, XG, Y, sizes) <= 0:
            constraint_grad = 0
        else:
            constraint_grad = np.matmul(noisy_group_grads,
                                        pgrad_g(coeffs, noisy_group_losses))

        objective_grad = np.matmul(noisy_group_grads,
                                   pgrad_f(coeffs, noisy_group_losses))
        direction = objective_grad + G * constraint_grad
        step = 1 / math.sqrt(np.sum(direction * direction))

        coeffs = np.clip(coeffs - step * direction, -10, 10)
    return coeffs

# privacy + fairness
def pgdcgoo(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y, sizes):
    """Call ``pgdcgoo_se`` with the same split, ``1/len(sizes)``, for every group."""
    num_groups = len(sizes)
    uniform_splits = np.ones(num_groups) / num_groups
    return pgdcgoo_se(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y, sizes,
                      uniform_splits)

# privacy without fairness
def pgdcgoo_nonfair(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y, sizes):
    """Call ``pgdcgoo_nonfair_se`` with the same split, ``1/len(sizes)``, for every group."""
    num_groups = len(sizes)
    uniform_splits = np.ones(num_groups) / num_groups
    return pgdcgoo_nonfair_se(start_c, T, rho, ell_f, ell_g, K, G, P, XG, Y,
                              sizes, uniform_splits)

# calculates mean squared prediction error
def mspe(c, XG, Y, sizes):
    """Return the squared prediction error of ``c`` summed over all rows,
    divided by ``sum(sizes)``.

    The last column of ``XG`` is dropped before predicting.
    """
    residual = Y - np.matmul(XG[:, :-1], c)
    total_rows = np.sum(sizes)
    return np.dot(residual, residual) / total_rows

# calculates mean squared prediction error for each group
def mspe_all_groups(c, XG, Y, sizes):
    """Return the mean squared prediction error of ``c`` within each group.

    The last column of ``XG`` is dropped before predicting. Group ``k`` is
    taken to be the ``sizes[k]`` consecutive rows that follow the rows of the
    groups before it. The result has one entry per element of ``sizes``.
    """
    features = XG[:, :-1]
    per_group = np.zeros(len(sizes))

    offset = 0
    for index, size in enumerate(sizes):
        rows = slice(offset, offset + size)
        residual = Y[rows] - np.matmul(features[rows, :], c)
        per_group[index] = np.dot(residual, residual) / size
        offset = offset + size

    return per_group

# calculates standard errors
# currently implemented for simple linear regression alone but can be generalized easily
def standard_errors(XG, Y, sizes, group_nums, rho):
    """Estimate a noisy per-group standard error and the derived splits.

    The budget ``rho`` is divided into ``K + 1`` equal parts
    (``K = len(group_nums)``): one for the noise added to the ``[1, 1]``
    entry of ``X^T X`` and one for each group's noisy sum of clipped squared
    residuals. Residuals are taken around the least-squares fit obtained
    from the perturbed ``X^T X``.

    Parameters:
        XG: data matrix whose last column is dropped before use; groups are
            stored as consecutive row blocks.
        Y: responses aligned with the rows of ``XG``.
        sizes: number of rows in each group.
        group_nums: only its length (the number of groups) is used.
        rho: total budget for this routine.

    Returns:
        ``(s_e, splits)`` where ``s_e[k]`` is
        ``sqrt(noisy clipped SSE_k / sizes[k])`` and
        ``splits = s_e**2 / sum(s_e**2)``; or ``None`` when a noisy quantity
        that must be positive comes out non-positive.
    """
    Delta = 2
    K = len(group_nums)
    rho = rho/(K+1)
    s_e = np.zeros(K)
    X = XG[:, :-1]
    XX = np.matmul(X.transpose(), X)
    XY = np.matmul(X.transpose(), Y)

    # np.random.normal expects a standard deviation, so take the square root
    # of the variance Delta**2 / (2 * rho).
    # NOTE(review): with columns ordered [x, 1], entry [1, 1] is the row
    # count rather than sum(x**2); confirm this is the entry meant to be
    # perturbed.
    nvarn = XX[1, 1] + np.random.normal(0, math.sqrt(Delta**2/(2*rho)))
    if nvarn <= 0:
        return None
    XX[1, 1] = nvarn
    hat_beta = np.linalg.solve(XX, XY)

    cur = 0
    for k in range(K):
        Xk = X[cur:cur+sizes[k], :]
        Yk = Y[cur:cur+sizes[k]]
        cur = cur+sizes[k]

        # Clip every sample's squared error separately. Clipping the already
        # summed error would cap the whole group at Delta**2.
        squared_residuals = (Yk-np.matmul(Xk, hat_beta))**2
        noisy_sse = (np.sum(np.clip(squared_residuals, 0, Delta**2))
                     + np.random.normal(0, math.sqrt(Delta**4/(2*rho))))
        if noisy_sse <= 0:
            return None
        s_e[k] = math.sqrt(noisy_sse)/math.sqrt(sizes[k])

    squared_se = s_e**2
    splits = squared_se/np.sum(squared_se)

    return (s_e, splits)

def synthetic_datasets(n1, n2, slope1, slope2, num_trials, vare):
    """Generate two synthetic groups, run the simulation and save the results.

    Group 0 has ``n1`` rows and slope ``slope1``; group 1 has ``n2`` rows and
    slope ``slope2``. Both use x ~ N(0, 0.05), intercept 5 and noise variance
    ``vare``. The simulation is run for rho = eps**2 / 2 with
    eps = 2, ..., 10, and the three per-trial result collections are written
    as ``.npy`` files under ``sgd_vectors_se<num_trials>/``.
    """
    import os

    mspe_path = "sgd_vectors_se{4}/all_mspe_results_{0}_{1}_{2}_{3}_{4}_{5}.npy".format(n1, n2, slope1, slope2, num_trials, vare)
    smaller_path = "sgd_vectors_se{4}/all_smaller_results_{0}_{1}_{2}_{3}_{4}_{5}.npy".format(n1, n2, slope1, slope2, num_trials, vare)
    larger_path = "sgd_vectors_se{4}/all_larger_results_{0}_{1}_{2}_{3}_{4}_{5}.npy".format(n1, n2, slope1, slope2, num_trials, vare)

    # np.save does not create missing directories. Create the output directory
    # before the long simulation so its results cannot be lost at save time.
    os.makedirs(os.path.dirname(mspe_path), exist_ok=True)

    (XG1, Y1) = gen_data_lin_norm(n1, 0.05, 0, vare, slope1, 5, 0)
    (XG2, Y2) = gen_data_lin_norm(n2, 0.05, 0, vare, slope2, 5, 1)
    XG = np.concatenate([XG1, XG2])
    Y = np.concatenate([Y1, Y2])

    eps = np.arange(2, 11)
    rhos = eps**2/2

    results = simulation(XG, Y, 2, [n1, n2], num_trials, rhos, 2)
    # Only the last three entries (the per-trial results) are persisted.
    all_mspe_results, all_smaller_results, all_larger_results = results[-3:]

    np.save(mspe_path, all_mspe_results)
    np.save(smaller_path, all_smaller_results)
    np.save(larger_path, all_larger_results)

def simulation(XG, Y, P, sizes, num_trials, rhos, K):
    """Compare the non-fair and fair private solvers over a range of budgets.

    For every ``rho`` in ``rhos`` and every trial:
      1. ``standard_errors`` is run with budget ``rho / 5``; if it returns
         ``None`` the trial is skipped.
      2. ``pgdcgoo_nonfair`` is run from a random start with budget ``rho``.
      3. ``pgdcgoo_se`` is run from a fresh random start with budget
         ``4 * rho / 5`` and the splits from step 1.
    After each run the overall MSPE, the MSPE of group 1 ("smaller") and the
    MSPE of group 0 ("larger") are recorded.

    Parameters:
        XG, Y, sizes: data in the layout expected by the solvers.
        P: number of regression coefficients.
        num_trials: trials attempted per budget.
        rhos: iterable of budgets.
        K: number of groups.

    Returns:
        A 15-tuple: six lists of per-rho means (mspe/smaller/larger for the
        non-fair solver, then for the fair solver), six lists of per-rho
        standard deviations in the same order, then the per-trial results for
        mspe, smaller and larger (each indexed ``[rho][solver][trial]``).
    """
    # These do not depend on rho, so they are built once.
    l_f = 2
    alpha = 1
    G = (alpha + l_f*math.sqrt(K))/alpha
    T = 100
    betaf = betag = 2
    group_nums = list(range(K))

    metrics = ("mspe", "smaller", "larger")
    labels = {"mspe": "MSPE", "smaller": "Smaller", "larger": "Larger"}
    # Index 0 of each pair is the non-fair solver, index 1 the fair solver.
    means = {metric: ([], []) for metric in metrics}
    stds = {metric: ([], []) for metric in metrics}
    all_results = {metric: [] for metric in metrics}

    for rho in rhos:
        trials = {metric: [[], []] for metric in metrics}

        for num in range(num_trials):
            res = standard_errors(XG, Y, sizes, group_nums, 1.0/5*rho)
            if res is None:
                continue
            (_, splits) = res

            c = np.random.uniform(-10, 10, P)
            nonfair_c = pgdcgoo_nonfair(c, T, rho, betaf, betag, K, G, P, XG, Y, sizes)
            c = np.random.uniform(-10, 10, P)
            fair_c = pgdcgoo_se(c, T, 4.0/5*rho, betaf, betag, K, G, P, XG, Y, sizes, splits)

            for solver, coeffs in enumerate((nonfair_c, fair_c)):
                all_mse = mspe_all_groups(coeffs, XG, Y, sizes)
                trials["mspe"][solver].append(mspe(coeffs, XG, Y, sizes))
                trials["smaller"][solver].append(all_mse[1])
                trials["larger"][solver].append(all_mse[0])

            if num%10 == 0:
                print("Trial Number: {0}".format(num))

        for solver, suffix in ((0, " with privacy without fairness:"),
                               (1, " with privacy with fairness (w/ SE):")):
            for metric in metrics:
                mean_value = np.mean(trials[metric][solver])
                print(labels[metric] + suffix, mean_value)
                means[metric][solver].append(mean_value)
                stds[metric][solver].append(np.std(trials[metric][solver]))
        print("rho =", rho)
        print("="*50)

        for metric in metrics:
            all_results[metric].append(trials[metric])

    return (means["mspe"][0], means["smaller"][0], means["larger"][0],
            means["mspe"][1], means["smaller"][1], means["larger"][1],
            stds["mspe"][0], stds["smaller"][0], stds["larger"][0],
            stds["mspe"][1], stds["smaller"][1], stds["larger"][1],
            all_results["mspe"], all_results["smaller"], all_results["larger"])

def getlawschooldata():
    """Load ``./data/lawschool.csv`` and return one record per complete row.

    Rows containing any missing value are dropped. Each record is a list
    ``[features, gpa, race]`` where ``features`` is an array of
    (cluster, lsat, ugpa, age, fam_inc, zfygpa), ``gpa`` is ``zgpa`` as a
    float and ``race`` is the ``race`` column minus one, as an int.
    """
    # breakdown by race
    # race = 0, 1, 2, 3, 4, 5, 6, 7
    frame = pd.read_csv('./data/lawschool.csv').dropna()
    feature_columns = ("cluster", "lsat", "ugpa", "age", "fam_inc", "zfygpa")

    records = []
    for _, entry in frame.iterrows():
        features = np.array([entry[column] for column in feature_columns])
        records.append([features, float(entry["zgpa"]), int(entry["race"] - 1)])
    return records

def real_datasets(num_trials, text):
    """Run the simulation on a real dataset and save the per-trial results.

    Only ``text == "lawschool"`` is supported. Records whose race code is 6
    form group 0 and records whose race code is 2 form group 1; all other
    records are ignored. Each feature vector gets a trailing 1 appended and
    is then replaced by its element-wise squares divided by their sum.
    Results are written as ``.npy`` files under ``sgd_vectors_se_<text>/``.

    Raises:
        ValueError: if ``text`` names an unsupported dataset.
    """
    import os

    if text != "lawschool":
        raise ValueError("unsupported dataset: {0!r}".format(text))

    mspe_path = "sgd_vectors_se_{0}/all_mspe_results.npy".format(text)
    smaller_path = "sgd_vectors_se_{0}/all_smaller_results.npy".format(text)
    larger_path = "sgd_vectors_se_{0}/all_larger_results.npy".format(text)

    # np.save does not create missing directories. Create the output directory
    # before the long simulation so its results cannot be lost at save time.
    os.makedirs(os.path.dirname(mspe_path), exist_ok=True)

    K = 2
    old_data = getlawschooldata()
    P = len(old_data[0][0])+1

    # Collect each group separately. The loss and error routines slice groups
    # as consecutive row blocks of length sizes[k], so group 0 rows must all
    # come before group 1 rows regardless of their order in the file.
    group_rows = ([], [])
    group_targets = ([], [])
    for features, target, race in old_data:
        if race == 6:
            group = 0
        elif race == 2:
            group = 1
        else:
            continue
        x_i = np.concatenate((features, [1]))
        x_i = x_i**2/np.sum(x_i**2)
        group_rows[group].append(np.concatenate((x_i, [group])))
        group_targets[group].append(target)

    n1 = len(group_rows[0])
    n2 = len(group_rows[1])
    XG = np.array(group_rows[0] + group_rows[1])
    Y = np.array(group_targets[0] + group_targets[1])

    eps = np.arange(2, 11)
    rhos = eps**2/2

    results = simulation(XG, Y, P, [n1, n2], num_trials, rhos, K)
    # Only the last three entries (the per-trial results) are persisted.
    all_mspe_results, all_smaller_results, all_larger_results = results[-3:]

    np.save(mspe_path, all_mspe_results)
    np.save(smaller_path, all_smaller_results)
    np.save(larger_path, all_larger_results)

if __name__ == "__main__":
    # Script entry point: one synthetic experiment with a large group
    # (10000 rows, slope -10) and a small group (500 rows, slope 10).
    num_trials = 1000
    synthetic_datasets(n1=10000, n2=500, slope1=-10, slope2=10,
                       num_trials=num_trials, vare=1)
