
import argparse


def _str2bool(value):
    """Convert a command-line string such as ``"True"``, ``"false"`` or ``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``-anchored False`` would
    silently leave the option enabled.

    Args:
        value: The raw string from the command line (a bool is passed through
            unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean
            word; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def config(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            an explicit list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()

    # Main setting
    parser.add_argument("-cuda", type=str, default="2")
    parser.add_argument('-output_folder', type=str, default="results_ALDH1_try")

    parser.add_argument("-mode", type=str, default="a")   # "a", "e", "d"
    parser.add_argument('-architecture', type=str, default="ginl")
    parser.add_argument('-strategy', type=str, default="grpo")   # "grpo", "greedy", "mi", "uncertainty", "similarity", "random"
    parser.add_argument('-dataset', type=str, default="ALDH1")   # "ALDH1", "PKM2", "VDR", "Enamine50k"

    parser.add_argument('-seed', type=int, default=0)
    parser.add_argument("-start_active_num", type=int, default=1)
    parser.add_argument('-start_num', help='How many molecules we have in our starting set (min=2)', type=int, default=64)
    parser.add_argument('-batch_size', help='How many molecules we select each cycle', type=int, default=64)
    parser.add_argument('-max_screen_size', help='Total budget', type=int, default=1000)

    parser.add_argument('-ensemble_size', type=int, default=2)
    parser.add_argument("-epochs", type=int, default=2)
    parser.add_argument('-pretrain_file', type=str, default="pretrain/GraphMVP_simple_features_for_classification/output/3D_hybrid_02_masking/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0/pretraining_model.pth")
    parser.add_argument("-model_save_file", type=str, default="pretrain/Enamine50k_30000_ep1.pth")

    # Train setting
    parser.add_argument('-train_batch_size', type=int, default=64)
    parser.add_argument('-infer_batch_size', type=int, default=512)

    parser.add_argument("-mol_emb_dim", type=int, default=130)
    parser.add_argument("-hidden_dim", type=int, default=1024)
    parser.add_argument("-output_dim", type=int, default=2)

    # Per-architecture layer counts (mlp / gcn / gin / gine variants).
    parser.add_argument("-mlp_fc_layer", type=int, default=3)
    parser.add_argument("-gcn_graph_conv_layer", type=int, default=5)
    parser.add_argument("-gcn_x_fc_layer", type=int, default=3)
    parser.add_argument("-gin_graph_conv_layer", type=int, default=3)
    parser.add_argument("-gin_x_fc_layer", type=int, default=3)
    parser.add_argument("-gin_fp_fc_layer", type=int, default=3)
    parser.add_argument("-gine_graph_conv_layer", type=int, default=3)
    parser.add_argument("-gine_x_fc_layer", type=int, default=1)
    parser.add_argument("-gine_fp_fc_layer", type=int, default=1)

    parser.add_argument("-lr", type=float, default=3e-4)
    parser.add_argument("-weight_decay", type=float, default=0)
    parser.add_argument('-retrain', help='Retrain the model every cycle', type=int, default=1)
    # _str2bool instead of type=bool: bool("False") is True, which made this
    # flag impossible to disable from the command line.
    parser.add_argument("-anchored", type=_str2bool, default=True)
    parser.add_argument("-l2_lambda", type=float, default=3e-4)
    # NOTE(review): presumably hyperparameters of the "grpo" strategy's
    # objective — confirm against the training code.
    parser.add_argument("-grpo_lambda", type=float, default=7e-2)
    parser.add_argument("-grpo_epsilon", type=float, default=2e-1)
    parser.add_argument("-grpo_beta", type=float, default=1e-2)

    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args







