import numpy as np
import math, os
import random
import copy
import torch
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm
from tensorboardX import SummaryWriter

from models import convert_vars_to_gpu
from utils import logsumexp, shuffle_combined, exp_lr_scheduler, get_optimizer, serialize, transform_dataset, create_group_mask_tensor, gini_indices
from evaluation import compute_dcg_rankings, evaluate_soft_model_owa, multiple_sample_and_log_probability, compute_dcg_max

from fairness_loss import GroupFairnessLoss, BaselineAshudeepGroupFairnessLoss, get_group_merits, get_group_identities, get_group_quantiles  #JK
from frank_wolfe import FWS, FWS_batch, FWS_batch_fast
from networksJK import PolicyLP, PolicyLP_Plus, PolicyLP_PlusNeq, PolicyLP_PlusSP, PolicyBlackboxWrapper, create_torch_LP
from birkhoff import birkhoff_von_neumann_decomposition
import time
from models import LinearModel, init_weights # JK
import pickle as pkl
import pandas as pd
from ort_rank import *
from fairness_loss import test_fairness
import matplotlib.pyplot as plt
# import tracemalloc


import sys
sys.path.insert(0,'../..')


def _load_pickle(path):
    """Return the object pickled at *path*, closing the file handle afterwards."""
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(path, "rb") as handle:
        return pkl.load(handle)


def _resolve_group_quantiles(train_feats, dataset, num_item, precompute_quantile,
                             root_dir, args):
    """Compute or load the group thresholds for ``args.multi_groups`` groups.

    Raises:
        ValueError: if precomputed quantiles are requested for a dataset name
            this function has no rule for.
    """
    if not precompute_quantile:
        return get_group_quantiles(train_feats, args.multi_groups)

    print('Loading precomputed group quantiles')
    print('datast', dataset)
    if dataset == 'mslr':
        print('load mslr')
        return _load_pickle(os.path.join(
            root_dir,
            'data/{}/full/quantile_{}group.pkl'.format(dataset, args.multi_groups)))
    if dataset in ('yahoo', 'yahoobinary', 'yahoobinary2'):
        # The yahoo variants recompute the thresholds instead of loading a file.
        print('args group_feat_id', args.group_feat_id)
        return get_group_quantiles(train_feats, args.multi_groups, num_feature=519,
                                   group_feat_id=args.group_feat_id, quantile=0.5)
    if dataset == 'mslr-web10k':
        print('load mslr web10k')
        return _load_pickle(os.path.join(
            root_dir,
            'data/{}/{}_lst/quantile_{}group.pkl'.format(dataset, num_item, args.multi_groups)))
    raise ValueError(
        "No precomputed group quantiles are defined for dataset {!r}".format(dataset))


def _keep_lists_with_all_groups(feats, rels, args):
    """Keep only the ranking lists in which every one of the groups appears.

    Returns:
        ``(feats, rels, group_identities, group_item_mask)`` restricted to the
        retained lists; the mask is cast to double.
    """
    group_identities = get_group_identities(
        feats, args.group_feat_id, args.group_feat_threshold)
    keep = [len(g.unique()) == args.multi_groups for g in group_identities]
    feats, rels, group_identities = feats[keep], rels[keep], group_identities[keep]
    group_item_mask = create_group_mask_tensor(
        args.multi_groups, group_identities).double()
    return feats, rels, group_identities, group_item_mask


