#!/bin/bash


# Positional arguments (all ten are required):
#   $1  TASK                                 task name; appended to "GeoILP/" for --task and
#                                            used to pick the seed list
#   $2  OBJECTIVE                            value for --objective
#   $3  TOPK                                 value for --rloo_topk
#   $4  NUM_SAMPLE_VARS                      value for --num_sample_vars
#   $5  NUM_SAMPLE_ATOMS                     value for --num_sample_atoms
#   $6  EPOCHS                               value for --train_epochs and
#                                            --entropy_coeff_anneal_epochs
#   $7  NUM_PARALLEL                         value for --num_interop_threads
#   $8  CWA                                  "CWA" or "OWA"
#   $9  REMOVE_IRRELEVANT_VARS_BOOL          "True" or "False"
#   $10 SAMPLING_BACC_NON_PARALLEL_COMPUTE   value for --sampling_bacc_non_parallel_compute
#
# Refuse to continue with missing arguments: an empty value would otherwise be
# expanded into the python command line and silently produce malformed flags.
if [ "$#" -lt 10 ]; then
    echo "Usage: $0 TASK OBJECTIVE TOPK NUM_SAMPLE_VARS NUM_SAMPLE_ATOMS EPOCHS NUM_PARALLEL CWA|OWA True|False SAMPLING_BACC_NON_PARALLEL_COMPUTE" >&2
    exit 2
fi

TASK=$1
OBJECTIVE=$2
TOPK=$3
NUM_SAMPLE_VARS=$4
NUM_SAMPLE_ATOMS=$5
EPOCHS=$6
NUM_PARALLEL=$7
CWA=$8
REMOVE_IRRELEVANT_VARS_BOOL=$9
SAMPLING_BACC_NON_PARALLEL_COMPUTE=${10}


# Translate the CWA/OWA keyword into the True/False string passed to --cwa.
case "$CWA" in
    CWA) CWA_BOOL=True ;;
    OWA) CWA_BOOL=False ;;
    *)
        # Abort rather than continue with CWA_BOOL unset, which would hand an
        # empty "--cwa=" to the python entry point.
        echo "Unknown CWA choices!" >&2
        exit 1
        ;;
esac


# Validate the True/False switch and derive a short label from it.
# NOTE(review): REMOVE_IRRELEVANT_VARS (noirre/yesirre) is not read anywhere else
# in the visible script; the python call receives REMOVE_IRRELEVANT_VARS_BOOL
# directly. The label is kept in case something outside this view relies on it.
case "$REMOVE_IRRELEVANT_VARS_BOOL" in
    True) REMOVE_IRRELEVANT_VARS=noirre ;;
    False) REMOVE_IRRELEVANT_VARS=yesirre ;;
    *)
        # Abort rather than forward an invalid value to --remove_irrelevant_vars.
        echo "Unknown REMOVE_IRRELEVANT_VARS choices!" >&2
        exit 1
        ;;
esac


# Pick the seed list for the requested task; one run is launched per seed.
#
# NOTE(review): for the eqangle_* and eqratio_* tasks the original listed five
# arms with the same pattern. `case` executes only the first matching arm, so
# the other four were unreachable. The first arm is kept active (unchanged
# behavior) and the remaining seed pairs are preserved as "alt:" comments --
# presumably alternate batches that were swapped in by hand; confirm intent.
case "$TASK" in
    basic/circle_1) seeds=(6734 5112 6180 27125 7 8231 102 27 195 58) ;;
    basic/circle_2) seeds=(730156592 33 32 94546 98 32989549 4094010 48940 83835377 8763626) ;;
    basic/cong_1) seeds=(9833 53199423 946657 4975 29 56 4031234 60 80 3481) ;;
    basic/eqangle_1) seeds=(66805803 277) ;;
        # alt: seeds=(4176722 6)
        # alt: seeds=(39451424 13192)
        # alt: seeds=(91958 406)
        # alt: seeds=(369663 859)
    basic/eqangle_2) seeds=(2625305 94241) ;;
        # alt: seeds=(8961 260263)
        # alt: seeds=(39 8063)
        # alt: seeds=(184156332 4076)
        # alt: seeds=(7322 3854581)
    basic/eqangle_3) seeds=(1118 2081) ;;
        # alt: seeds=(5829 84794656)
        # alt: seeds=(3693729 25063)
        # alt: seeds=(5163346 4)
        # alt: seeds=(716 10)
    basic/eqratio_1) seeds=(91 51068600) ;;
        # alt: seeds=(9899765 3)
        # alt: seeds=(17443 70)
        # alt: seeds=(9 3674185)
        # alt: seeds=(24389850 6244)
    basic/eqratio_2) seeds=(32 862066053) ;;
        # alt: seeds=(629 33154911)
        # alt: seeds=(87269072 3)
        # alt: seeds=(7790123 6262606)
        # alt: seeds=(35 40612)
    basic/eqratio_3) seeds=(64649106 1) ;;
        # alt: seeds=(953701992 102575181)
        # alt: seeds=(884 19993)
        # alt: seeds=(902464253 45)
        # alt: seeds=(9327 23112124)
    basic/eqratio_4) seeds=(9001 815501938) ;;
        # alt: seeds=(9199677 7769)
        # alt: seeds=(151520 68117)
        # alt: seeds=(632 28420)
        # alt: seeds=(6832426 37463)
    basic/para_1) seeds=(596543 29609 79 6821630 25229 231 282539 3 9590 867) ;;
    basic/para_2) seeds=(547346382 60170073 5212 8 866 740726912 16989368 40 1013 8720) ;;
    basic/para_3) seeds=(85 46163 928 45423 72244944 73 284756 4142 96655 2180873) ;;
    *)
        # Without a seed list nothing would be launched and the script would
        # still exit 0; fail loudly instead.
        echo "Unknown task!" >&2
        exit 1
        ;;
