#!/bin/bash

# Usage: bash <this-script> [dataset] [num_train] [m] [task] [framework]
#   dataset    mnist | cifar10                               (default: mnist)
#   num_train  value for --num-training-samples              (default: 1000)
#   m          value for --sum-n                             (default: 5)
#   task       msum | mmax                                   (default: msum)
#   framework  scallop, or a prefix for the training script  (default: scallop)
# Every configuration is run once per seed listed below; logs are written under
# logs_<framework>/.

seeds=(1234 5425 3914)
dataset="${1:-mnist}"
num_train="${2:-1000}"
m="${3:-5}"
task="${4:-msum}"
framework="${5:-scallop}"
provenance="diffaddmultprob"   # Scallop default; adjusted below per framework/task

# Reject non-integer num_train/m up front rather than letting the numeric `-eq`
# test on m or the training script fail later with a less obvious error.
if [[ ! "$num_train" =~ ^[0-9]+$ || ! "$m" =~ ^[0-9]+$ ]]; then
    echo "Invalid num_train '$num_train' or m '$m'. Please use non-negative integers." >&2
    exit 1
fi

# Map the dataset name to the training script that handles it; any other
# dataset is rejected.
case "$dataset" in
    mnist)
        script="sum_n_algo1_structured_pruning.py"
        ;;
    cifar10)
        script="cifar_sum_n_algo1_structured_pruning.py"
        ;;
    *)
        echo "Invalid dataset. Please use 'mnist' or 'cifar10'."
        exit 1
        ;;
esac

# Any framework other than Scallop uses a training script whose name is prefixed
# with "<framework>_", together with the "damp" provenance.
case "$framework" in
    scallop)
        ;;
    *)
        script="${framework}_${script}"
        provenance="damp"
        ;;
esac

# Only the 'msum' and 'mmax' tasks are supported; stop early on anything else.
case "$task" in
    msum | mmax)
        ;;
    *)
        echo "Invalid task. Please use 'msum' or 'mmax'."
        exit 1
        ;;
esac

# Per-task settings:
#   msum - batch size 64.
#   mmax - under Scallop the provenance switches to "difftopkproofs"; batch size
#          is 4 for the m=4 CIFAR-10 configuration and 64 otherwise.
# NOTE(review): earlier comments here mentioned 8 (and 16) for mmax, while the
# code selects 4 -- confirm the intended value.

# batch_size_for TASK M DATASET -- print the training batch size for one run.
batch_size_for() {
    local task_name=$1 digits=$2 data=$3
    # The dataset must be compared as a string. The previous
    # `[ "$dataset" -eq "cifar10" ]` was an integer test that errored
    # ("integer expression expected"), so batch size 4 was never selected.
    if [[ "$task_name" != "msum" && "$data" == "cifar10" ]] && [ "$digits" -eq 4 ]; then
        echo 4
    else
        echo 64
    fi
}

if [[ "$task" != "msum" && "$framework" == "scallop" ]]; then
    provenance="difftopkproofs"
fi
batch_size=$(batch_size_for "$task" "$m" "$dataset")


# NOTE(review): logs go to logs_<framework>/, so this directory is only used when
# framework=dolphin; likely a leftover -- confirm before removing.
mkdir -p logs_dolphin

# All logs for this framework share one directory; create it once.
log_dir="logs_${framework}"
mkdir -p "$log_dir"

# launch_run LOG_FILE ARG...
# Start `python "$script" ARG...` in the background. Both stdout and stderr go
# to LOG_FILE and are also echoed to the console.
# Globals (read): script, seed, num_train, m.
launch_run() {
    local log_path=$1
    shift
    echo "Running for seed $seed with num_train $num_train and m $m storing in $log_path"
    echo "Command: $script $*"
    # 2>&1 has to apply to python, before the pipe. Placing it after `tee`
    # redirected only tee's own stderr, so python tracebacks missed the log.
    (python "$script" "$@" 2>&1 | tee "$log_path") &
}

# Arguments shared by every run. They are kept as an array so each value stays
# a single word.
head_args=(
    --sum-n="$m" --provenance="$provenance" --device cuda --n-epochs 30
    --gpu 0 --num-training-samples "$num_train"
)

for seed in "${seeds[@]}"; do
    # The task name is passed via --dataset, exactly as before.
    tail_args=(--dataset "$task" --seed "$seed" --batch-size-train "$batch_size")
    log_prefix="${log_dir}/${dataset}_seed${seed}_num_train${num_train}_m${m}_${task}"

    # 1) With purification and --mock_proximity.
    launch_run "${log_prefix}_purification.log" \
        "${head_args[@]}" --with_purification "${tail_args[@]}" \
        --mock_proximity --structure-k 1

    # 2) With purification, without GT: no --mock_proximity, lr 1e-4.
    launch_run "${log_prefix}_purification_no_gt.log" \
        "${head_args[@]}" --with_purification "${tail_args[@]}" \
        --structure-k 1 --learning-rate 1e-4

    # 3) Without purification.
    launch_run "${log_prefix}.log" "${head_args[@]}" "${tail_args[@]}"

    # The three variants for one seed run concurrently; finish them before
    # starting the next seed.
    wait
done
wait
echo "All jobs completed."
