#!/usr/bin/env bash
###############################################################################
# train_array.sbatch  —  massive training sweep on Slurm (conda + safe job list)
#
# Array task 0 builds one shared, shuffled job list (one training job per
# line); every array task then runs its own CHUNK_SIZE-line slice of that list
# with GNU parallel (see sections 3 and 4).
# NOTE: slurm_logs_train/ must exist before submission — Slurm does not create
# the directories named in --output / --error.
# NOTE: --array must span at least ceil(total jobs / CHUNK_SIZE) tasks;
# surplus tasks exit with "Nothing to do".
###############################################################################
#SBATCH --partition= 
#SBATCH --job-name=parityfree
#SBATCH --ntasks=1              
#SBATCH --cpus-per-task=16          
#SBATCH --exclusive                 # exclusive access to the node
#SBATCH --mem-per-cpu=4G              
#SBATCH --time=6:00:00
#SBATCH --array=0-250                
#SBATCH --output=slurm_logs_train/slurm_%A_%a.out
#SBATCH --error=slurm_logs_train/slurm_%A_%a.err
###############################################################################

# Strict mode for this driver shell only: the worker shells GNU parallel
# starts for run_single_job do not inherit these options.
set -euo pipefail

###############################################################################
# 0.  WORK DIR, CONDA, MODULES
###############################################################################


# ADAPT TO YOUR SETTING
# TODO: cd into the repository root (JOB_LIST and every python script path
# below are relative), activate the conda environment and load any modules
# here. Also fill in the empty --partition directive above.


###############################################################################
# 1.  CONFIGURATION
###############################################################################
BASE_DIR="experiments_parityfree_intervention"
AUTOMATON="parity_free_hp"

# Dataset size (NOTE(review): presumably strings per dataset — confirm). Also
# the cap applied to intervention counts for the "alo" semiring in
# build_job_list.
DATASET_SIZE=500
# Directory suffix; must match the data-generation run.
# NOTE(review): meaning of the 0.1/2/1000/100 fields is not visible here.
POSTFIX="_${DATASET_SIZE}_0.1_2_1000_100"

# Datasets are read from DATA_DIR; models and evaluations go under MODELS_DIR.
DATA_DIR="${BASE_DIR}/data${POSTFIX}/datasets/${AUTOMATON}"
MODELS_DIR="${BASE_DIR}/models${POSTFIX}"

# Intervention counts enumerated by `seq START STEP END` in build_job_list.
INTERVENTION_START=2
INTERVENTION_END=603
INTERVENTION_STEP=100

# Targets per intervention type in build_job_list: symbol targets are
# 0..NUM_SYMBOLS-1, state targets 1..NUM_STATES-1 (state 0 is skipped).
# NUM_ARCS is not referenced elsewhere in this script.
NUM_STATES=4
NUM_ARCS=7
NUM_SYMBOLS=3

# Job list configuration
JOB_LIST="job_list.txt"              # relative to the job's working directory
LOCK_FILE="${JOB_LIST}.lock"         # not referenced elsewhere in this script
COMPLETE_MARKER="${JOB_LIST}.complete"  # holds the job count once the list is ready

###############################################################################
# 2.  HELPERS  —  per-job functions (exported below for GNU parallel)
###############################################################################
# Thin wrapper around the project's hyper-parameter sampler script; every
# argument is forwarded unchanged. Exported for the parallel workers but not
# currently called (run_single_job uses fixed hyper-parameters).
random_sample() {
  local -r sampler_script="src/intervention_sampling/neural_networks/random_sample.py"
  python "${sampler_script}" "$@"
}

