set -euo pipefail

. intervention_sampling/functions.bash

# Print the command-line synopsis and a one-line description to stdout.
# Callers redirect it to stderr when reporting a usage error. The help text
# ends with an empty line.
usage() {
  printf '%s\n' \
    "Usage: $0 <base-directory> <training-data> <test-data> <architecture> \\" \
    '  <trial-no> [--no-progress]' \
    '' \
    'Train and evaluate a neural network on a language.' \
    ''
}

# Draw a random hyperparameter value by running random_sample.py.
# Every argument is forwarded verbatim (this file calls it as, e.g.,
# `--int 128 4096` or `--log 0.0001 0.01`). The script's stdout and exit
# status pass straight through to the caller.
random_sample() {
  local -r sampler_script=intervention_sampling/neural_networks/random_sample.py
  python "$sampler_script" "$@"
}

# Five positional arguments are required. Anything after them (e.g.
# --no-progress) is kept in progress_args and forwarded verbatim to train.py.
if (( $# < 5 )); then
  usage >&2
  exit 1
fi
base_dir=$1
training_data=$2
test_data=$3
architecture=$4
trial_no=$5
shift 5
progress_args=("$@")

training_data_dir=$(get_dataset_dir "$base_dir" "$training_data")
test_data_dir=$(get_dataset_dir "$base_dir" "$test_data")

# Ask get_architecture_args.py for the architecture-specific model flags,
# given a parameter budget of 64000 and the test set's vocabulary file.
# The output goes into a scalar first, so a failing helper is a checked
# assignment under `set -e`. It is then split on whitespace into an array.
# `read -a` does not perform pathname expansion on the words, unlike an
# unquoted `$(...)` inside `( ... )`.
architecture_args_output=$(
  python intervention_sampling/neural_networks/get_architecture_args.py \
    --architecture "$architecture" \
    --parameter-budget 64000 \
    --vocabulary-file "$test_data_dir"/main.vocab
)
model_flags=()
# `-d ''` reads the whole string (newlines count as separators via IFS).
# read then returns 1 at end of input because no NUL delimiter is found.
# That status is expected, hence `|| true`.
IFS=$' \t\n' read -r -d '' -a model_flags <<<"$architecture_args_output" || true

model_dir=$(get_model_dir "$base_dir" "$training_data" "$architecture" "$trial_no")

# Sample the randomly searched hyperparameters before invoking the trainer.
# A command substitution used directly as an argument has its exit status
# ignored by `set -e` (the command just receives an empty string). As a
# plain assignment its status is checked, so a failing sampler aborts here.
# The sampling order matches the original argument order.
max_tokens_per_batch=$(random_sample --int 128 4096)
initial_learning_rate=$(random_sample --log 0.0001 0.01)

# `${arr[@]+"${arr[@]}"}` expands to nothing for an empty array. A plain
# "${arr[@]}" aborts with "unbound variable" under `set -u` on bash < 4.4,
# e.g. when no optional progress flag was passed.
python rau/tasks/language_modeling/train.py \
  --output "$model_dir" \
  --training-data "$training_data_dir" \
  --validation-data-file "$training_data_dir"/datasets/validation/main.prepared \
  --vocabulary-file "$test_data_dir"/main.vocab \
  --architecture "$architecture" \
  ${model_flags[@]+"${model_flags[@]}"} \
  --init-scale 0.1 \
  --max-epochs 1000 \
  --max-tokens-per-batch "$max_tokens_per_batch" \
  --optimizer Adam \
  --initial-learning-rate "$initial_learning_rate" \
  --gradient-clipping-threshold 5 \
  --early-stopping-patience 10 \
  --learning-rate-patience 5 \
  --learning-rate-decay-factor 0.5 \
  --examples-per-checkpoint 10000 \
  ${progress_args[@]+"${progress_args[@]}"}
bash intervention_sampling/neural_networks/evaluate.bash "$model_dir" "$test_data_dir"
