import numpy as np
import tensorflow as tf
from utils import generateTwoGroupsData
from fair_training import train_fair_nn

# Synthetic two-group training data. alphaMin and n are passed straight
# to generateTwoGroupsData (seeded with 0); their precise meaning is defined
# in utils -- presumably minority-group proportion and sample count, TODO confirm.
alphaMin = 0.2
n = 150
X_train, y_train, z_train = generateTwoGroupsData(alphaMin, n, seed=0)

# One sensitive direction stored as a (1, 2) row: the first input axis.
sens_directions = np.array([[1, 0]])

# train_fair_nn reads the sensitive directions from slot 4 of fair_info;
# every other slot is left unset (None) in this experiment.
fair_info = [None] * 4 + [sens_directions, None]

## Train SenSeI with different fair regularization strengths.
# Each run's outputs are stored in `sensei_results`, keyed by fair_reg, so
# later iterations no longer overwrite earlier ones. After the loop,
# `weights`, `train_logits` and `test_logits` still hold the last run's
# outputs, as before.
fair_reg_values = [0.1, 0.5, 1., 2., 3., 4., 5.]
sensei_results = {}
for fair_reg in fair_reg_values:
    # TF1 graph API: drop the previous run's graph so each training run
    # builds its variables from scratch.
    tf.reset_default_graph()
    # train_fair_nn returns four values; the fourth is unused here.
    # NOTE(review): seed=None presumably leaves runs unseeded, so the sweep
    # is not reproducible -- confirm against fair_training if that matters.
    weights, train_logits, test_logits, _ = train_fair_nn(
        X_train, y_train, tf_prefix='sensei', adv_epoch_full=10, l2_attack=0.01,
        adv_epoch=10, ro=0.1, adv_step=0.1, plot=False, fair_info=fair_info,
        balance_batch=False, X_test=None, X_test_counter=None, y_test=None,
        lamb_init=2., n_units=[10], l2_reg=0, epoch=1000, batch_size=n, lr=0.01,
        fair_reg=fair_reg, fair_start=0., counter_init=False, seed=None,
        simul=True)
    sensei_results[fair_reg] = (weights, train_logits, test_logits)