# Strict mode: abort on any unhandled command failure, on expansion of an
# unset variable, and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Load shared helper functions. The path is relative, so this script has to be
# started from the directory that contains recognizers/.
# NOTE(review): get_language_dir and get_model_dir, used further down, are
# presumably defined in this file -- confirm.
source recognizers/functions.bash

#######################################
# Print the command-line help text for this script.
# Arguments:
#   None. $0 is interpolated into the synopsis as the script name.
# Outputs:
#   Writes the help text to stdout; callers redirect it (e.g. to stderr)
#   as needed.
#######################################
usage() {
  # Unquoted here-document: $0 is expanded, \\ yields a literal backslash and
  # \` yields a literal backtick. The blank line before EOF keeps the trailing
  # empty line at the end of the message.
  cat <<EOF
Usage: $0 <base-directory> <language> <architecture> <loss-terms> \\
  <validation-data> <string-length> <trial-no> <init-method> [--no-progress]

Train and evaluate a neural network on a language.

  <base-directory>
    Directory under which all datasets and models are stored.
  <language>
    Name of the language to run on. Corresponds to the name of a directory
    under <base-directory>/languages/.
  <architecture>
    One of transformer, rnn, lstm.
  <loss-terms>
    Any of the following, joined by \`+\` characters:
    - rec: recognition (binary classification with binary cross-entropy loss)
    - lm: language modeling (cross-entropy loss of next symbol)
    - ns: next set prediction (binary cross-entropy loss of whether every
          symbol at every position is valid in the next position)
    Example: rec+lm for recognition and language modeling.
  <validation-data>
    Which validation set to use. One of: validation-short, validation-long.
  <string-length>
    Length of strings.
  <trial-no>
    A number distinguishing this random restart.
  <init-method>
    Initialization strategy. Forwarded to the training script as
    --init-strategy.
  --no-progress
    Don't show progress messages.

EOF
}

#######################################
# Thin wrapper around the random_sample.py helper script.
# Arguments:
#   All arguments are forwarded unchanged to the Python script. This file
#   calls it as `--int LO HI`, `--int --choices V...` and `--log LO HI`.
# Outputs:
#   Whatever the Python script writes to stdout; callers capture it with
#   command substitution.
#######################################
random_sample() {
  python \
    recognizers/neural_networks/random_sample.py \
    "$@"
}

# All eight positional arguments are mandatory; print the help text to stderr
# and fail if any of them is missing.
if (( $# < 8 )); then
  usage >&2
  exit 1
fi

base_dir=$1
language=$2
architecture=$3
loss_terms=$4
validation_data=$5
string_len=$6
trial_no=$7
init_method=$8
shift 8

# Whatever follows the required arguments (e.g. --no-progress) is kept
# verbatim so it can be appended to the training command line.
progress_args=("$@")

# Resolve the dataset directory for this language and string length.
language_dir=$(get_language_dir "$base_dir" "$language" "$string_len")


# Build model_flags: randomly sampled, architecture-specific hyperparameters
# that are later appended to the training command line.
case "$architecture" in
  transformer)
    # The head count and the width of a single head are sampled
    # independently; d_model is their product, so it is always a multiple of
    # the number of heads.
    num_heads=$(random_sample --int --choices 1 2 4 8)
    head_size=$(random_sample --int --choices 16 32 64 128)
    d_model=$((num_heads * head_size))

    # The feedforward width is d_model scaled by a sampled multiplier.
    ff_mult=$(random_sample --int --choices 2 4 8)
    feedforward_size=$((d_model * ff_mult))

    model_flags=(
      --num-layers "$(random_sample --int 2 6)"
      --d-model "$d_model"
      --num-heads "$num_heads"
      --feedforward-size "$feedforward_size"
      --dropout 0.1
    )
    ;;
  *)
    # Every non-transformer architecture gets a sampled layer count and
    # hidden size.
    model_flags=(
      --num-layers "$(random_sample --int 1 4)"
      --hidden-units "$(random_sample --int --choices 64 128 256 512)"
      --dropout 0.1
    )
    ;;
esac

#######################################
# Translate a '+'-joined list of loss terms into training flags.
# Globals:
#   loss_term_flags (appended to)
# Arguments:
#   $1 - loss terms joined by '+', e.g. "rec+lm". Recognised terms:
#        rec - adds no flags
#        lm  - enables the language modeling head with a sampled coefficient
#        ns  - enables the next-symbols head with a sampled coefficient
# Outputs:
#   Writes "invalid loss term <term>" to stderr for an unrecognised term.
# Returns:
#   Exits the script with status 1 on an unrecognised term.
#######################################
add_loss_term_flags() {
  local terms=()
  local term
  # Split with `read -a` instead of an unquoted expansion in the `for` list:
  # the words are still separated on whitespace after replacing '+', but they
  # no longer undergo pathname expansion, so a term such as '*' is reported
  # as invalid instead of being replaced by file names.
  read -ra terms <<< "${1//+/ }"
  # The ${arr[@]+...} form keeps an empty list from tripping `set -u` on
  # bash versions older than 4.4.
  for term in ${terms[@]+"${terms[@]}"}; do
    case $term in
      rec) ;;
      lm)
        loss_term_flags+=(
          --use-language-modeling-head
          --language-modeling-loss-coefficient "$(random_sample --log 0.01 10)"
        )
        ;;
      ns)
        loss_term_flags+=(
          --use-next-symbols-head
          --next-symbols-loss-coefficient "$(random_sample --log 0.01 10)"
        )
        ;;
      *)
        echo "invalid loss term $term" >&2
        exit 1
        ;;
    esac
  done
}

loss_term_flags=()
add_loss_term_flags "${loss_terms-}"

# Resolve the output directory for this language/architecture/length/trial.
model_dir=$(get_model_dir "$base_dir" "$language" "$architecture" "$string_len" "$trial_no")

# Run the training script with the sampled model and loss-term flags plus a
# sampled batch size and learning rate.
#
# loss_term_flags is empty when the only loss term is `rec`, and progress_args
# is empty when no extra arguments were given. Under `set -u`, bash versions
# older than 4.4 treat "${arr[@]}" on an empty array as an unbound variable
# and abort, so both are expanded through ${arr[@]+"${arr[@]}"}, which yields
# nothing for an empty array and the quoted elements otherwise.
#
# NOTE(review): --max-epochs 0 presumably means no training epochs are run
# before evaluation -- confirm this is intended.
python recognizers/neural_networks/train.py \
  --output "$model_dir" \
  --training-data "$language_dir" \
  --validation-data "$validation_data" \
  --architecture "$architecture" \
  "${model_flags[@]}" \
  --init-scale 1 \
  --init-strategy "$init_method" \
  ${loss_term_flags[@]+"${loss_term_flags[@]}"} \
  --max-epochs 0 \
  --max-tokens-per-batch "$(random_sample --int 128 4096)" \
  --optimizer Adam \
  --initial-learning-rate "$(random_sample --log 0.0001 0.01)" \
  --gradient-clipping-threshold 5 \
  --early-stopping-patience 10 \
  --learning-rate-patience 5 \
  --learning-rate-decay-factor 0.5 \
  --examples-per-checkpoint 10000 \
  ${progress_args[@]+"${progress_args[@]}"}

# Evaluate the saved model; the third argument is the string length plus one.
bash recognizers/neural_networks/evaluate_sensitivity.bash "$language_dir" "$model_dir" "$((string_len + 1))"