#!/usr/bin/env bash
set -euo pipefail
. recognizers/functions.bash

# Base seed for random language sampling: the integer taken from the language
# name is added to it before it is passed to sample_random_language.py.
: "${LANGUAGE_SAMPLING_SEED:=123456789}"
# Seed passed to generate_datasets.py as --random-seed.
: "${STRING_SAMPLING_SEED:=123456789}"
# Number of restarts performed so far; it is incremented and exported before
# the script re-executes itself after a failure.
: "${RESTART:=0}"
# A failure aborts the script instead of restarting it once RESTART + 1
# reaches this value.
: "${MAX_RESTARTS:=20}"

# Copy of the original command-line arguments, taken before `shift 4` consumes
# them, so that a restart can re-run the script with the same arguments.
ARGS=("$@")

# Print the help text to stdout. Callers that report a usage error redirect
# it to stderr themselves.
usage() {
  cat <<EOF
Usage: $0 <base-directory> <language> <device> <max-string-length>

Perform all preprocessing and data generation from scratch for a particular
language. The name of the language determines whether the language is randomly
sampled and, if so, what kind of language class it is sampled from. The results
of this script are always deterministic; an integer in the language name can be
used to distinguish randomly sampled languages generated with different seeds.

  <base-directory>
    Directory under which all datasets and models are stored.
  <language>
    Name of the language to prepare.
    In the following, <n> is an integer that distinguishes different randomly
    sampled languages.
    Choices:
    - random-regular-<n>
      A randomly sampled regular language.
    - random-context-free-<n>
      A randomly sampled context-free language.
    - random-podfa-<n>
      A randomly sampled partially-ordered DFA language.
    - random-star-free-<n>
      A randomly sampled star-free language.
    - The name of a hand-picked language accepted by
      recognizers/string_sampling/sample_dataset.py.
  <device>
    Either 'cpu' or 'gpu'.
  <max-string-length>
    The maximum length of the test sequence.

EOF
}

