import numpy as np
from scipy.stats import norm
from scipy.stats import invgamma
import scipy.stats
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

from data import load_adult, preprocess_adult
from metrics import eval_model
from train import train_decaf
import statsmodels.api as sm

import sys
import time

# generating the augmentation data Z using Y
def aug_fn(dat, sig, M):
    """Draw M noisy copies of the ``income`` column.

    Row ``k`` of the result is ``dat['income']`` plus independent Gaussian
    noise with mean 0 and standard deviation ``sig``.

    Args:
        dat: table (e.g. a DataFrame) with an ``income`` column.
        sig: standard deviation of the additive Gaussian noise.
        M: number of noisy copies to draw.

    Returns:
        Float ndarray of shape ``(M, len(dat['income']))``.
    """
    income = dat['income']
    n_rows = len(income)

    draws = np.zeros((M, n_rows))  # one noisy copy of income per row
    for row in range(M):
        noise = np.random.normal(loc=0, scale=sig, size=n_rows)
        draws[row, :] = income + noise
    return draws
    

def syn_fn(synth_data, sig, M, dataset_train, eta_sq, sig_M):
    """Sample binary synthetic ``income`` labels from a normal-normal posterior.

    The prior mean for each row is ``synth_data['income']`` with variance
    ``eta_sq``. The "observations" are M noisy copies of
    ``dataset_train['income']`` drawn by ``aug_fn`` with noise scale ``sig``.
    The combined mean and variance are::

        y_mu  = (sig_M * prior + mean(Z) * eta_sq) / (sig_M + eta_sq)
        y_sig = sqrt(sig_M * eta_sq / (sig_M + eta_sq))

    One normal draw per row is then thresholded at 0.5 to give 0/1 labels.

    Args:
        synth_data: frame whose ``income`` column is the prior mean.
        sig: noise standard deviation passed to ``aug_fn``.
        M: number of noisy copies passed to ``aug_fn``.
        dataset_train: frame whose ``income`` column is perturbed; the
            returned frame is a copy of it.
        eta_sq: prior variance (scalar or length-1 array, as the caller
            passes it).
        sig_M: observation variance of the averaged draws; the caller
            computes it as ``sig**2 / M``.

    Returns:
        A copy of ``dataset_train`` with ``income`` replaced by the sampled
        0.0/1.0 labels.
    """
    Z = aug_fn(dataset_train, sig, M)
    # With sig_M = sig**2 / M, the precision-weighted update needs the
    # *mean* of the M draws (sum / M). The previous code used the sum, which
    # was only correct for M == 1; for M == 1 the two are identical.
    z_bar = Z.mean(axis=0)
    y_mu = (sig_M * synth_data['income'] + z_bar * eta_sq) / (sig_M + eta_sq)
    y_sig = np.sqrt((sig_M * eta_sq) / (sig_M + eta_sq))
    print('y_mu')
    print(y_mu)
    print('y_sig')
    print(y_sig)

    # One vectorized draw instead of a per-row norm.rvs call.
    loc = np.asarray(y_mu, dtype=float)
    draws = norm.rvs(loc=loc, scale=y_sig, size=loc.shape[0])
    y_syn = (draws > 0.5).astype(float)

    synth_data = dataset_train.copy()
    synth_data['income'] = y_syn
    return synth_data

def wasserstein_fn(X, Y):
    """Return the 2-Wasserstein distance between two equal-size samples.

    For two 1-D empirical distributions with the same number of equally
    weighted points, W2 is the root-mean-square difference of the sorted
    samples. Inputs are flattened (``axis=None``) before sorting.

    Args:
        X, Y: array-likes with the same total number of elements.

    Returns:
        The distance as a float.

    Raises:
        ValueError: if X and Y contain a different number of elements. This
            case would otherwise broadcast, silently when one side has a
            single element, into a meaningless number.
    """
    x_sorted = np.sort(X, axis=None)
    y_sorted = np.sort(Y, axis=None)
    if x_sorted.size != y_sorted.size:
        raise ValueError(
            f"wasserstein_fn needs equal-size samples, got {x_sorted.size} and {y_sorted.size}"
        )
    return np.sqrt(np.mean((x_sorted - y_sorted) ** 2))


# Configuration for generating synthetic alpha-DP fair data; xi = [sigma, M].
# Noise-scale grid: a near-zero sigma, 0.1..1.0 in steps of 0.1, and a huge
# sigma. Earlier grids tried: np.sqrt([0.25, 0.5, 1, 2, 4]) and
# np.sqrt([0.1, 0.2]).
sig_list = np.concatenate(([0.0001], np.arange(0.1, 1.1, 0.1), [100000000]))
print(sig_list)
num_runs = 10
faith_fair = False   # compute the uf / wasserstein / tv metrics
downstream = False   # run eval_model on synthetic and original data

