import argparse

def get_parser():
    """Build the command-line parser for precomputing ncb encodings.

    Options are registered in three argparse groups: 'Base arguments',
    'Sysbinder arguments' and 'Retrieval corpus arguments'. No option is
    registered with ``required=True``; options without an explicit default
    parse to ``None`` when omitted.

    Returns:
        argparse.ArgumentParser: the configured parser. Nothing is parsed
        here; callers invoke ``parse_args`` themselves.
    """
    cli = argparse.ArgumentParser(description='Precompute ncb encodings of different types')

    # --- Base arguments -------------------------------------------------
    base = cli.add_argument_group('Base arguments')
    base.add_argument(
        '--enc_type',
        type=str,
        choices=['concept_slot', 'one_hot', 'one_hot_padded', 'block_slot', 'retrieval_corpus'],
    )
    for str_flag in ('--data_dir', '--result_dir'):
        base.add_argument(str_flag, type=str)
    for int_flag, int_default in (('--num_workers', 4), ('--batch_size', 128)):
        base.add_argument(int_flag, type=int, default=int_default)

    # --- Sysbinder arguments --------------------------------------------
    # Registration order determines the order shown in --help, so the
    # integer options are added in two runs around '--temp'.
    sysbinder = cli.add_argument_group('Sysbinder arguments')
    sysbinder.add_argument('--sysbinder_path', type=str)
    ints_before_temp = (
        ('--model_seed', 0),
        ('--image_size', 128),
        ('--image_channels', 3),
        ('--num_iterations', 3),
        ('--num_slots', 4),
        ('--num_blocks', 16),
        ('--cnn_hidden_size', 512),
        ('--slot_size', 2048),
        ('--mlp_hidden_size', 192),
        ('--num_prototypes', 64),
    )
    for int_flag, int_default in ints_before_temp:
        sysbinder.add_argument(int_flag, type=int, default=int_default)
    sysbinder.add_argument(
        '--temp',
        type=float,
        default=1.,
        help='softmax temperature for prototype binding',
    )
    ints_after_temp = (
        ('--vocab_size', 4096),
        ('--num_decoder_layers', 8),
        ('--num_decoder_heads', 4),
        ('--d_model', 192),
    )
    for int_flag, int_default in ints_after_temp:
        sysbinder.add_argument(int_flag, type=int, default=int_default)
    sysbinder.add_argument('--dropout', type=float, default=0.1)
    sysbinder.add_argument(
        '--binarize',
        action='store_true',
        default=False,
        help='Should the encodings of the sysbinder be binarized?',
    )

    # --- Retrieval corpus arguments -------------------------------------
    retrieval = cli.add_argument_group('Retrieval corpus arguments')
    retrieval.add_argument('--retrieval_corpus_path', type=str)
    retrieval.add_argument(
        '--retrieval_encs',
        default='proto',
        choices=['proto', 'exem', 'basis', 'proto-exem', 'proto-exem-basis'],
    )
    retrieval.add_argument(
        '--majority_vote',
        action='store_true',
        default=False,
        help='If set then the retrieval binder takes the majority vote of the topk nearest encodings',
    )
    retrieval.add_argument(
        '--topk',
        type=int,
        default=5,
        help='Number of nearest encodings to consider for majority vote',
    )
    retrieval.add_argument(
        '--thresh_count_obj_slots',
        type=int,
        default=-1,
        help='Threshold value for determining object slots (-1 to use all slots)',
    )

    return cli