#!/bin/bash
# Run from the parent of this script's directory (presumably the repo root,
# where the `train` module lives). Abort instead of launching every job from
# the wrong working directory if that directory cannot be entered.
cd "$(dirname "$0")"/.. || exit 1

# Hydra overrides shared by every run launched below. The string is split on
# whitespace when executed, so no single override may contain spaces.
common_command="python -m train ++trainer.max_epochs=100 ++dataset.generator.debug=False ++dataset.generator.h=1 ++model.dropout=0.2 ++encoder._name_=absolute_time ++dataset.num_seq=200000"

# Build a wandb run name from the experiment file, encoder name and extra
# Hydra overrides.
#
# Arguments:
#   $1 - experiment config path, e.g. "timeseries/ts_lc.yaml"
#   $2 - encoder name; known encoders are replaced by a short alias
#   $3 - extra overrides, e.g. "++model.layer.T=20" (may be empty)
# Globals:
#   common_command (read) - its dataset.num_seq=<value> override, when
#                           present, is embedded in the name
# Outputs:
#   The run name on stdout, e.g.
#   "timeseries/ts_lc_abs_time_num_seq_200000_kernel_dropout_0.2"
get_wandb_name() {
    local experiment_file=$1
    local encoder_name=$2
    local additional_args=$3

    # Strip the file extension, but only when the final path component has
    # one, so a dotted directory (e.g. "exp.v2/ts_id") is left intact.
    local experiment_name=$experiment_file
    local base_name=${experiment_file##*/}
    if [[ $base_name == *.* ]]; then
        experiment_name=${experiment_file%.*}
    fi

    local -A encoder_name_map=(
        ["timeseries_synthetics"]="enc_base"
        ["absolute_time"]="abs_time"
        ["positional_linear"]="pos"
        ["calendar_positional_linear"]="cpl"
    )

    # Use the shorthand for known encoders, the raw name otherwise. The
    # non-empty check comes first so an empty key is never looked up.
    local encoder_short_name=$encoder_name
    if [[ -n $encoder_name && -n ${encoder_name_map[$encoder_name]:-} ]]; then
        encoder_short_name=${encoder_name_map[$encoder_name]}
    fi

    # Embed dataset.num_seq from the shared command line, but only if it is
    # actually set there (otherwise the pattern strip would return the whole
    # command and its first word would end up in the name).
    local num_seq_part=""
    if [[ ${common_command:-} == *dataset.num_seq=* ]]; then
        local num_seq=${common_command##*dataset.num_seq=}
        num_seq=${num_seq%% *}
        num_seq_part="_num_seq_${num_seq}"
    fi

    # Remove the "++model.layer." prefix, then turn "=" and spaces into "_"
    # so the name stays a single word when passed on the command line.
    additional_args=${additional_args//++model.layer./}
    additional_args=${additional_args//=/_}
    additional_args=${additional_args// /_}

    local wandb_name="${experiment_name}_${encoder_short_name}${num_seq_part}"
    # Only append overrides when there are some, to avoid a trailing "_".
    if [[ -n $additional_args ]]; then
        wandb_name+="_${additional_args}"
    fi
    printf '%s\n' "$wandb_name"
}

# Launch one training run as a background job.
#
# Arguments:
#   $1 - value for CUDA_VISIBLE_DEVICES
#   $2 - experiment config, passed as experiment=<file>
#   $3 - optional extra Hydra overrides, whitespace-separated
# Globals:
#   common_command (read) - base command line; its encoder._name_ override
#                           is used to build the wandb run name
# Side effects:
#   Starts the command with `&` and returns immediately without waiting.
run_command() {
    local cuda_device=$1
    local experiment_file=$2
    local additional_args=${3:-}

    # Extract the encoder name from common_command. If no encoder override is
    # present, use a fixed label rather than an unrelated word of the command.
    local encoder_name="default_enc"
    if [[ ${common_command:-} == *encoder._name_=* ]]; then
        encoder_name=${common_command##*encoder._name_=}
        encoder_name=${encoder_name%% *}
    fi

    local wandb_name
    wandb_name=$(get_wandb_name "$experiment_file" "$encoder_name" "$additional_args")

    # Split the shared command and the extra overrides on whitespace into
    # arrays (this avoids the glob expansion of unquoted words), and pass the
    # run name as exactly one argument.
    local -a base_cmd extra_args
    read -ra base_cmd <<< "${common_command:-}"
    read -ra extra_args <<< "$additional_args"

    CUDA_VISIBLE_DEVICES=$cuda_device "${base_cmd[@]}" "experiment=$experiment_file" "++wandb.name=$wandb_name" "${extra_args[@]}" &
}

# Training runs to start, one background job each, launched in this order.
# Format: "<cuda_device>|<experiment_file>|<extra overrides, may be empty>"
launch_specs=(
    "0|timeseries/ts_lc.yaml|++model.layer.kernel_dropout=0.2"
    "1|timeseries/ts_id.yaml|++model.layer.T=20"
    "4|timeseries/ts_id.yaml|++model.layer.T=1"
    "5|timeseries/ts_id.yaml|++model.layer.T=0"
    "6|timeseries/ts_mha.yaml|"
    "3|timeseries/ts_lc.yaml|++model.layer.kernel_dropout=0.3"
    "2|timeseries/ts_ff.yaml|++model.layer.T=1"
    "7|timeseries/ts_ff.yaml|++model.layer.T=20"
)

for spec in "${launch_specs[@]}"; do
    IFS='|' read -r gpu experiment overrides <<< "$spec"
    # ${overrides:+...} passes no third argument at all when there are no
    # overrides, matching a two-argument call.
    run_command "$gpu" "$experiment" ${overrides:+"$overrides"}
done