import argparse
from wandb_train import main
from sweap_predict import sweap_predict
import wandb
import torch





def sweep():
    wandb.init(project="ExerCAKT_assist2009")
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default=wandb.config.dataset_name)
    parser.add_argument("--model_name", type=str, default="akt")
    parser.add_argument("--emb_type", type=str, default="qid")
    parser.add_argument("--save_dir", type=str, default=wandb.config.save_dir)
    # parser.add_argument("--learning_rate", type=float, default=1e-5)


    parser.add_argument("--seed", type=int, default=wandb.config.seed)
    parser.add_argument("--fold", type=int, default=wandb.config.fold)
    parser.add_argument("--RNN1type", type=str, default=wandb.config.RNN1type)
    parser.add_argument("--RNN2type", type=str, default=wandb.config.RNN2type)    
    parser.add_argument("--dropout", type=float, default=wandb.config.dropout)
    parser.add_argument("--n_layer", type=int, default=wandb.config.n_layer)
    
    parser.add_argument("--d_model", type=int, default=wandb.config.d_model)
    parser.add_argument("--final_fc_zoom", type=int, default=wandb.config.final_fc_zoom)
    parser.add_argument("--ratio", type=float, default=wandb.config.ratio)    
    parser.add_argument("--ablation", type=str, default=wandb.config.ablation)
    parser.add_argument("--learning_rate", type=float, default=wandb.config.learning_rate)



    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=1)
    parser.add_argument("--use_sweep", type=int, default=1)
    args = parser.parse_args()

    params = vars(args)
    model_save_path=main(params)
    

    sweap_params={}
    sweap_params['use_wandb']=0
    sweap_params['use_sweep']=1
    sweap_params['fusion_type']="early_fusion,late_fusion"
    parts=model_save_path.split("/")
    desired_part = "/".join(parts[:-1])
    sweap_params['save_dir']=desired_part
    
    print(model_save_path)
    sweap_params['bz']=512

    sweap_predict(sweap_params)
if __name__ == "__main__":
    
    sweep()