import argparse
import os
import shutil

import solver.config


def str2bool(v):
    """Convert a command-line token into a bool; intended as an argparse ``type``.

    A value that is already a ``bool`` is returned unchanged. A string is
    lower-cased and matched against accepted spellings:
    'yes', 'true', 't', 'y', '1' -> True; 'no', 'false', 'f', 'n', '0' -> False.

    Raises:
        argparse.ArgumentTypeError: if the string matches no accepted spelling.
    """
    if isinstance(v, bool):
        return v
    # One table mapping every accepted spelling to its boolean value.
    spellings = {
        **dict.fromkeys(('yes', 'true', 't', 'y', '1'), True),
        **dict.fromkeys(('no', 'false', 'f', 'n', '0'), False),
    }
    result = spellings.get(v.lower())
    if result is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return result


def parse_args(argv=None):
    """Parse the command line and apply the solver options to global config.

    Side effects:
        * Creates the ``--log-dir`` directory (including missing parents).
          Anything already in that directory is left in place.
        * Sets ``solver.config.ODEConfig.linear_solver`` to the attribute of
          ``solver.config.SolverType`` named by ``--solver``.
        * Sets ``solver.config.ODEConfig.central_diff`` from ``--central-diff``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the original no-argument call.

    Returns:
        The ``argparse.Namespace`` with ``log_dir``, ``solver``,
        ``central_diff``, ``wandb`` and ``gpu_stats_port``.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--log-dir', type=str, required=True)
    parser.add_argument('--solver', type=str, choices=['DENSE_CHOLESKY', 'SPARSE_INDIRECT_BLOCK_CG', 'LEAST_SQUARES'], default='LEAST_SQUARES')
    parser.add_argument('--central-diff', type=str2bool, default=False)
    parser.add_argument('--wandb', type=str2bool, default=False)
    parser.add_argument('--gpu-stats-port', type=int, default=8000)
    args = parser.parse_args(argv)
    # The log directory is not cleared: existing contents are kept and only
    # missing directories are created.
    os.makedirs(args.log_dir, exist_ok=True)
    # Plain attribute lookup instead of eval(): same result, no code execution.
    # `choices` above already limits args.solver to the listed names.
    solver.config.ODEConfig.linear_solver = getattr(solver.config.SolverType, args.solver)
    solver.config.ODEConfig.central_diff = args.central_diff
    return args


# Parsed once, at import time: importing this module reads the process's
# command line, creates the log directory, and writes the solver settings into
# solver.config.ODEConfig as side effects of parse_args().
# NOTE(review): presumably other modules import this shared `args`; confirm
# before moving the call behind an `if __name__ == "__main__":` guard.
args = parse_args()
