#!/bin/bash

# Root directory under which every artifact produced by this script is written.
BASE_DIR="experiments_parityfree_intervention"

# Automaton family name; passed to the topology generator via --automaton and
# used as a path component below.
AUTOMATON="parity_free_hp"

# NUM_STATES / NUM_SYMBOLS bound the intervention targets that are iterated
# (targets 0..N-1). NUM_ARCS is not read anywhere in this script.
NUM_STATES=4
NUM_ARCS=7
NUM_SYMBOLS=3
DATASET_SIZE=500
ACCEPTANCE_PROB=0.1
size=${DATASET_SIZE}
prob=${ACCEPTANCE_PROB}

# Intervention schedule, derived from the dataset size: counts run from
# INTERVENTION_START up to twice the dataset size, in ten equal steps.
# Plain shell integer arithmetic is used, so no external `bc` is needed.
INTERVENTION_START=2
INTERVENTION_END=$(( DATASET_SIZE * 2 ))
INTERVENTION_STEP=$(( INTERVENTION_END / 10 ))

# A zero step (dataset sizes below 5) would make the schedule loop never
# advance, so clamp it to at least 1.
if (( INTERVENTION_STEP < 1 )); then
    INTERVENTION_STEP=1
fi

# Report the effective configuration before any work starts.
echo "Dataset size: $DATASET_SIZE, Acceptance prob: $ACCEPTANCE_PROB"
echo "Interventions: Start=$INTERVENTION_START, End=$INTERVENTION_END, Step=$INTERVENTION_STEP"

echo "STARTING size ${size} with acc prob ${prob}"

# The data directory name encodes the full configuration, so runs with
# different settings do not share output paths.
DATA_DIR="${BASE_DIR}/data_${size}_${prob}_${INTERVENTION_START}_${INTERVENTION_END}_${INTERVENTION_STEP}"
DATA_OUTPUT_BASE="$DATA_DIR/datasets/${AUTOMATON}"

# Runs one pipeline stage (a Python script followed by its arguments) and
# aborts the whole script if that stage fails. Every stage below reads the
# output of an earlier one, so continuing after a failure would only operate
# on missing or stale files.
run_step() {
    local status
    python "$@" || {
        status=$?
        printf 'ERROR: pipeline step failed (exit %d): python %s\n' "$status" "$*" >&2
        exit "$status"
    }
}

# For the at-least-once ("alo") semiring, any scheduled intervention count
# above this percentage of the dataset size is replaced by the dataset size
# itself, and the schedule stops after that run.
THRESHOLD_PERCENTAGE=98
alo_cap_threshold=$(( DATASET_SIZE * THRESHOLD_PERCENTAGE / 100 ))

