import numpy as np
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from cvxpy import *
from DTools3 import *
from matplotlib import gridspec
from matplotlib.patches import Rectangle
import os, sys
# %matplotlib inline


        
# Derive a fixed integer seed from the characters of a constant string and
# seed NumPy's legacy global RNG with it, so every run draws the same numbers.
seed = sum(map(ord, 'Bhanu'))
np.random.seed(seed=seed)
####################################

def randomize(df, dfMap, features=None):
    """Return a copy of ``df`` with each row re-drawn from a randomized mapping.

    For every row, the values of ``features`` (taken in list order) form a
    tuple used as a ``.loc`` key into ``dfMap``. The selected entry of
    ``dfMap`` is treated as a probability vector over its own index: one
    index entry is sampled with ``np.random.choice`` and written back into
    the columns named by that index's level names.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data. It is copied, never modified in place.
    dfMap : pandas.DataFrame
        Mapping whose ``.loc[key]`` rows are probability vectors; each must
        sum to 1, otherwise ``np.random.choice`` raises ``ValueError``.
    features : sequence of str, optional
        Columns of ``df`` that form the lookup key. ``None`` (the default)
        means no columns, the same as the previous ``[]`` default but without
        a mutable default object shared between calls.

    Returns
    -------
    pandas.DataFrame
        The randomized copy of ``df``.
    """
    features = [] if features is None else list(features)
    df2 = df.copy()
    print('Randomizing...')
    for idx in tqdm(df2.index):
        # The key is read from df2 before this row is overwritten below.
        key = tuple(df2.loc[idx, :][features].tolist())
        draw = dfMap.loc[key]
        # Sample a position in `draw`, weighted by its values.
        mapVal = np.random.choice(range(len(draw)), p=draw.tolist())
        df2.loc[idx, draw.index.names] = draw.index[mapVal]

    return df2

def wasserstein_fn(X, Y):
    """Return the 2-Wasserstein distance between two equal-size empirical samples.

    Both inputs are flattened and sorted. For one-dimensional empirical
    distributions with the same number of equally weighted points, the
    optimal coupling pairs the order statistics. The distance is therefore
    the root mean squared difference of the sorted values.

    Parameters
    ----------
    X, Y : array_like
        Samples (flattened before comparison). They must contain the same
        number of elements.

    Returns
    -------
    numpy.float64
        ``sqrt(mean((sort(X) - sort(Y)) ** 2))``.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` hold different numbers of elements. Without this
        check, a single-element sample would be silently broadcast against
        the other one.
    """
    x_sorted = np.sort(X, axis=None)
    y_sorted = np.sort(Y, axis=None)
    if x_sorted.size != y_sorted.size:
        raise ValueError(
            f"wasserstein_fn needs samples of equal size, got {x_sorted.size} and {y_sorted.size}"
        )

    return np.sqrt(np.mean((x_sorted - y_sorted) ** 2))


print(os.getcwd())

# Load the Adult data file. Fields keep the space that follows each comma
# (later code compares against ' White' and ' >50K'), so a missing value is
# stored as ' ?' rather than '?'; both spellings are treated as NA.
df = pd.read_csv(
    "data/adult.data",
    names=[
        "Age", "Workclass", "fnlwgt", "Education", "Education-Num", "Martial Status",
        "Occupation", "Relationship", "Race", "Gender", "Capital Gain", "Capital Loss",
        "Hours per week", "Country", "Income"],
    na_values=["?", " ?"])

# Age rounded down to its decade (e.g. 37 -> 30.0), computed column-wise.
df['Age (decade)'] = np.floor(df['Age'] / 10.0) * 10.0

def group_edu(x):
    """Map a numeric education level onto a coarse bucket.

    Levels of 5 or less become the label '<6' and levels of 13 or more
    become '>12'. Anything else is returned unchanged, including values
    that compare False against both thresholds, such as NaN.
    """
    if x <= 5:
        return '<6'
    return '>12' if x >= 13 else x
    
def age_cut(x):
    """Cap an age decade: 70 or above becomes the label '>=70'; otherwise x is returned as-is."""
    return '>=70' if x >= 70 else x

# Limit education range: levels <= 5 -> '<6', levels >= 13 -> '>12',
# anything in between kept as-is (see group_edu).
df['Education Years'] = df['Education-Num'].apply(group_edu)

# Limit age range: every decade >= 70 collapses into the '>=70' bucket.
df['Age (decade)'] = df['Age (decade)'].apply(age_cut)

# Transform all that is non-white into 'Minority'. Values other than
# ' White' (including NaN) are replaced, as with the former per-row lambda.
df['Race'] = df['Race'].where(df['Race'] == ' White', 'Minority')

# Add binary income variable: 1 for ' >50K', 0 otherwise (int64 as before).
df['Income Binary'] = (df['Income'] == " >50K").astype('int64')

# Column roles passed to DT.setFeatures: D_features are the attributes the
# fairness metrics split on, X_features the covariates, Y_features the outcome.
# Two-attribute variant: add 'Race' to both `features` and `D_features`.
features = ['Age (decade)', 'Education Years', 'Income', 'Gender', 'Income Binary']
D_features = ['Gender']
Y_features = ['Income Binary']
X_features = ['Age (decade)', 'Education Years']

# Drop every column the experiment does not use.
df = df[features]

from sklearn.model_selection import ShuffleSplit

