import argparse

parser = argparse.ArgumentParser()

# logging arguments
logging_arg = parser.add_argument_group('Logging')
logging_arg.add_argument('--run-name', type=str, default='runX')
logging_arg.add_argument('--hadoop-dir', type=str, default=None)
logging_arg.add_argument('--test-freq', type=int, default=10)
logging_arg.add_argument('--unmute-tqdm', action='store_true')

# dataset arguments
data_arg = parser.add_argument_group('Dataset')
data_arg.add_argument('--db5-data-dir', type=str, default='./datasets/dataset_DB5/')
data_arg.add_argument('--rcsb-data-dir', type=str, default='./datasets/dataset_RCSB/')
data_arg.add_argument('--cv', type=int, default=0, choices=[0, 1, 2])
data_arg.add_argument('--batch-size', type=int, default=4)
data_arg.add_argument('--num-data-workers', type=int, default=2)
data_arg.add_argument('--num-preprocess-workers', type=int, default=64)
data_arg.add_argument('--rotation', type=lambda x: eval(x), default=True)
data_arg.add_argument('--iface-cutoff', type=float, default=3.0)
data_arg.add_argument('--swap-pairs', type=lambda x: eval(x), default=True)
data_arg.add_argument('--apply-filter', type=lambda x: eval(x), default=True)
data_arg.add_argument('--min-iface-size', type=int, default=50)
data_arg.add_argument('--max-iface-size', type=int, default=1E6)
data_arg.add_argument('--max-iface-ratio', type=float, default=0.9)
data_arg.add_argument('--vert-nbr-atoms', type=int, default=16) # KNN for chemical environment
data_arg.add_argument('--num-gdf', type=int, default=16)
data_arg.add_argument('--num-signatures', type=int, default=16)
data_arg.add_argument('--num-lb-basis', type=float, default=0.06)
data_arg.add_argument('--smoothing', type=lambda x: eval(x), default=False)

# network arguments
net_arg = parser.add_argument_group('Network')
net_arg.add_argument('--model', type=str, default='PuzzleDock')
net_arg.add_argument('--h-dim', type=int, default=128)
net_arg.add_argument('--chem-embed-dim', type=int, default=32)
net_arg.add_argument('--geom-feat-dim', type=int, default=32)
net_arg.add_argument('--chem-pooling', type=str, default='mean', choices=['max', 'mean'])
net_arg.add_argument('--num-message-passing-blocks', type=int, default=2)
net_arg.add_argument('--num-propagation-layers', type=int, default=3)
net_arg.add_argument('--propagation-time-scale', type=float, default=10.)
net_arg.add_argument('--apply-band-filter', type=lambda x: eval(x), default=True)
net_arg.add_argument('--band-e-mean', type=float, default=0.1)
net_arg.add_argument('--band-e-std', type=float, default=0.1)
net_arg.add_argument('--num-cross-attn-layers', type=int, default=1)
net_arg.add_argument('--num-attn-heads', type=int, default=4)
net_arg.add_argument('--attn-mechanism', type=str, default='softmax', choices=['softmax', 'sigmoid', 'tanh'])
net_arg.add_argument('--h-dim-div', type=int, default=1)
net_arg.add_argument('--num-smoothing-layers', type=int, default=1)
net_arg.add_argument('--bsp-loss', type=str, default='focal', choices=['focal', 'bce'])
net_arg.add_argument('--metric', type=str, default='AP', choices=['AP', 'AUC'])
net_arg.add_argument('--focal-alpha', type=float, default=0.25)
net_arg.add_argument('--focal-gamma', type=int, default=2)
net_arg.add_argument('--nce-loss', type=lambda x: eval(x), default=True)
net_arg.add_argument('--nce-loss-weight', type=float, default=0.1)
net_arg.add_argument('--nce-num-samples', type=int, default=50)
net_arg.add_argument('--nce-T', type=float, default=10.)
net_arg.add_argument('--attn-loss-weight', type=float, default=0.)
net_arg.add_argument('--dropout', type=float, default=0.1)
net_arg.add_argument('--msgpass-mechanism', type=str, default='harmonic', choices=['harmonic', 'graph', 'none'])
# optimizer arguments
opt_arg = parser.add_argument_group('Optimizer')
opt_arg.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam'])
opt_arg.add_argument('--lr', type=float, default=5E-4)
opt_arg.add_argument('--lr-scheduler', type=str, default='CosineAnnealingWarmRestarts',
                     choices=['StepLR', 'CosineAnnealingWarmRestarts'])
opt_arg.add_argument('--lr-t0', type=int, default=20)
opt_arg.add_argument('--lr-eta-min', type=float, default=1E-6)
opt_arg.add_argument('--lr-tmult', type=int, default=1)
opt_arg.add_argument('--lr-step-size', type=int, default=5)
opt_arg.add_argument('--lr-gamma', type=float, default=0.5)
opt_arg.add_argument('--clip-grad-norm', type=float, default=1.)
opt_arg.add_argument('--epochs', type=int, default=40)
opt_arg.add_argument('--fp16', action='store_true')
opt_arg.add_argument('--fine-tune', action='store_true')

# misc arguments
misc_arg = parser.add_argument_group('Misc')
misc_arg.add_argument('--rand-seed', type=int, default=2022)
misc_arg.add_argument('--serial', action='store_true')
misc_arg.add_argument('--restore', type=str, default=None)

# parse args
def get_config():
    args = parser.parse_args()
    return args


