import numpy as np
from scipy.stats import norm
from scipy.stats import invgamma
import scipy.stats
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

from data import load_adult, preprocess_adult
from metrics import eval_model
from train import train_decaf
import statsmodels.api as sm

import sys

# DAG for the Adult dataset, written as parent column -> child columns.
# Both the key order and the order inside each child list are significant:
# they are flattened below into the ordered edge list exposed as ``DAG``.
_DAG_CHILDREN = {
    'race': [
        'occupation', 'income', 'hours-per-week', 'education',
        'marital-status',
    ],
    'age': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'marital-status', 'education', 'relationship',
    ],
    'sex': [
        'occupation', 'marital-status', 'income', 'workclass',
        'education', 'relationship',
    ],
    'native-country': [
        'marital-status', 'hours-per-week', 'education', 'workclass',
        'income', 'relationship',
    ],
    'marital-status': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'relationship', 'education',
    ],
    'education': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'relationship',
    ],
    # All remaining edges point straight at the target column.
    'occupation': ['income'],
    'hours-per-week': ['income'],
    'workclass': ['income'],
    'relationship': ['income'],
}

# Edge list: one ``[parent, child]`` pair of column names per edge.
DAG = [
    [parent, child]
    for parent, children in _DAG_CHILDREN.items()
    for child in children
]


def dag_to_idx(df, dag):
    """Translate DAG edges from column names to column positions.

    Args:
        df: Object whose ``columns`` index resolves a name to its position
            via ``get_loc``.
        dag: Iterable of edges; the first two items of each edge are the
            source and destination column names.

    Returns:
        A list with one ``[source_position, destination_position]`` pair per
        edge, in the same order as ``dag``.
    """
    return [
        [df.columns.get_loc(edge[0]), df.columns.get_loc(edge[1])]
        for edge in dag
    ]

def create_bias_dict(df, edge_map):
    """Build the positional bias dict used for generating debiased data.

    Args:
        df: Object whose ``columns`` index resolves a name to its position
            via ``get_loc``.
        edge_map: Mapping from a column name to an iterable of column names.

    Returns:
        A dict with the same structure as ``edge_map`` in which every column
        name (keys and list entries alike) is replaced by its position in
        ``df.columns``.
    """
    return {
        df.columns.get_loc(target): [
            df.columns.get_loc(feature) for feature in features
        ]
        for target, features in edge_map.items()
    }

def aug_fn(dat, sig, M):
    """Generate the augmentation data Z from the 'income' column Y.

    Args:
        dat: Container indexed by ``'income'`` (the values Y).
        sig: Standard deviation of the zero-mean Gaussian noise.
        M: Number of noisy replicates to draw.

    Returns:
        An ``M x n`` float array, where ``n = len(Y)``; each row is Y plus an
        independent noise vector drawn from ``np.random.normal``.
    """
    target = dat['income']
    n_obs = len(target)

    Z = np.zeros((M, n_obs))
    # Iterating a 2-D array yields row views, so writing through ``row[:]``
    # fills Z in place, one replicate (one noise draw) per row.
    for row in Z:
        row[:] = target + np.random.normal(loc=0, scale=sig, size=n_obs)

    return Z
    

def syn_fn(synth_data, sig, M, dataset_train, eta_sq, sig_M):
    """Resample a binary 'income' column from synthetic and noisy real labels.

    Args:
        synth_data: Data whose ``'income'`` column enters the mean with
            weight ``sig_M``.
        sig: Noise standard deviation forwarded to ``aug_fn``.
        M: Number of noisy replicates forwarded to ``aug_fn``.
        dataset_train: Training data; supplies the labels that ``aug_fn``
            perturbs and the frame that is copied for the return value.
        eta_sq: Weight applied to the summed noisy labels.
        sig_M: Weight applied to the synthetic labels.

    Returns:
        A copy of ``dataset_train`` whose ``'income'`` column is replaced by
        the thresholded (``> 0.5``) normal draws, stored as floats.
    """
    noisy_labels = aug_fn(dataset_train, sig, M)
    column_totals = np.apply_along_axis(np.sum, 0, noisy_labels)

    weight_total = sig_M + eta_sq
    # The summed replicates are divided by 1 rather than by M: the original
    # note says M = 1 is always used here.
    y_mu = (sig_M * synth_data['income'] + (column_totals * eta_sq / 1)) / weight_total
    y_sig = np.sqrt((sig_M * eta_sq) / weight_total)

    print('y_mu')
    print(y_mu)
    print('y_sig')
    print(y_sig)

    # One scalar draw per entry of y_mu, in order.
    draws = np.array(
        [norm.rvs(loc=center, scale=y_sig, size=1)[0] for center in y_mu]
    )
    y_syn = (draws > 0.5).astype(float)

    resampled = dataset_train.copy()
    resampled['income'] = y_syn

    return resampled

def wasserstein_fn(X, Y):
    """Return the root-mean-square gap between the sorted values of X and Y.

    Both inputs are flattened and sorted ascending, paired element by
    element, and the square root of the average squared difference is
    returned. The flattened inputs must have compatible sizes.
    """
    sorted_x = np.sort(X, axis=None)
    sorted_y = np.sort(Y, axis=None)
    gap = sorted_x - sorted_y

    return np.sqrt(np.average(np.square(gap)))


