
import torch,os

def parse_default_config(parser):
    parser.add_argument('--seed', default='123', type=int, help='random seed')         
    return parser

def parser_config(parser):
    parse_default_config(parser)

    parser.add_argument('--exp_id', default='test', type=str, help='exp id')     
    parser.add_argument('--datan', default='-1', type=str, help='the name of dataset')
    parser.add_argument('-m', '--modelname', default="-1", type=str, help='the name of model')

    parser.add_argument('--cuda', default=-1, type=int, help='cuda')
    parser.add_argument('--beam_size', default=150, type=int, help='beam size')
    parser.add_argument('--topk', default=20, type=int, help='retrieve number')

    parser.add_argument('--bs', default=200, type=int, help='batch size in each prediction')

    parser.add_argument('--filiting_list_path', default='-1', type=str, help='for train ranker')


    

    return parser

def config_args(args):
    if args.cuda == -1:
        args.device = "cpu"
    else:
        args.device = f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"

    if args.beam_size < args.topk:
        args.beam_size = args.topk
        
    return args

        