
import argparse
from wandb_train import main
import wandb
import torch





def sweep():
    wandb.init(project="mlp4kt")
    # device = torch.device("cuda")
    # torch.cuda.reset_peak_memory_stats(device)
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default="assist2009")
    parser.add_argument("--model_name", type=str, default="lkt")
    parser.add_argument("--emb_type", type=str, default="qid")
    parser.add_argument("--save_dir", type=str, default="saved_model")
    # parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fold", type=int, default=0)
    parser.add_argument("--dropout", type=float, default=wandb.config.dropout)
    parser.add_argument("--batch_size", type=int, default=wandb.config.batch_size)

    parser.add_argument("--final_fc_dim", type=int, default=wandb.config.final_fc_dim)
    parser.add_argument("--final_fc_dim2", type=int, default=wandb.config.final_fc_dim2)

    parser.add_argument("--mixer_ratio1", type=float, default=wandb.config.mixer_ratio1)
    parser.add_argument("--mixer_ratio2", type=float, default=wandb.config.mixer_ratio2)

    
    parser.add_argument("--d_model", type=int, default=wandb.config.d_model)
    parser.add_argument("--d_ff", type=int, default=256)


    parser.add_argument("--learning_rate", type=float, default=wandb.config.learning_rate)

    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=1)
    parser.add_argument("--use_sweep", type=int, default=1)
    
    args = parser.parse_args()

    params = vars(args)
    main(params)

if __name__ == "__main__":
    
    sweep()