# Checkpoint sync helper (run separately; this commented line is not executed by the script):
# MIN_WAIT=60 MAX_WAIT=300 bash scripts/osync.sh --on-changes --initiator=/home/aswerdlo/hdd/data/medim/ckpts/sync --target=ssh://mprabhud@grogu//grogu/user/mprabhud/aswerdlo/medim/ckpts/sync

# Toolchain paths come from the active conda environment ($CONDA_PREFIX).
# Only add the ':' separator when LD_LIBRARY_PATH already has a value: an empty
# entry (e.g. a trailing ':') makes the dynamic loader search the current directory.
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
export CUDA_HOME="$CONDA_PREFIX"
export UNIDISC_FORCE_CUDNN_SPDA_CONTEXT=0
# Resource knobs forwarded to main.py as devices=/constraint=/mem_per_gpu=/cpus_per_gpu=.
export NUM_GPUS=${NUM_GPUS:-4}   # overridable from the caller's environment
# NOTE(review): presumably a Slurm --constraint feature expression; confirm in main.py.
export CONSTRAINT="L40|L40S|A100_40GB|A100_80GB|6000Ada|A6000|A4500"
export MEM_PER_GPU=32
export CPUS_PER_GPU=8

# Directory holding the {115m,300m}_{ar,nar}.safetensors checkpoints.
export CKPT_DIR='/home/aswerdlo/repos/unidisc_arxiv'

# 0/1 switches, each overridable from the environment (e.g. RUN_AR=1 RUN_COCO=1 bash <script>).
# RUN_AR / RUN_NAR pick which checkpoint family is evaluated; RUN_COCO / RUN_FLICKR /
# RUN_CC / RUN_DB pick the FID datasets; RUN_MEDIUM=1 selects the 300m model instead of 115m.
RUN_NAR=${RUN_NAR:-0}
RUN_AR=${RUN_AR:-0}
RUN_CC=${RUN_CC:-0}
RUN_DB=${RUN_DB:-0}
RUN_FLICKR=${RUN_FLICKR:-0}
RUN_COCO=${RUN_COCO:-0}
RUN_MEDIUM=${RUN_MEDIUM:-0}

# Hydra-style overrides shared by every eval launch.
# Model size and eval batch size follow RUN_MEDIUM (medium: batch 3, otherwise small: batch 24).
# Each key appears exactly once. Previously, trainer.compile=true and partition=preempt were
# listed and then overridden later in the same list by trainer.compile=false and
# partition=general; only those effective values are kept here.
if [[ "$RUN_MEDIUM" -eq 1 ]]; then
    _model_size=medium
    _eval_batch_size=3
else
    _model_size=small
    _eval_batch_size=24
fi

common_args=(
    debug=true
    "model=${_model_size}"
    "loader.eval_batch_size=${_eval_batch_size}"
    '+trainer.forced_keys=[eval.cfg,eval.unconditional_fid,sampling.predictor,data.fid_dataset,sampling.sampling_step_frac]'
    model.force_optimized_native_attn=false
    wandb.project=medim-jan-eval-ablations
    'wandb.tags=[11_12_fid_ar_v2]'
    eval.fid_samples=16384
    sampling.predictor=maskgit
    sampling.sampling_step_frac=0.05
    eval.cfg=2
    trainer.compile=false
    "slurm_name=${USER}_ablations_nar"
    "mem_per_gpu=${MEM_PER_GPU}"
    "cpus_per_gpu=${CPUS_PER_GPU}"
    "devices=${NUM_GPUS}"
    "constraint=${CONSTRAINT}"
    partition=general
)
unset _model_size _eval_batch_size

# Per-dataset FID eval overrides:
#   a -> sayakpaul/coco-30-val-2014 (RUN_COCO), b -> nlphuji/flickr30k (RUN_FLICKR),
#   c -> fid_cc12m experiment (RUN_CC),         d -> fid_datacomp1b experiment (RUN_DB).
# All four share the same base experiment list and differ only in the final FID
# experiment and, for the fid_hf sets, the dataset id.
_fid_base_experiments='small_scale_train,paired_standalone_fid_eval,master_eval'

common_a_args=("+experiments=[${_fid_base_experiments},fid_hf]" 'data.fid_dataset=sayakpaul/coco-30-val-2014')
common_b_args=("+experiments=[${_fid_base_experiments},fid_hf]" 'data.fid_dataset=nlphuji/flickr30k')
common_c_args=("+experiments=[${_fid_base_experiments},fid_cc12m]")
common_d_args=("+experiments=[${_fid_base_experiments},fid_datacomp1b]")