# Four positional arguments are mandatory. Whatever follows them is left in
# "$@" after the shift and is forwarded to the random language sampler.
if (( $# < 4 )); then
  usage >&2
  exit 1
fi
base_dir=$1
language=$2
device=$3
max_string_length=$4
shift 4

# Translate the user-facing device name into the value handed to --device;
# any name other than 'cpu' or 'gpu' is a usage error.
if [[ $device == gpu ]]; then
  device=cuda
elif [[ $device != cpu ]]; then
  usage >&2
  exit 1
fi

# Everything produced for this language is written below this directory.
language_dir=$(get_language_dir "$base_dir" "$language")
mkdir -p "$language_dir"

#######################################
# Test whether a language name is one of the hand-picked names listed here.
# Arguments:
#   $1 - language name
# Returns:
#   0 if the name is in the fixed list below or has the form
#   dyck-<digits>-<digits>; 1 otherwise.
#######################################
is_hand_picked_dfa() {
  local language=$1
  # A case statement compares against the literal names without needing a
  # loop variable, so no name leaks into (or overwrites one in) the caller.
  case "$language" in
    all-strings | \
    empty-set | \
    repeat-01 | \
    even-pairs | \
    modular-arithmetic | \
    parity | \
    cycle-navigation | \
    first)
      return 0
      ;;
  esac
  if [[ $language =~ ^dyck-[0-9]+-[0-9]+$ ]]; then
    return 0
  fi
  return 1
}

# Decide from the language name how the language is obtained:
#   - random-<class>-<n>, optionally followed by -mean-<parameter>-<value>:
#     sampled at random; <class> and <n> are kept in language_class and index.
#   - a name accepted by is_hand_picked_dfa: not random, but a sampler is
#     still prepared.
#   - anything else: no sampler is prepared.
# Both patterns capture the class and <n> as groups 1 and 2, and the second
# one is only tried when the first does not match, so BASH_REMATCH always
# comes from the pattern that matched.
do_prepare_sampler=true
if [[ $language =~ ^random-(regular|context-free|podfa|star-free)-([0-9]+)$ ]] ||
  [[ $language =~ ^random-(regular|context-free|podfa|star-free)-([0-9]+)-mean-(num-states|alphabet-size|num-variables)-([0-9]+)$ ]]; then
  is_random=true
  language_class=${BASH_REMATCH[1]}
  index=${BASH_REMATCH[2]}
elif is_hand_picked_dfa "$language"; then
  is_random=false
else
  do_prepare_sampler=false
fi

# Choose the input handed to the dataset generator. Languages that need a
# sampler first get a language file (sampled at random or saved from a
# hand-picked automaton) and a sampler built from it; all other languages are
# passed to the generator by name.
if ! $do_prepare_sampler; then
  sample_dataset_input=(--language "$language")
else
  language_file=$language_dir/language.pt
  sampler_file=$language_dir/sampler.pt
  echo "writing $language_file"
  if $is_random; then
    # Remaining command-line arguments ("$@") go to the sampler unchanged.
    if ! python recognizers/language_sampling/sample_random_language.py \
      --language-class "$language_class" \
      --random-seed "$((LANGUAGE_SAMPLING_SEED + index))" \
      --output "$language_file" \
      "$@"; then
      # Sampling failed: re-execute the whole script with shifted seeds,
      # unless the restart limit has been reached.
      next_attempt=$((RESTART + 1))
      echo "error in language sampling. restarting..(attempt $next_attempt/$MAX_RESTARTS)"
      if (( next_attempt >= MAX_RESTARTS )); then
        echo "too many restarts, exiting."
        exit 1
      fi
      seed_bump=$((index * 100 + 1))
      export RESTART="$next_attempt"
      export LANGUAGE_SAMPLING_SEED="$((LANGUAGE_SAMPLING_SEED + seed_bump))"
      export STRING_SAMPLING_SEED="$((STRING_SAMPLING_SEED + seed_bump))"
      # exec replaces this process, so nothing below runs in this instance.
      exec /usr/bin/env bash "$0" "${ARGS[@]}"
    fi
  else
    python recognizers/hand_picked_languages/save_automaton.py \
      --name "$language" \
      --output "$language_file"
  fi
  echo "writing $sampler_file"
  python recognizers/string_sampling/prepare_sampler.py \
    --input "$language_file" \
    --output "$sampler_file" \
    --max-length "$max_string_length" \
    --device "$device"
  sample_dataset_input=(--sampler "$sampler_file")
fi

# Generate the datasets for this language. If the generator fails, re-execute
# the whole script with new seeds; give up once RESTART + 1 reaches
# MAX_RESTARTS.
if ! python recognizers/dataset_generation/generate_datasets.py \
  "${sample_dataset_input[@]}" \
  --random-seed "$STRING_SAMPLING_SEED" \
  --output "$language_dir" \
  --skip-test-edit-distance \
  --device cpu \
  --max-length "$max_string_length"; then
  echo "error in dataset generation. restarting..(attempt $((RESTART + 1))/$MAX_RESTARTS)"
  if (( RESTART + 1 >= MAX_RESTARTS )); then
    echo "too many restarts, exiting."
    exit 1
  fi
  export RESTART="$((RESTART + 1))"
  # `index` is only assigned for random-* languages. With `set -u`, an unset
  # name inside $(( )) aborts the script, so use 0 when it is missing; the
  # seeds still change between attempts because RESTART is added.
  export LANGUAGE_SAMPLING_SEED="$((LANGUAGE_SAMPLING_SEED + ${index:-0} * 100 + RESTART))"
  export STRING_SAMPLING_SEED="$((STRING_SAMPLING_SEED + ${index:-0} * 100 + RESTART))"
  # exec replaces this process, so nothing below runs in this instance.
  exec /usr/bin/env bash "$0" "${ARGS[@]}"
fi
bash recognizers/neural_networks/prepare_language.bash "$base_dir" "$language"
echo "data generation completed successfully!"
echo "Timestamp end: $(date)"
