#!/bin/bash
set -euo pipefail
. intervention_sampling/functions.bash

# Print the one-line command synopsis to stdout; callers redirect it to
# stderr when reporting an argument error.
usage() {
  printf 'Usage: %s <base-directory> <automaton> <intervention> <num-occurrences>\n' "$0"
}

# Positional arguments: all four are required; anything after them is ignored.
#   $1  base directory passed to get_dataset_dir
#   $2  automaton, "<topology>/<weight-setting>"
#   $3  intervention, "<type>-<target>"
#   $4  number of occurrences (non-negative decimal integer)
if (( $# < 4 )); then
  usage >&2
  exit 1
fi

base_dir=$1
automaton=$2
intervention=$3
num_occurrences=$4
shift 4

# num_occurrences becomes a path component and is evaluated inside $(( ))
# later in the script. Reject anything that is not a canonical decimal
# integer up front: a leading zero would be read as octal by the arithmetic,
# and arbitrary text would be evaluated as an expression.
if ! [[ $num_occurrences =~ ^(0|[1-9][0-9]*)$ ]]; then
  echo "error: invalid number of occurrences: $num_occurrences" >&2
  exit 1
fi

# Split "<topology>/<weight-setting>": the weight setting is the trailing run
# of digits after the last slash, the topology is everything before it.
if ! [[ $automaton =~ ^(.*)/([0-9]+)$ ]]; then
  echo "error: invalid automaton: $automaton" >&2
  exit 1
fi
automaton_topology=${BASH_REMATCH[1]}
weight_setting=${BASH_REMATCH[2]}

# Choose the topology seed, the weight seed and any extra topology flags.
# weight_setting (and the number in "random-N") are digit strings taken from
# user input; the 10# prefix forces base-10 evaluation so that zero-padded
# values such as "08" or "010" are not interpreted as octal.
if [[ $automaton_topology = example ]]; then
  # For the example automaton the topology seed is fixed and the weight seed
  # is offset from it by the weight setting.
  topology_seed=123
  weight_seed=$((123 + 10#$weight_setting))
  automaton_args=(--automaton_name canonical_parity)
elif [[ $automaton_topology =~ ^random-([0-9]+)$ ]]; then
  # For random topology, use the number in the topology name as the topology seed
  # and the weight_setting as an offset for the weight seed
  topology_seed=${BASH_REMATCH[1]}
  weight_seed=$((10#$topology_seed + 10#$weight_setting))
  automaton_args=()
else
  echo "error: invalid automaton topology: $automaton_topology" >&2
  exit 1
fi

# The intervention and sampling seeds are fixed offsets (+10 and +20) from
# the weight seed.
sampling_seed=$(( 20 + weight_seed ))
intervention_seed=$(( 10 + weight_seed ))

# Parse "<type>-<target>", where <type> is symbol, state, transition or
# vanilla, and build the intervention flags handed to the Python steps.
if ! [[ $intervention =~ ^(symbol|state|transition|vanilla)-(.+)$ ]]; then
  echo "error: invalid intervention: $intervention" >&2
  exit 1
fi
intervention_type=${BASH_REMATCH[1]}
target=${BASH_REMATCH[2]}

# Map 'transition' to 'arc' for compatibility with our scripts
if [[ $intervention_type = transition ]]; then
  intervention_type=arc
fi

# Start with the type flag and append the matching target flag. A vanilla
# intervention gets no target flag (its "-<target>" suffix is required by the
# pattern above but is not forwarded). Appending in place avoids expanding an
# empty helper array, which aborts under `set -u` on bash older than 4.4.
intervention_args=(--intervention_type "$intervention_type")
case $intervention_type in
  symbol) intervention_args+=(--target_symbol "$target") ;;
  state)  intervention_args+=(--target_state "$target") ;;
  arc)    intervention_args+=(--target_transition "$target") ;;
esac

# Resolve the dataset root via get_dataset_dir, lay out the train /
# validation / test / artifact directories beneath it, and create them.
dataset_dir=$(get_dataset_dir "$base_dir" "$automaton")
artifact_dir="${dataset_dir}/artifacts"
test_dir="${dataset_dir}/test"
train_dir="${dataset_dir}/train/${intervention}/${num_occurrences}"
validation_dir="${train_dir}/datasets/validation"
mkdir -p \
  "$train_dir" \
  "$validation_dir" \
  "$test_dir" \
  "$artifact_dir"

# Step 1: Generate automaton topology parameters
# Output goes to $artifact_dir/automaton, which step 2 uses as its input.
# automaton_args is empty for random topologies; the ${arr[@]+...} guard keeps
# that expansion from aborting under `set -u` on bash older than 4.4.
python intervention_sampling/generate_automaton_topology.py \
  --topology_seed "$topology_seed" \
  --accept_prob 0.05 \
  --num_states 20 \
  --num_symbols 10 \
  ${automaton_args[@]+"${automaton_args[@]}"} \
  --output_dir "$artifact_dir/automaton"

# Step 2: Preprocess the weighted automaton for the chosen intervention.
# Reads $artifact_dir/automaton and writes $artifact_dir/lifted_automaton.
# The maximum occurrence count passed on is one more than the requested
# number of occurrences.
max_occ_count=$((num_occurrences + 1))
python intervention_sampling/preprocess_weighted_automaton.py \
  --input_dir "$artifact_dir/automaton" \
  --output_dir "$artifact_dir/lifted_automaton" \
  --weight_seed "$weight_seed" \
  --max_occ_count "$max_occ_count" \
  "${intervention_args[@]}"

# Step 3: Lift the weighted automaton and create the sampler.
# Reads $artifact_dir/lifted_automaton and writes $artifact_dir/intervened_sampler.
# The validation occurrence count is a tenth of the requested count
# (integer division).
validation_num_occurrences=$((num_occurrences / 10))
python intervention_sampling/lift_weighted_automaton.py \
  --input_dir "$artifact_dir/lifted_automaton" \
  --output_dir "$artifact_dir/intervened_sampler" \
  --intervention_seed "$intervention_seed" \
  --intervention_count "$num_occurrences" \
  --validation_num_occurrences "$validation_num_occurrences"

# Step 4: Sample and prepare data.
# Draws from the sampler built in step 3 and writes the training, validation
# and test token files (main.tok) into their respective directories.
# NOTE(review): --num_states / --num_symbols repeat the literals used in
# step 1; presumably they must stay in agreement — confirm.
sample_cmd=(
  python intervention_sampling/sample_and_prepare_data.py
  --input_dir "$artifact_dir/intervened_sampler"
  --output_dir "$dataset_dir"
  --dataset_size 10000
  --num_val 1000
  --num_test 1000
  --sampling_seed "$sampling_seed"
  --automaton_name "$(basename "$automaton_topology")"
  --num_states 20
  --num_symbols 10
  "${intervention_args[@]}"
  --intervention_count "$num_occurrences"
  --output_type text
  --training_output "$train_dir"/main.tok
  --validation_output "$validation_dir"/main.tok
  --test_output "$test_dir"/main.tok
)
"${sample_cmd[@]}"

# Run rau's language-modeling prepare_data script over the generated token
# files. Each --more-data-files option takes an input .tok file followed by
# the .prepared file to write.
# NOTE(review): the test directory is what gets passed as --training-data;
# confirm this is intentional.
python rau/tasks/language_modeling/prepare_data.py \
  --more-data-files "$train_dir"/main.tok "$train_dir"/main.prepared \
  --more-data-files "$validation_dir"/main.tok "$validation_dir"/main.prepared \
  --training-data "$test_dir" \
  --never-allow-unk


printf '%s\n' "Pipeline completed successfully!"