import argparse
from rt.main_synthetic import main
from rt.tasks import all_tasks, forecast_tasks, generate_rel_synthetic_tasks

if __name__ == "__main__":
    # Sweep driver: for every synthetic pre-training set size (num_train_dbs)
    # and every step budget (max_steps), train on the synthetic autocomplete
    # tasks and evaluate on a fixed set of forecast tasks plus the held-out
    # synthetic autocomplete tasks.

    # --- Sweep-invariant settings (hoisted out of the loops) ---------------
    # Database-seed layout forwarded to generate_rel_synthetic_tasks.
    offset = 14000
    num_dbs = 2000
    num_test_dbs = 100
    train_db_counts = [8, 16, 32, 64, 128, 256, 512, 1024]
    max_steps_list = [4_001, 8_001, 16_001, 32_001, 64_001, 128_001, 256_001]

    skip_clf_tasks = False
    skip_reg_tasks = False

    lr = 5e-4
    lr_schedule = True
    use_qk_norm = True

    # Forecast tasks are kept when their first field (t[0]) is one of these
    # dataset names. The filter depends only on forecast_tasks, so it is
    # computed once here rather than once per num_train_dbs.
    eval_db_names = (
        "rel-hm",
        "rel-avito",
        "rel-stack",
        "rel-trial",
        "rel-f1",
        "rel-amazon",
    )
    forecast_eval_tasks = [t for t in forecast_tasks if t[0] in eval_db_names]

    for num_train_dbs in train_db_counts:
        rel_synthetic_tasks = generate_rel_synthetic_tasks(
            offset=offset,
            num_dbs=num_dbs,
            num_train_dbs=num_train_dbs,
            num_test_dbs=num_test_dbs,
            skip_clf_tasks=skip_clf_tasks,
            skip_reg_tasks=skip_reg_tasks,
        )
        train_tasks = (
            rel_synthetic_tasks["train_autocomplete_clf_tasks"]
            + rel_synthetic_tasks["train_autocomplete_reg_tasks"]
        )
        # Build a fresh list per iteration so nothing downstream can mutate
        # the shared forecast_eval_tasks list.
        eval_tasks = [
            *forecast_eval_tasks,
            *rel_synthetic_tasks["test_autocomplete_clf_tasks"],
            *rel_synthetic_tasks["test_autocomplete_reg_tasks"],
        ]

        for max_steps in max_steps_list:
            # NOTE(review): presumably main() evaluates every eval_freq steps;
            # with budgets of the form N*1000 + 1 this places an evaluation
            # at step N*1000 -- confirm against main().
            eval_freq = max_steps - 1
            # A single model seed for now; widen the range for more seeds.
            for model_seed in range(1):
                # The mixed "_"/"-" separators are kept as-is so that the
                # directory names match checkpoints written by earlier runs.
                save_ckpt_dir = (
                    "~/scratch/ckpts/syn-pt-shuffle"
                    f"-use_qk_norm_{use_qk_norm}"
                    f"_lr_{lr}"
                    f"_lr_schedule_{lr_schedule}"
                    f"_skip_clf_{skip_clf_tasks}"
                    f"-skip_reg_{skip_reg_tasks}"
                    f"-m_seed_{model_seed}"
                    f"_d_seeds_{offset}_{num_dbs}_{num_train_dbs}_{num_test_dbs}"
                    f"_max_steps_{max_steps}"
                )
                main(
                    # misc
                    project="rt",
                    eval_splits=["val", "test"],
                    eval_freq=eval_freq,
                    eval_pow2=False,
                    max_eval_steps=80,
                    load_ckpt_path=None,
                    save_ckpt_dir=save_ckpt_dir,
                    compile_=True,
                    seed=model_seed,
                    # data
                    train_tasks=train_tasks,
                    eval_tasks=eval_tasks,
                    batch_size=128,
                    eval_batch_size=128,
                    num_workers=2,
                    ctx_len=1024,
                    max_local_ctx_len=1024,
                    max_bfs_width=128,
                    use_random_walk=False,
                    use_random_sampling=False,
                    use_connecting_nodes=False,
                    num_walks=20000,
                    walk_length=20,
                    mask_prob=0.0,
                    # optimization
                    lr=lr,
                    lr_schedule=lr_schedule,
                    wd=0.1,
                    max_grad_norm=1.0,
                    max_steps=max_steps,
                    # model
                    embedding_model="all-MiniLM-L12-v2",
                    d_text=384,
                    num_blocks=12,
                    d_model=256,
                    num_heads=8,
                    d_ff=1024,
                    use_temporal_mask=False,
                    use_sw_attn=False,
                    sw_len=None,
                    use_qk_norm=use_qk_norm,
                )
