#!/usr/bin/env python
# coding: utf-8

import numpy as np
import torch
import os
import time
from tensorboardX import SummaryWriter
import ast

from parse_args import args
from datareader import reader_from_pickle
#from train2 import on_policy_training, soft_policy_training, soft_policy_training_spo, soft_policy_training_spo_multi, soft_policy_training_qp, soft_policy_training_bb, soft_policy_training_int, soft_policy_training_twostage # JK
from train2 import on_policy_training, soft_policy_training_spo, soft_policy_training_spo_multi # JK
from train_owa import soft_policy_training_spo_owa_multi
from train_owa_pgd import monreau_smoothing_owa_training
from models import LinearModel, MLP, MLPGroupEmbedding, SiameseMLP, MLPQuadScore # JK
from evaluation import evaluate_model
from utils import serialize, transform_dataset, unserialize
from models import LinearModel, init_weights
from zehlike import demographic_parity_train
from baselines import vvector

# Training modes whose trainer import is commented out at the top of this
# file.  They are resolved lazily from ``train2`` so that the commonly used
# modes do not depend on them and a missing trainer yields a clear error
# instead of a NameError.
LEGACY_TRAINER_NAMES = {
    'qp': 'soft_policy_training_qp',
    'bb': 'soft_policy_training_bb',
    'int': 'soft_policy_training_int',
    'twostage': 'soft_policy_training_twostage',
}


def as_float_list(parsed):
    """Return ``parsed`` as a list of floats.

    ``parsed`` is the already-evaluated lambda literal: a list/tuple of
    numbers, or a single number (which is wrapped in a one-element list).
    """
    if not isinstance(parsed, (list, tuple)):
        parsed = [parsed]
    return [float(value) for value in parsed]


def lambda_label(lgroup):
    """Format a lambda value for use inside an experiment name.

    Values of at least 0.01 use ``str``; smaller ones use one-digit
    scientific notation (e.g. ``0.001`` -> ``"1.0e-03"``).
    """
    if lgroup >= 0.01:
        return str(lgroup)
    return "{:.1e}".format(lgroup)


def group_merits(disparity_type, train_path):
    """Return ``(group0_merit, group1_merit)`` for the listwise baseline.

    ``'disp0'`` uses unit merits.  Otherwise the constants are picked by
    whether ``train_path`` contains ``'mslr'`` (case-insensitive).
    NOTE(review): the constants are presumably per-dataset average merits
    computed offline -- confirm before adding further datasets.
    """
    if disparity_type == 'disp0':
        # TODO write these in for the different datasets
        return 1.0, 1.0
    if 'mslr' in train_path.lower():
        return 1.91391408523019, 2.5832905470933882
    return 3.1677021123091107, 1.415066736141729


def resolve_legacy_trainer(mode):
    """Look up the ``train2`` trainer for one of ``LEGACY_TRAINER_NAMES``.

    Raises:
        NotImplementedError: if ``train2`` does not provide the trainer.
    """
    import train2  # local import: only needed for the legacy modes

    name = LEGACY_TRAINER_NAMES[mode]
    trainer = getattr(train2, name, None)
    if trainer is None:
        raise NotImplementedError(
            "Training mode {!r} needs train2.{}, which is not available".format(
                mode, name))
    return trainer


def select_trainer(args):
    """Pick the training routine for ``args.soft_train`` / ``args.mode``.

    Returns the trainer callable, or ``None`` when no routine matches (the
    caller decides whether that is an error).  Every returned trainer is
    called with the same signature by ``run_single``.
    """
    if not args.soft_train or args.mode == 'policy_grad':
        print('policy_grad')
        return on_policy_training
    if args.soft_train != 1:
        return None

    if args.multi_groups != 0:
        multi_group_trainers = {
            'spo': soft_policy_training_spo_multi,
            'owa_lp': soft_policy_training_spo_owa_multi,
            'monreau_owa': monreau_smoothing_owa_training,
        }
        return multi_group_trainers.get(args.mode)

    if args.mode == 'spo':
        print('spo training')
        return soft_policy_training_spo
    if args.mode == 'owa_lp':
        return soft_policy_training_spo_owa_multi
    if args.mode in LEGACY_TRAINER_NAMES:
        return resolve_legacy_trainer(args.mode)
    return None


def build_model(args, test_data):
    """Construct the scoring model described by ``args``.

    A ``LinearModel`` is used when ``args.hidden_layer`` is None; otherwise an
    Xavier-initialised ``MLP`` whose ``num_item`` is the last dimension of
    ``test_data[1]``.  The result is optionally wrapped for group embedding
    or quadratic scoring.
    """
    model_kwargs = {'clamp': args.clamp}
    if args.mask_group_feat:
        model_kwargs['masked_feat_id'] = args.group_feat_id

    if args.hidden_layer is None:
        model = LinearModel(input_dim=args.input_dim, **model_kwargs)
    else:
        num_item = test_data[1].shape[-1]
        print('num_item ', num_item)
        # Batch norm is switched off for the policy-gradient and listwise modes.
        use_bn = args.mode not in ('policy_grad', 'listwise')
        if not use_bn:
            print('bn1', use_bn)
        model = MLP(input_dim=args.input_dim,
                    hidden_layer=args.hidden_layer,
                    num_item=num_item,
                    dropout=args.dropout,
                    use_bn=use_bn,
                    **model_kwargs)
        print('model', model)
        init_weights(model, 'xavier')

    if args.embed_groups:
        group_embedding = MLPGroupEmbedding(
            input_dim=args.list_len,
            hidden_layer=1,  # JK static choice for now
            dropout=args.dropout, **model_kwargs)
        model = SiameseMLP(model, group_embedding)
    elif args.embed_quadscore:
        model = MLPQuadScore(model, list_len=args.list_len)
    return model


