import torch
from utils import shuffle_combined
import numpy as np
from YahooDataReader import YahooDataReader
from models import NNModel, LinearModel        , MLP  # JK 11/11
from utils import parse_my_args_reinforce, torchify,      transform_dataset   #JK
from evaluation import evaluate_model   # JK argpase error (argument not found) happens here
from baselines import vvector
from progressbar import progressbar
# JK new
from parse_args import args
import ast, os, time
from datareader import reader_from_pickle
from fairness_loss import  get_group_identities, get_group_merits, get_group_quantiles
import copy
import pickle as pkl
import pandas as pd


def _resolve_root_dir():
    """Return the directory that anchors the ``data/`` and ``results/`` paths.

    A module-level ``ROOT_DIR`` is used when one exists. None is defined in
    this file, so otherwise the directory containing this file is used.
    """
    # NOTE(review): the fallback location is an assumption -- confirm where
    # ROOT_DIR was meant to point and define it explicitly at module level.
    root_dir = globals().get('ROOT_DIR')
    if root_dir is None:
        root_dir = os.path.dirname(os.path.abspath(__file__))
    return root_dir


def _load_pickle(path):
    """Unpickle a single object from ``path`` and close the file afterwards."""
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at trusted, locally generated files.
    with open(path, "rb") as handle:
        return pkl.load(handle)


def _keep_two_group_queries(feats, rels, args):
    """Keep only the queries whose group identities take exactly two values."""
    identities = get_group_identities(
        feats, args.group_feat_id, args.group_feat_threshold)
    keep = [len(g.unique()) == 2 for g in identities]
    return feats[keep], rels[keep]


