#!/usr/bin/env bash
# Thread limits: keep each individual run on a single thread.
export OMP_NUM_THREADS=1 # important for CPU & numpy
# Additional per-library thread caps (currently disabled):
#export MKL_NUM_THREADS=1
#export OPENBLAS_NUM_THREADS=1
#export NUMEXPR_NUM_THREADS=1
#export NUM_INTER_THREADS=1
#export NUM_INTRA_THREADS=1

#export NPROC=1 #to avoid multiple processes using the same GPU/CPU
# For JAX (currently disabled):
# export XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1 --xla_force_host_platform_device_count=1"


# ---------------------------------------------------------------------------
# Experiment configuration. Each value below is forwarded as a command-line
# flag to the Python launcher invoked further down in this script.
# ---------------------------------------------------------------------------

# Python launcher that is executed, and the program it receives via --bin.
RUN=pyrun_train_MPE_mol_uncon.py
BIN=../source/main_gen_mol_uncon.py
# Name of this script without the .sh suffix ($0 is quoted so that a script
# path containing spaces is not word-split).
SCRIPT_NAME=$(basename "$0" .sh)
# Sub-folder of ../results that receives the outputs.
#FOLDER_NAME=$SCRIPT_NAME
FOLDER_NAME=demo_train_qm9
THREADS=1 # --threads

# Comma-separated random seeds (--rseed); alternative seed sets kept below.
#SEEDS='10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49'
#SEEDS='0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49'
#SEEDS='0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
SEEDS='0,1,2,3,4,5,6,7,8,9'

DATNAME='qm9' # --dat_name
# Generator circuit type (--gen_circuit_type).
#ANSATZ='rxycz'
ANSATZ='SU2-full'

NATOMS=7  # --n_atoms
NRINGS=2  # --n_rings
NQUBITS=7 # --n_qubits
NANCILA=3 # --n_ancilla

NTRAIN=200 # --n_train
NTEST=488  # --n_test

EPOCHS=2001      # --n_outer_epochs
DIST_TYPE='wass' # --dist_type
LR='0.001' # Learning rate
MAG='1.0'  # Magnitude of initial parameters
VENDI_LAMBDA='0.0001' # Vendi loss lambda

GEN_LAYERS='20' # --n_layers
BLOCH=0         # --bloch

# Sweep over the number of diffusion steps (--n_diff_steps) and the input
# type (--input_type); one launcher invocation per combination.
for STEPS in 10; do
  for INPUT in product; do
    # Results directory, named after the dataset, qubit/ancilla counts and
    # the test/train set sizes.
    SAVE="../results/${FOLDER_NAME}/${DATNAME}_qubits_${NQUBITS}_${NANCILA}_dat_${NTEST}_${NTRAIN}"
    # Using taskset to limit the CPU cores used by the script.
    # NOTE(review): the core range 114-123 is machine-specific — adjust it
    # for the host this runs on.
    # All expansions are quoted so that empty values or values containing
    # whitespace/glob characters cannot shift or split the arguments.
    taskset -c 114-123 python "$RUN" \
      --bin "$BIN" \
      --vendi_lambda "$VENDI_LAMBDA" \
      --n_qubits "$NQUBITS" \
      --gen_circuit_type "$ANSATZ" \
      --n_atoms "$NATOMS" \
      --n_rings "$NRINGS" \
      --threads "$THREADS" \
      --rseed "$SEEDS" \
      --n_outer_epochs "$EPOCHS" \
      --dat_name "$DATNAME" \
      --bloch "$BLOCH" \
      --n_diff_steps "$STEPS" \
      --n_ancilla "$NANCILA" \
      --lr "$LR" \
      --mag "$MAG" \
      --save_dir "$SAVE" \
      --input_type "$INPUT" \
      --n_train "$NTRAIN" \
      --n_test "$NTEST" \
      --n_layers "$GEN_LAYERS" \
      --dist_type "$DIST_TYPE"
  done
done
