import numpy as np
from scipy.stats import norm
import scipy.stats 
import pandas as pd
from sklearn.model_selection import train_test_split
import statsmodels.api as sm

from data import load_adult, preprocess_adult
from train import train_decaf, train_fairgan

import sys

# DAG for the Adult dataset, written as source column -> target columns.
# Dict insertion order and the order inside each list are significant:
# flattening them below reproduces the edge list in its original order.
_DAG_CHILDREN = {
    'race': [
        'occupation', 'income', 'hours-per-week', 'education',
        'marital-status',
    ],
    'age': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'marital-status', 'education', 'relationship',
    ],
    'sex': [
        'occupation', 'marital-status', 'income', 'workclass',
        'education', 'relationship',
    ],
    'native-country': [
        'marital-status', 'hours-per-week', 'education', 'workclass',
        'income', 'relationship',
    ],
    'marital-status': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'relationship', 'education',
    ],
    'education': [
        'occupation', 'hours-per-week', 'income', 'workclass',
        'relationship',
    ],
    # All remaining edges point straight at the outcome column.
    'occupation': ['income'],
    'hours-per-week': ['income'],
    'workclass': ['income'],
    'relationship': ['income'],
}

# Flat list of [source, target] column-name pairs.
DAG = [
    [source, target]
    for source, targets in _DAG_CHILDREN.items()
    for target in targets
]


def dag_to_idx(df, dag):
    """Translate DAG edges from column names to column positions.

    Args:
        df: Object whose ``columns`` attribute provides ``get_loc`` (a
            pandas DataFrame in this script).
        dag: Iterable of edges; the first two entries of each edge are the
            source and target column names.

    Returns:
        A list of ``[source_position, target_position]`` lists, in the same
        order as ``dag``.
    """
    locate = df.columns.get_loc
    return [[locate(edge[0]), locate(edge[1])] for edge in dag]

def create_bias_dict(df, edge_map):
    """Re-key a column-name edge mapping by column position.

    Args:
        df: Object whose ``columns`` attribute provides ``get_loc`` (a
            pandas DataFrame in this script).
        edge_map: Mapping from a column name to an iterable of column names.

    Returns:
        A dict with the same entries as ``edge_map`` in which every column
        name, key and list element alike, is replaced by its position in
        ``df.columns``. This is the bias dict used when generating debiased
        synthetic data.
    """
    locate = df.columns.get_loc
    positions = {}
    for target_name, source_names in edge_map.items():
        source_positions = [locate(name) for name in source_names]
        positions[locate(target_name)] = source_positions
    return positions

def wasserstein_fn(X, Y):
    """Return the root-mean-square gap between the sorted values of X and Y.

    Both inputs are flattened and sorted independently, then compared
    element by element. For two equal-sized one-dimensional samples this is
    the empirical 2-Wasserstein distance between them.

    Args:
        X: Array-like of numeric values.
        Y: Array-like of numeric values; after flattening it must be
            broadcast-compatible with X (in practice, the same size).

    Returns:
        The distance as a NumPy floating-point scalar.
    """
    x_sorted = np.sort(X, axis=None)
    y_sorted = np.sort(Y, axis=None)
    gaps = x_sorted - y_sorted
    return np.sqrt(np.mean(np.square(gaps)))

# Generate synthetic alpha-DP fair data with DECAF over several random
# splits of the Adult dataset and score the synthetic 'income' column.

num_runs = 10

# Pool the two provided partitions; every run draws its own split from it.
dataset_train = preprocess_adult(load_adult())
dataset_test = preprocess_adult(load_adult(test=True))
dataset = pd.concat([dataset_train, dataset_test])

# One list per metric, one entry appended per run. Only the entry selected
# by `model` is filled in this script; the other keys stay empty here.
results = {
    'decaf': {'wasserstein': [], 'uf':[], 'tv': []},
    'cfgan': {'wasserstein': [], 'uf':[], 'tv': []},
    'fairgan': {'wasserstein': [], 'uf':[], 'tv': []},
    'LFR': {'wasserstein': [], 'uf':[], 'tv': []},
    'OPPDP': {'wasserstein': [], 'uf':[], 'tv': []},
    'tabfairgan': {'wasserstein': [], 'uf':[], 'tv': []}
}

model = 'decaf'
for run in range(num_runs):

    # Fresh split per run, stratified on 'income', holding out 2000 rows.
    dataset_train, dataset_test = train_test_split(
        dataset, test_size=2000, stratify=dataset['income'])
    print(dataset_train['income'])
    y_train = dataset_train['income']

    # The DAG and the biased edges are passed to the trainer as column
    # positions rather than column names.
    train_kwargs = {'dag_seed': dag_to_idx(dataset, DAG)}

    bias_dict_dp = create_bias_dict(dataset, {'income': [
        'occupation', 'hours-per-week', 'marital-status', 'education', 'sex',
        'workclass', 'relationship']})
    bias_dicts = {'dp': bias_dict_dp}

    # Only the synthetic data from the last bias dict is scored below;
    # there is currently a single entry ('dp').
    for biased_edges in bias_dicts.values():
        train_kwargs['biased_edges'] = biased_edges
        synth_data_decaf = train_decaf(
            dataset_train,
            model_name=f'{model}_experiment_1_run_{run+1}',
            **train_kwargs)
    decaf_wasser = wasserstein_fn(
        synth_data_decaf['income'], dataset_train['income'])

    # Feature-independent reference outcome: the training mean of 'income'
    # plus Gaussian noise with the training standard deviation, binarised
    # at 0.5 within each 'sex' group.
    # NOTE(review): y_compare has len(y_train) entries but is masked with
    # synth_data_decaf['sex']; this presumes the synthetic frame has the
    # same number of rows as dataset_train -- confirm against train_decaf.
    y_compare = np.mean(y_train) + np.random.normal(
        loc=0, scale=np.std(y_train), size=len(y_train))
    is_sex_1 = synth_data_decaf['sex'] == 1
    is_sex_0 = synth_data_decaf['sex'] == 0
    syn_data_1_decaf = synth_data_decaf['income'][is_sex_1]
    syn_data_0_decaf = synth_data_decaf['income'][is_sex_0]
    y_compare_1 = (y_compare[is_sex_1] > 0.5).astype(float)
    y_compare_0 = (y_compare[is_sex_0] > 0.5).astype(float)

    # 'uf': equal-weight average of the per-group distances between the
    # synthetic income and the reference outcome.
    uf_decaf = (0.5 * wasserstein_fn(syn_data_1_decaf, y_compare_1)
                + 0.5 * wasserstein_fn(syn_data_0_decaf, y_compare_0))
    results[model]['wasserstein'].append(decaf_wasser)
    results[model]['uf'].append(uf_decaf)

    # 'tv': absolute gap between the two groups' mean synthetic income.
    tv = np.abs(np.mean(syn_data_0_decaf) - np.mean(syn_data_1_decaf))
    results[model]['tv'].append(tv)

    # Progress report: everything accumulated so far, after every run.
    for mod, metrics in results.items():
        print(f'{mod}: {metrics}', ',', sep ='')


# Write the final results in the same format as the progress report. The
# context manager closes the file even if a write fails, and stdout is never
# rebound, so nothing has to be restored afterwards.
with open('decaf_wasser_nips.txt', 'w') as f:
    for mod, metrics in results.items():
        print(f'{mod}: {metrics}', ',', sep ='', file=f)
