# ---------------------------------------------------------------------------
# Experiment configuration for a single memory-transformer training run.
# Everything below is a plain shell variable consumed by the launch command
# further down in this script.
# ---------------------------------------------------------------------------

# Logging / bookkeeping names.
wandb_project_name="memory_transformer"
exp_name="memory_retain"

# GPU index; used as the CUDA_VISIBLE_DEVICES value when launching.
count="4"

# Task selection. To sweep several tasks, wrap the script body in a loop:
# for task in sum prod gcd remainder
# do
task="gcd"
encoding_method="standard"

# Training schedule.
epochs="10"
batch_size="128"
# max_steps_per_epoch=1000

# Problem parameters forwarded to src/main.py.
nvars="2"
field="ZZ"
max_coefficient="1000"
max_degree="20"

# Model parameters forwarded to src/main.py.
num_memory_tokens="16"
max_seq_len="1024"

# Input dataset location: data/<class>/<task>/data_<field>_n=<nvars>
data_class="large"
data_name="data_${field}_n=${nvars}"
data_path="data/${data_class}/${task}/${data_name}"

# run_name is passed as the experiment id; run_name_sub names the results
# directory (note the literal "2" suffix after exp_name and the batch size).
run_name="${exp_name}_${data_class}_${task}_${field}_n=${nvars}_${encoding_method}_add_memory=${num_memory_tokens}"
run_name_sub="${exp_name}2_${field}_n=${nvars}_${encoding_method}_bs=${batch_size}_add_memory=${num_memory_tokens}"
save_path="results/${data_class}/${task}/${run_name_sub}"

# Create the results directory, then start one training run in the background
# on the GPU selected by $count.
#
# Reads: save_path, data_path, task, epochs, batch_size, wandb_project_name,
#        run_name, nvars, field, max_coefficient, max_degree, encoding_method,
#        num_memory_tokens, max_seq_len, count.
#
# All expansions are quoted so a value containing whitespace or glob
# characters is passed through as a single argument. Abort early if the
# directory cannot be created, since the log redirection below needs it.
mkdir -p -- "$save_path" || {
  printf 'error: cannot create output directory: %s\n' "$save_path" >&2
  exit 1
}

# Only stdout is redirected to run.log; stderr stays on the caller's terminal.
# The trailing '&' backgrounds the job, so this script does not wait for it.
CUDA_VISIBLE_DEVICES="$count" python3 src/main.py  --save_path "$save_path" \
                                            --data_path "$data_path" \
                                            --task "$task" \
                                            --epochs "$epochs" \
                                            --batch_size "$batch_size" \
                                            --exp_name "$wandb_project_name" \
                                            --exp_id "$run_name" \
                                            --num_variables "$nvars" \
                                            --field "$field" \
                                            --max_coefficient "$max_coefficient" \
                                            --max_degree "$max_degree" \
                                            --encoding_method "$encoding_method" \
                                            --num_memory_tokens "$num_memory_tokens" \
                                            --use_memory_transformer \
                                            --max_sequence_length "$max_seq_len" > "${save_path}/run.log" &

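                                            # Optional flags, currently disabled. To enable one, move it into the
                                            # command above, before the final --max_sequence_length line
                                            # (--max_steps_per_epoch also needs the max_steps_per_epoch variable set):
                                            # --max_steps_per_epoch $max_steps_per_epoch \
                                            # --positional_encoding embedding \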
# next_count=$((count+1))
# debug
# CUDA_VISIBLE_DEVICES=$count python3 src/main.py  --save_path $save_path \
#                                             --data_path $data_path \
#                                             --task $task \
#                                             --epochs $epochs \
#                                             --batch_size $batch_size \
#                                             --exp_name $wandb_project_name \
#                                             --exp_id $run_name \
#                                             --num_variables $nvars \
#                                             --field $field \
#                                             --max_coefficient $max_coefficient \
#                                             --max_degree $max_degree \
#                                             --encoding_method $encoding_method \
#                                             --num_memory_tokens $num_memory_tokens\
#                                             --max_sequence_length $max_seq_len \
#                                             --max_steps_per_epoch 1 \

#                                             # --max_steps_per_epoch $max_steps_per_epoch \
#                                             # --positional_encoding embedding \

