#!/bin/bash
#SBATCH --partition= 
#SBATCH --job-name=wfsasamp
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=16
#SBATCH --mem=16G
#SBATCH --time=4:00:00
#SBATCH --array=1-999%120  # Process seeds 1-999, with max 120 concurrent jobs
#SBATCH --output=logs/slurm_%A_%a.out
#SBATCH --error=logs/slurm_%A_%a.err
#SBATCH --gres=gpu:1 # Request 1 GPU

# Use the SLURM array task ID as the seed. Abort immediately when the variable
# is missing or empty (e.g. the script was started outside of a SLURM array
# job): an empty seed would otherwise end up in every output path and in every
# `--seed` argument further down.
SEED=${SLURM_ARRAY_TASK_ID:?must be set - submit this script as a SLURM array job}

# Navigate to the project directory. Every relative path used below
# (src/..., ${BASE_DIR}/..., logs) is resolved from here, so stop if the
# directory cannot be entered instead of running in the wrong place.
cd ANONYMOUS_FOLDER || {
    echo "ERROR: cannot change into project directory ANONYMOUS_FOLDER" >&2
    exit 1
}

# Set up conda: evaluate the shell hook printed by conda, then switch to the
# "formal" environment. Stop if the environment cannot be activated so that the
# python steps below never run with an unintended interpreter.
eval "$(conda shell.bash hook)"
conda activate formal || {
    echo "ERROR: failed to activate conda environment 'formal'" >&2
    exit 1
}

# Shape of the generated automaton and size of each sampled dataset
NUM_STATES=50
NUM_SYMBOLS=10
ACCEPTANCE_PROB=0.3
NUM_SAMPLES=500

# Intervention-count sweep used by the sampling loop: first value,
# inclusive upper bound, and increment
INTERVENTION_START=50
INTERVENTION_END=2000
INTERVENTION_STEP=200

# Root directory (relative to the current directory) for all generated data
BASE_DIR=experiments_random_100_10

# Make sure the log directory exists.
# NOTE(review): the #SBATCH --output/--error paths above also point into
# logs/; SLURM presumably opens those files before this line runs and relative
# to the submission directory rather than the directory entered by `cd` -
# confirm that logs/ already exists where the job is submitted from.
mkdir -p logs

echo "Starting job for SEED=${SEED}"

# Per-seed output locations; only the current seed is handled by this task.
# Both directories share the "<states>st_<symbols>sym" size tag.
SIZE_TAG="${NUM_STATES}st_${NUM_SYMBOLS}sym"
AUTOMATON_OUTPUT_DIR="${BASE_DIR}/data/${SIZE_TAG}/machine/${SEED}"
TOPOLOGY_OUTPUT_DIR="${BASE_DIR}/data/${SIZE_TAG}/topology/${SEED}"

# Create both directories (and any missing parents) in one go
mkdir -p "${BASE_DIR}/data/${SIZE_TAG}"/{machine,topology}/"${SEED}"

# Step 1: generate the automaton for this seed and save it into
# ${AUTOMATON_OUTPUT_DIR}. The alphabet size comes from NUM_SYMBOLS (the same
# variable that names the output directories) so the two cannot drift apart.
# The seed is used both as the general seed and as the topology seed.
echo "Generating automaton topology for seed ${SEED}"
python src/intervention_sampling/generate_automaton_topology.py \
    --output_dir "${AUTOMATON_OUTPUT_DIR}" \
    --save_automaton \
    --num_states "${NUM_STATES}" \
    --num_symbols "${NUM_SYMBOLS}" \
    --accept_prob "${ACCEPTANCE_PROB}" \
    --seed "${SEED}" \
    --topology_seed "${SEED}" || {
    # Everything after this point reads from ${AUTOMATON_OUTPUT_DIR}.
    echo "ERROR: automaton generation failed for seed ${SEED}" >&2
    exit 1
}

