import argparse
import ast
import json

from config import *

def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` is a classic argparse pitfall: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This converter accepts the usual
    spellings instead.

    Args:
        value: raw string from the command line. A ``bool`` is returned
            unchanged, so the function is also safe to apply to defaults.

    Returns:
        ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse prints a
            clean usage error instead of silently picking a value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def str2dict(value):
    """Convert a command-line string to a dict for use as an argparse ``type``.

    ``type=dict`` cannot work: ``dict('{...}')`` raises because a string is
    not a sequence of key/value pairs. This converter accepts either strict
    JSON (``'{"metric": "eo", "bound": 0.05}'``) or a Python dict literal
    (``"{'generate': True}"``). ``ast.literal_eval`` only evaluates literals,
    so no arbitrary code is executed.

    Args:
        value: raw string from the command line. A ``dict`` is returned
            unchanged, so the function is also safe to apply to defaults.

    Returns:
        The parsed dict.

    Raises:
        argparse.ArgumentTypeError: if the string is not valid JSON or a
            Python literal, or if it parses to something other than a dict.
    """
    if isinstance(value, dict):
        return value
    try:
        parsed = json.loads(value)
    except ValueError:
        # Not strict JSON (e.g. single quotes or True/False); try a Python literal.
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            raise argparse.ArgumentTypeError(
                f'expected a dict literal, got {value!r}') from None
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f'expected a dict literal, got {value!r}')
    return parsed


def arg_parser(argv=None):
    """Build the experiment's argument parser and parse the command line.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets tests or notebooks drive the parser.

    Returns:
        argparse.Namespace holding every option defined below.

    Notes:
        ``OPTIMIZERS`` and ``MODELS`` are not defined in this file. They are
        presumably supplied by ``from config import *``; confirm in config.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--algorithm', help='name of algorithm',
                        type=str, choices=OPTIMIZERS, default='erm')
    parser.add_argument('--data', help='name of dataset',
                        type=str, default='enem')
    # Dict-valued options: str2dict parses JSON / Python dict literals from
    # the CLI. argparse does not apply `type` to non-string defaults, so the
    # dict defaults below are used as-is.
    parser.add_argument('--data_setting', help='operations on raw datas to generate data set; input a dict',
                        type=str2dict, default={'sensitive_attr': 'race', 'generate': False})
    parser.add_argument('--fairness_constraints', help='fairness constraints: dp, eo, eopp, range from 0 to 1',
                        type=str2dict, default={'metric': 'dp', 'bound': 0.01})
    parser.add_argument('--seed', help='random seed',
                        type=int, default=123)

    # model
    parser.add_argument('--model', help='name of model;',
                        type=str, choices=MODELS, default='1nn')
    parser.add_argument('--loss', help='name of loss function;',
                        type=str, default='CE')
    parser.add_argument('--lr', help='learning rate for local update',
                        type=float, default=0.001)
    parser.add_argument('--wd', help='weight decay parameter;',
                        type=float, default=1e-4)
    # `--gpu` is store_true with default True, so by itself it can never turn
    # the GPU off. `--no_gpu` writes False into the same `gpu` destination.
    parser.add_argument('--gpu', help='use gpu (default: True)',
                        default=True, action='store_true')
    parser.add_argument('--no_gpu', help='disable gpu (sets gpu to False)',
                        dest='gpu', action='store_false')
    parser.add_argument('--load_model', help='if load trained model',
                        type=str2bool, default=False)

    # parameter
    parser.add_argument('--num_round', help='number of rounds to simulate',
                        type=int, default=20)
    parser.add_argument('--eval_round', help='evaluate every ___ rounds',
                        type=int, default=1)
    parser.add_argument('--batch_size', help='batch size when clients train on data',
                        type=int, default=1024)
    parser.add_argument('--optimizer', help='optimizer for training',
                        type=str, default='AdamW')
    parser.add_argument('--print_result', help='if print the result',
                        type=str2bool, default=True)

    # Ours
    parser.add_argument('--inner_round', help='number of rounds to optimize inner problem',
                        type=int, default=20)
    parser.add_argument('--dual_params_lr', help='personalized parameter',
                        type=float, default=0.5)
    parser.add_argument('--dual_params_bound', help='personalized parameter',
                        type=float, default=2)
    parser.add_argument('--tau', help='entropic regularization',
                        type=float, default=0.02)
    parser.add_argument('--post_num_round', help='number of rounds to simulate',
                        type=int, default=80)
    parser.add_argument('--post_batch_size', help='post batch size when training on data',
                        type=int, default=2048)

    # FairProjection (following paper's original setting)
    parser.add_argument('--num_iter', help='number of rounds',
                        type=int, default=100)
    parser.add_argument('--div', help='distance to use, choose from cross-entropy and kl',
                        type=str, default='kl')
    parser.add_argument('--tune_threshold', help='if tune threshold',
                        type=str2bool, default=False)

    # LinearPost (following paper's original setting)
    parser.add_argument('--calibration', help='if calibrated',
                        type=str2bool, default=True)

    # F-div
    parser.add_argument('--param_lamb', help='weight of regularizer',
                        type=float, default=0.05)
    parser.add_argument('--inner_step', help='divergence model training step',
                        type=int, default=60)

    args = parser.parse_args(argv)

    return args