unset _fid_base_experiments

# Resolve the checkpoint pair for the requested model size:
# RUN_MEDIUM=1 -> 300m checkpoints, anything else -> 115m checkpoints.
if [ "$RUN_MEDIUM" -eq 1 ]; then
    _ckpt_size=300m
else
    _ckpt_size=115m
fi
AR_CKPT="$CKPT_DIR/${_ckpt_size}_ar.safetensors"
NAR_CKPT="$CKPT_DIR/${_ckpt_size}_nar.safetensors"
unset _ckpt_size

# Print the resolved switches and checkpoint paths before anything is launched.
printf '%s\n' \
    "RUN_AR: ${RUN_AR}, RUN_NAR: ${RUN_NAR}, RUN_MEDIUM: ${RUN_MEDIUM}" \
    "RUN_CC: ${RUN_CC}, RUN_DB: ${RUN_DB}, RUN_FLICKR: ${RUN_FLICKR}, RUN_COCO: ${RUN_COCO}" \
    "NAR_CKPT: ${NAR_CKPT}" \
    "AR_CKPT: ${AR_CKPT}"

# Launch one AR-checkpoint FID eval per selected dataset. Each launch runs in the
# background with stdout/stderr discarded, so failures are not reported here.
# NOTE(review): consider redirecting to per-launch log files instead of /dev/null.
if [ "$RUN_AR" -eq 1 ]; then
    # AR-specific overrides; they follow the caller's "$@", as in every launch below.
    ar_overrides=(
        parameterization=ar
        trainer.compile=false
        wandb.name=1_2_ar_60k
        "trainer.load_from_state_dict=$AR_CKPT"
        --multirun
    )

    # "$@" is quoted so caller-supplied overrides containing spaces or glob
    # characters reach main.py unchanged.
    if [ "$RUN_COCO" -eq 1 ]; then
        python main.py "${common_a_args[@]}" "${common_args[@]}" "$@" "${ar_overrides[@]}" > /dev/null 2>&1 &
    fi

    if [ "$RUN_FLICKR" -eq 1 ]; then
        python main.py "${common_b_args[@]}" "${common_args[@]}" "$@" "${ar_overrides[@]}" > /dev/null 2>&1 &
    fi

    # Launched like the other datasets (background, output discarded) so it does
    # not block the DB and NAR launches that follow.
    if [ "$RUN_CC" -eq 1 ]; then
        python main.py "${common_c_args[@]}" "${common_args[@]}" "$@" "${ar_overrides[@]}" > /dev/null 2>&1 &
    fi

    if [ "$RUN_DB" -eq 1 ]; then
        python main.py "${common_d_args[@]}" "${common_args[@]}" "$@" "${ar_overrides[@]}" > /dev/null 2>&1 &
    fi
fi

# Launch one NAR-checkpoint FID eval per selected dataset. Each launch runs in the
# background with stdout/stderr discarded, so failures are not reported here.
# NOTE(review): consider redirecting to per-launch log files instead of /dev/null.
if [ "$RUN_NAR" -eq 1 ]; then
    # NAR-specific overrides; they follow the caller's "$@", as in every launch below.
    nar_overrides=(
        wandb.name=1_2_nar_325k
        "trainer.load_from_state_dict=$NAR_CKPT"
        --multirun
    )

    # "$@" is quoted so caller-supplied overrides containing spaces or glob
    # characters reach main.py unchanged.
    if [ "$RUN_COCO" -eq 1 ]; then
        python main.py "${common_a_args[@]}" "${common_args[@]}" "$@" "${nar_overrides[@]}" > /dev/null 2>&1 &
    fi

    if [ "$RUN_FLICKR" -eq 1 ]; then
        python main.py "${common_b_args[@]}" "${common_args[@]}" "$@" "${nar_overrides[@]}" > /dev/null 2>&1 &
    fi

    if [ "$RUN_CC" -eq 1 ]; then
        python main.py "${common_c_args[@]}" "${common_args[@]}" "$@" "${nar_overrides[@]}" > /dev/null 2>&1 &
    fi

    if [ "$RUN_DB" -eq 1 ]; then
        python main.py "${common_d_args[@]}" "${common_args[@]}" "$@" "${nar_overrides[@]}" > /dev/null 2>&1 &
    fi
fi