def load_split(path, gpu, weighted):
    """Read the pickled dataset at ``path`` and run ``transform_dataset`` on it."""
    return transform_dataset(reader_from_pickle(path).data, gpu, weighted)


def evaluate_split(model, data, args):
    """Evaluate ``model`` on ``data`` with the fixed settings used for reporting."""
    return evaluate_model(
        model, data, fairness_evaluation=False,
        group_fairness_evaluation=True, track_other_disparities=True,
        deterministic=args.evaluation_deterministic,
        args=args, num_sample_per_query=args.sample_size, normalize=True,
        noise=False, en=0.0)


def run_single(args, lgroup, train_data, val_data, test_data, writer):
    """Train and evaluate one model for the lambda value ``lgroup``.

    Returns the result dictionary that is later serialised; also writes a
    checkpoint when ``args.save_checkpoints`` is set.

    Raises:
        ValueError: if ``args`` selects no training routine at all.
    """
    args.lambda_reward = 1.0
    args.lambda_ind_fairness = 0.0
    # args.lambda_group_fairness = lgroup

    # Captured before training so the reported values are the configured ones.
    wd = args.weight_decay
    er = args.entropy_regularizer

    experiment_name = "{}_{}_lambda{}_lr{}_wd{}_er{}_ed{}".format(
        args.experiment_prefix, args.fullinfo, lambda_label(lgroup), args.lr,
        args.weight_decay, args.entropy_regularizer, args.entreg_decay)

    model = build_model(args, test_data)

    result = None
    trainer = select_trainer(args)
    if trainer is not None:
        result = trainer(
            train_data, val_data, test_data, model, writer=writer,
            experiment_name=experiment_name, args=args)
    elif args.mode != 'listwise':
        # Previously this fell through to an unbound ``result``.
        raise ValueError(
            "Invalid training mode chosen: mode={!r}, soft_train={!r}, "
            "multi_groups={!r}".format(
                args.mode, args.soft_train, args.multi_groups))

    if args.mode == 'listwise':
        # NOTE(review): when soft_train is falsy this runs *after*
        # on_policy_training on the same model, as in the original script --
        # confirm whether that double training is intended.
        print('do zehlike')
        group0_merit, group1_merit = group_merits(
            args.disparity_type, args.partial_train_data)
        # 200 is the original hard-coded vvector argument -- TODO confirm meaning.
        result = demographic_parity_train(
            model, train_data, val_data, test_data, vvector(200), args,
            group0_merit, group1_merit)

    model, performance = result
    print(model)
    print("Get best performance {} at weight decay {}, entropy_regularizer {}".format(
        performance, wd, er))

    results_test = evaluate_split(model, test_data, args)
    print("Best performance on valid set: {}".format(performance))
    out_dict = {'best_perf': performance, "test": results_test, 'args': vars(args)}

    if args.eval_weighted_val:
        weighted_val = load_split(args.eval_weighted_val_location, args.gpu, True)
        out_dict['valid'] = evaluate_split(model, weighted_val, args)

    if args.eval_other_train:
        other_train = load_split(args.eval_other_train_location, args.gpu, True)
        out_dict['train'] = evaluate_split(model, other_train, args)

    out_dict.update({
        "gf_lambda": lgroup,
        "weight_decay": wd,
        "entropy_regularizer": er,
        "early_stopping": args.early_stopping,
        "full_info": args.fullinfo,
        "learning_rate": args.lr,
        "performance": performance
    })

    if args.save_checkpoints:
        checkpoint_name = "best_{}_{}_lr{}_wd{}_er{}_es{}.ckpt".format(
            args.fullinfo, lgroup, args.lr, wd, er, args.early_stopping)
        torch.save(model, os.path.join(args.hyperparam_folder, checkpoint_name))
    return out_dict


def main(args):
    """Run the full experiment: load data, sweep lambdas, save the results."""
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    # torch.set_num_threads(args.num_cores)  # JK 0805

    # JK 11/13: the sweep is pinned to a single lambda, overriding whatever
    # value was supplied for args.lambda_list.
    args.lambda_list = "[0.0]"

    print('args, ', args)

    # A seed of 9999 leaves the random generators unseeded.
    if args.seed != 9999:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    use_partial = args.fullinfo == "partial"
    # Raw data is a tuple: train_features (num_query x num_item x feature_dim)
    # and rel_score (num_query x num_item).
    train_data = load_split(
        args.partial_train_data if use_partial else args.full_train_data,
        args.gpu, args.weighted)
    val_data = load_split(
        args.partial_val_data if use_partial else args.full_val_data,
        args.gpu, args.weighted)
    test_data = load_split(args.full_test_data, args.gpu, True)

    writer = None
    if args.summary_writing:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    try:
        raw_lambdas = ast.literal_eval(args.lambda_list)
        print("a = ")
        print(raw_lambdas)
        lambdas_list = as_float_list(raw_lambdas)

        os.makedirs(args.hyperparam_folder, exist_ok=True)

        plt_data_dict = [
            run_single(args, lgroup, train_data, val_data, test_data, writer)
            for lgroup in lambdas_list
        ]

        output_name = 'plt_data_pl_{}_{}_tune{}_{}.json'.format(
            lambdas_list, args.fullinfo, args.tuning,
            time.strftime("%m-%d-%H-%M"))
        serialize(
            plt_data_dict, os.path.join(args.hyperparam_folder, output_name),
            in_json=True)
    finally:
        # Close the summary writer even when training or evaluation fails.
        if writer is not None:
            writer.close()


if __name__ == "__main__":
    main(args)