#!/usr/bin/env bash
# Launch src/train_with_mask.py with the fixed configuration below.
# Every setting is forwarded verbatim as a command-line flag.
# NOTE(review): the relative paths (src/train_with_mask.py, ./datasets) assume
# the script is run from the repository root — confirm.
set -euo pipefail

# --- Dataset / hardware ------------------------------------------------------
readonly DATASET="cifar10"
readonly DEVICE="cuda"
readonly NUM_CLASSES=10
readonly DATASET_PATH="./datasets"

# --- Input placement ---------------------------------------------------------
readonly PAD_SIZE=8
readonly NUM_IMAGE_LOCATIONS="edges"
readonly BACKGROUND="nature"

# --- Model / query configuration ---------------------------------------------
readonly BASE_CLASSIFIER="cifar_resnet110"
readonly NUM_QUERIES=2
readonly BUDGET_SPLIT="fixed"
readonly FIRST_QUERY_BUDGET_FRAC=0.5
readonly SECOND_QUERY_MASK_MODEL="unet"
readonly MASK_OUTPUT="sigmoid"
readonly MASK_INIT="random"
readonly WM_ACROSS_CHANNELS="same"
readonly TOTAL_SIGMA=0.75
readonly RUN_DESCRIPTION="description"

readonly SEED=1

# Build the argument list as an array so each value is passed as exactly one
# argv entry. Unquoted expansions would be word-split and glob-expanded, for
# example a RUN_DESCRIPTION containing spaces.
train_args=(
  --seed "$SEED"
  --device "$DEVICE"
  --epochs 100
  --dataset "$DATASET"
  --num_classes "$NUM_CLASSES"
  --dataset_path "$DATASET_PATH"
  --pad_size "$PAD_SIZE"
  --num_image_locations "$NUM_IMAGE_LOCATIONS"
  --background "$BACKGROUND"
  --base_classifier "$BASE_CLASSIFIER"
  --train_batch_size 256
  --test_batch_size 20
  # Optimizer settings for the flags without a t_ prefix.
  --optim_lr 0.01
  --weight_decay 1e-4
  --momentum 0.9
  --step_size 30
  --gamma 0.1
  # Separate optimizer settings passed under the t_ prefix.
  # NOTE(review): which model these apply to is not visible here — see
  # train_with_mask.py.
  --t_optim_lr 1e-3
  --t_weight_decay 1e-4
  --t_momentum 0.9
  --t_step_size 40
  --t_gamma 0.5
  --num_queries "$NUM_QUERIES"
  --budget_split "$BUDGET_SPLIT"
  --first_query_budget_frac "$FIRST_QUERY_BUDGET_FRAC"
  --second_query_mask_model "$SECOND_QUERY_MASK_MODEL"
  --mask_output "$MASK_OUTPUT"
  --average_queries
  --averaging_style "old"
  --mask_init "$MASK_INIT"
  --wm_across_channels "$WM_ACROSS_CHANNELS"
  --total_sigma "$TOTAL_SIGMA"
  --mask_recon 0
  --run_description "$RUN_DESCRIPTION"
)

# This is the script's last command, so the training exit status becomes the
# script's exit status.
python3 src/train_with_mask.py "${train_args[@]}"
