import argparse

from wandb_train import main

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # meta
    parser.add_argument("--dataset_name", type=str, default="assist2009")
    parser.add_argument("--model_name", type=str, default="ktst")
    parser.add_argument("--save_dir", type=str, default="saved_model")
    parser.add_argument("--seed", type=int, default=3407)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=0)

    # optimizer
    parser.add_argument("--learning_rate", type=float, default=5.0e-05)
    parser.add_argument("--weight_decay", type=float, default=1.0e-05)

    # model
    parser.add_argument("--d_model", type=int, default=256)
    parser.add_argument("--nhead_tf", type=int, default=8)
    parser.add_argument("--nhead_agg", type=int, default=4)
    parser.add_argument("--num_layers_tf", type=int, default=6)
    parser.add_argument("--num_layers_agg", type=int, default=2)
    parser.add_argument("--dim_feedforward", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.05)
    parser.add_argument("--dim_classifier", type=int, default=128)
    parser.add_argument("--aggregation", type=str, default="q_mean_c")
    parser.add_argument("--use_bias_emb", type=bool, default=True)
    parser.add_argument("--use_zero_init", type=bool, default=True)
    parser.add_argument(
        "--attn_variant", type=str, default="learnable_alibi_monotonic_q_k"
    )

    # mandatory arguments, but not used
    parser.add_argument("--emb_type", type=str, default="qid")

    args = parser.parse_args()

    params = vars(args)
    main(params)
