#!/bin/bash
#
# Sweep sampling runs of sample_scripts/vd_image_sample_general_multihop.py
# over every combination of SETTINGS x PARAM_SETS. Each run is launched with
# torchrun (one process per node) on a freshly picked master port.
#
# Optional environment overrides (defaults shown):
#   SAMPLE_INPUT_DIR   value for --input_dir  (the face_sketch_segment test set)
#   SAMPLE_SAVE_DIR    value for --save_dir   (default: "")
#   SAMPLE_MODEL_PATH  value for --model_path (default: "")
#
# Exit status: 0 if every run succeeded (or nothing is configured), 1 if any
# run failed or a free port could not be obtained. A failed run does not stop
# the remaining runs of the sweep.

set -euo pipefail

# Extra CLI flags for each run, one string per run (pairs of context_class and
# target_class). Each string is split on whitespace into separate arguments,
# so individual values must not contain spaces. Uncomment the pairs to sample.
PARAM_SETS=(
  # "--context_class 2 --target_class 1"
  # "--context_class 1 --target_class 2"
  # "--context_class 0 --target_class 1"
  # "--context_class 0 --target_class 2"
  # "--context_class 1 --target_class 0"
  # "--context_class 2 --target_class 0"
)

# Labels for the outer loop.
# NOTE(review): a setting is only printed; it is not passed to the sample
# script, so every setting runs the identical command — confirm intent.
SETTINGS=(
  "ddr"
)

readonly SAMPLE_SCRIPT=sample_scripts/vd_image_sample_general_multihop.py
readonly SAMPLE_INPUT_DIR="${SAMPLE_INPUT_DIR:-/scratch/s223719687/DiffusionRouterModel/datasets/Faces_dataset/face_sketch_segment/test}"
readonly SAMPLE_SAVE_DIR="${SAMPLE_SAVE_DIR:-}"
readonly SAMPLE_MODEL_PATH="${SAMPLE_MODEL_PATH:-}"

# Model architecture flags; they must match the checkpoint given by
# --model_path.
# NOTE(review): earlier revisions passed both --num_channels 192 and, later on
# the same command line, --num_channels 128. Assuming the sample script's
# parser keeps the last occurrence (argparse default), 128 was the effective
# value and is kept here — confirm it matches the checkpoint.
readonly -a MODEL_FLAGS=(
  --attention_resolutions 32,16,8
  --diffusion_steps 1000
  --image_size 32
  --learn_sigma False
  --noise_schedule linear
  --num_channels 128
  --num_heads 4
  --num_res_blocks 2
  --resblock_updown True
  --use_scale_shift_norm True
  --num_classes 3
)

# Data / sampling flags shared by every run.
readonly -a SAMPLE_FLAGS=(
  --in_channels 8
  --batch_size 64
  --input_dir "$SAMPLE_INPUT_DIR"
  --save_dir "$SAMPLE_SAVE_DIR"
  --model_path "$SAMPLE_MODEL_PATH"
  --timestep_respacing 1000
  --use_ddim True
  --dataset_name face_sketch_segment
  --latent_space
  --clip_denoised False
)

#######################################
# Print a TCP port number the OS reports as free.
# Binds port 0 so the OS assigns an ephemeral port, prints it, then releases
# it. The port is closed before torchrun binds it, so another process could
# take it in between (race). torchrun's --standalone mode may avoid this —
# consider it if port clashes show up.
# Outputs:   the port number on stdout
# Returns:   python3's exit status
#######################################
pick_free_port() {
  # Quoted delimiter: the Python snippet is passed through literally.
  python3 - <<'PYCODE'
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(('', 0))
    print(s.getsockname()[1])
PYCODE
}

#######################################
# Run the sample script once for every (setting, param set) combination.
# Globals:   SETTINGS, PARAM_SETS, MODEL_FLAGS, SAMPLE_FLAGS, SAMPLE_SCRIPT
#            (read); MASTER_PORT (written and exported)
# Outputs:   progress on stdout, warnings/errors on stderr
# Returns:   0 if all runs succeeded or none are configured, 1 otherwise
#######################################
main() {
  if (( ${#PARAM_SETS[@]} == 0 )); then
    printf 'warn: PARAM_SETS is empty; nothing to run\n' >&2
    return 0
  fi

  local settings params port
  local -a param_args cmd
  local -a failed=()

  for settings in "${SETTINGS[@]}"; do
    for params in "${PARAM_SETS[@]}"; do
      printf 'Running sample script with settings: %s and params: %s\n' \
        "$settings" "$params"

      # Assign separately from any declaration so a python3 failure is seen,
      # and reject anything that is not a bare port number.
      if ! port=$(pick_free_port) || [[ ! "$port" =~ ^[0-9]+$ ]]; then
        printf 'error: could not pick a free port (got %q)\n' "${port-}" >&2
        return 1
      fi
      MASTER_PORT=$port
      export MASTER_PORT

      # Split the param string into words (the original relied on unquoted
      # expansion for this).
      read -ra param_args <<< "$params"

      cmd=(
        --nproc-per-node=1
        "--master-port=${MASTER_PORT}"
        "$SAMPLE_SCRIPT"
        "${MODEL_FLAGS[@]}"
        "${SAMPLE_FLAGS[@]}"
      )
      # Guarded append: "${empty[@]}" trips `set -u` on bash < 4.4.
      if (( ${#param_args[@]} > 0 )); then
        cmd+=("${param_args[@]}")
      fi

      # Record failures and keep going so one bad run doesn't cancel the sweep.
      if ! torchrun "${cmd[@]}"; then
        printf 'warn: run failed (settings: %s, params: %s)\n' \
          "$settings" "$params" >&2
        failed+=("settings: $settings, params: $params")
      fi
    done
  done

  if (( ${#failed[@]} > 0 )); then
    printf 'error: %d run(s) failed:\n' "${#failed[@]}" >&2
    printf '  %s\n' "${failed[@]}" >&2
    return 1
  fi
}

main "$@"