from agent.args import args, task

# Base learning rate shared by every task (scaled per task below).
_BASE_LR = 7e-5


def _make_params(train_iters, warmup_frac, lr, decay_begin_frac=0.35,
                 decay_end_frac=0.85):
    """Build one task's hyperparameter dict from the command-line ``args``.

    The learning-rate schedule boundaries (warm-up, decay start, decay end)
    are computed from ``train_iters`` so they always stay consistent with the
    iteration budget instead of being repeated as separate literals.

    Args:
        train_iters: total number of training iterations for the task.
        warmup_frac: fraction of ``train_iters`` used for ``warmup_steps``.
        lr: learning rate stored under ``"lr"``.
        decay_begin_frac: fraction of ``train_iters`` at which
            ``decay_begins`` is placed.
        decay_end_frac: fraction of ``train_iters`` at which ``decay_ends``
            is placed.

    Returns:
        A dict with keys ``lr``, ``batch_size``, ``pre_steps``,
        ``post_steps``, ``train_iters``, ``num_centers``, ``transformer``,
        ``warmup_steps``, ``decay_begins``, ``decay_ends``, ``mask_begin``
        and ``mask_end`` (in that order).
    """
    return {
        "lr": lr,
        "batch_size": args.batch_size,
        "pre_steps": args.pre_steps,
        "post_steps": args.post_steps,
        "train_iters": train_iters,
        "num_centers": args.num_centers,
        "transformer": {
            "hidden_size": args.hidden_size,
            # The single ``args.n`` value sets both head and layer count.
            "nhead": args.n,
            "num_layers": args.n,
            # pre_steps + post_steps + 1 positions; the extra one is
            # presumably the current step -- confirm against the model.
            "max_step_len": args.pre_steps + args.post_steps + 1,
        },
        # Schedule boundaries are truncated to whole iteration counts.
        "warmup_steps": int(train_iters * warmup_frac),
        "decay_begins": int(train_iters * decay_begin_frac),
        "decay_ends": int(train_iters * decay_end_frac),
        "mask_begin": args.mask_begin,
        "mask_end": args.mask_end,
    }


# Per-task hyperparameters, keyed by task name.
ALL_PARAMS = {
    "pointmaze": _make_params(
        train_iters=int(1e4),
        warmup_frac=0.10,
        lr=_BASE_LR * args.lr_ratio,
    ),
    # NOTE(review): unlike the other tasks, "shadow" does not scale its
    # learning rate by ``args.lr_ratio``, so --lr_ratio has no effect here.
    # Kept as-is to preserve existing results; confirm this is intentional.
    "shadow": _make_params(
        train_iters=int(1.5e4),
        warmup_frac=0.05,
        lr=_BASE_LR,
    ),
    "ur5e": _make_params(
        train_iters=int(1.2e4),
        warmup_frac=0.05,
        lr=_BASE_LR * args.lr_ratio,
    ),
}

# Hyperparameters for the task selected on the command line.
try:
    PARAMS = ALL_PARAMS[task]
except KeyError:
    # Re-raise the same exception type, but name the valid choices so a
    # typo in the task name is easy to diagnose.
    raise KeyError(
        f"unknown task {task!r}; expected one of {sorted(ALL_PARAMS)}"
    ) from None