# Step 2: preprocess the generated automaton; reads ${AUTOMATON_OUTPUT_DIR}
# and writes ${TOPOLOGY_OUTPUT_DIR}, which the lifting step consumes.
echo "Preprocessing weighted automaton for seed ${SEED}"
python src/intervention_sampling/preprocess_weighted_automaton.py \
    --input_dir "${AUTOMATON_OUTPUT_DIR}" \
    --output_dir "${TOPOLOGY_OUTPUT_DIR}" \
    --accept_prob "${ACCEPTANCE_PROB}" \
    --seed "${SEED}" || {
    echo "ERROR: preprocessing failed for seed ${SEED}" >&2
    exit 1
}

# Main sweep. For every intervention type, target and semiring:
#   1. lift the preprocessed automaton,
#   2. build a sampler from the lifted automaton,
#   3. sample train/validation/test data for each intervention count in
#      INTERVENTION_START..INTERVENTION_END (step INTERVENTION_STEP).
# A failing step is reported on stderr and the steps that depend on its output
# are skipped; the remaining combinations are still processed. The number of
# failed steps decides the exit status at the very end.
FAILED_STEPS=0

for INTERVENTION in symbol state; do
    # Number of targets and the command-line flag that names the target.
    # Arc interventions (--target_transition) are not wired up yet; they and
    # any other unknown type are rejected by the default branch.
    case "${INTERVENTION}" in
        symbol)
            NUM_TGTS=${NUM_SYMBOLS}
            target_flag=--target_symbol
            ;;
        state)
            NUM_TGTS=${NUM_STATES}
            target_flag=--target_state
            ;;
        vanilla)
            # NOTE(review): assumes the vanilla setting takes no target
            # argument - confirm against lift_weighted_automaton.py.
            NUM_TGTS=1
            target_flag=""
            ;;
        *)
            echo "ERROR: unsupported intervention type '${INTERVENTION}', skipping" >&2
            FAILED_STEPS=$((FAILED_STEPS + 1))
            continue
            ;;
    esac

    for ((TARGET = 0; TARGET < NUM_TGTS; TARGET++)); do

        # State 0 is never used as an intervention target.
        # NOTE(review): presumably because it is the initial state - confirm.
        if [[ ${INTERVENTION} == "state" ]] && ((TARGET == 0)); then
            continue
        fi

        # Rebuilt for every target so nothing leaks over from a previous
        # intervention type.
        target_args=()
        if [[ -n ${target_flag} ]]; then
            target_args=("${target_flag}" "${TARGET}")
        fi

        for SEMIRING in alo binning; do

            # semiring_args go to the lifting step, semiring_flag to the
            # sampler step.
            case "${SEMIRING}" in
                alo)
                    semiring_args=(--at_least_once_semiring --intervention_count 1 --validation_num_occurrences 1)
                    semiring_flag=(--at_least_once_semiring)
                    ;;
                binning)
                    semiring_args=(--intervention_count "${NUM_SAMPLES}" --validation_num_occurrences "${NUM_SAMPLES}")
                    semiring_flag=()
                    ;;
                *)
                    echo "ERROR: unsupported semiring '${SEMIRING}', skipping" >&2
                    FAILED_STEPS=$((FAILED_STEPS + 1))
                    continue
                    ;;
            esac

            LIFTED_OUTPUT_DIR="${BASE_DIR}/data/${SEMIRING}/${NUM_STATES}st_${NUM_SYMBOLS}sym/lifted/${SEED}/${INTERVENTION}/${TARGET}"
            SAMPLER_OUTPUT_DIR="${BASE_DIR}/data/${SEMIRING}/${NUM_STATES}st_${NUM_SYMBOLS}sym/sampler/${SEED}/${INTERVENTION}/${TARGET}"

            # Create directories if they don't exist
            mkdir -p "${LIFTED_OUTPUT_DIR}" "${SAMPLER_OUTPUT_DIR}"

            echo "Lifting weighted automaton for ${INTERVENTION}, target ${TARGET}, seed ${SEED}"
            if ! python src/intervention_sampling/lift_weighted_automaton.py \
                --input_dir "${TOPOLOGY_OUTPUT_DIR}" \
                --output_dir "${LIFTED_OUTPUT_DIR}" \
                --intervention_type "${INTERVENTION}" \
                "${semiring_args[@]}" \
                "${target_args[@]}"; then
                # Sampler creation and sampling need the lifted automaton.
                echo "ERROR: lifting failed (${SEMIRING}, ${INTERVENTION}, target ${TARGET}, seed ${SEED})" >&2
                FAILED_STEPS=$((FAILED_STEPS + 1))
                continue
            fi

            echo "Lifted machine (${INTERVENTION}, target ${TARGET}, seed ${SEED})"

            echo "Creating sampler for ${INTERVENTION}, target ${TARGET}, seed ${SEED}"
            if ! python src/intervention_sampling/create_sampler.py \
                --input_dir "${LIFTED_OUTPUT_DIR}" \
                --output_dir "${SAMPLER_OUTPUT_DIR}" \
                --max_occ_count "${NUM_SAMPLES}" \
                --seed "${SEED}" \
                "${semiring_flag[@]}"; then
                # Sampling needs the sampler written by this step.
                echo "ERROR: sampler creation failed (${SEMIRING}, ${INTERVENTION}, target ${TARGET}, seed ${SEED})" >&2
                FAILED_STEPS=$((FAILED_STEPS + 1))
                continue
            fi

            echo "Created sampler for ${INTERVENTION}, target ${TARGET}, seed ${SEED}"

            DATA_OUTPUT="${BASE_DIR}/data/datasets/${SEMIRING}/${NUM_STATES}st_${NUM_SYMBOLS}sym/${SEED}/${INTERVENTION}/${TARGET}"

            # The test location does not depend on the intervention count, so
            # it is set up once, outside the sweep.
            # NOTE(review): every iteration below passes the same test output
            # path with a different sampling seed, so the test file presumably
            # gets rewritten each time - confirm this is intended.
            test_dir="${DATA_OUTPUT}/test"
            test_out="${test_dir}/main.tok"
            mkdir -p "${test_dir}"

            for ((i = INTERVENTION_START; i <= INTERVENTION_END; i += INTERVENTION_STEP)); do
                train_dir="${DATA_OUTPUT}/train/${i}"
                train_out="${train_dir}/main.tok"
                val_dir="${DATA_OUTPUT}/validation/${i}"
                val_out="${val_dir}/main.tok"

                # Create directories if they don't exist
                mkdir -p "${train_dir}" "${val_dir}"

                echo "Sampling data for ${INTERVENTION}, target ${TARGET}, seed ${SEED}, intervention count ${i}"
                # The intervention count doubles as the sampling seed.
                if ! python src/intervention_sampling/sample_and_prepare_data.py \
                    --input_dir "${SAMPLER_OUTPUT_DIR}" \
                    --dataset_size "${NUM_SAMPLES}" \
                    --num_val "${NUM_SAMPLES}" \
                    --num_test 2000 \
                    --sampling_seed "${i}" \
                    --intervention_count "${i}" \
                    --output_type text \
                    --training_output "${train_out}" \
                    --validation_output "${val_out}" \
                    --test_output "${test_out}"; then
                    # Other intervention counts are independent; keep going.
                    echo "ERROR: sampling failed (${SEMIRING}, ${INTERVENTION}, target ${TARGET}, seed ${SEED}, intervention count ${i})" >&2
                    FAILED_STEPS=$((FAILED_STEPS + 1))
                fi

            done
        done
    done
done

# Let SLURM see a failed job when any step above went wrong.
if ((FAILED_STEPS > 0)); then
    echo "Completed job for SEED=${SEED} with ${FAILED_STEPS} failed step(s)" >&2
    exit 1
fi

echo "Completed job for SEED=${SEED}"
