import torch as t


def add_solver_args(parser):
    parser.add_argument(
        "--n-trials",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--d-hidden",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--exploration-constant",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--depth",
        type=int,
        default=2,
    )


def add_train_args(parser):
    parser.add_argument(
        "--solver",
        required=True,
    )
    parser.add_argument(
        "--trainer",
        required=True,
    )
    parser.add_argument(
        "--data",
        required=True,
    )
    parser.add_argument(
        "--exp-name",
        default="exp",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=int(1e5),
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=256,
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
    )
    parser.add_argument(
        "--w-oodl",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--w-tdl",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--w-rwl",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--w-trl",
        type=float,
        default=0.001,
    )
    parser.add_argument(
        "--w-tsl",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--n-steps",
        type=int,
        default=5,
    )
    parser.add_argument(
        "--checkpoint-frequency",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--evaluation-frequency",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--n-evaluation-runs",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--project",
        default="DOS",
    )
    add_solver_args(parser)


def add_evaluate_args(parser):
    parser.add_argument(
        "--solver",
        required=True,
    )
    parser.add_argument(
        "--weights",
        required=True,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--save-summary-path",
        default="./exp.summary",
    )
    parser.add_argument(
        "--n-levels",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--n-workers",
        type=int,
        default=max(1, t.cuda.device_count()),
    )
    add_solver_args(parser)


def add_visualise_args(parser):
    parser.add_argument(
        "--solver",
        required=True,
    )
    parser.add_argument(
        "--weights",
        required=True,
    )
    parser.add_argument(
        "--debug",
        action="store_true",
    )
    add_solver_args(parser)