# Set to 200 for same as in paper {1..200}
for automata_idx in 1; do
    AUTOMATON_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/automaton/${automata_idx}"
    TOPOLOGY_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/topology/${automata_idx}"

    VANILLA_LIFTED_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/lifted/${automata_idx}/vanilla"
    VANILLA_SAMPLER_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/sampler/${automata_idx}/vanilla"

    # Stage 1: generate the automaton; the automaton index doubles as the seed.
    run_step src/intervention_sampling/generate_automaton_topology.py \
        --output_dir "${AUTOMATON_OUTPUT_DIR}" \
        --automaton "${AUTOMATON}" \
        --accept_prob "${ACCEPTANCE_PROB}" \
        --seed "${automata_idx}"
    echo "Machine written"

    # Stage 2: preprocess the generated automaton into the topology directory.
    echo "Preprocessing weighted automaton"
    run_step src/intervention_sampling/preprocess_weighted_automaton.py \
        --input_dir "${AUTOMATON_OUTPUT_DIR}" \
        --output_dir "${TOPOLOGY_OUTPUT_DIR}" \
        --accept_prob "${ACCEPTANCE_PROB}" \
        --seed "${automata_idx}"

    # Stage 3: lift and build a sampler without any intervention ("vanilla").
    run_step src/intervention_sampling/lift_weighted_automaton.py \
        --input_dir "${TOPOLOGY_OUTPUT_DIR}" \
        --output_dir "${VANILLA_LIFTED_OUTPUT_DIR}" \
        --intervention_type vanilla

    echo "Lifted machine (vanilla)"

    run_step src/intervention_sampling/create_sampler.py \
        --input_dir "${VANILLA_LIFTED_OUTPUT_DIR}" \
        --output_dir "${VANILLA_SAMPLER_OUTPUT_DIR}" \
        --max_occ_count 5000 \
        --seed 1

    echo "Created sampler (vanilla)"
    echo "Sampling vanilla"
    # First the vanilla datasets for the baselines
    for seed in 1; do
        train_dir="${DATA_OUTPUT_BASE}/vanilla/train/${automata_idx}/${seed}"
        train_out="${train_dir}/main.tok"
        val_dir="${DATA_OUTPUT_BASE}/vanilla/validation/${automata_idx}/${seed}"
        val_out="${val_dir}/main.tok"
        test_dir="${DATA_OUTPUT_BASE}/vanilla/test/${automata_idx}"
        test_out="${test_dir}/main.tok"

        run_step src/intervention_sampling/sample_and_prepare_data.py \
            --input_dir "${VANILLA_SAMPLER_OUTPUT_DIR}" \
            --dataset_size "${DATASET_SIZE}" \
            --num_val 1000 \
            --num_test 2000 \
            --sampling_seed "${seed}" \
            --intervention_count 1 \
            --output_type text \
            --training_output "${train_out}" \
            --validation_output "${val_out}" \
            --test_output "${test_out}"

        # Each --more-data-files pair expands to "<dir>/main.tok <dir>/main.prepared".
        # NOTE(review): the *test* directory is what gets passed as
        # --training-data here — confirm this is intended.
        run_step src/rau/tasks/language_modeling/prepare_data.py \
            --more-data-files "$train_dir"/main.{tok,prepared} \
            --more-data-files "$val_dir"/main.{tok,prepared} \
            --training-data "$test_dir" \
            --never-allow-unk
    done

    echo "Sampling under intervention"

    # Only state interventions are run; add "symbol" to this list to also
    # intervene on symbols.
    for INTERVENTION in state; do
        case "$INTERVENTION" in
            symbol) NUM_TGTS=${NUM_SYMBOLS} ;;
            state)  NUM_TGTS=${NUM_STATES} ;;
            *)
                echo "Unknown intervention type: ${INTERVENTION}" >&2
                exit 1
                ;;
        esac

        for (( TARGET = 0; TARGET < NUM_TGTS; TARGET++ )); do

            # State 0 is never used as an intervention target
            # (presumably it is the initial state — TODO confirm).
            if [[ $INTERVENTION = "state" && $TARGET -eq 0 ]]; then
                continue
            fi

            if [[ $INTERVENTION = "symbol" ]]; then
                target_args=(--target_symbol "$TARGET")
            else
                target_args=(--target_state "$TARGET")
            fi

            for SEMIRING in alo binning; do

                # semiring_args go to the lifting step, semiring_flag to the
                # sampler-creation step.
                if [[ $SEMIRING = "alo" ]]; then
                    semiring_args=(--at_least_once_semiring --intervention_count 1 --validation_num_occurrences 1)
                    semiring_flag=(--at_least_once_semiring)
                elif [[ $SEMIRING = "binning" ]]; then
                    semiring_args=(--intervention_count "${DATASET_SIZE}" --validation_num_occurrences "${DATASET_SIZE}")
                    semiring_flag=()
                fi

                # NOTE(review): these two directories do not include the
                # semiring name, so the "binning" pass writes to the same
                # paths as the preceding "alo" pass. Sampling for a semiring
                # happens before the next one rebuilds them.
                LIFTED_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/lifted/${automata_idx}/${INTERVENTION}/${TARGET}"
                SAMPLER_OUTPUT_DIR="${DATA_DIR}/${AUTOMATON}/sampler/${automata_idx}/${INTERVENTION}/${TARGET}"

                echo "Lifting weighted automaton for ${INTERVENTION}, target ${TARGET}"
                run_step src/intervention_sampling/lift_weighted_automaton.py \
                    --input_dir "${TOPOLOGY_OUTPUT_DIR}" \
                    --output_dir "${LIFTED_OUTPUT_DIR}" \
                    --intervention_type "${INTERVENTION}" \
                    "${semiring_args[@]}" \
                    "${target_args[@]}"

                echo "Lifted machine ${INTERVENTION}, target ${TARGET}"

                echo "Creating sampler for ${INTERVENTION}, target ${TARGET}"
                run_step src/intervention_sampling/create_sampler.py \
                    --input_dir "${LIFTED_OUTPUT_DIR}" \
                    --output_dir "${SAMPLER_OUTPUT_DIR}" \
                    --max_occ_count "${DATASET_SIZE}" \
                    --seed "${automata_idx}" \
                    "${semiring_flag[@]}"

                echo "Created sampler for ${INTERVENTION}, target ${TARGET}"

                DATA_OUTPUT="${DATA_OUTPUT_BASE}/${SEMIRING}/${AUTOMATON}/${automata_idx}/${INTERVENTION}/${TARGET}"

                for (( i = INTERVENTION_START; i <= INTERVENTION_END; i += INTERVENTION_STEP )); do

                    # For "alo", a count above the threshold is replaced by
                    # the dataset size, and that run is the last one.
                    alo_capped=0
                    if [[ $SEMIRING = "alo" ]] && (( i > alo_cap_threshold )); then
                        i=$DATASET_SIZE
                        alo_capped=1
                    fi

                    train_dir="${DATA_OUTPUT}/train/${i}"
                    train_out="${train_dir}/main.tok"
                    val_dir="${DATA_OUTPUT}/validation/${i}"
                    val_out="${val_dir}/main.tok"
                    test_dir="${DATA_OUTPUT}/test/"
                    test_out="${test_dir}/main.tok"

                    # The intervention count is also used as the sampling seed.
                    echo "Sampling data for ${INTERVENTION}, target ${TARGET}, intervention count ${i}"
                    run_step src/intervention_sampling/sample_and_prepare_data.py \
                        --input_dir "${SAMPLER_OUTPUT_DIR}" \
                        --dataset_size "${DATASET_SIZE}" \
                        --num_val "${DATASET_SIZE}" \
                        --num_test 2000 \
                        --sampling_seed "${i}" \
                        --intervention_count "${i}" \
                        --output_type text \
                        --training_output "${train_out}" \
                        --validation_output "${val_out}" \
                        --test_output "${test_out}"

                    # Same preparation call as for the vanilla datasets.
                    run_step src/rau/tasks/language_modeling/prepare_data.py \
                        --more-data-files "$train_dir"/main.{tok,prepared} \
                        --more-data-files "$val_dir"/main.{tok,prepared} \
                        --training-data "$test_dir" \
                        --never-allow-unk

                    if (( alo_capped )); then
                        break
                    fi

                done
            done
        done
    done
done