esac


# Launch one training run per entry in `seeds`, all concurrently, then wait
# for every run to finish.
# Globals (read): seeds, TASK, NUM_PARALLEL, OBJECTIVE, TOPK, CWA_BOOL,
#                 REMOVE_IRRELEVANT_VARS_BOOL, SAMPLING_BACC_NON_PARALLEL_COMPUTE,
#                 EPOCHS, NUM_SAMPLE_VARS, NUM_SAMPLE_ATOMS
# Returns: 0 if every run exited successfully, 1 if at least one run failed.
# Note: ../src/main.py and ./log are relative paths, so the result depends on
# the directory the script is started from.
run_all_seeds() {
    local seed idx status=0
    local -a pids=() launched=()

    # export PYTORCH_JIT=0
    # Loop-invariant, so exported once instead of once per seed.
    export TORCH_COMPILE_DISABLE=0

    for seed in "${seeds[@]}"; do
        # Every expansion is quoted so that each value reaches python as exactly
        # one argument (no word splitting or glob expansion).
        python ../src/main.py \
            --task "GeoILP/$TASK" \
            --seed "$seed" \
            --num_intraop_threads 1 \
            --num_interop_threads "$NUM_PARALLEL" \
            --objective "$OBJECTIVE" \
            --rloo_topk "$TOPK" \
            --cwa="$CWA_BOOL" \
            --remove_irrelevant_vars="$REMOVE_IRRELEVANT_VARS_BOOL" \
            --stop_per_sampling_rules \
            --sampling_bacc_non_parallel_compute "$SAMPLING_BACC_NON_PARALLEL_COMPUTE" \
            --train_epochs "$EPOCHS" \
            --max_train_inference_steps 10 \
            --num_rules 4 \
            --max_variables 4 \
            --max_body_atoms 3 \
            --num_aux_predicates 1 \
            --max_aux_arity 4 \
            --max_occurrence_in_body 2 \
            --predicate_embed_dim 5 \
            --variable_embed_dim 10 \
            --rule_head_atom_embed_dim 5 \
            --rule_body_atom_embed_dim 5 \
            --num_sample_vars "$NUM_SAMPLE_VARS" \
            --num_sample_atoms "$NUM_SAMPLE_ATOMS" \
            --lr 5e-2 \
            --entropy_coeff_init 1e-2 \
            --entropy_coeff_final 1e-6 \
            --entropy_coeff_anneal_epochs "$EPOCHS" \
            --log_dir ./log \
            --log_period_to_disk 120 \
            --log_buffer_capacity 10 &
        pids+=("$!")
        launched+=("$seed")
    done

    # A bare `wait` always reports success; waiting on each PID individually
    # lets a failed run surface in the script's exit status. All runs are still
    # allowed to finish before the result is reported.
    for idx in "${!pids[@]}"; do
        if ! wait "${pids[$idx]}"; then
            echo "Run for seed ${launched[$idx]} failed" >&2
            status=1
        fi
    done

    return "$status"
}

run_all_seeds || exit 1