dataset_train = preprocess_adult(load_adult())
dataset_test = preprocess_adult(load_adult(test=True))
# ignore_index=True gives the pooled rows unique labels. The two frames may
# reuse the same index labels, and duplicate labels make later label-based
# alignment ambiguous.
dataset = pd.concat([dataset_train, dataset_test], ignore_index=True)

# print(dataset_train)
# Per-model lists of metric values accumulated across runs. There is one
# 'fdami_linear_dp_<k>' entry per noise level in sig_list (k = 1..len(sig_list)),
# so the table cannot fall out of sync with the grid.
metric_names = ['precision', 'recall', 'auroc', 'dp', 'ftu', 'etasq', 'wasserstein', 'uf', 'tv']
model_names = ['original', 'decaf_dp'] + [
    f'fdami_linear_dp_{k}' for k in range(1, len(sig_list) + 1)
]
# The nested comprehension creates a fresh, unshared list for every metric.
results = {model: {metric: [] for metric in metric_names} for model in model_names}


for run in range(num_runs):
    # Fresh stratified train/test split of the pooled data on every run.
    dataset_train, dataset_test = train_test_split(
        dataset, test_size=2000, stratify=dataset['income'])
    X_train, y_train = dataset_train.drop(columns=['income']), dataset_train['income']

    M = 1  # number of noisy copies of Y per noise level (passed to syn_fn)
    new_model = 'fdami_linear_dp'

    # Intercept-only fair model: every row gets the training mean of income.
    # None of this depends on sig, so compute it once per run.
    Y_pred = np.mean(y_train)
    print(Y_pred)
    syn_data_int = dataset_train.copy()
    syn_data_int['income'] = Y_pred
    residual = y_train - Y_pred

    if downstream:
        # Baseline on the original (unfair) training data. It does not depend
        # on sig, so evaluate it once per run. Evaluating it once per noise
        # level retrained it redundantly and duplicated entries in
        # results['original'].
        model_results = eval_model(dataset_train, dataset_test)
        for key, value in model_results.items():
            results['original'][key].append(value)

    for ind, sig in enumerate(sig_list, start=1):
        model_key = f'{new_model}_{ind}'
        sig_M = sig**2 / M

        start_time = time.time()

        # sample from inverse-gamma distribution with non-informative prior
        eta_sq = invgamma.rvs(a=(X_train.shape[0] - 1) / 2, scale=np.sum(residual ** 2) / 2, size=1)

        syn_data = syn_fn(synth_data=syn_data_int, sig=sig, M=M, dataset_train=dataset_train, eta_sq=eta_sq, sig_M=sig_M)

        results[model_key]['etasq'].append(eta_sq[0])

        print("--- %s seconds ---" % (time.time() - start_time))

        if downstream:
            model_results = eval_model(syn_data, dataset_test)
            for key, value in model_results.items():
                results[model_key][key].append(value)

        if faith_fair:
            y_std = np.std(y_train)
            print('np.std(y_train)')
            print(y_std)
            # Reference labels: the intercept prediction plus noise with the
            # spread of y_train, thresholded at 0.5 per 'sex' group below.
            y_compare = Y_pred + np.random.normal(loc=0, scale=y_std, size=len(y_train))
            mask_1 = (syn_data['sex'] == 1).to_numpy()
            mask_0 = (syn_data['sex'] == 0).to_numpy()
            syn_data_1 = syn_data['income'][mask_1]
            syn_data_0 = syn_data['income'][mask_0]
            y_compare_1 = (y_compare[mask_1] > 0.5).astype(float)
            y_compare_0 = (y_compare[mask_0] > 0.5).astype(float)
            uf = 0.5 * wasserstein_fn(syn_data_1, y_compare_1) + 0.5 * wasserstein_fn(syn_data_0, y_compare_0)
            results[model_key]['uf'].append(uf)

            # alpha = eta_sq / (sig_M + eta_sq): the weight syn_fn puts on the noisy data.
            alpha = 1 / (sig_M / eta_sq + 1)
            print('alpha')
            print(alpha)
            W_dist = np.sqrt((1 - alpha) ** 2 * np.mean((syn_data['income'] - dataset_train['income']) ** 2) + (1 - alpha ** 2) * eta_sq)
            results[model_key]['wasserstein'].append(W_dist[0])

            # Absolute gap between the two 'sex' groups' mean synthetic income.
            tv = np.abs(np.mean(syn_data_0) - np.mean(syn_data_1))
            results[model_key]['tv'].append(tv)

        for mod, res in results.items():
            print(f'{mod}: {res}', ",", sep='')


# Handle on the real stdout. To capture the summary in a file
# (e.g. 'fda-cmm_full_adult_nips.txt'), open the file, assign it to
# sys.stdout before the loop below, then restore orig_stdout and close
# the file afterwards.
orig_stdout = sys.stdout

# Final summary: one line per model variant, each followed by a comma.
for model_name, model_metrics in results.items():
    print(f'{model_name}: {model_metrics}' + ",")