def demographic_parity_train(model, dr, vdr, tdr, vvector, args, group0_merit, group1_merit):
    """Train ``model`` with a listwise loss plus a group-exposure penalty.

    Each epoch runs one pass over the training queries, then evaluates on the
    validation split. The model with the best ``dcg - lambda * asym_disparity``
    is kept, and training stops once that metric has failed to improve for
    more than 15 consecutive epochs. The kept model is evaluated on the test
    split and the metrics are written to two CSV files under
    ``<root>/results``.

    Args:
        model: network called as ``model(feats)``; ``model.parameters()`` is
            handed to the optimizer.
        dr: ``(features, relevances)`` tuple for training.
            Assumes relevances is (queries, items) -- TODO confirm.
        vdr: ``(features, relevances)`` tuple for validation.
        tdr: ``(features, relevances)`` tuple for testing.
        vvector: indexable; only ``vvector[0]`` is used, as the weight that
            turns softmax probabilities into exposures.
        args: parsed run configuration. ``args.group_feat_threshold`` is
            overwritten by this function.
        group0_merit: printed, then replaced (hard-coded for 'mslr-web10k',
            recomputed from the training data otherwise).
        group1_merit: same treatment as ``group0_merit``.

    Returns:
        The model snapshot with the best validation stop metric.

    Raises:
        ValueError: if ``args.epochs`` is below 1 or ``args.dataset`` is not
            one of the supported dataset names.
    """
    if args.epochs < 1:
        # Without at least one epoch there are no validation results to report.
        raise ValueError("args.epochs must be at least 1, got {}".format(args.epochs))

    print('group0_merit', group0_merit)
    print('group1_merit', group1_merit)

    feat, rel = dr

    JK_best_model = model
    patience = 15
    best_so_far = -8000
    # Initialised up front so an epoch that fails to improve (e.g. a NaN
    # metric on the very first epoch) cannot hit an unbound local.
    time_since_best = 0
    results_valid_best = None
    num_item = rel.shape[-1]
    root_dir = _resolve_root_dir()

    dataset = args.dataset
    print('Loading precomputed group quantiles')
    print('ROOT_DIR')
    print('dataset', dataset)
    if dataset == 'mslr':
        quantiles = _load_pickle(os.path.join(
            root_dir, 'data/{}/full/quantile_{}group.pkl'.format(dataset, 2)))
    elif dataset in ('yahoo', 'yahoobinary', 'yahoobinary2'):
        print('args group_feat_id', args.group_feat_id)
        quantiles = get_group_quantiles(feat, 2, num_feature=519, group_feat_id=args.group_feat_id)
    elif dataset == 'mslr-web10k':
        quantile_path = os.path.join(
            root_dir, 'data/{}/{}_lst/quantile_{}group.pkl'.format(dataset, num_item, 2))
        print('dataset', dataset, quantile_path)
        # Loaded as before, although the threshold below is hard-coded for
        # this dataset and the loaded value is not used.
        quantiles = _load_pickle(quantile_path)
    else:
        raise ValueError("Unsupported dataset for group quantiles: {!r}".format(dataset))

    if dataset == 'mslr-web10k':
        # M 12/14: try hard code from old threshold
        args.group_feat_threshold = 0.03252032399177551
    else:
        args.group_feat_threshold = quantiles

    print('threshold', args.group_feat_threshold)

    # Every split is restricted to queries that contain both groups.
    feat, rel = _keep_two_group_queries(feat, rel, args)
    train_dataset = torch.utils.data.TensorDataset(feat, rel)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    print('feat', feat.shape)

    vdr = _keep_two_group_queries(vdr[0], vdr[1], args)
    tdr = _keep_two_group_queries(tdr[0], tdr[1], args)

    if dataset == 'mslr-web10k':
        group0_merit = 1.91391408523019
        group1_merit = 2.5832905470933882
    else:
        with torch.no_grad():
            group0_merit, group1_merit = get_group_merits(
                feat, rel, args.group_feat_id, args.group_feat_threshold, mean=False)
        print("Group 0 mean merit: {}, Group1 mean merit: {}".format(
            group0_merit, group1_merit))

    from utils import get_optimizer
    optimizer = get_optimizer(
        model.parameters(),
        args.lr,
        args.optimizer,
        weight_decay=0.0)
    softmax_over_items = torch.nn.Softmax(dim=1)
    total_time = 0
    counter = 0
    for epoch in range(args.epochs):
        for batch_feats, batch_rels in train_dataloader:
            counter += 1
            start_time = time.time()

            optimizer.zero_grad()
            scores = model(batch_feats).squeeze()
            if scores.dim() == 1:
                # A final batch holding a single query loses its batch axis in
                # squeeze(); restore it so the dim=1 softmax still normalises
                # over that query's items.
                scores = scores.unsqueeze(0)
            probs = softmax_over_items(scores)

            # Listwise cross-entropy between softmaxed relevances and scores.
            normalized_rels = softmax_over_items(batch_rels)
            ranking_loss = -torch.sum(
                normalized_rels * torch.log(probs))

            exposures = vvector[0] * probs
            groups = get_group_identities(
                batch_feats, args.group_feat_id, args.group_feat_threshold)

            if (groups == 0).all() or (groups == 1).all():
                fairness_loss = 0.0
            else:
                avg_exposure_0 = torch.sum(
                    torch.FloatTensor(1 - groups) * exposures) / torch.sum(
                        1 - torch.FloatTensor(groups))
                avg_exposure_1 = torch.sum(
                    torch.FloatTensor(groups) * exposures) / torch.sum(
                        torch.FloatTensor(groups))
                # One-sided penalty: only group 1 receiving more average
                # exposure than group 0 is penalised.
                fairness_loss = torch.pow(
                    torch.clamp(avg_exposure_1 - avg_exposure_0, min=0), 2)
            loss = 1.0 * ranking_loss + args.lambda_group_fairness * fairness_loss

            loss.backward()
            optimizer.step()
            total_time += time.time() - start_time

        # End of epoch: validate and update the early-stopping state.
        results = evaluate_model(
            model,
            vdr,
            group0_merit = group0_merit,
            group1_merit = group1_merit,
            fairness_evaluation=False,
            group_fairness_evaluation=True,
            deterministic=True,
            args=args,
            num_sample_per_query=100)
        print( results['ndcg'] )
        print( results['avg_group_disparity'] )

        stop_metric = 1.0 * results["dcg"]
        if args.lambda_group_fairness > 0:
            stop_metric -= args.lambda_group_fairness * results["avg_group_asym_disparity"]

        if stop_metric > (best_so_far + 1e-3):
            JK_best_model = copy.deepcopy(model)
            time_since_best = 0
            best_so_far = stop_metric
            results_valid_best = results.copy()
        else:
            time_since_best = time_since_best + 1

        print("time_since_best = {}".format(time_since_best))

        if time_since_best > patience:
            print("Early Stopping. Valid hasn't improved for {}".format(patience))
            break

    if results_valid_best is None:
        # No epoch beat the initial threshold, so JK_best_model is still the
        # in-place trained `model`; its validation results are the latest ones.
        results_valid_best = results.copy()

    test_results = evaluate_model(
        JK_best_model,
        tdr,
        fairness_evaluation=False,
        group_fairness_evaluation=True,
        deterministic=True,
        args=args,
        num_sample_per_query=100,
        group0_merit = group0_merit,
        group1_merit = group1_merit
        )
    fair_viols_quantiles_test = test_results["fair_viols_quantiles"]
    fair_abs_viols_quantiles_test = test_results["fair_abs_viols_quantiles"]

    csv_outs = {}
    csv_outs["test_ndcg_final"] = test_results["ndcg"]
    csv_outs["test_dcg_final"] = test_results["dcg"]
    csv_outs["test_rank_final"] = test_results["avg_rank"]
    csv_outs["test_abs_group_expos_disp_final"] = test_results["avg_abs_group_disparity"]
    csv_outs["test_group_expos_disp_final"] = test_results["avg_group_disparity"]
    csv_outs["test_group_asym_disp_final"] = test_results["avg_group_asym_disparity"]
    csv_outs["valid_ndcg_final"] = results_valid_best["ndcg"]
    csv_outs["valid_dcg_final"] = results_valid_best["dcg"]
    csv_outs["valid_rank_final"] = results_valid_best["avg_rank"]
    csv_outs["valid_abs_group_expos_disp_final"] = results_valid_best["avg_abs_group_disparity"]
    csv_outs["valid_group_expos_disp_final"] = results_valid_best["avg_group_disparity"]
    csv_outs["valid_group_asym_disp_final"] = results_valid_best["avg_group_asym_disparity"]

    csv_outs["fair_viol_q_100_test"] = fair_viols_quantiles_test['1.00']
    csv_outs["fair_viol_q_75_test"] = fair_viols_quantiles_test['0.75']
    csv_outs["fair_viol_q_50_test"] = fair_viols_quantiles_test['0.50']
    csv_outs["fair_viol_q_25_test"] = fair_viols_quantiles_test['0.25']
    csv_outs["fair_viol_q_00_test"] = fair_viols_quantiles_test['0.0']
    csv_outs["fair_abs_viol_q_100_test"] = fair_abs_viols_quantiles_test['1.00']
    csv_outs["fair_abs_viol_q_75_test"] = fair_abs_viols_quantiles_test['0.75']
    csv_outs["fair_abs_viol_q_50_test"] = fair_abs_viols_quantiles_test['0.50']
    csv_outs["fair_abs_viol_q_25_test"] = fair_abs_viols_quantiles_test['0.25']
    csv_outs["fair_abs_viol_q_00_test"] = fair_abs_viols_quantiles_test['0.0']

    csv_outs["stop_epoch"] = epoch

    csv_outs["index"] = args.index
    csv_outs["epochs"] = args.epochs
    csv_outs["lr"] = args.lr
    csv_outs["hidden_layer"] = args.hidden_layer
    csv_outs["optimizer"] = args.optimizer
    csv_outs["quad_reg"] = args.quad_reg
    csv_outs["partial_train_data"] = args.partial_train_data
    csv_outs["partial_val_data"] = args.partial_val_data
    csv_outs["full_test_data"] = args.full_test_data
    csv_outs["log_dir"] = args.log_dir
    csv_outs["sample_size"] = args.sample_size
    csv_outs["batch_size"] = args.batch_size
    csv_outs["soft_train"] = args.soft_train
    csv_outs["disparity_type"] = args.disparity_type
    csv_outs["lambda_group_fairness"] = args.lambda_group_fairness
    csv_outs["dropout"] = args.dropout
    csv_outs["gme_new"] = args.gme_new
    # counter stays 0 when no training query contains both groups.
    csv_outs['avg_training_time'] = total_time / counter if counter else 0.0
    csv_outs['total_training_time'] = total_time
    csv_outs['num_item'] = num_item
    csv_outs["dataset"] = args.dataset

    # pandas expects column -> list-of-values, so wrap each scalar.
    csv_outs = {k: [v] for (k, v) in csv_outs.items()}
    df_outs = pd.DataFrame.from_dict(csv_outs)
    output_tag = '{}{}_hidden-{}_lb-{}_lr-{}_bs-{}_group-{}_seed-{}'.format(args.dataset,num_item, args.hidden_layer, args.lambda_group_fairness, args.lr, args.batch_size, args.multi_groups,args.seed )

    # Create the output directory so a finished run is not lost to a missing
    # folder.
    results_dir = os.path.join(root_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)

    outPathCsv_test_vio = os.path.join(results_dir, "listwise_test_fairness_vio_" + output_tag + '_' + str(args.index) + "_plotting.csv")
    df_outs_vio = pd.DataFrame({'test_vio': test_results['fair_viol_all_list']})
    df_outs_vio.to_csv(outPathCsv_test_vio)
    print('df_outs_vio', df_outs_vio)

    outPathCsv = os.path.join(results_dir, "listwise_finalres_" + output_tag + '_' + str(args.index) + "_plotting.csv")
    df_outs.to_csv(outPathCsv)
    print(df_outs)

    for (k, v) in csv_outs.items():
        print("{}:  {}".format(k, v))

    print("Outputs saved")

    # The former quit() call made this return unreachable and terminated any
    # caller that imported the function.
    return JK_best_model


