import argparse
from rt.main import main
from rt.tasks import all_tasks, forecast_tasks

if __name__ == "__main__":

    use_qk_norm = True
    ckpt = 50000
    load_ckpt_path = ""  # trained model path
    model_name = load_ckpt_path.split("/")[-2] + f"_ckpt_{ckpt}"
    for leave_db in [
        "rel-amazon",
        "rel-hm",
        "rel-avito",
        "rel-trial",
        "rel-stack",
        "rel-f1",
    ]:
        print(f"Continued Pre-training without {leave_db}")

        main(
            # misc
            project="rt",
            eval_splits=["val", "test"],
            eval_freq=2000,
            eval_pow2=False,
            max_eval_steps=40,
            load_ckpt_path=load_ckpt_path,
            save_ckpt_dir=f"~/scratch/ckpts/baselines_use_qk_norm_{use_qk_norm}/cpt/{model_name}/leave_{leave_db}",
            compile_=True,
            seed=0,
            # data
            train_tasks=[t for t in all_tasks if t[0] != leave_db],
            eval_tasks=[t for t in forecast_tasks if t[0] == leave_db],
            batch_size=128,
            num_workers=2,
            ctx_len=1024,
            max_local_ctx_len=1024,
            max_bfs_width=128,
            use_random_walk=False,
            use_random_sampling=False,
            use_connecting_nodes=False,
            num_walks=20000,
            walk_length=20,
            mask_prob=0.0,
            # optimization
            lr=5e-4,
            wd=0.1,
            lr_schedule=True,
            max_grad_norm=1.0,
            max_steps=50_001,
            # model
            embedding_model="all-MiniLM-L12-v2",
            d_text=384,
            num_blocks=12,
            d_model=256,
            num_heads=8,
            d_ff=1024,
            use_temporal_mask=False,
            use_sw_attn=False,
            sw_len=None,
            use_qk_norm=use_qk_norm,
        )
