#!/usr/bin/env bash
set -euo pipefail

#
# train.sh — train, evaluate, and run the KL decomposition for every job of
# the experiments_parityfree_intervention setup (see BASE_DIR below).
#

#### CONFIGURATION ####
# Experiment root directory and the automaton family whose datasets are used.
BASE_DIR=experiments_parityfree_intervention
AUTOMATON=parity_free_hp

# Automaton dimensions. NUM_STATES / NUM_SYMBOLS bound the intervention targets
# enumerated when the job list is built; NUM_ARCS is not referenced elsewhere
# in this script.
NUM_STATES=4
NUM_ARCS=7
NUM_SYMBOLS=3

# Intervention-count sweep: START, START+STEP, ... up to END.
INTERVENTION_START=2
INTERVENTION_END=1000
INTERVENTION_STEP=100

# Dataset size; also the cap applied to "alo" intervention counts.
DATASET_SIZE=500

# Suffix shared by the data and model directory names.
# NOTE(review): the trailing "2_1000_100" presumably mirrors
# INTERVENTION_START/END/STEP — confirm before changing either.
POSTFIX="_${DATASET_SIZE}_0.1_2_1000_100"

# Derived locations.
MODELS_DIR="${BASE_DIR}/models${POSTFIX}"
DATA_DIR="${BASE_DIR}/data${POSTFIX}/datasets/${AUTOMATON}"

# Set to 10 to replicate paper (not referenced elsewhere in this script).
NUM_MODELS_TO_TRAIN=1

# Upper bound on concurrent jobs handed to GNU parallel via --jobs.
MAX_JOBS=8

#### HELPERS ####
# Forward every argument, unchanged, to the random_sample.py helper script.
random_sample() {
  local sampler="src/intervention_sampling/neural_networks/random_sample.py"
  python "${sampler}" "$@"
}


# Make the configuration visible to the child shells in which GNU parallel
# executes run_single_job.
export \
  BASE_DIR AUTOMATON \
  DATA_DIR MODELS_DIR \
  NUM_STATES NUM_SYMBOLS \
  INTERVENTION_START INTERVENTION_END INTERVENTION_STEP \
  MAX_JOBS