#######################################
# Train, evaluate and KL-decompose one model for one job-list line.
#
# Runs inside a worker shell spawned by GNU parallel, which does NOT inherit
# the driver's `set -euo pipefail`. Every step is therefore checked explicitly
# and the first failure is returned as a non-zero status (recorded in
# parallel's --joblog) instead of silently continuing to the next step.
#
# Globals (read): DATA_DIR, AUTOMATON, MODELS_DIR
# Arguments (the 8 fields of one job-list line):
#   $1 TYPE          "vanilla"; any other value selects the intervention layout
#   $2 SEMIRING      semiring name ("none" on vanilla lines)
#   $3 AM_IDX        automaton index
#   $4 INTERVENTION  intervention kind ("none" on vanilla lines)
#   $5 TARGET        intervention target ("none" on vanilla lines)
#   $6 ID            data seed (vanilla) or intervention count (intervention)
#   $7 ARCH          architecture name passed to the python scripts
#   $8 MSEED         model seed; NOTE(review): only used in the output path,
#                    it is not passed to train_kl.py — confirm intended
# Outputs: progress on stdout, failures on stderr.
# Returns: 0 when the job completes or is skipped (results already present or
#          no training data); otherwise the failing step's exit status.
#######################################
run_single_job() {
  # Keep numeric libraries single-threaded: parallel runs many jobs per node.
  export OMP_NUM_THREADS=1
  export MKL_NUM_THREADS=1
  export OPENBLAS_NUM_THREADS=1
  export NUMEXPR_NUM_THREADS=1
  export VECLIB_MAXIMUM_THREADS=1

  local TYPE="$1"
  local SEMIRING="$2"
  local AM_IDX="$3"
  local INTERVENTION="$4"
  local TARGET="$5"
  local ID="$6"
  local ARCH="$7"
  local MSEED="$8"

  local train_dir val_dir test_dir base rc arch_out
  local -a arch_args
  local id="$ID"

  if [[ $TYPE == "vanilla" ]]; then
    train_dir="${DATA_DIR}/vanilla/train/${AM_IDX}/${id}"
    val_dir="${DATA_DIR}/vanilla/validation/${AM_IDX}/${id}"
    test_dir="${DATA_DIR}/vanilla/test/${AM_IDX}"
  else
    base="${DATA_DIR}/${SEMIRING}/${AUTOMATON}/${AM_IDX}/${INTERVENTION}/${TARGET}"
    train_dir="${base}/train/${id}"
    val_dir="${base}/validation/${id}"
    test_dir="${base}/test"
  fi

  local rel_path="${SEMIRING}/${AM_IDX}/${INTERVENTION}/${TARGET}/train/${ID}"
  local output_dir="${MODELS_DIR}/${ARCH}/${TYPE}/${rel_path}/${MSEED}"
  local eval_dir="${output_dir}/eval"
  # Its presence marks the job as done (presumably written by evaluate_kl.py
  # next to --model_logprobs — TODO confirm).
  local final_kl_file="${eval_dir}/decomposed_kls.json"

  if [[ -f "$final_kl_file" ]]; then
    echo "✅ [$TYPE] Final KL output exists at $final_kl_file, skipping job."
    return 0
  fi

  if [[ ! -d "$train_dir" || ! -s "$train_dir/main.tok" ]]; then
    echo "⚠️  [$TYPE] no training data in $train_dir, skipping."
    return 0
  fi

  mkdir -p "${eval_dir}" || return

  echo "➡️  [$TYPE] arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED | target=${TARGET:-n/a}"

  # Random 0.1–5 s delay to spread filesystem access of jobs started together.
  # Seeded per process: awk's bare srand() uses the time of day in whole
  # seconds, so jobs launched in the same second would all sleep equally long.
  sleep "$(awk -v min=0.1 -v max=5 -v seed="$(( $$ * 32768 + RANDOM ))" \
    'BEGIN{srand(seed); print min+rand()*(max-min)}')"

  # ─── PREPARE DATA ──────────────────────────────────────────────────────
  # Done before the architecture query because that query reads
  # ${test_dir}/main.vocab, which this step may write (--training-data).
  # NOTE(review): several jobs share one train/test dir (2 archs x model
  # seeds), so concurrent prepares can still race; consider preparing all
  # data once before launching the sweep.
  if [[ ! -f "${train_dir}/main.prepared" ]]; then
    python src/rau/tasks/language_modeling/prepare_data.py \
          --more-data-files "$train_dir"/main.{tok,prepared} \
          --more-data-files "$val_dir"/main.{tok,prepared} \
          --training-data "$test_dir" \
          --never-allow-unk \
      || { rc=$?; echo "❌ [$TYPE] data preparation failed for $train_dir" >&2; return "$rc"; }
  fi

  # ─── ARCHITECTURE FLAGS ───────────────────────────────────────────────
  # Captured separately so a failure is not hidden by the here-string.
  arch_out=$(
    python src/intervention_sampling/neural_networks/get_architecture_args.py \
      --architecture "$ARCH" \
      --parameter-budget 128000 \
      --vocabulary-file "${test_dir}/main.vocab"
  ) || { rc=$?; echo "❌ [$TYPE] could not get architecture args for $ARCH" >&2; return "$rc"; }
  read -r -a arch_args <<< "$arch_out"

  # fixed hyper‑params (override with random_sample if desired)
  local max_tokens=256
  local lr=0.01

  # ─── TRAIN ────────────────────────────────────────────────────────────
  python src/rau/tasks/language_modeling/train_kl.py \
    --training-data-file   "${train_dir}/main.prepared" \
    --validation-data-file "${val_dir}/main.prepared" \
    --vocabulary-file      "${test_dir}/main.vocab" \
    --output               "${output_dir}" \
    --architecture         "${ARCH}" \
    "${arch_args[@]}" \
    --init-scale 0.1 \
    --max-epochs 1000 \
    --max-tokens-per-batch "${max_tokens}" \
    --optimizer Adam \
    --initial-learning-rate "${lr}" \
    --gradient-clipping-threshold 5 \
    --early-stopping-patience 100 \
    --learning-rate-patience 5 \
    --learning-rate-decay-factor 0.5 \
    --examples-per-checkpoint 1000 \
    --no-progress \
    --automaton "${train_dir}/machine.pkl" \
    --device cpu \
    || { rc=$?; echo "❌ [$TYPE] training failed: $output_dir" >&2; return "$rc"; }

  # ─── EVAL ─────────────────────────────────────────────────────────────
  python src/intervention_sampling/neural_networks/evaluate.py \
    --batching-max-tokens 1024 \
    --load-model "${output_dir}" \
    --input-file "${test_dir}/main.prepared" \
    --output    "${eval_dir}" \
    || { rc=$?; echo "❌ [$TYPE] evaluation failed: $output_dir" >&2; return "$rc"; }

  # ─── KL DECOMP ────────────────────────────────────────────────────────
  python src/intervention_sampling/evaluate_kl.py \
    --model_logprobs "${eval_dir}/token-negative-log-probabilities.pt" \
    --automaton       "${train_dir}/machine.pkl" \
    --arcs            "${test_dir}/arcs.txt" \
    --vocab_file      "${test_dir}/main.vocab" \
    || { rc=$?; echo "❌ [$TYPE] KL decomposition failed: $eval_dir" >&2; return "$rc"; }

  echo "✅ Completed [$TYPE] arch=$ARCH | am_idx=$AM_IDX | id=$id | model_seed=$MSEED"
}