# Ten random (train, test) partitions, 2000 test rows each.
# Change random_state to obtain a different set of splits.
rs = ShuffleSplit(n_splits=10, test_size=2000, random_state=888)
df_list = [
    (df.iloc[train_rows, :].copy(), df.iloc[test_rows, :].copy())
    for train_rows, test_rows in rs.split(df)
]

# Distortion constants forwarded to the distortion class through clist:
#   c1 -- value of (delta1, c1): to keep
#   c2 -- value of (delta2, c2): value that should not happen
#   c3 -- penalty for adjusting age
c1, c2, c3 = .99, 1.99, 2.99
clist = [c1, c2, c3]
# NOTE(review): this rebinds the name imported from DTools3 to an instance;
# the instance is what DT.setDistortion receives later in the script.
Dclass = Dclass()

# these were the values used in the paper
dlist = [.1, 0.05, 0]
epsilon = .00

# NOTE(review): result_folder is not referenced anywhere else in this script.
result_folder = '../experiment_data1/'
split_num = 0

# Metric history per method; every split appends one value to each list.
results = {'OPPDP': {metric: [] for metric in ('wasserstein', 'uf', 'tv')}}

# Iterate over (train, test) pairs. enumerate keeps split_num in step with the
# split actually being processed (it was previously never incremented, so
# every split was reported as split 0).
for split_num, (df_train, df_test) in enumerate(df_list):
    file_name = str(split_num)

    print('-----------------')
    print('Current split: ' + file_name)

    # initialize a new DT object
    DT = DTools(df=df_train, features=features)

    # Set features
    DT.setFeatures(D=D_features, X=X_features, Y=Y_features)

    # Set Distortion
    DT.setDistortion(Dclass, clist=clist)

    # solve optimization for previous parameters -- This uses an older implementation, based on the FATML submission.
    DT.optimize(epsilon=epsilon, dlist=dlist, verbose=False)

    DT.computeMarginals()

    # Randomized mapping for training: entries below 1e-8 are zeroed, then
    # each row is renormalised to sum to 1. mask() leaves NaN untouched,
    # exactly like the former applymap lambda (NaN < 1e-8 is False).
    dfPtrain = DT.dfP.mask(DT.dfP < 1e-8, 0)
    dfPtrain = dfPtrain.divide(dfPtrain.sum(axis=1), axis=0)

    # Randomized mapping for testing, assembled from DT.dfFull by grouped
    # sums and transposes.
    d1 = DT.dfFull.reset_index().groupby(D_features + X_features).sum()
    d2 = d1.transpose().reset_index().groupby(X_features).sum()
    dTest = d2.transpose()
    # axis must be passed by keyword: pandas >= 2.0 rejects drop(labels, 1).
    dTest = dTest.drop(Y_features, axis=1)
    # Keep only entries strictly above 1e-8; everything else (NaN included)
    # becomes 0, matching the former applymap lambda.
    dTest = dTest.where(dTest > 1e-8, 0)
    dTest = dTest / dTest.sum()

    # NOTE(review): dfPtest (and df_test) are not used later in this script;
    # only the training set is randomized and scored.
    dfPtest = dTest.divide(dTest.sum(axis=1), axis=0)

    # Randomize train data
    print('Randomizing training set...')
    df_train_new = randomize(df_train, dfPtrain, features=D_features + X_features + Y_features)
    oppdp_wasser = wasserstein_fn(df_train_new['Income Binary'], df_train['Income Binary'])
    results['OPPDP']['wasserstein'].append(oppdp_wasser)

    # Reference outcome: overall label mean plus Gaussian noise with the
    # label's (population) standard deviation, thresholded at 0.5 per group below.
    y_compare = np.mean(df_train['Income Binary']) + np.random.normal(
        loc=0, scale=np.std(df_train['Income Binary']), size=len(df_train['Income Binary']))

    # Group masks on the randomized data, computed once and reused.
    is_male = df_train_new['Gender'] == ' Male'
    is_female = df_train_new['Gender'] == ' Female'
    syn_data_1_oppdp = df_train_new['Income Binary'][is_male].astype(float)
    syn_data_0_oppdp = df_train_new['Income Binary'][is_female].astype(float)
    y_compare_1 = (y_compare[is_male] > 0.5).astype(float)
    y_compare_0 = (y_compare[is_female] > 0.5).astype(float)

    uf_oppdp = 0.5 * wasserstein_fn(syn_data_1_oppdp, y_compare_1) + 0.5 * wasserstein_fn(syn_data_0_oppdp, y_compare_0)
    results['OPPDP']['uf'].append(uf_oppdp)

    # Absolute gap between the two groups' mean outcomes on the randomized data.
    tv = np.abs(np.mean(syn_data_0_oppdp) - np.mean(syn_data_1_oppdp))
    results['OPPDP']['tv'].append(tv)

    for mod in results.keys():
        print(f'{mod}: {results[mod]}', ',', sep='')



# Persist the final metrics, one "<method>: <metrics>," line per method.
# Writing through print(file=...) inside a with-block closes the file even if
# an error occurs, and never leaves sys.stdout redirected (the previous
# manual stdout swap did both on failure).
with open('oppdp_wasser_nips.txt', 'w') as f:
    for mod, metrics in results.items():
        print(f'{mod}: {metrics}', ',', sep='', file=f)

