import argparse

def add_data_group(group):
    """Register dataset selection and node-feature construction options.

    Args:
        group: an ``argparse`` parser or argument group (anything exposing
            ``add_argument``) that receives the options.
    """
    group.add_argument('--seed', default=123, type=int)
    group.add_argument(
        '--dataset',
        default='FRANKENSTEIN',
        type=str,
        choices=['AIDS', 'PROTEINS_full', 'FRANKENSTEIN'],
        help="used dataset",
    )
    group.add_argument(
        '--data_path',
        default='../dataset',
        type=str,
        help="the directory used to save dataset",
    )

    # Boolean switches, registered in their original order so the --help
    # listing is unchanged.
    switches = (
        ('--use_nlabel_asfeat', "use node labels as (part of) node features"),
        ('--use_org_node_attr', "use node attributes as (part of) node features"),
        ('--use_degree_asfeat', "use node degrees as (part of) node features"),
        ('--data_verbose', "print detailed dataset info"),
    )
    for flag, description in switches:
        group.add_argument(flag, action='store_true', help=description)

    group.add_argument('--save_data', action='store_true')


def _str2bool(value):
    """Convert a command-line spelling of a boolean ('True', 'no', '1', ...) to bool.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def add_model_group(group):
    """Register model architecture, training schedule and checkpoint options.

    Args:
        group: an ``argparse`` parser or argument group (anything exposing
            ``add_argument``) that receives the options.
    """
    group.add_argument('--model', type=str, default='gcn', help="used model")
    group.add_argument('--train_ratio', type=float, default=0.5, help="ratio of trainset from whole dataset")
    group.add_argument('--hidden_dim', nargs='+', default=[64, 16], type=int,
                       help="hidden dimension(s) of the model (one or more ints)")
    group.add_argument('--num_head', type=int, default=2, help="GAT head number")

    group.add_argument('--batch_size', type=int, default=128)
    group.add_argument('--train_epochs', type=int, default=40)
    group.add_argument('--lr', type=float, default=0.01)
    group.add_argument('--lr_decay_steps', nargs='+', default=[80, 120], type=int)
    group.add_argument('--weight_decay', type=float, default=5e-4)
    group.add_argument('--dropout', type=float, default=0.5)
    # Without an explicit type argparse keeps the raw string, so
    # "--train_verbose False" would be the truthy string 'False'. Parse it as a
    # real boolean; a bare "--train_verbose" (no value) means True.
    group.add_argument('--train_verbose', type=_str2bool, nargs='?', const=True, default=True,
                       help="print training details")
    group.add_argument('--log_every', type=int, default=1, help='print every x epoch')
    group.add_argument('--eval_every', type=int, default=5, help='evaluate every x epoch')

    group.add_argument('--clean_model_save_path', type=str, default='../save/model/clean')
    group.add_argument('--save_clean_model', action='store_true')
    group.add_argument('--readdetection', action='store_true')
    group.add_argument('--readtrigger', action='store_true')
    group.add_argument('--restart', action='store_true')
    group.add_argument('--retest', action='store_true')


def add_atk_group(group):
    """Register backdoor-attack options (poisoning ratios, trigger generator, schedule).

    Args:
        group: an ``argparse`` parser or argument group (anything exposing
            ``add_argument``) that receives the options.
    """
    # Poisoning ratios and trigger shape.
    group.add_argument('--bkd_gratio_train', type=float, default=0.05, help="backdoor graph ratio in trainset")
    # NOTE(review): this help text is unclear; confirm the intended meaning with the caller.
    group.add_argument('--save_gratio_train', type=float, default=0.05, help="last time for train trigger")
    group.add_argument('--bkd_gratio_test', type=float, default=0.5, help="backdoor graph ratio in testset")
    group.add_argument('--bkd_num_pergraph', type=int, default=1, help="number of backdoor triggers per graph")
    group.add_argument('--bkd_size', type=int, default=4, help="number of nodes for each trigger")
    # The default stays 0 for every dataset; the help documents the convention
    # callers are expected to follow for PROTEINS_full.
    group.add_argument('--target_class', type=int, default=0,
                       help="the targeted node/graph label. If dataset is PROTEINS_full, target label is 1, else 0")

    # GraphTrojanNet (GTN) structure and input preprocessing.
    group.add_argument('--gtn_layernum', type=int, default=3, help="layer number of GraphTrojanNet")
    group.add_argument('--pn_rate', type=float, default=1, help="ratio between trigger-embedded graphs (positive) and benign ones (negative)")
    group.add_argument('--gtn_input_type', type=str, default='2hop', help="how to process org graphs before inputting to GTN")

    # Optimisation schedule.
    group.add_argument('--resample_steps', type=int, default=1, help="# iterations to re-select graph samples")
    group.add_argument('--bilevel_steps', type=int, default=4, help="# bi-level optimization iterations")
    group.add_argument('--gtn_lr', type=float, default=0.01)
    group.add_argument('--feat_lr', type=float, default=0.01)
    # The two epoch counts previously shared an identical help string; make
    # them distinguishable in --help.
    group.add_argument('--gtn_epochs', type=int, default=20, help="# attack epochs for GraphTrojanNet")
    group.add_argument('--feat_epochs', type=int, default=20, help="# attack epochs for the feature generator")
    group.add_argument('--notopo', action='store_true')
    group.add_argument('--nofeat', action='store_true')
    group.add_argument('--topo_activation', type=str, default='sigmoid', help="activation function for topology generator")
    group.add_argument('--feat_activation', type=str, default='sigmoid', help="activation function for feature generator")
    group.add_argument('--topo_thrd', type=float, default=0.5, help="threshold for topology generator")
    group.add_argument('--feat_thrd', type=float, default=0, help="threshold for feature generator (only useful for binary feature)")

    group.add_argument('--lambd', type=float, default=1, help="a hyperparameter to balance attack loss components")

    # Persistence and sample-selection strategy.
    group.add_argument('--save_bkd_model', action='store_true')
    group.add_argument('--bkd_model_save_path', type=str, default='../save/model/bkd')
    group.add_argument('--bkd_data_save_path', type=str, default='../save/model/data')
    group.add_argument('--cleanlabel', type=int, default=0)
    group.add_argument('--chose', type=str, default='con', choices=['random', 'con', 'loss'],
                       help="how to chose backdoor graphs")
    group.add_argument('--pos', type=str, default='import', choices=['random', 'import', 'least'],
                       help="how to chose backdoor nodes")
def add_det_group(parser):
    """Register SIGNET detector and explainer hyper-parameters.

    Args:
        parser: an ``argparse`` parser or argument group (anything exposing
            ``add_argument``) that receives the options.
    """
    # (flag, keyword arguments) in registration order; order is kept so the
    # --help listing is unchanged.
    option_specs = (
        # Detector training loop.
        ('--SIGNETbatch_size', {'type': int, 'default': 128}),
        ('--batch_size_test', {'type': int, 'default': 9999}),
        ('--log_interval', {'type': int, 'default': 1}),
        ('--num_trials', {'type': int, 'default': 1}),
        ('--device', {'type': int, 'default': 0}),
        ('--SIGNET_lr', {'type': float, 'default': 0.0001}),
        ('--epochs', {'type': int, 'default': 500}),
        # Detector encoder.
        ('--encoder_layers', {'type': int, 'default': 5}),
        ('--SIGNEThidden_dim', {'type': int, 'default': 16}),
        ('--pooling', {'type': str, 'default': 'add', 'choices': ['add', 'max']}),
        ('--readout', {'type': str, 'default': 'concat', 'choices': ['concat', 'add', 'last']}),
        # Explainer network.
        ('--explainer_model', {'type': str, 'default': 'gin', 'choices': ['mlp', 'gin']}),
        ('--explainer_layers', {'type': int, 'default': 5}),
        ('--explainer_hidden_dim', {'type': int, 'default': 8}),
        ('--explainer_readout', {'type': str, 'default': 'add', 'choices': ['concat', 'add', 'last']}),
        # Loss weights.
        ('--beta', {'type': float, 'default': 1.0}),
        ('--beta2', {'type': float, 'default': 1.0}),
        ('--alpha', {'type': float, 'default': 1.0}),
        ('--alpha2', {'type': float, 'default': 1.0}),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)

def add_new_group(parser):
    """Register options for retraining a model after detection.

    Args:
        parser: an ``argparse`` parser or argument group (anything exposing
            ``add_argument``) that receives the options.
    """
    parser.add_argument('--new_model_save_path', default='../save/model/new', type=str)
    # Plain on/off switches, registered in their original order.
    for switch in ('--save_new_model', '--nostop'):
        parser.add_argument(switch, action='store_true')
    # NOTE(review): name suggests an early-stopping patience; confirm in the training code.
    parser.add_argument('--earlystop', default=100, type=int)
def parse_args(args=None):
    """Build the full command-line parser and parse arguments.

    Args:
        args: optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour,
            because ``ArgumentParser.parse_args(None)`` reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every data, model, attack, detection and
        retraining option.
    """
    parser = argparse.ArgumentParser()
    data_group = parser.add_argument_group(title="Data-related configuration")
    model_group = parser.add_argument_group(title="Model-related configuration")
    atk_group = parser.add_argument_group(title="Attack-related configuration")
    det_group = parser.add_argument_group(title="Detection-related configuration")
    new_group = parser.add_argument_group(title="retrain after detection configuration")
    add_data_group(data_group)
    add_model_group(model_group)
    add_atk_group(atk_group)
    add_det_group(det_group)
    add_new_group(new_group)

    return parser.parse_args(args)