# GNU parallel runs each job in a fresh shell: publish the configuration and
# the helper functions so that shell can see them.
export BASE_DIR AUTOMATON DATA_DIR MODELS_DIR \
       NUM_STATES NUM_SYMBOLS \
       INTERVENTION_START INTERVENTION_END INTERVENTION_STEP
export -f random_sample run_single_job

###############################################################################
# 3.  BUILD & SHUFFLE THE JOB LIST - ONLY THE FIRST PROCESS CREATES IT
###############################################################################

# File to track which process is building the job list
BUILDING_MARKER="${JOB_LIST}.building"

#######################################
# Write the whole sweep to JOB_LIST (one job per line), shuffle it, then
# publish COMPLETE_MARKER containing the line count.
# Line format (8 space-separated fields, consumed by run_single_job):
#   TYPE SEMIRING AM_IDX INTERVENTION TARGET ID ARCH MSEED
# Globals:
#   written: JOB_LIST, COMPLETE_MARKER, BUILDING_MARKER
#   read:    NUM_SYMBOLS, NUM_STATES, DATASET_SIZE,
#            INTERVENTION_START, INTERVENTION_STEP, INTERVENTION_END
# Arguments (all optional; the defaults reproduce the full sweep):
#   $1 automata for vanilla jobs        (default 200)
#   $2 automata for intervention jobs   (default 80)
#   $3 model seeds per configuration    (default 20)
#######################################
build_job_list() {
  local n_vanilla_am=${1:-200}
  local n_intervention_am=${2:-80}
  local n_model_seeds=${3:-20}
  local am_idx seed arch mseed semiring intervention target ic num_tgts count

  # Record our PID so task 0 can later detect a builder that died.
  echo "$$" > "${BUILDING_MARKER}"

  echo "🛠  Building job list (PID $$)..."
  rm -f "${JOB_LIST}"

  # One redirection for the whole generation instead of one open per line.
  {
    # ——— vanilla jobs ————————————————————————————————————————————
    for (( am_idx = 1; am_idx <= n_vanilla_am; am_idx++ )); do
      for seed in 1; do
        for arch in lstm transformer; do
          for (( mseed = 1; mseed <= n_model_seeds; mseed++ )); do
            echo "vanilla none ${am_idx} none none ${seed} ${arch} ${mseed}"
          done
        done
      done
    done

    # ——— intervention jobs ——————————————————————————————————————
    # alo binning
    for semiring in alo; do
      for (( am_idx = 1; am_idx <= n_intervention_am; am_idx++ )); do
        for intervention in symbol state; do
          if [[ $intervention == symbol ]]; then
            num_tgts=${NUM_SYMBOLS}
          else
            num_tgts=${NUM_STATES}
          fi
          for (( target = 0; target < num_tgts; target++ )); do
            # State target 0 is never intervened on.
            if [[ $intervention == state ]] && (( target == 0 )); then
              continue
            fi
            for ic in $(seq "${INTERVENTION_START}" "${INTERVENTION_STEP}" "${INTERVENTION_END}"); do
              # alo: counts above the dataset size are capped at DATASET_SIZE.
              if [[ $semiring == "alo" ]] && (( ic > DATASET_SIZE )); then
                ic=${DATASET_SIZE}
              fi
              for arch in lstm transformer; do
                for (( mseed = 1; mseed <= n_model_seeds; mseed++ )); do
                  echo "intervention ${semiring} ${am_idx} ${intervention} ${target} ${ic} ${arch} ${mseed}"
                done
              done
              # Once the cap is reached every later count maps to the same
              # value; stop so no duplicate job lines are emitted. (Testing
              # `>` here, after clamping, could never fire.)
              if [[ $semiring == "alo" ]] && (( ic >= DATASET_SIZE )); then
                break
              fi
            done
          done
        done
      done
    done
  } > "${JOB_LIST}"

  echo "🔀 Shuffling job list ..."
  shuf "${JOB_LIST}" -o "${JOB_LIST}.shuffled"
  mv "${JOB_LIST}.shuffled" "${JOB_LIST}"

  count=$(wc -l < "${JOB_LIST}")
  echo "✅ Job list created with ${count} lines."

  # Publish the marker atomically: waiting tasks read it as soon as it exists,
  # so it must never be visible half-written.
  echo "${count}" > "${COMPLETE_MARKER}.tmp"
  mv -f "${COMPLETE_MARKER}.tmp" "${COMPLETE_MARKER}"

  rm -f "${BUILDING_MARKER}"
}

