#!/bin/bash

# Both the classifier and the feature selector are SetTransformer models.

# ---------------------------------------------------------------------------
# Filesystem locations
# NOTE: the "path/to/your/..." prefixes look like placeholders; presumably
# they must be replaced with real locations before running -- confirm.
# ---------------------------------------------------------------------------
# Passed to the training script as --data_dir.
DATA_DIR='path/to/your/datasets/clevr-hans/encodings/unconfounded/CLEVR_Hans3_4/one_hot_padded_encodings_unconfounded_CLEVR_Hans3_4'
# Passed as --pretrained_path (used together with the --pretrained_model flag).
PRETRAINED_PATH='path/to/your/merlinarthur-ncb-results/checkpoints/regular/unconfounded/regular_SetTransformer_on_one_hot_padded_seed42_likely-sweep-1/best_model.pth'
# Passed as --res_dir.
RES_DIR='path/to/your/merlinarthur-ncb-results/checkpoints/learnable_fs/unconfounded/fs_SetTransf_classifier_SetTransf'

# ---------------------------------------------------------------------------
# Required arguments
# ---------------------------------------------------------------------------
ENC_TYPE='one_hot_padded'   # --enc_type
MASK_SIZE=4                 # --mask_size

# Model choices: MODEL feeds --model, FS_MODEL feeds --fs_model.
MODEL='SetTransformer'
FS_MODEL='SetTransformer'

# Launch src/main.py with the "learn_fs" approach.
#
# The arguments are collected in an array instead of a backslash-continued
# command line for two reasons:
#   * every variable expansion is double-quoted, so values containing spaces
#     or glob characters (e.g. real dataset/checkpoint paths) reach the
#     program as a single, unmodified argument;
#   * optional flags can be switched on or off by (un)commenting one line,
#     without having to keep trailing backslashes consistent.
# The script's exit status is the exit status of the python process.
train_args=(
    --epochs 50
    --approach "learn_fs"
    --lr 0.001
    --seed 1
    --data_dir "$DATA_DIR"
    --enc_type "$ENC_TYPE"
    --res_dir "$RES_DIR"
    --batch_size 512
    --num_workers 4
    --model "$MODEL"
    --n_heads 4
    --set_transf_hidden 128
    --pretrained_model
    --pretrained_path "$PRETRAINED_PATH"
    --mask_size "$MASK_SIZE"
    --gamma 1
    --lr_merlin 0.001
    --lr_morgana 0.001
    --fs_model "$FS_MODEL"
    --fs_hidden_dim 256
    --fs_dropout 0.3
    --fs_n_heads 4
    --weight_decay_merlin 0.00001
    --weight_decay_morgana 0.00001
    --weight_decay 0.0001
    --l1_penalty_coefficient 0.1
    --compute_avg_occ
    # Optional flags, currently disabled; uncomment a line to enable it.
    #--wandb
    #--save_model
)

python src/main.py "${train_args[@]}"

