
from solver import AICL_solver, RAG_solver
from solver_self_ask import SASE_solver

def get_args(argv=None):
    """Parse the command-line configuration for a run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which was the original behaviour. Passing a list
            lets callers and tests build a configuration without touching
            ``sys.argv``.

    Returns:
        A ``(args, kwargs)`` tuple: the parsed ``argparse.Namespace`` and
        ``vars(args)``. ``vars`` returns the namespace's own ``__dict__``, so
        both share state: mutating one mutates the other.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Running time configurations")

    # Solver name; the script's __main__ section dispatches on this value.
    parser.add_argument('--method', default="RAG", type=str)
    parser.add_argument(
        '--strategy',
        default="diversity",
        type=str,
        choices=["random", "similarity", "diversity"],
    )
    # The __main__ section appends task/mask/strategy/k tags to this value.
    parser.add_argument('--suffix', default="", type=str)
    # Numeric defaults are written as numbers, not strings. argparse would
    # convert string defaults such as "3" / "0.5" through ``type`` anyway, so
    # the parsed values are unchanged; this just makes the intended types
    # explicit.
    parser.add_argument('--k', default=3, type=int)
    parser.add_argument('--task', default="ES", type=str)
    parser.add_argument('--mask_rate', default=0.5, type=float)
    parser.add_argument('--is_active_judgment', action='store_true')
    parser.add_argument('--is_full_context', action='store_true')
    parser.add_argument('--model', default="gpt-3.5-turbo", type=str)
    # NOTE(review): the default is a placeholder, not a real path; presumably
    # it must be overridden on the command line.
    parser.add_argument('--docs_path', default="enter your data path here", type=str)
    parser.add_argument('--retrieve_top_k', default=8, type=int)
    parser.add_argument('--retriever_name', default="base", type=str)

    # parse_args(None) falls back to sys.argv[1:], matching the old behaviour.
    args = parser.parse_args(argv)
    kwargs = vars(args)

    return args, kwargs



if __name__ == '__main__':

    args, _ = get_args()

    # Extend the run suffix with the task, mask rate, strategy and k before
    # any solver is built. The solver receives ``args`` and can therefore see
    # the extended suffix (how it uses it is defined in solver code — confirm
    # there).
    args.suffix += f"_{args.task}_mask_{args.mask_rate}_{args.strategy}_{args.k}"

    # Accepted --method values mapped to their solver classes. Only the
    # selected class is instantiated.
    solver_classes = {
        "aicl": AICL_solver,
        "RAG": RAG_solver,
        "self_ask": SASE_solver,
    }

    solver_cls = solver_classes.get(args.method)
    if solver_cls is None:
        raise NotImplementedError(f"The solver {args.method} has not been implemented!")

    solver = solver_cls(args)
    solver.run()