#######################################
# Train, evaluate and KL-decompose one model for a single job-list entry.
#
# This function is exported and executed by GNU parallel in a fresh shell,
# where the script-level `set -euo pipefail` is NOT in effect. Every pipeline
# step is therefore checked explicitly so that a failure stops the job and is
# reported through the return status instead of being followed by a
# misleading "Completed" line.
#
# Globals:
#   DATA_DIR, MODELS_DIR, AUTOMATON (read)
# Arguments:
#   $1 - TYPE          "vanilla", anything else is treated as an intervention job
#   $2 - SEMIRING      semiring name (path component)
#   $3 - AM_IDX        automaton index (path component)
#   $4 - INTERVENTION  intervention kind (path component)
#   $5 - TARGET        intervention target (path component)
#   $6 - ID            data seed for vanilla jobs, intervention count otherwise
#   $7 - ARCH          architecture name passed to the Python tools
#   $8 - MSEED         model seed
# Outputs:
#   Progress messages on stdout, failure messages on stderr.
# Returns:
#   0 on success or when the training data is missing (job skipped);
#   non-zero as soon as any step fails.
#######################################
run_single_job() {
  local TYPE="$1"
  local SEMIRING="$2"
  local AM_IDX="$3"
  local INTERVENTION="$4"
  local TARGET="$5"
  local ID="$6"
  local ARCH="$7"
  local MSEED="$8"

  local train_dir val_dir test_dir id base

  if [[ "$TYPE" == "vanilla" ]]; then
    train_dir="${DATA_DIR}/vanilla/train/${AM_IDX}/${ID}"
    val_dir="${DATA_DIR}/vanilla/validation/${AM_IDX}/${ID}"
    test_dir="${DATA_DIR}/vanilla/test/${AM_IDX}"
  else
    base="${DATA_DIR}/${SEMIRING}/${AUTOMATON}/${AM_IDX}/${INTERVENTION}/${TARGET}"
    train_dir="${base}/train/${ID}"
    val_dir="${base}/validation/${ID}"
    test_dir="${base}/test"
  fi
  id="$ID"

  if [[ ! -d "$train_dir" || ! -s "$train_dir/main.tok" ]]; then
    echo "⚠️  [$TYPE] no training data in $train_dir, skipping."
    return 0
  fi

  echo "➡️  [$TYPE] arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED | target=${TARGET:-n/a}"

  # Jitter (0.100s – 5.000s) to avoid races between concurrently started jobs.
  # $RANDOM is used instead of `awk srand()`: srand() without an argument seeds
  # from the time of day, so jobs launched within the same second could all
  # draw the same delay.
  local jitter_ms=$(( 100 + RANDOM % 4901 ))
  local jitter
  printf -v jitter '%d.%03d' "$(( jitter_ms / 1000 ))" "$(( jitter_ms % 1000 ))"
  sleep "${jitter}"

  # Architecture-specific flags: the first output line of the helper is split
  # on whitespace into individual arguments.
  local arch_args_line
  if ! arch_args_line="$(
    python src/intervention_sampling/neural_networks/get_architecture_args.py \
      --architecture "$ARCH" \
      --parameter-budget 128000 \
      --vocabulary-file "${test_dir}/main.vocab"
  )"; then
    echo "❌ [$TYPE] get_architecture_args.py failed (arch=$ARCH | am_idx=$AM_IDX | id=$id)" >&2
    return 1
  fi
  local -a arch_args=()
  read -r -a arch_args <<< "${arch_args_line}"

  # Fixed hyperparameters (or swap in random_sample calls).
  local max_tokens=256  # A reasonable middle value
  local lr=0.01         # A typical learning rate

  # NOTE(review): MSEED only selects the output directory here; it is not
  # passed to the training command — confirm that this is intended.
  local rel_path="${SEMIRING}/${AM_IDX}/${INTERVENTION}/${TARGET}/train/${ID}"
  local output_dir="${MODELS_DIR}/${ARCH}/${TYPE}/${rel_path}/${MSEED}"

  # ─── PREPARE DATA ─────────────────────────────────────────────────────
  # Brace expansion yields "<dir>/main.tok" "<dir>/main.prepared" per flag.
  if ! python src/rau/tasks/language_modeling/prepare_data.py \
        --more-data-files "$train_dir"/main.{tok,prepared} \
        --more-data-files "$val_dir"/main.{tok,prepared} \
        --training-data "$test_dir" \
        --never-allow-unk; then
    echo "❌ [$TYPE] prepare_data.py failed (arch=$ARCH | am_idx=$AM_IDX | id=$id)" >&2
    return 1
  fi

  # ─── TRAIN ────────────────────────────────────────────────────────────
  # ${arr[@]+...} keeps an empty flag list from tripping `set -u` on old bash.
  if ! python src/rau/tasks/language_modeling/train_kl.py \
    --training-data-file   "${train_dir}/main.prepared" \
    --validation-data-file "${val_dir}/main.prepared" \
    --vocabulary-file      "${test_dir}/main.vocab" \
    --output               "${output_dir}" \
    --architecture         "${ARCH}" \
    ${arch_args[@]+"${arch_args[@]}"} \
    --init-scale 0.1 \
    --max-epochs 1000 \
    --max-tokens-per-batch "${max_tokens}" \
    --optimizer Adam \
    --initial-learning-rate "${lr}" \
    --gradient-clipping-threshold 5 \
    --early-stopping-patience 100 \
    --learning-rate-patience 5 \
    --learning-rate-decay-factor 0.5 \
    --examples-per-checkpoint 1000 \
    --no-progress \
    --automaton "${train_dir}/machine.pkl" \
    --device cpu; then
    echo "❌ [$TYPE] train_kl.py failed (arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED)" >&2
    return 1
  fi

  # ─── EVAL ─────────────────────────────────────────────────────────────
  local eval_dir="${output_dir}/eval"
  if ! mkdir -p "${eval_dir}"; then
    echo "❌ [$TYPE] cannot create ${eval_dir}" >&2
    return 1
  fi
  if ! python src/intervention_sampling/neural_networks/evaluate.py \
    --batching-max-tokens 1024 \
    --load-model "${output_dir}" \
    --input-file "${test_dir}/main.prepared" \
    --output    "${eval_dir}"; then
    echo "❌ [$TYPE] evaluate.py failed (arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED)" >&2
    return 1
  fi

  # ─── KL DECOMP ────────────────────────────────────────────────────────
  if ! python src/intervention_sampling/evaluate_kl.py \
    --model_logprobs "${eval_dir}/token-negative-log-probabilities.pt" \
    --automaton       "${train_dir}/machine.pkl" \
    --arcs            "${test_dir}/arcs.txt" \
    --vocab_file      "${test_dir}/main.vocab"; then
    echo "❌ [$TYPE] evaluate_kl.py failed (arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED)" >&2
    return 1
  fi

  echo "✅ Completed [$TYPE] arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED"
}

