1. LF-CBM training (supervised):

CIFAR10:
python train_cbm.py --concept_set data/concept_sets/cifar10_filtered.txt

CIFAR100:
python train_cbm.py --dataset cifar100 --concept_set data/concept_sets/cifar100_filtered.txt --save_dir saved_models_gpt4

CIFAR100-20:
python train_cbm.py --dataset cifar100-20 --concept_set data/concept_sets/cifar100-20_filtered_gpt4.txt --save_dir saved_models_gpt4

CUB200:
python train_cbm.py --dataset cub --backbone resnet18_cub --concept_set data/concept_sets/cub_filtered.txt --feature_layer features.final_pool --clip_cutoff 0.26 --n_iters 5000 --lam 0.0002

Places365:
python train_cbm.py --dataset places365 --backbone resnet50 --concept_set data/concept_sets/places365_filtered.txt --clip_cutoff 0.28 --n_iters 80 --lam 0.0003

ImageNet:
python train_cbm.py --dataset imagenet --backbone resnet50 --concept_set data/concept_sets/imagenet_filtered.txt --clip_cutoff 0.28 --n_iters 80 --lam 0.0001


CEIR:

1. Train CBM + VAE:

CIFAR10:
python -u train_cbm.py --dataset cifar10 --concept_set data/concept_sets/cifar10_filtered_gpt4.txt \
--backbone clip_ViT-L/14 --train_vae True --vae_train_set both --vae_hidden_dim 256 --vae_latent_dim 128 \
--vae_epochs 450 --save_vae True --device cuda

CIFAR100-20:
python -u train_cbm.py --dataset cifar100-20 --concept_set data/concept_sets/cifar100-20_filtered_gpt4.txt \
--backbone clip_RN50 --train_vae True --vae_train_set both --vae_hidden_dim 256 --vae_latent_dim 128 \
--vae_epochs 450 --save_vae True --device cuda

STL10:
python -u train_cbm.py --dataset stl10 --concept_set data/concept_sets/stl10_filtered_gpt4.txt \
--backbone clip_ViT-B/16 --train_vae True --vae_train_set both --vae_hidden_dim 256 --vae_latent_dim 128 \
--vae_epochs 450 --save_vae True --device cuda

STL10-unlabeled (adds an additional 100k unlabeled images):
python -u train_cbm.py --dataset stl10-unlabeled --concept_set data/concept_sets/stl10_filtered_gpt4.txt \
--backbone clip_ViT-B/16 --train_vae True --vae_train_set both --vae_hidden_dim 256 --vae_latent_dim 128 \
--vae_epochs 450 --save_vae True --device cuda

1.1 Exclude the test set from VAE training (use this in place of --vae_train_set both):
--vae_train_set train

1.2 Remove class-related concepts (use this in place of the --concept_set argument):
--concept_set data/concept_sets/gpt4_remove_class_concept/{dataset}_filtered_gpt4.txt


2. Train VAE only:

CIFAR10:
python -u train_vae.py --dataset cifar10 --concept_load_path saved_models_gpt4/cifar10/{model_dir}/cbm_clip_{backbone} \
--base_save_path saved_models_gpt4/cifar10/{model_dir} --vae_train_set both --vae_hidden_dim 256 --vae_latent_dim 128 \
--vae_epochs 450 --save_vae True --device cuda --multiprocess True

ImageNet:
python -u train_vae.py --dataset imagenet --concept_load_path saved_models_gpt4/imagenet/{model_dir}/cbm_clip_{backbone} \
--base_save_path saved_models_gpt4/imagenet/{model_dir} --vae_train_set both --vae_hidden_dim 512 --vae_latent_dim 256 \
--vae_epochs 450 --save_vae True --device cuda --multiprocess True