class Namespace:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Write straight into the instance dict, one entry per keyword.
        attributes = self.__dict__
        for key, value in kwargs.items():
            attributes[key] = value


if __name__ == "__main__":
    # Script entry point: choose initial group merits, load the three data
    # splits, build the MLP and run demographic-parity training.

    # JK 11/11
    if args.disparity_type == 'disp0':
        group0_merit = 1.0     # TODO write these in for the different datasets
        group1_merit = 1.0
    else:
        if 'mslr' in args.partial_train_data.lower():
            group0_merit = 1.91391408523019
            group1_merit = 2.5832905470933882
        else:
            group0_merit = 3.1677021123091107
            group1_merit = 1.415066736141729

    train_data = reader_from_pickle(args.partial_train_data)
    train_data = train_data.data
    dr = transform_dataset(train_data, args.gpu, args.weighted)

    # Validation previously read args.full_test_data, so early stopping and
    # model selection were performed on the test set. Prefer the dedicated
    # validation file and fall back to the old path only when none is given.
    # NOTE(review): confirm args.partial_val_data is populated for your runs.
    valid_path = args.partial_val_data or args.full_test_data
    valid_data = reader_from_pickle(valid_path)
    valid_data = valid_data.data
    vdr = transform_dataset(valid_data, args.gpu, True)

    test_data = reader_from_pickle(args.full_test_data)
    test_data = test_data.data
    tdr = transform_dataset(test_data, args.gpu, True)

    kwargs = {'clamp': args.clamp}
    if args.mask_group_feat:
        kwargs['masked_feat_id'] = args.group_feat_id
    model = MLP(args.input_dim, args.hidden_layer, args.dropout, **kwargs)

    model = demographic_parity_train(model, dr, vdr, tdr, vvector(200), args, group0_merit, group1_merit)