# Export the helpers so the shells spawned by GNU parallel can call them.
export -f random_sample
export -f run_single_job

#### BUILD & SHUFFLE JOB LIST ####
JOB_LIST="job_list.txt"

#######################################
# Write one job per line to ${JOB_LIST}, replacing any previous list.
# Line format (space separated), matching run_single_job's arguments:
#   TYPE SEMIRING AM_IDX INTERVENTION TARGET ID ARCH MSEED
# Globals:
#   JOB_LIST, NUM_STATES, NUM_SYMBOLS, DATASET_SIZE,
#   INTERVENTION_START, INTERVENTION_END, INTERVENTION_STEP (all read)
# Outputs:
#   The job list file; nothing on stdout.
#######################################
build_job_list() {
  local am_idx seed arch mseed
  local semiring intervention num_targets target ic job_ic

  {
    # — vanilla jobs
    for am_idx in 1; do            # change to {1..200}
      for seed in 1; do
        for arch in lstm transformer; do
          for mseed in 1; do       # {1..10}
            printf 'vanilla none %s none none %s %s %s\n' \
              "${am_idx}" "${seed}" "${arch}" "${mseed}"
          done
        done
      done
    done

    # — intervention jobs
    for semiring in alo binning; do
      for am_idx in 1; do          # {1..80}
        for intervention in state; do   # add "symbol" to sweep symbols too
          if [[ "${intervention}" == "symbol" ]]; then
            num_targets="${NUM_SYMBOLS}"
          else
            num_targets="${NUM_STATES}"
          fi

          for (( target = 0; target < num_targets; target++ )); do
            # State 0 is never used as a state-intervention target.
            if [[ "${intervention}" == "state" ]] && (( target == 0 )); then
              continue
            fi

            for (( ic = INTERVENTION_START; ic <= INTERVENTION_END; ic += INTERVENTION_STEP )); do
              # For "alo", counts beyond the dataset size are capped at
              # DATASET_SIZE. The capped value is kept separate from the loop
              # counter so the "was it capped?" check below still works; the
              # capped job is emitted exactly once and the sweep then stops,
              # instead of emitting the same DATASET_SIZE job for every
              # remaining step.
              job_ic="${ic}"
              if [[ "${semiring}" == "alo" ]] && (( ic > DATASET_SIZE )); then
                job_ic="${DATASET_SIZE}"
              fi

              for arch in lstm transformer; do
                for mseed in 1; do   # {1..10}
                  printf 'intervention %s %s %s %s %s %s %s\n' \
                    "${semiring}" "${am_idx}" "${intervention}" "${target}" \
                    "${job_ic}" "${arch}" "${mseed}"
                done
              done

              if (( job_ic != ic )); then
                break
              fi
            done
          done
        done
      done
    done
  } > "${JOB_LIST}"
}

build_job_list

# Randomise the job order in place.
echo "🔀 Shuffling jobs..."
sort -R -o "${JOB_LIST}" "${JOB_LIST}"

TOTAL="$(wc -l < "${JOB_LIST}")"
echo "📋 Total jobs to run: ${TOTAL}"

#### LAUNCH ####
# Per-job stdout/stderr goes to logs/job_<sequence number>.out.
mkdir -p logs

# Each job-list line is split on single spaces into the eight positional
# arguments of run_single_job.
parallel_opts=(
  --progress
  --jobs "${MAX_JOBS}"
  --colsep ' '
  --joblog parallel_job.log
  --will-cite
)
job_template='run_single_job {1} {2} {3} {4} {5} {6} {7} {8} > logs/job_{#}.out 2>&1'

echo "🚀 Launching up to ${MAX_JOBS} concurrent jobs..."
parallel "${parallel_opts[@]}" "${job_template}" :::: "${JOB_LIST}"

echo "🎉 All training jobs done."
