#!/bin/bash
# Launch in the background with, e.g.:
#   nohup bash ./scripts/energy/ETTh1_backbone.sh &
#
# Optional environment overrides (the defaults reproduce the original setup):
#   CUDA_DEVICE  - GPU index exported as CUDA_VISIBLE_DEVICES (default: 2)
#   MAX_WORKERS  - number of runs launched before waiting for that batch to
#                  finish (default: 1, i.e. strictly sequential)
export CUDA_VISIBLE_DEVICES="${CUDA_DEVICE:-2}"

max_workers="${MAX_WORKERS:-1}"
# The main loop computes `counter % max_workers`, so reject anything that is
# not a positive integer up front instead of failing later with a
# division-by-zero / arithmetic error.
if [[ ! "$max_workers" =~ ^[1-9][0-9]*$ ]]; then
    echo "MAX_WORKERS must be a positive integer, got '${max_workers}'" >&2
    exit 1
fi

# Number of runs launched so far; used by the main loop to decide when to wait.
counter=0

# ---------------------------------------------------------------------------
# Shared defaults for every run. The per-model and per-horizon branches in the
# main loop override some of these values.
# ---------------------------------------------------------------------------

# Look-back window and architecture sizes.
seq_len=96
e_layers=2 d_layers=1
factor=1 n_heads=8 top_k=5
d_model=512 d_ff=2048

# Training batch size.
batch_size=32

# Dataset selection. enc_in is passed as --enc_in, --dec_in and --c_out.
dataset=ETTh1 enc_in=7
root_path='./dataset/ETT-small/' data_path='ETTh1.csv'

# hidden_size and SCI are only overridden by the Amplifier branches below,
# but every model receives them on the command line.
hidden_size=128 SCI=0

# Learning rate shared by all models (0.0005).
learning_rate=0.0005

# ---------------------------------------------------------------------------
# Main sweep: for every model x prediction horizon x seed, launch run.py in
# the background with its output redirected to
#   ./logs/energy_exp/backbone/<dataset>/<model>/<desc>_date_<timestamp>.out
# After every `max_workers` launches the script waits for that batch, reporting
# each run as Done or FAILED (with its exit status).
# ---------------------------------------------------------------------------

# Hyperparameters that model/horizon branches may override. Snapshot the
# shared defaults once and restore them at the start of every model, so one
# model's overrides (e.g. TimesNet's d_model=16, PatchTST's n_heads) do not
# silently leak into the models that run after it.
tunable_params=(e_layers d_layers factor n_heads top_k d_model d_ff batch_size hidden_size SCI)
declare -A param_defaults=()
for param in "${tunable_params[@]}"; do
    param_defaults[$param]=${!param}
done

# PIDs and human-readable labels of the runs launched since the last wait.
batch_pids=()
batch_labels=()

# Wait for every run in the current batch, report each one's outcome, then
# clear the batch bookkeeping. A failed run is reported but does not stop
# the sweep.
wait_for_batch() {
    local i status
    for i in "${!batch_pids[@]}"; do
        status=0
        wait "${batch_pids[$i]}" || status=$?
        if (( status == 0 )); then
            echo "Done: ${batch_labels[$i]}"
        else
            echo "FAILED (exit ${status}): ${batch_labels[$i]}" >&2
        fi
        echo "----------------------------------------"
    done
    batch_pids=()
    batch_labels=()
}

for model_name in DLinear PatchTST TimesNet Amplifier; do

    # Every model starts from the shared defaults.
    for param in "${tunable_params[@]}"; do
        printf -v "$param" '%s' "${param_defaults[$param]}"
    done

    # Per-model overrides; the *_energy variants share their base model's settings.
    case "$model_name" in
        PatchTST | PatchTST_energy)
            e_layers=1
            d_layers=1
            factor=3
            ;;
        DLinear | DLinear_energy)
            e_layers=2
            factor=3
            ;;
        TimesNet | TimesNet_energy)
            d_model=16
            d_ff=32
            e_layers=2
            d_layers=1
            factor=3
            top_k=5
            ;;
        Amplifier | Amplifier_energy)
            hidden_size=64
            batch_size=256
            SCI=0
            ;;
    esac

    folder="./logs/energy_exp/backbone/${dataset}/${model_name}/"
    mkdir -p -- "$folder"

    for pred_len in 96 192 336 720; do

        # PatchTST on ETTh1: attention head count depends on the horizon.
        if [[ "$model_name" == 'PatchTST' || "$model_name" == 'PatchTST_energy' ]] \
            && [[ "$dataset" == 'ETTh1' ]]; then
            case "$pred_len" in
                96) n_heads=2 ;;
                192 | 336) n_heads=8 ;;
                720) n_heads=16 ;;
            esac
        fi

        # Amplifier: hidden size depends on the horizon.
        if [[ "$model_name" == 'Amplifier' || "$model_name" == 'Amplifier_energy' ]]; then
            case "$pred_len" in
                96) hidden_size=64 ;;
                192 | 336 | 720) hidden_size=512 ;;
            esac
        fi

        for seed in 2021; do
            des_1="${dataset}_${model_name}_${seq_len}_${pred_len}_${seed}"
            current_time=$(date +'%Y-%m-%d_%H-%M-%S')
            des_2="${des_1}_date_${current_time}"
            log_file="${folder}${des_2}.out"
            job_label="${dataset} ${model_name} pred_len=${pred_len} seed=${seed}"

            echo "Running: ${job_label}"
            python -u run.py \
                --task_name long_term_forecast \
                --is_training 1 \
                --model_id "${dataset}_${seq_len}_${pred_len}" \
                --model "$model_name" \
                --data "$dataset" \
                --root_path "$root_path" \
                --data_path "$data_path" \
                --features M \
                --seq_len "$seq_len" \
                --label_len 48 \
                --pred_len "$pred_len" \
                --enc_in "$enc_in" \
                --dec_in "$enc_in" \
                --c_out "$enc_in" \
                --e_layers "$e_layers" \
                --d_layers "$d_layers" \
                --factor "$factor" \
                --n_heads "$n_heads" \
                --seed "$seed" \
                --d_model "$d_model" \
                --d_ff "$d_ff" \
                --top_k "$top_k" \
                --batch_size "$batch_size" \
                --learning_rate "$learning_rate" \
                --hidden_size "$hidden_size" \
                --SCI "$SCI" \
                --train_epochs 10 \
                --patience 3 \
                --des "$des_1" \
                --itr 1 >"$log_file" 2>&1 &
            batch_pids+=("$!")
            batch_labels+=("$job_label")

            # Assignment form instead of ((counter++)), whose status is 1 when
            # counter was 0 (would abort under `set -e`).
            counter=$((counter + 1))
            if (( counter % max_workers == 0 )); then
                wait_for_batch
            fi

        done
    done
done

# Collect runs from a final, partially filled batch so the script does not
# return while training jobs are still in flight.
wait_for_batch