# Fail fast with a clear message when not running as a Slurm array task
# (under `set -u` an unset variable would otherwise abort with a terse error).
: "${SLURM_ARRAY_TASK_ID:?SLURM_ARRAY_TASK_ID is not set - submit this script with sbatch --array}"

#######################################
# Succeeds when BUILDING_MARKER exists but the PID recorded in it is not a
# live process visible from this node, i.e. its builder most likely died.
# Heuristic only: `kill -0` sees just this node's processes (and fails for
# other users' processes).
#######################################
building_marker_is_stale() {
  local pid
  [[ -f "${BUILDING_MARKER}" ]] || return 1
  pid=$(cat -- "${BUILDING_MARKER}" 2>/dev/null) || return 1
  ! kill -0 "${pid}" 2>/dev/null
}

# Check if we're the first task (lowest SLURM_ARRAY_TASK_ID)
if [[ "${SLURM_ARRAY_TASK_ID}" -eq "0" && ! -f "${COMPLETE_MARKER}" ]]; then
  # We're the first task and no completed job list exists - create it
  echo "🏁 This is the first task (ID ${SLURM_ARRAY_TASK_ID}), will create the job list"

  # First, check if another process is already building
  if [[ -f "${BUILDING_MARKER}" ]]; then
    BUILDING_PID=$(cat "${BUILDING_MARKER}")
    if kill -0 "${BUILDING_PID}" 2>/dev/null; then
      echo "⏳ Another process (PID ${BUILDING_PID}) is already building the job list, waiting..."
    else
      echo "🚨 Found stale building marker from PID ${BUILDING_PID}, will take over building"
      build_job_list
    fi
  else
    # No other process is building, so we'll do it
    build_job_list
  fi
