import numpy as np
from scipy.stats import norm
import scipy.stats 
import pandas as pd
from sklearn.model_selection import train_test_split
import statsmodels.api as sm

from data import load_adult, preprocess_adult
from train import train_decaf, train_fairgan

import sys

# Causal DAG over the Adult dataset's columns. It is written as a
# parent -> children mapping and flattened into a list of [parent, child]
# edges. Edge order follows the mapping's insertion order, then each
# children list.
DAG = [
    [parent, child]
    for parent, children in {
        'race': [
            'occupation', 'income', 'hours-per-week', 'education',
            'marital-status',
        ],
        'age': [
            'occupation', 'hours-per-week', 'income', 'workclass',
            'marital-status', 'education', 'relationship',
        ],
        'sex': [
            'occupation', 'marital-status', 'income', 'workclass',
            'education', 'relationship',
        ],
        'native-country': [
            'marital-status', 'hours-per-week', 'education', 'workclass',
            'income', 'relationship',
        ],
        'marital-status': [
            'occupation', 'hours-per-week', 'income', 'workclass',
            'relationship', 'education',
        ],
        'education': [
            'occupation', 'hours-per-week', 'income', 'workclass',
            'relationship',
        ],
        # Remaining edges all point into the target column.
        'occupation': ['income'],
        'hours-per-week': ['income'],
        'workclass': ['income'],
        'relationship': ['income'],
    }.items()
    for child in children
]


def dag_to_idx(df, dag):
    """Translate column-name edges into positional column indices of ``df``.

    Each edge contributes ``[index_of(edge[0]), index_of(edge[1])]``, and the
    output keeps the order of ``dag``. Any name missing from ``df.columns``
    propagates the lookup error raised by ``Index.get_loc``.
    """
    locate = df.columns.get_loc
    return [[locate(edge[0]), locate(edge[1])] for edge in dag]

def create_bias_dict(df, edge_map):
    """
    Build the index-based bias dict used when generating debiased synthetic data.

    ``edge_map`` maps a source column name to an iterable of target column
    names. The result maps the source's positional index in ``df.columns``
    to the list of the targets' positional indices. Key order follows
    ``edge_map``.
    """
    locate = df.columns.get_loc
    bias = {}
    for source, targets in edge_map.items():
        # Resolve targets before the source, as the original assignment did.
        target_positions = [locate(name) for name in targets]
        bias[locate(source)] = target_positions
    return bias

def wasserstein_fn(X, Y):
    """Return the 2-Wasserstein distance between two empirical samples.

    Both inputs are flattened (``axis=None``) and sorted. In one dimension,
    optimal transport matches order statistics, so the distance is the
    root-mean-square gap between the two empirical quantile functions::

        sqrt( integral_0^1 (F_X^{-1}(t) - F_Y^{-1}(t))**2 dt )

    Parameters
    ----------
    X, Y : array-like
        Samples to compare. Any shape is accepted; values are flattened.
        The two samples may have different sizes.

    Returns
    -------
    numpy.float64
        The distance. NaN (with NumPy's "mean of empty slice" warning) when
        both samples are empty, or one is empty and the other has a single
        element.

    Raises
    ------
    ValueError
        If one sample is empty and the other has more than one element.
    """
    x = np.sort(X, axis=None)
    y = np.sort(Y, axis=None)
    n, m = x.size, y.size

    # Equal sizes pair order statistics one-to-one. A single-element sample
    # has a constant quantile function, so broadcasting it is also exact.
    if n == m or n == 1 or m == 1:
        return np.sqrt(np.average((x - y) ** 2))

    if n == 0 or m == 0:
        raise ValueError(
            "wasserstein_fn: cannot compare an empty sample with a sample "
            f"of size {max(n, m)}")

    # Unequal sizes: both quantile functions are step functions with jumps
    # at k/n and j/m. On each sub-interval of the merged breakpoint grid both
    # are constant, so the integral is an exact weighted sum.
    grid = np.union1d(np.arange(n + 1) / n, np.arange(m + 1) / m)
    widths = np.diff(grid)
    midpoints = 0.5 * (grid[:-1] + grid[1:])
    # On (k/n, (k+1)/n] the quantile of x is x[k]. Clip guards against
    # floating-point round-up at the right edge.
    ix = np.minimum(np.floor(midpoints * n).astype(np.intp), n - 1)
    iy = np.minimum(np.floor(midpoints * m).astype(np.intp), m - 1)
    return np.sqrt(np.sum(widths * (x[ix] - y[iy]) ** 2))

