import os, argparse
from logging import WARNING
import warnings
warnings.filterwarnings("ignore")

from utils import set_seed
from dataloader import load_dataset_portfolio_multi
from trainer import MultiTrainer
import torch
# Report once, at start-up, whether PyTorch can see a CUDA device; the label
# and the boolean are printed on separate lines.
print("torch.cuda.is_available() = ", torch.cuda.is_available(), sep="\n")
# from owa_optimization import owa_optim_lp_portifolio_wrapper, cvx_qp, owa_optim_lp_cvxpy, owa_optim_lp_norm_cvxpy, owa_pgd_solver_wrapper, gini_indices_square

from owa_optimization import owa_optim_lp_portifolio_wrapper, cvx_qp, owa_optim_lp_cvxpy, owa_optim_lp_norm_cvxpy, owa_pgd_solver_wrapper, foldopt_owa_pgd_solver_wrapper, gini_indices_square
import pickle
import pandas as pd

# Base directory under which main() looks for 'data/portfolio/' and writes
# 'results/multi_objectives'. It is a relative path, so it is resolved against
# the current working directory of the process.
ROOT_DIR = "fair_optim"

# Trainers for which the model is built with n_task = params.n_task; every
# other trainer gets n_task = 1.
_MULTI_TASK_TRAINERS = (
    'OWA2Stage', 'OWALP', 'OWALPNorm', 'OWALPNormMoreauGrad',
    "OWAPGDSubGrad", "OWAPGDMoreauGrad", "OWAFoldedSubGrad",
)

# Run settings copied verbatim from the parsed arguments into the results file.
_DUMPED_PARAMS = (
    "seed", "index", "n_task", "num_item", "trainer_name", "eps", "beta",
    "gamma", "add_mse", "lamb", "data_ver", "use_subgrad",
)

# (key in the dict returned by trainer.train_epoch(), key in the results file).
_HISTORY_KEYS = (
    ("train_mae_loss", "epoch_train_mae_loss"),
    ("train_regret", "epoch_train_regret"),
    ("train_obj", "epoch_train_obj"),
    ("train_loss", "epoch_train_loss_z"),
    ("val_mae_loss", "epoch_val_mae_loss"),
    ("val_regret", "epoch_val_regret"),
    ("val_obj", "epoch_val_obj"),
    ("val_loss", "epoch_val_loss_z"),
    ("val_mae_loss2", "epoch_val_mae_loss2"),
    ("val_regret2", "epoch_val_regret2"),
    ("val_obj2", "epoch_val_obj2"),
    ("val_loss2", "epoch_val_loss_z2"),
)


