#!/usr/bin/env bash
# Launch multi-process (one per GPU) fine-tuning of a pocket model with
# `unicore-train` through `python -m torch.distributed.launch`.
#
# Edit the paths and hyperparameters below, then run:
#   bash <this script>
set -euo pipefail

# ---- Paths -----------------------------------------------------------------
data_path="/data/path"
save_dir="/save/dir"
tmp_save_dir="/tmp/save/dir"
tsb_dir="./tsb/dir"

# ---- Distributed launch ----------------------------------------------------
n_gpu=4
MASTER_PORT=10052

finetune_pocket_model="/pocket/model"

# Value passed to --fpocket-score (presumably the regression target, given
# --reg and the MSE loss below). Choose one of:
#   "Score", "Druggability Score", "Total SASA", "Hydrophobicity score"
fpocket_score="Hydrophobicity score"

# ---- Hyperparameters -------------------------------------------------------
batch_size=32
batch_size_valid=32
epoch=200
dropout=0.0 # NOTE(review): defined but not passed to unicore-train below
warmup=0.03
update_freq=1
dist_threshold=8.0
recycling=3
lr=1e-4

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Resolve the training entry point explicitly; an empty result would otherwise
# shift every argument and make the launcher treat $data_path as the script.
unicore_train=$(command -v unicore-train) ||
  die "unicore-train not found on PATH"

# Expose GPUs 0..n_gpu-1 so the visible device list always matches
# --nproc_per_node.
gpu_ids=""
for ((i = 0; i < n_gpu; i++)); do
  gpu_ids+="${gpu_ids:+,}${i}"
done

# Arguments for unicore-train, kept in an array so paths/values containing
# spaces (e.g. "$fpocket_score") stay single arguments.
train_args=(
  "$data_path"
  --user-dir ./unimol
  --train-subset train
  --valid-subset valid,test
  --num-workers 8
  --ddp-backend=c10d
  --task pocket_ft
  --loss finetune_mse_pocket
  --arch pocket_ft
  --max-pocket-atoms 256
  --optimizer adam
  --adam-betas "(0.9, 0.999)"
  --adam-eps 1e-8
  --clip-norm 1.0
  --lr-scheduler polynomial_decay
  --lr "$lr"
  --warmup-ratio "$warmup"
  --max-epoch "$epoch"
  --batch-size "$batch_size"
  --batch-size-valid "$batch_size_valid"
  --fp16
  --fp16-init-scale 4
  --fp16-scale-window 256
  --update-freq "$update_freq"
  --seed 1
  --tensorboard-logdir "$tsb_dir"
  --log-interval 100
  --log-format simple
  --validate-interval 1
  --best-checkpoint-metric valid_rmse
  --patience 200
  --all-gather-list-size 2048000
  --dist-threshold "$dist_threshold"
  --recycling "$recycling"
  --save-dir "$save_dir"
  --tmp-save-dir "$tmp_save_dir"
  --keep-last-epochs 5
  --find-unused-parameters
  --finetune-pocket-model "$finetune_pocket_model"
  --reg
  --fpocket-score "$fpocket_score"
)

export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
CUDA_VISIBLE_DEVICES="$gpu_ids" python -m torch.distributed.launch \
  --nproc_per_node="$n_gpu" --master_port="$MASTER_PORT" \
  "$unicore_train" "${train_args[@]}"