# generating synthetic alpha-DP fair data
#
# Experiment driver. For each of ``num_runs`` random train/test splits and
# each noise level in ``sig_list`` it trains DECAF, optionally evaluates a
# downstream model, optionally computes fairness/faithfulness statistics, and
# finally dumps everything collected in ``results`` to a text file.
sig_list = np.arange(0.1, 1.1, 0.1)
sig_list = np.append(0.0001, sig_list)
sig_list = np.append(sig_list, 100000000)
print(sig_list)
num_runs = 10
faith_fair = True
downstream = True

dataset_train = preprocess_adult(load_adult())
dataset_test = preprocess_adult(load_adult(test=True))
dataset = pd.concat([dataset_train, dataset_test])

# Pre-declared result slots: the two baselines plus one 'fdami_decaf_dp_<i>'
# entry per noise level, each with the same set of (initially empty) metrics.
metric_names = (
    'precision', 'recall', 'auroc', 'dp', 'ftu',
    'wasserstein', 'diff_mean', 'overlap_min',
)
model_names = ['original', 'decaf_dp'] + [
    f'fdami_decaf_dp_{i}' for i in range(1, len(sig_list) + 1)
]
results = {
    name: {metric: [] for metric in metric_names} for name in model_names
}


def _record(model_key, metric, value):
    """Append ``value`` to ``results[model_key][metric]``.

    Missing model or metric slots are created on demand. The loop below
    stores under keys ('decaf_dp_<ind>') and metrics ('etasq', 'uf', 'tv')
    that are not pre-declared in ``results``; indexing them directly raised
    KeyError on the very first iteration.
    """
    results.setdefault(model_key, {}).setdefault(metric, []).append(value)


for run in range(num_runs):
    dataset_train, dataset_test = train_test_split(
        dataset, test_size=2000, stratify=dataset['income'])
    y_train = dataset_train['income']

    # Constant predictor (mean label) used for the 'uf' comparison below.
    Y_pred = np.mean(y_train)

    M = 1
    for ind, sig in enumerate(sig_list, start=1):
        sig_M = sig**2 / M
        eta_sq = 1

        new_model = 'decaf_dp'
        org_model = 'original'
        result_key = f'{new_model}_{ind}'

        dag_seed = dag_to_idx(dataset, DAG)
        train_kwargs = {}
        train_kwargs['dag_seed'] = dag_seed

        bias_dict_dp = create_bias_dict(dataset, {'income': [
            'occupation', 'hours-per-week', 'marital-status', 'education', 'sex',
            'workclass', 'relationship']})
        bias_dicts = {'dp': bias_dict_dp}

        # NOTE(review): DECAF is retrained for every sigma with the same
        # model_name, and syn_fn is never called, so sigma only enters
        # through sig_M (alpha / W_dist) below. Confirm this is intended.
        for biased_edges in bias_dicts.values():
            train_kwargs['biased_edges'] = biased_edges
            syn_data = train_decaf(
                dataset_train,
                model_name=f'{new_model}_experiment_1_run_{run+1}',
                **train_kwargs)

        # eta_sq is a plain scalar here; it cannot be subscripted.
        _record(result_key, 'etasq', eta_sq)

        if downstream:
            model_results = eval_model(syn_data, dataset_test)
            for key, value in model_results.items():
                _record(result_key, key, value)
            # results using the original unfair dataset
            # NOTE(review): this baseline is re-evaluated for every sigma,
            # so 'original' collects len(sig_list) entries per run.
            model_results = eval_model(dataset_train, dataset_test)
            for key, value in model_results.items():
                _record(org_model, key, value)

        if faith_fair:
            print('np.std(y_train)')
            print(np.std(y_train))
            y_compare = Y_pred + np.random.normal(loc=0, scale=np.std(y_train), size=len(y_train))
            syn_data_1 = syn_data['income'][syn_data['sex'] == 1]
            syn_data_0 = syn_data['income'][syn_data['sex'] == 0]
            y_compare_1 = (y_compare[syn_data['sex'] == 1] > 0.5).astype(float)
            y_compare_0 = (y_compare[syn_data['sex'] == 0] > 0.5).astype(float)
            uf = 0.5 * wasserstein_fn(syn_data_1, y_compare_1) + 0.5 * wasserstein_fn(syn_data_0, y_compare_0)
            _record(result_key, 'uf', uf)

            alpha = 1 / (sig_M / eta_sq + 1)
            print('alpha')
            print(alpha)
            # NOTE(review): if both operands are pandas Series this
            # subtraction aligns on index labels, not position - confirm the
            # indices of syn_data and dataset_train agree.
            W_dist = np.sqrt((1 - alpha) ** 2 * np.mean((syn_data['income'] - dataset_train['income']) ** 2) + (1 - alpha ** 2) * eta_sq)
            # np.sqrt of a scalar returns a NumPy scalar, which cannot be
            # indexed with [0]; store it as a plain float instead.
            _record(result_key, 'wasserstein', float(W_dist))

            tv = np.abs(np.mean(syn_data_0) - np.mean(syn_data_1))
            _record(result_key, 'tv', tv)

        # Progress dump after every sigma.
        for mod in results:
            print(f'{mod}: {results[mod]}', ",", sep='')


# Write the final results to disk. Printing straight to the file (instead of
# swapping sys.stdout) keeps stdout intact, and the context manager closes
# the file even if formatting raises.
with open('fda-cmm_full_adult_nips.txt', 'w') as f:
    for mod in results:
        print(f'{mod}: {results[mod]}', ",", sep='', file=f)