import argparse
import warnings


from MCTS import MCTS
from utils import redirect_log_file
from BFSearch import k_fold_cv_bfs
from test_sindy import test_sindy

def arg_parser(argv=None):
    """Parse the command-line options for a search experiment.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            existing callers that pass nothing keep working. Passing an
            explicit list makes the parser reusable from other code/tests.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``datasource``,
        ``method``, ``time_limit`` and ``verbose_level``.

    Raises:
        SystemExit: raised by argparse on an unknown option or a value
            outside the allowed ``choices``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="2d_comp_viscose_newton_ns",
        choices=[
            "2d_comp_viscose_newton_ns",
            "2d_comp_viscose_new_non_newton",
            "2d_heat_comp_v2",
        ],
        help="Name of the dataset to load.",
    )
    parser.add_argument(
        "--datasource",
        type=str,
        default="COMSOL",
        help="Source label of the dataset (default: COMSOL).",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="mcts",
        choices=["mcts", "es", "sindy"],
        help="Search method: 'mcts' (MCTS k-fold CV search), "
             "'es' (k_fold_cv_bfs) or 'sindy' (test_sindy).",
    )
    parser.add_argument(
        "--time_limit",
        type=float,
        default=None,
        # Only the 'es' branch of the script forwards this value.
        help="Time limit forwarded to the 'es' method only.",
    )
    parser.add_argument(
        "--verbose_level",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Verbosity: passed to MCTS as-is; 'es' is verbose when > 0.",
    )

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Ignore every RuntimeWarning for the remainder of the process.
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    cli = arg_parser()
    data_key = (cli.dataset, cli.datasource)

    # redirect_log_file receives the log root and an experiment name of the
    # form "<dataset> <method>"; presumably it opens/redirects a log file
    # under ./log/ and returns a handle — confirm against utils.
    log_target = redirect_log_file("./log/", f"{cli.dataset} {cli.method}")

    chosen = cli.method
    if chosen == "sindy":
        # The SINDy baseline only needs the dataset identifiers.
        test_sindy(cli.dataset, cli.datasource)
    elif chosen == "es":
        k_fold_cv_bfs(
            dataname_tuple=data_key,
            out_file=log_target,
            time_limit=cli.time_limit,
            verbose=cli.verbose_level > 0,
        )
    elif chosen == "mcts":
        # Worker count is fixed at 10 here.
        searcher = MCTS(
            dataname_tuple=data_key,
            out_file=log_target,
            verbose_level=cli.verbose_level,
            n_jobs=10,
        )
        searcher.k_fold_cv_search()