else
  # We're not the first task or job list already exists
  echo "⏳ Task ${SLURM_ARRAY_TASK_ID} waiting for job list to be built by task 0..."
fi

# All tasks wait for the job list to be fully built before proceeding
while [[ ! -f "${COMPLETE_MARKER}" ]]; do
  echo "⏳ Waiting for job list completion marker... ($(date))"
  sleep 5

  # Task 0 takes over when nobody is (still) building. Re-check the completion
  # marker first: it may have appeared during the sleep, and rebuilding then
  # would reshuffle the list under tasks that are already reading it.
  if [[ "${SLURM_ARRAY_TASK_ID}" -eq "0" && ! -f "${COMPLETE_MARKER}" ]]; then
    if [[ ! -f "${BUILDING_MARKER}" ]]; then
      echo "🚨 No process is currently building the job list, task 0 will take over"
      build_job_list
    elif building_marker_is_stale; then
      echo "🚨 Building process is gone (stale marker), task 0 will take over"
      build_job_list
    fi
  fi
done

# Verify job list exists and get total count
if [[ ! -f "${JOB_LIST}" || ! -f "${COMPLETE_MARKER}" ]]; then
  echo "🚨 ERROR: Job list or completion marker missing after wait loop!" >&2
  exit 1
fi

TOTAL=$(<"${COMPLETE_MARKER}")
ACTUAL_COUNT=$(wc -l < "${JOB_LIST}")

# Final verification: the list must still match what the builder published.
if [[ "${TOTAL}" != "${ACTUAL_COUNT}" ]]; then
  echo "🚨 ERROR: Job count mismatch! Expected ${TOTAL}, found ${ACTUAL_COUNT}." >&2
  exit 1
fi

echo "📋 Total jobs to run: ${TOTAL} (verified)"

###############################################################################
# 4.  CHUNK DISPATCHING PER SLURM‑ARRAY ELEMENT
###############################################################################
# Array task k runs 0-based job-list lines [k*CHUNK_SIZE, (k+1)*CHUNK_SIZE).
CHUNK_SIZE=1000                                 # 1 000 jobs per array element
START=$(( SLURM_ARRAY_TASK_ID * CHUNK_SIZE ))
END=$(( START + CHUNK_SIZE - 1 ))

if (( START >= TOTAL )); then
  echo "Nothing to do for array task ${SLURM_ARRAY_TASK_ID} – START=$START >= $TOTAL"
  exit 0
fi

echo "🚀 Array task $SLURM_ARRAY_TASK_ID executing jobs $START … $END (max)."

# One parallel slot per CPU allocated to this task (--cpus-per-task); each job
# pins its numeric libraries to a single thread.
N_PARALLEL=${SLURM_CPUS_PER_TASK:-16}

# sed prints 1-based lines START+1..END+1 and stops reading after the slice.
# parallel exits with the number of failed jobs; capture that instead of
# letting `set -e` abort, so the summary below is always printed.
rc=0
sed -n "$((START + 1)),$((END + 1))p; $((END + 1))q" "${JOB_LIST}" \
  | parallel --jobs "${N_PARALLEL}" --colsep ' ' \
         --joblog "parallel_job_${SLURM_ARRAY_TASK_ID}.log" \
         --will-cite \
         --results "results_logs/${SLURM_ARRAY_TASK_ID}/job_{#}/" \
         run_single_job {1} {2} {3} {4} {5} {6} {7} {8} \
  || rc=$?

if (( rc != 0 )); then
  echo "⚠️  Array task ${SLURM_ARRAY_TASK_ID}: parallel exited with status ${rc} (see parallel_job_${SLURM_ARRAY_TASK_ID}.log)" >&2
fi

echo "🎉 Array task ${SLURM_ARRAY_TASK_ID} finished."
exit "${rc}"