# Experiment setup: pool the preprocessed Adult train and test files into a
# single frame. Each run below draws its own split from this pool.
num_runs = 10
dataset_train = preprocess_adult(load_adult())
dataset_test = preprocess_adult(load_adult(test=True))
dataset = pd.concat([dataset_train, dataset_test])

# Per-model metric history: each model gets its own fresh list per metric.
results = {
    model_name: {'wasserstein': [], 'uf': [], 'tv': []}
    for model_name in ('decaf', 'cfgan', 'fairgan', 'LFR', 'OPPDP', 'tabfairgan')
}

# Train FairGAN `num_runs` times on fresh splits and record its metrics.
# Only the 'fairgan' entry of `results` is appended to here.
model = 'fairgan'
for run in range(num_runs):
    # Stratified split of the pooled data: 2000 held-out rows, with the
    # 'income' class proportions preserved. No random_state is passed, so the
    # split differs between runs and between invocations.
    dataset_train, dataset_test = train_test_split(
        dataset, test_size=2000, stratify=dataset['income'])
    print(dataset_train['income'])
    # NOTE(review): X_train is unused in the visible code; y_train only feeds y_compare.
    X_train, y_train = dataset_train.drop(columns=['income']), dataset_train['income']
    
    # Train FairGAN on the training frame (features and 'income'). The result
    # is indexed by column name below; presumably a DataFrame with the same
    # columns as the input. TODO: confirm against train_fairgan.
    synth_data_fairgan = train_fairgan(dataset_train, model_name=f'{model}_experiment_3_run_{run+1}')
    print(synth_data_fairgan['income'])
    # Distance between the synthetic and real 'income' columns. Both samples
    # are cut to a common length, keeping the first min_len rows of each.
    min_len = min(len(synth_data_fairgan['income']), len(dataset_train['income']))
    fairgan_wasser = wasserstein_fn(synth_data_fairgan['income'][:min_len], dataset_train['income'][:min_len])
    # Reference labels independent of every feature: Gaussian noise around
    # the mean training label, scaled by the training label's std. Draws from
    # the global NumPy RNG, which is not seeded in the visible code.
    y_compare = np.mean(y_train) + np.random.normal(loc=0, scale=np.std(y_train), size=len(y_train))
    # y_compare_new = Y_pred + np.random.normal(loc=0, scale=np.std(y_train), size=len(y_train))
    # Binarize the reference labels at 0.5 and split them by the real 'sex' column.
    y_compare_1 = (y_compare[dataset_train['sex'] == 1]> 0.5).astype(float)
    y_compare_0 = (y_compare[dataset_train['sex'] == 0]> 0.5).astype(float)
    # Split the synthetic 'income' by the synthetic 'sex'. This uses a 0.5
    # threshold rather than == 0 / == 1 as above, presumably because the
    # generator emits non-integer 'sex' values. Confirm.
    syn_data_1_fairgan = synth_data_fairgan['income'][synth_data_fairgan['sex'] >= 0.5]
    syn_data_0_fairgan = synth_data_fairgan['income'][synth_data_fairgan['sex'] < 0.5]
    print(syn_data_0_fairgan)
    # min_len_1 pairs the sex-0 groups; min_len_2 pairs the sex-1 groups.
    min_len_1 = min(len(y_compare_0), len(syn_data_0_fairgan))
    min_len_2 = min(len(y_compare_1), len(syn_data_1_fairgan))
    # 'uf': equal-weight average of the per-group distances between the
    # synthetic 'income' and the binarized reference labels.
    uf_fairgan = 0.5 * wasserstein_fn(syn_data_1_fairgan[:min_len_2], y_compare_1[:min_len_2]) + 0.5 * wasserstein_fn(syn_data_0_fairgan[:min_len_1], y_compare_0[:min_len_1])

    results['fairgan']['wasserstein'].append(fairgan_wasser)
    results['fairgan']['uf'].append(uf_fairgan)

    # 'tv': absolute gap in mean synthetic 'income' between the two synthetic
    # 'sex' groups. Presumably NaN if either group is empty. Verify.
    tv = np.abs(np.mean(syn_data_0_fairgan) - np.mean(syn_data_1_fairgan))
    results['fairgan']['tv'].append(tv)
    # Progress report: print every model's metric lists after each run.
    for mod in results.keys():
        print(f'{mod}: {results[mod]}', ',', sep ='')


# Write the final per-model metric lists to a text file, one "name: {...},"
# line per model. The `with` block closes the file even if writing fails.
# print(file=...) avoids redirecting the process-wide sys.stdout, which the
# previous version could leave redirected after an exception.
with open('fairgan_wasser_nips.txt', 'w', encoding='utf-8') as f:
    for mod, metrics in results.items():
        print(f'{mod}: {metrics}', ',', sep='', file=f)