def main(params):
    """Train, evaluate and record one multi-task portfolio experiment.

    Steps: seed the run, load the portfolio data loaders, build the solver
    selected by ``params.trainer_name``, train for at most
    ``params.num_epochs`` epochs (stopping early once
    ``trainer.step >= trainer.patience``), save the model weights to
    ``combres1.pt`` in the current working directory, evaluate with
    ``is_test=True`` and pickle the settings, per-epoch histories and final
    metrics under ``ROOT_DIR/results/multi_objectives``.

    Args:
        params: parsed command-line namespace (see the argparse definitions
            in this file). It is also forwarded unchanged to the data loader
            and to ``MultiTrainer``.

    Raises:
        KeyError: if ``params.trainer_name`` has no entry in the solver table,
            or an expected key is missing from the trainer's result dicts.
    """
    set_seed(params.seed)

    data_dir = os.path.join(ROOT_DIR, 'data/portfolio/')
    lst_train_loader, lst_test_loader = load_dataset_portfolio_multi(
        data_dir, params, data_ver=params.data_ver)

    model_params = {
        'input_dim': params.input_dim,
        'hidden_layer': params.hidden_layer,
        'num_item': params.num_item,
        'dropout': 0.2,
        'n_task': params.n_task if params.trainer_name in _MULTI_TASK_TRAINERS else 1,
    }

    w_gini = gini_indices_square(params.n_task)
    # NOTE(review): every entry is constructed eagerly although only
    # solver[params.trainer_name] is used. Left as-is because the constructors'
    # side effects are not visible from this file.
    solver = {
        # Passed as the wrapper itself (not called here).
        "OWA2Stage": owa_optim_lp_portifolio_wrapper,
        "OWASurrogateQP": cvx_qp(params.num_item, params.eps),
        # OWA loss with the QP surrogate.
        "OWASurrogateMoreauGrad": cvx_qp(params.num_item, params.eps),
        "SumSurrogateQP": cvx_qp(params.num_item, params.eps),
        "OWALP": owa_optim_lp_cvxpy(params.n_task, params.num_item),
        "OWALPNorm": owa_optim_lp_norm_cvxpy(params.n_task, params.num_item, params.eps),
        "OWALPNormMoreauGrad": owa_optim_lp_norm_cvxpy(params.n_task, params.num_item, params.eps),
        # OWA loss with subgradient.
        "OWAPGDSubGrad": owa_pgd_solver_wrapper(
            w_gini, params.num_item, params.beta, params.num_iter, params.gamma, params.use_subgrad),
        # OWA loss with Moreau gradient.
        "OWAPGDMoreauGrad": owa_pgd_solver_wrapper(
            w_gini, params.num_item, params.beta, params.num_iter, params.gamma, params.use_subgrad),
        # OWA loss with subgradient, folded optimization.
        "OWAFoldedSubGrad": foldopt_owa_pgd_solver_wrapper(
            w_gini, params.num_item, params.beta, params.num_iter, params.gamma),
    }

    trainer = MultiTrainer(lst_train_loader, lst_test_loader, model_params,
                           solver[params.trainer_name], params)

    # Record the run settings first so the results file is self-describing.
    results_dump = {name: getattr(params, name) for name in _DUMPED_PARAMS}

    # One growing list per tracked metric, keyed by its name in the results file.
    history = {dump_key: [] for _, dump_key in _HISTORY_KEYS}

    for _ in range(params.num_epochs):
        print('=' * 100)
        train_results = trainer.train_epoch()
        for result_key, dump_key in _HISTORY_KEYS:
            # Each epoch contributes a sequence of values that is appended to
            # the running history.
            history[dump_key] += train_results[result_key]

        # presumably trainer.step counts epochs without improvement -- TODO confirm
        if trainer.step >= trainer.patience:
            print('Reached early stopping. Exit training.')
            break

    torch.save(trainer.model.state_dict(), "combres1.pt")
    print('model saved')

    eval_results = trainer.evaluate(is_test=True)
    print(eval_results)

    # Includes "epoch_train_loss_z", which used to be collected during
    # training but was never written to the results file.
    results_dump.update(history)

    results_dump['val_z_opt'] = eval_results['z_opt']
    results_dump['best_z_owa'] = eval_results['owa_obj']
    results_dump['best_mae_loss'] = eval_results['mae_loss']
    results_dump['best_regret'] = eval_results['regret']
    results_dump['best_z_loss'] = eval_results['loss_z']

    suffix = 'ver1'
    out_name = "portfolio{}_equalratio_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.p".format(
        params.data_ver, params.trainer_name, params.hidden_layer, params.lr,
        params.n_task, params.eps, params.beta, params.num_iter, params.gamma,
        params.add_mse, params.lamb, params.use_subgrad, params.seed,
        params.index, suffix)
    out_fp = os.path.join(ROOT_DIR, 'results/multi_objectives', out_name)

    # Create the results directory on demand so a finished run is not lost to
    # a missing folder, and close the file handle deterministically.
    os.makedirs(os.path.dirname(out_fp), exist_ok=True)
    with open(out_fp, 'wb') as out_file:
        pickle.dump(results_dump, out_file)
    print('save results files to: ', out_fp)



if __name__ == "__main__":
    # Command-line entry point: every option below becomes an attribute of the
    # namespace that is handed to main().
    parser = argparse.ArgumentParser(description='Multitask training script')

    # Training configuration.
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--num_epochs', type=int, default=30,
                        help='maximum number of training epochs')
    parser.add_argument('--index', type=int, default=0,
                        help='index of the experiment run (recorded in the results file and its name)')
    parser.add_argument('--trainer_name', type=str, default='OWA2Stage',
                        help='name of the trainer to run',
                        choices=['OWA2Stage', 'OWASurrogateQP', 'OWALP', 'OWALPNorm',
                                 'OWALPNormMoreauGrad', 'OWASurrogateMoreauGrad',
                                 'SumSurrogateQP', "OWAPGDSubGrad", "OWAPGDMoreauGrad",
                                 "OWAFoldedSubGrad"])
    parser.add_argument('--add_mse', type=int, default=0,
                        help='whether to add MSE loss (0/1); OWA2Stage and the surrogate models use 0')
    parser.add_argument('--data_ver', type=str, default='_syn1',
                        help='version of data')

    # Solver configuration.
    parser.add_argument('--n_task', type=int, default=5,
                        help='number of tasks')
    parser.add_argument('--beta', type=float, default=1e-4,
                        help="OWA's smoothing parameter")
    parser.add_argument('--eps', type=float, default=0.5,
                        help='regularization weight of the cp layer')
    parser.add_argument('--num_iter', type=int, default=1000,
                        help='# of PGD iterations')
    parser.add_argument('--gamma', type=float, default=0.02,
                        help='PGD step size')
    parser.add_argument('--use_subgrad', type=int, default=0,
                        help='only for PGD-based models')

    # Model configuration.
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--input_dim", type=int, default=50)
    parser.add_argument("--num_item", type=int, default=20)
    parser.add_argument('--use_cuda', type=int, default=0,
                        help='whether to use cuda')
    parser.add_argument("--hidden_layer", type=int, default=3)
    parser.add_argument('--lr', type=float, default=0.01,
                        help='model learning rate')
    parser.add_argument('--lamb', type=float, default=0,
                        help='weight of the MSE loss; set it to 0 for OWA2Stage and the surrogate models')

    args = parser.parse_args()

    main(args)