def soft_policy_training_spo_owa_multi(data_reader,
                             validation_data_reader,
                             test_data_reader,
                             model,
                             precompute_quantile=True,
                             writer=None,
                             experiment_name=None,
                             use_merits=True,
                             args=None):
    """Train ``model`` with an SPO-style surrogate gradient and an OWA fairness term.

    The routine
      1. resolves the group thresholds and drops every train/valid/test list
         that does not contain all ``args.multi_groups`` groups,
      2. loads precomputed "true" solutions for the training lists from a pickle,
      3. runs the training loop; for each batch the Frank-Wolfe solver is called
         on the predicted scores and on ``2 * scores - true_costs``, and the
         difference between the latter solution and the true one is
         back-propagated through the cost coefficients,
      4. periodically evaluates on the validation loader, keeps a deep copy of
         the best model and stops after ``patience`` evaluations without gain,
      5. evaluates the best model on test and train data and writes three CSV
         files under ``fair_optim/results``.

    Args:
        data_reader: ``(features, relevances)`` pair for training.
        validation_data_reader: ``(features, relevances)`` pair for validation.
        test_data_reader: ``(features, relevances)`` pair for testing.
        model: scoring network; called as ``model(feats)``.
        precompute_quantile: when true, group thresholds are loaded from disk
            (or recomputed with fixed settings for the yahoo datasets) instead
            of being computed with default settings.
        writer: unused; kept for interface compatibility.
        experiment_name: unused; kept for interface compatibility.
        use_merits: selects which precomputed-solution file is loaded and is
            forwarded to ``get_group_merits``.
        args: experiment configuration namespace. ``args.group_feat_threshold``
            is overwritten when ``args.multi_groups`` is truthy.

    Returns:
        ``(best_model, best_validation_criterion)``. NOTE: the process currently
        exits right after the CSV files are written, so this return is only
        reached if that exit is removed.

    Raises:
        ValueError: precomputed quantiles requested for an unknown dataset.
        RuntimeError: no validation evaluation ever improved on the initial
            criterion, so there is no "best" validation result to report.
    """
    ROOT_DIR = 'fair_optim/'

    position_bias_vector = 1. / torch.arange(1.,
                                             100.) ** args.position_bias_power
    lr = args.lr
    num_epochs = args.epochs
    weight_decay = args.weight_decay
    sample_size = args.sample_size

    print('Start training with OWA')
    print("Starting training with the following config")
    print(
        "Batch size {}, Learning rate {}, Weight decay {}, Entropy Regularizer {}, Entreg Decay {} Sample size {}\n"
        "Lambda_reward: {}, lambda_ind_fairness:{}, lambda_group_fairness:{}".
            format(args.batch_size, lr, weight_decay, args.entropy_regularizer,
                   args.entreg_decay, sample_size,
                   args.lambda_reward, args.lambda_ind_fairness,
                   args.lambda_group_fairness))

    # NOTE(review): only the model and the position-bias vector are moved to the
    # GPU here; the batches themselves are not. Confirm the GPU path is exercised.
    if args.gpu:
        print("Use GPU")
        model = model.cuda()
        position_bias_vector = position_bias_vector.cuda()

    optimizer = get_optimizer(model.parameters(), lr, args.optimizer,
                              weight_decay)
    # mode='max': the scheduler is stepped with the validation criterion below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=args.lr_decay, min_lr=1e-6, verbose=True,
        patience=10)

    # The dataset name is the third-from-last component of the training path.
    dataset = args.partial_train_data.split('/')[-3]
    print('dataset', dataset)
    train_feats, train_rels = data_reader
    num_item = train_rels.shape[-1]

    # OWA weights; args.multi_groups is the number of groups.
    w_item = gini_indices(args.multi_groups).double()

    if args.multi_groups:
        args.group_feat_threshold = _resolve_group_quantiles(
            train_feats, dataset, num_item, precompute_quantile, ROOT_DIR, args)
        print('threshold', args.group_feat_threshold)

    # ------------------------------------------------------------------ data
    print('before', train_feats.shape)
    train_feats, train_rels, train_group_identities, train_group_item_mask = \
        _keep_lists_with_all_groups(train_feats, train_rels, args)
    w_user = torch.ones(train_feats.shape[0]).double()

    if use_merits:
        train_owa_fp = os.path.join(ROOT_DIR,"data/{}/{}_lst/true_owasol_beta-{}_lb-{}_train_{}group_merit_iter{}_softmax2.pkl".format(dataset, num_item,args.beta, args.lambda_group_fairness, args.multi_groups, args.num_iter))
        print('train_owa_fp', train_owa_fp)
    else:
        train_owa_fp = os.path.join(ROOT_DIR,"data/{}/{}_lst/true_owasol_beta-{}_lb-{}_train_{}group_iter10k.pkl".format(dataset, num_item,args.beta, args.lambda_group_fairness, args.multi_groups))
    # presumably aligned row-for-row with the filtered training lists — TODO confirm
    train_sol_true, train_exposure_true, train_owa_true = _load_pickle(train_owa_fp)
    print('after', train_feats.shape)
    print('train_group_identities', train_group_identities.shape)
    print('train_group_item_mask',train_group_item_mask.shape, train_feats.shape, train_rels.shape,train_sol_true.shape, train_exposure_true.shape, train_owa_true.shape)

    train_dataset = torch.utils.data.TensorDataset(train_feats, train_rels,train_group_item_mask, train_sol_true, train_exposure_true, train_owa_true)
    # NOTE(review): this over-counts by one when the data divides evenly by the
    # batch size; it is only used to derive the global step and in log messages.
    len_train_set = len(train_feats) // args.batch_size + 1
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    valid_feats, valid_rels = validation_data_reader
    print('before valid', valid_feats.shape)
    valid_feats, valid_rels, _, valid_group_item_mask = \
        _keep_lists_with_all_groups(valid_feats, valid_rels, args)
    # The last three slots only mirror the 6-tuple layout of the training set;
    # relevances are used as placeholders for the precomputed solutions.
    validation_dataset = torch.utils.data.TensorDataset(valid_feats, valid_rels, valid_group_item_mask, valid_rels, valid_rels, valid_rels)
    val_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.batch_size)
    print('after valid', valid_feats.shape)

    test_feats, test_rels = test_data_reader
    print('before test', test_feats.shape)
    test_feats, test_rels, _, test_group_item_mask = \
        _keep_lists_with_all_groups(test_feats, test_rels, args)
    # Report the shape (as for train/valid) instead of dumping the whole tensor.
    print('after test', test_feats.shape)
    test_dataset = torch.utils.data.TensorDataset(test_feats, test_rels,test_group_item_mask, test_rels, test_rels, test_rels)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)

    with torch.no_grad():
        lst_group_merits = get_group_merits(
            train_feats, train_rels, args.group_feat_id, args.group_feat_threshold, mean=False, use_merits=use_merits)
        print("Group merits: ", lst_group_merits)
        lst_group_merits = torch.tensor(lst_group_merits)

    # ------------------------------------------------------------ bookkeeping
    training_dcg_list, valid_dcg_list = [], []
    training_regrets, valid_regrets = [], []
    valid_criteria, valid_criteria2 = [], []
    training_vio, valid_vio = [], []
    training_step = []

    best_so_far, fails = 0, 0
    patience = 20
    stop_training = False
    JK_best_model = model
    results_valid_best = None
    # Despite their names these two lists accumulate over the whole run.
    epoch_regrets, epoch_dcg_list = [], []
    total_time = 0
    epochs_run = 0
    fairness_evaluation = bool(args.lambda_ind_fairness > 0.0)
    group_fairness_evaluation = True

    def run_evaluation(eval_model, loader, user_weights, **extra):
        """Evaluate *eval_model* on *loader* with the shared evaluation settings."""
        return evaluate_soft_model_owa(
            eval_model,
            loader,
            user_weights,
            w_item,
            deterministic=args.validation_deterministic,
            fairness_evaluation=fairness_evaluation,
            num_sample_per_query=args.sample_size,
            position_bias_vector=position_bias_vector,
            group_fairness_evaluation=group_fairness_evaluation,
            track_other_disparities=args.track_other_disparities,
            merits=lst_group_merits,
            args=args,
            **extra)

    # --------------------------------------------------------------- training
    for epoch in range(num_epochs):
        # Timed per epoch; the timer is deliberately not reset inside the batch
        # loop so that total_time covers every batch, not just the last one.
        epoch_start = time.time()
        print("Entering training Epoch {}".format(epoch))

        # Wrap a fresh iterator each epoch without rebinding train_dataloader,
        # otherwise progress bars nest one level deeper per epoch.
        batch_iterator = tqdm(train_dataloader) if args.progressbar else train_dataloader
        all_exposures = []
        for batch_id, data in enumerate(batch_iterator):
            step = epoch * len_train_set + batch_id
            print('step', step,)
            feats, rel, group_item_mask, sol_true, exposure_true, owa_true = data
            batsize = feats.shape[0]
            # NOTE(review): this batch-sized weight vector is also what the
            # evaluation calls receive — confirm that is intended.
            w_user = torch.ones(batsize).double()
            true_costs = rel.double()

            group_identities = get_group_identities(feats, args.group_feat_id, args.group_feat_threshold)
            optimizer.zero_grad()
            # Skip batches whose lists disagree on the number of groups present,
            # and single-list batches.
            groups_per_list = {len(g.unique()) for g in group_identities}
            if len(groups_per_list) != 1 or batsize == 1:
                continue

            # Outer product of the document scores with the position discounts,
            # flattened per list, then padded with zeros and the fairness weight
            # to line up with the stacked solution vectors built below.
            scores = model(feats).squeeze(-1).double()
            pos_bias = ( 1.0 / torch.log2(torch.arange(scores.shape[1]).double() + 2) )
            test_dscts = pos_bias.repeat(batsize,1,1)
            score_cross = torch.bmm( scores.unsqueeze(0).view(batsize,-1,1), test_dscts.view(batsize,1,-1)  ).reshape(batsize,-1)
            cost_coef = torch.stack([torch.cat([score_cross[i].flatten(), torch.zeros(args.list_len, requires_grad=True), torch.tensor([args.lambda_group_fairness], requires_grad=True)]) for i in range(batsize)])

            with torch.no_grad():
                V_spo = (2*scores - true_costs)
                if args.batchify:
                    sol_pred, exposure_pred, final_item_exp , _ = FWS_batch_fast(scores, w_user, w_item, args.num_iter, args.lambda_group_fairness, group_item_mask=group_item_mask, beta=args.beta, merits=lst_group_merits)
                    sol_spo, exposure_spo, _, owa_spo    = FWS_batch_fast(V_spo, w_user, w_item, args.num_iter, args.lambda_group_fairness, group_item_mask=group_item_mask, beta=args.beta, merits=lst_group_merits)
                    sol_spo_stack = torch.stack([torch.cat([sol_spo[i].flatten(), exposure_spo[i], torch.tensor([owa_spo[i]])]) for i in range(batsize)])
                    # sol_true / exposure_true / owa_true come precomputed from the dataset.
                    sol_true_stack = torch.stack([torch.cat([sol_true[i].flatten(), exposure_true[i], torch.tensor([owa_true[i]])]) for i in range(batsize)])
                    grad = sol_spo_stack - sol_true_stack
                    regrets =torch.einsum("ij, ijk -> ik", true_costs, sol_true - sol_pred)@pos_bias
                    print(regrets.mean())
                    epoch_regrets.extend( regrets.numpy() )

                    all_exposures.append(final_item_exp.numpy())
                    dcg_discounts = ( 1.0 / torch.log2(torch.arange(args.list_len).double() + 2) ).repeat(batsize,1,1)
                    if args.gpu:
                        dcg_discounts = dcg_discounts.cuda()
                    # DCG of the predicted solution under the true relevances.
                    loss_a = torch.bmm( sol_pred, dcg_discounts.view(batsize,-1,1) )
                    loss_b = torch.bmm( true_costs.view(batsize,1,-1), loss_a).squeeze()
                    epoch_dcg_list.append( loss_b.squeeze().mean().item() )

                else:
                    grad, per_list_regrets = [], []
                    for i in range(batsize):
                        group_item_mask = create_group_mask_tensor(args.multi_groups, group_identities[i]).double()
                        sol_pred, exposure_pred, _        = FWS(scores[i].unsqueeze(0), torch.tensor([w_user[i]]), w_item, args.num_iter, args.lambda_group_fairness, group_item_mask=group_item_mask, beta=args.beta)
                        sol_true, exposure_true, owa_true = FWS(true_costs[i].unsqueeze(0), torch.tensor([w_user[i]]), w_item, args.num_iter, args.lambda_group_fairness, group_item_mask=group_item_mask, beta=args.beta)
                        sol_spo, exposure_spo, owa_spo    = FWS(V_spo[i].unsqueeze(0), torch.tensor([w_user[i]]), w_item, args.num_iter, args.lambda_group_fairness, group_item_mask=group_item_mask, beta=args.beta)
                        sol_spo_stack = torch.cat([sol_spo.flatten(), exposure_spo, torch.tensor([owa_spo])])
                        sol_true_stack = torch.cat([sol_true.flatten(), exposure_true, torch.tensor([owa_true])])
                        reg =(true_costs[i]@(sol_true - sol_pred))@pos_bias
                        print('i ', i, reg)
                        per_list_regrets.append(reg.item())
                        grad.append(sol_spo_stack - sol_true_stack)

                    grad = torch.stack(grad)
                    # per_list_regrets is a plain list of floats, so average it
                    # with numpy rather than calling tensor methods on it.
                    epoch_regrets.append( float(np.mean(per_list_regrets)) )

            if step % args.write_losses_interval == 0:
                print("Evaluating on train set: iteration {}/{} of epoch {}: {}".
                       format(batch_id, len_train_set, epoch, np.mean(epoch_regrets) ))

            if step % args.evaluate_interval == 0:
                if all_exposures:
                    all_exposure_train = np.concatenate(all_exposures)
                    # Mean absolute deviation of each exposure from its row mean.
                    row_mean = all_exposure_train.mean(1, keepdims=True)
                    train_vio_mean = np.abs(all_exposure_train - row_mean).mean()
                else:
                    # The non-batchified path records no exposures.
                    train_vio_mean = float('nan')
                training_vio.append(train_vio_mean)

                results = run_evaluation(model, val_dataloader, w_user)

                print(
                   "Evaluating on validation set:  dcg: {}, fairness mean train: {}, fairness mean val: {}, owa: {}".
                       format(results["DSM_dcg"],train_vio_mean,results['fairness_vio_mean'],results['owa_obj'] ))
                print("SGD lr=%.4f" % (optimizer.param_groups[0]["lr"]))

                # Model-selection criterion: convex mix of ranking quality and OWA.
                quality_key = "DSM_dcg" if args.reward_type == 'dcg' else "DSM_ndcg"
                criteria = (1-args.lambda_group_fairness)*results[quality_key] +  args.lambda_group_fairness*results['owa_obj']

                print('best_so_far = ')
                print( best_so_far  )
                print('criteria = ')
                print( criteria  )
                training_dcg_list.append(np.mean(epoch_dcg_list))
                training_regrets.append(np.mean(epoch_regrets))
                valid_criteria.append(criteria)
                valid_criteria2.append(best_so_far)
                valid_dcg_list.append(results['DSM_dcg'])
                valid_vio.append(results['fairness_vio_mean'])
                valid_regrets.append(results['regrets'])

                training_step.append(step)
                scheduler.step(criteria)

                if  criteria > ( best_so_far + 1e-5):
                    JK_best_model = copy.deepcopy(model)
                    fails = 0
                    best_so_far = criteria
                    results_valid_best = results.copy()
                else:
                    fails = fails + 1

                if fails > patience:
                    print("Early Stopping. Valid hasn't improved for {}".format(patience))
                    # The current epoch is still finished; the break happens below.
                    stop_training = True

            # Inject the SPO subgradient directly as the upstream gradient.
            cost_coef.backward(gradient=grad)
            optimizer.step()

        total_time += time.time() - epoch_start
        epochs_run = epoch + 1

        if stop_training:
            break

    if results_valid_best is None:
        raise RuntimeError(
            "No validation evaluation improved on the initial criterion; "
            "there is no best validation result to report.")

    # ------------------------------------------------------- final evaluation
    print("Entering evalutation of test data")
    results_test_best = run_evaluation(
        JK_best_model, test_dataloader, w_user, is_test=True).copy()

    # Do a final evaluation on the training set (need fairness quantiles)
    print("Entering evalutation of train data")
    results_train_best = run_evaluation(JK_best_model, train_dataloader, w_user)

    # ---------------------------------------------------------------- reports
    # Insertion order defines the CSV column order, so keep it stable.
    csv_outs = {}

    csv_outs['test_DSM_dcg_final']  =  results_test_best['DSM_dcg']
    csv_outs['test_fairness_vio_mean_final']  =  results_test_best['fairness_vio_mean']
    csv_outs['test_fairness_vio_mean_final2']  =  results_test_best['fairness_vio_mean2']
    csv_outs['test_fairness_vio_min_final']  =  results_test_best['fairness_vio_min']
    csv_outs['test_fairness_vio_max_final']  =  results_test_best['fairness_vio_max']

    for split, split_results in (('valid', results_valid_best),
                                 ('train', results_train_best)):
        for metric in ('DSM_ndcg', 'DSM_dcg', 'fairness_vio_mean',
                       'fairness_vio_min', 'fairness_vio_max'):
            csv_outs['{}_{}_final'.format(split, metric)] = split_results[metric]

    fair_viols_quantiles_test = results_test_best["fair_viols_quantiles"]
    fair_abs_viols_quantiles_test = results_test_best["fair_abs_viols_quantiles"]

    # (column suffix, key used by the evaluation results)
    quantile_columns = (('100', '1.00'), ('75', '0.75'), ('50', '0.50'),
                        ('25', '0.25'), ('00', '0.0'))
    for prefix, quantile_values in (('fair_viol', fair_viols_quantiles_test),
                                    ('fair_abs_viol', fair_abs_viols_quantiles_test)):
        for suffix, key in quantile_columns:
            csv_outs["{}_q_{}_test".format(prefix, suffix)] = quantile_values[key]

    print('fair_viols_quantiles_test', fair_viols_quantiles_test)
    print('fair_abs_viols_quantiles_test', fair_abs_viols_quantiles_test)

    for option in ("num_iter", "epochs", "lr", "hidden_layer", "optimizer",
                   "sample_size", "batch_size", "fairness_gap", "index", "seed",
                   "dropout", "multi_groups", "beta"):
        csv_outs[option] = getattr(args, option)
    # max(..., 1) only guards the degenerate zero-epoch run.
    csv_outs['avg_training_time'] = total_time / max(epochs_run, 1)
    csv_outs['total_training_time'] = total_time
    csv_outs['num_item'] = num_item
    csv_outs["dataset"] = args.dataset

    csv_outs = {k: [v] for (k, v) in csv_outs.items()}
    df_outs = pd.DataFrame.from_dict(csv_outs)
    df_iter = pd.DataFrame({'iter':  training_step,
        'training_vio':training_vio, 'valid_vio': valid_vio, 'training_dcg': training_dcg_list,
        'valid_dcg': valid_dcg_list, 'training_regret': training_regrets, 'valid_regrets': valid_regrets,
        'valid_criteria': valid_criteria, 'valid_criteria2': valid_criteria2,
        })
    df_outs_vio = pd.DataFrame({'test_vio':results_test_best['fair_viols_quantiles_pop']})

    output_tag = '{}{}_hidden-{}_lb-{}_lr-{}_bs-{}_group-{}_beta-{}_iter-{}_seed-{}'.format(args.dataset,num_item, args.hidden_layer, args.lambda_group_fairness, args.lr, args.batch_size, args.multi_groups,args.beta,args.num_iter,args.seed )
    outPathCsv = os.path.join(ROOT_DIR, 'results',  "OWALP_finalres_" +output_tag + '_' + str(args.index)  + "_ver13.csv")
    outPathCsv_iter = os.path.join(ROOT_DIR,'results', "OWALP_iterres_" +output_tag + '_' + str(args.index) + "_ver13.csv")
    outPathCsv_test_vio = os.path.join(ROOT_DIR,'results', "OWALP_test_fairness_vio_" +output_tag + '_' + str(args.index) + "_ver13.csv")
    df_outs.to_csv(outPathCsv)
    print('df_outs', df_outs)
    df_iter.to_csv(outPathCsv_iter)
    print('df_iter', df_iter)
    df_outs_vio.to_csv(outPathCsv_test_vio)
    print('df_outs_vio', df_outs_vio)

    # HACK (kept from the original): terminate the process once the results are
    # written instead of returning to the main routine. sys.exit() replaces the
    # interactive-only quit() helper; both raise SystemExit.
    sys.exit()
    # Unreachable while the exit above is in place; refers only to defined
    # names so the function works as documented if the exit is ever removed.
    return JK_best_model, best_so_far
