import os
import subprocess
import sys

def run_func(description, ppi_path, pseq_path, vec_path,
             split_new, split_mode, train_valid_index_path,
             use_lr_scheduler, save_path, graph_only_train,
             batch_size, epochs, num_token, hidden, lr, gin_num_layer,
             th_epoch, use_jk, use_GRU, model_type):
    """Run ``gnn_prompt_train_binding.py`` in a child process and wait for it.

    Every parameter is forwarded as one ``--<name>=<value>`` argument, in the
    order of this signature. Values are stringified with ``"{}".format`` (so
    ``False`` becomes ``"False"`` and ``0.0003`` becomes ``"0.0003"``).

    The script path is relative, so it is resolved against the current
    working directory.

    Returns:
        int: The child process's exit code (0 on success). The call never
        raises on a non-zero exit; callers decide how to react.
    """
    options = [
        ("description", description),
        ("ppi_path", ppi_path),
        ("pseq_path", pseq_path),
        ("vec_path", vec_path),
        ("split_new", split_new),
        ("split_mode", split_mode),
        ("train_valid_index_path", train_valid_index_path),
        ("use_lr_scheduler", use_lr_scheduler),
        ("save_path", save_path),
        ("graph_only_train", graph_only_train),
        ("batch_size", batch_size),
        ("epochs", epochs),
        ("num_token", num_token),
        ("hidden", hidden),
        ("lr", lr),
        ("gin_num_layer", gin_num_layer),
        ("th_epoch", th_epoch),
        ("use_jk", use_jk),
        ("use_GRU", use_GRU),
        ("model_type", model_type),
    ]
    # sys.executable runs the child under the same interpreter/environment as
    # this launcher, rather than whichever "python" happens to be on PATH.
    cmd = [sys.executable, "-u", "gnn_prompt_train_binding.py"]
    cmd.extend("--{}={}".format(name, value) for name, value in options)
    # An argv list (no shell) delivers values containing spaces or shell
    # metacharacters to the script intact.
    return subprocess.run(cmd, check=False).returncode

if __name__ == "__main__":
    # Model to train; must be a key of hyperparams_by_model below.
    model_type = "prompt"  # one of: "prompt", "mlp", "gnn"
    # Also used as split_mode and to name the train/valid index JSON file.
    description = "random"  # random, bfs, dfs

    # Active dataset: yeast.
    ppi_path = "data/protein.actions.yeast.tsv"
    pseq_path = "data/protein.yeast.sequences.dictionary.tsv"

    # Alternative dataset: SHS27k (swap in for the two lines above).
    # ppi_path = "data/protein.actions.SHS27k.STRING.txt"
    # pseq_path = "data/protein.SHS27k.sequences.dictionary.tsv"

    # Alternative feature file:
    # vec_path = "data/CTCoding_onehot.txt"
    vec_path = "data/vec5_CTC.txt"

    # These flags are forwarded verbatim as command-line strings.
    split_new = "True"
    split_mode = description
    train_valid_index_path = "./train_valid_index_json/" + description + ".json"

    use_lr_scheduler = "True"
    save_path = "./save_model/"
    graph_only_train = "True"

    # Per-model hyperparameters. Keys match run_func's keyword parameters so
    # each entry can be splatted directly into the call below.
    hyperparams_by_model = {
        "prompt": dict(batch_size=128, epochs=3000, num_token=8, hidden=128,
                       lr=0.0003, gin_num_layer=2, th_epoch=300,
                       use_jk=False, use_GRU=False),
        "mlp": dict(batch_size=32, epochs=3000, num_token=7, hidden=256,
                    lr=0.00005, gin_num_layer=3, th_epoch=50,
                    use_jk=False, use_GRU=False),
        "gnn": dict(batch_size=64, epochs=3000, num_token=7, hidden=256,
                    lr=0.003, gin_num_layer=3, th_epoch=50,
                    use_jk=False, use_GRU=False),
    }
    # Fail fast on a typo instead of silently training with another model's
    # hyperparameters.
    if model_type not in hyperparams_by_model:
        raise ValueError("Unknown model_type {!r}; expected one of {}".format(
            model_type, sorted(hyperparams_by_model)))

    run_func(description, ppi_path, pseq_path, vec_path,
             split_new, split_mode, train_valid_index_path,
             use_lr_scheduler, save_path, graph_only_train,
             model_type=model_type, **hyperparams_by_model[model_type])