set -euo pipefail
. recognizers/functions.bash

#######################################
# Print the command-line help text to stdout.
# Callers redirect it to stderr when reporting a usage error.
# Globals:   $0 (read) - script name shown in the usage line
# Arguments: none
# Outputs:   help text on stdout
#######################################
usage() {
  # Unquoted heredoc so that $0 expands; \\ yields a literal backslash and
  # \` a literal backtick.
  cat <<EOF
Usage: $0 <base-directory> <language> <architecture> <loss-terms> \\
  <validation-data> <trial-no> <parameter-budget> [options...]

Train and evaluate a neural network on a language.

  <base-directory>
    Directory under which all datasets and models are stored.
  <language>
    Name of the language to run on. Corresponds to the name of a directory
    under <base-directory>/languages/.
  <architecture>
    One of transformer, rnn, lstm.
  <loss-terms>
    Any of the following, joined by \`+\` characters:
    - rec: recognition (binary classification with binary cross-entropy loss)
    - lm: language modeling (cross-entropy loss of next symbol)
    - ns: next set prediction (binary cross-entropy loss of whether every
          symbol at every position is valid in the next position)
    Example: rec+lm for recognition and language modeling.
  <validation-data>
    Which validation set to use. One of: validation-short, validation-long.
  <trial-no>
    A number distinguishing this random restart.
  <parameter-budget>
    Number of parameters.

Options:
  All [options...] are passed through unchanged to the training script,
  for example:
  --no-progress
    Don't show progress messages.
  --wandb-agent <id>
    Username for wandb.

EOF
}

#######################################
# Thin wrapper around the project's random_sample.py helper.
# Arguments: forwarded verbatim to random_sample.py
# Outputs:   whatever random_sample.py writes to stdout
# Returns:   the helper's exit status
#######################################
random_sample() {
  local -r sampler_script=recognizers/neural_networks/random_sample.py
  python "$sampler_script" "$@"
}

# Seven positional arguments are required, in the order shown by usage.
# Everything after them is kept for forwarding to train.py.
if (( $# < 7 )); then
  usage >&2
  exit 1
fi
base_dir=$1
language=$2
architecture=$3
loss_terms=$4
validation_data=$5
trial_no=$6
parameter_budget=$7
shift 7
extra_train_args=("$@")

language_dir=$(get_language_dir "$base_dir" "$language")

# Per-architecture settings:
#   model_size_args  - passed to get_architecture_args.py, whose output
#                      (together with the parameter budget) becomes model_args
#   model_train_args - passed directly to train.py
case "$architecture" in
  transformer)
    model_size_args=(
      --num-layers 4
      --num-heads 4
    )
    model_train_args=(
      --init-scale 0.15
      --initial-learning-rate 0.001
      --dropout 0.2
    )
    ;;
  rnn | lstm)
    # Both recurrent architectures use identical settings.
    model_size_args=(
      --num-layers 2
    )
    model_train_args=(
      --init-scale 0.15
      --initial-learning-rate 0.002
      --dropout 0.2
    )
    ;;
  *)
    usage >&2
    echo "unknown architecture: $architecture" >&2
    exit 1
    ;;
esac
# Capture the helper's output with a plain assignment (a failure stops the
# script under `set -e`), then split it on whitespace with `read -a`. Unlike
# an unquoted $(...) inside an array literal, this performs no pathname
# expansion on the words.
model_args_output=$(
  python recognizers/neural_networks/get_architecture_args.py \
    --architecture "$architecture" \
    --parameter-budget "$parameter_budget" \
    --training-data "$language_dir" \
    "${model_size_args[@]}"
)
# -d '' makes read consume the whole output (every line), so it returns 1 on
# reaching end of input; that status is expected and deliberately ignored.
read -r -d '' -a model_args <<<"$model_args_output" || true

# Translate the `+`-separated <loss-terms> argument into train.py flags.
# Splitting with `IFS=+ read` instead of an unquoted ${loss_terms//+/ } keeps
# pathname expansion from rewriting the terms (e.g. a term of `*`).
loss_term_flags=()
IFS=+ read -r -a loss_term_list <<<"$loss_terms"
# ${arr[@]+"${arr[@]}"} tolerates an empty array under `set -u` on bash < 4.4.
for loss_term in ${loss_term_list[@]+"${loss_term_list[@]}"}; do
  case $loss_term in
    # `rec` adds no flag here; presumably train.py always trains the
    # recognition objective -- TODO confirm.
    rec) ;;
    lm)
      loss_term_flags+=(--use-language-modeling-head)
      ;;
    ns)
      loss_term_flags+=(--use-next-symbols-head)
      ;;
    *)
      usage >&2
      echo "invalid loss term $loss_term" >&2
      exit 1
      ;;
  esac
done

model_dir=$(get_model_dir "$base_dir" "$language" "$architecture" "$loss_terms" "$validation_data" "$trial_no" "$parameter_budget")
# Sample the batch size in its own assignment. When a command substitution is
# used as an argument, its exit status is discarded, so a failing sampler
# would slip past `set -e` and hand train.py an empty value.
max_tokens_per_batch=$(random_sample --int 128 4096)
# Arrays that may be empty (model_args, loss_term_flags for `rec` alone,
# extra_train_args when no options are given) are expanded as
# ${arr[@]+"${arr[@]}"}: bash releases before 4.4 report a plain "${arr[@]}"
# of an empty array as an unbound variable under `set -u`.
python recognizers/neural_networks/train.py \
  --output "$model_dir" \
  --training-data "$language_dir" \
  --validation-data "$validation_data" \
  --architecture "$architecture" \
  ${model_args[@]+"${model_args[@]}"} \
  "${model_train_args[@]}" \
  ${loss_term_flags[@]+"${loss_term_flags[@]}"} \
  --max-epochs 1000 \
  --max-tokens-per-batch "$max_tokens_per_batch" \
  --optimizer Adam \
  --gradient-clipping-threshold 5 \
  --early-stopping-patience 10 \
  --learning-rate-patience 5 \
  --learning-rate-decay-factor 0.5 \
  --examples-per-checkpoint 10000 \
  ${extra_train_args[@]+"${extra_train_args[@]}"}
# Run the evaluation script on the trained model directory.
bash recognizers/neural_networks/evaluate.bash "$language_dir" "$model_dir"
