# Merlin-Arthur x Neural Concept Binder

A framework that combines Merlin-Arthur feature selection with the Neural Concept Binder (NCB). Instead of operating in pixel space, this implementation works with NCB encodings of images. This repository contains all source code required to reproduce the NCB-based experiments of the paper *Neural Concept Verifier: Scaling Prover-Verifier Games via Concept Encodings*.

### Prerequisites
1. Downloads (from the original NCB repository [here](https://github.com/NeuralBinder/NCB)):
    - Pre-trained models (SysBinder and retrieval corpus): [Download link](https://hessenbox.tu-darmstadt.de/getlink/fi6WzuWtQ87Px5P3ewEVNQyZ/CLEVR-4.zip) (currently not working)
    - CLEVR-Hans3 dataset: [Download link](https://hessenbox.tu-darmstadt.de/getlink/fiDrNmUD4dMihFzhg8qigguo/CLEVR-4-4-cls.zip) (currently not working)
    - CLEVR-Hans7 dataset: [PROVIDE Download Link]()
    - Nonconfounded CLEVR-Hans3 dataset: [PROVIDE Download Link]()
    - Nonconfounded CLEVR-Hans7 dataset: [PROVIDE Download Link]()

2. Environment setup:
    - Create the conda environment: `conda env create -f environment.yaml -n myenv`
    - Activate it: `conda activate myenv`
    - Install the Python packages: `pip install torch tqdm rtpt wandb torchvision`

### Repository Structure
```
.
├── README.md
├── environment.yaml
├── scripts/
│   ├── precompute_encodings/
│   ├── train_regular/
│   ├── train_sfw/
│   ├── train_with_learnable_fs/
│   └── pixel_space/
│       ├── train_regular/
│       ├── train_unet/
│       └── train_sfw/
├── src/
│   ├── config/
│   │   ├── argparser.py
│   │   ├── config.py
│   │   └── config_dataclass.py
│   ├── data_ncb/
│   │   ├── sysbinder/
│   │   ├── precompute_encodings.py
│   │   ├── CLEVR_Hans_image_dataset.py
│   │   ├── argparser_precompute_encodings.py
│   │   └── neural_concept_binder.py
│   ├── merlin_arthur_framework/
│   │   ├── feature_selectors.py
│   │   └── stochastic_frank_wolfe.py
│   ├── models/
│   │   ├── classifier.py
│   │   └── feature_selector_models.py
│   ├── pixel_space/
│   │   ├── argparser_pixel.py
│   │   ├── classifier_pixel.py
│   │   ├── config_dataclass_pixel.py
│   │   ├── config_pixel.py
│   │   ├── feature_selectors_pixel.py
│   │   └── trainer_framework_pixel.py
│   ├── trainer/
│   │   └── trainer_framework.py
│   ├── utils/
│   │   ├── logger.py
│   │   └── metrics.py
│   ├── main.py
│   └── main_pixel.py
└── tests/
    └── feature_analysis/
        └── feature_analysis.py
```

### Data Preparation

1. **Generate Concept-Slot Encodings**
    - First, generate concept-slot encodings from raw images using NCB:
    ```bash
    # Paths - modify these according to your setup
    DATA_DIR="path/to/CLEVR_Hans3_4"                   # Path to downloaded CLEVR-Hans3 dataset
    RESULT_DIR="path/to/save/concept_slot_encodings"   # Where to save the generated encodings

    SYSBINDER_PATH="path/to/best_model.pt"             # Path to downloaded SysBinder model
    RETRIEVAL_CORPUS_PATH="path/to/block_concept_dicts.pkl"  # Path to downloaded retrieval corpus

    ENC_TYPE="concept_slot"
    NUM_WORKERS=4
    BATCH_SIZE=128

    python src/data_ncb/precompute_encodings.py \
        --enc_type ${ENC_TYPE} \
        --data_dir ${DATA_DIR} \
        --sysbinder_path ${SYSBINDER_PATH} \
        --result_dir ${RESULT_DIR} \
        --retrieval_corpus_path ${RETRIEVAL_CORPUS_PATH} \
        --num_workers ${NUM_WORKERS} \
        --batch_size ${BATCH_SIZE}
    ```

2. **Generate One-Hot-Padded Encodings**
    - Convert concept-slot encodings to one-hot-padded format:
    ```bash
    # Paths
    # DATA_DIR: Directory where the concept-slot encodings are stored (not images)
    DATA_DIR="path/to/concept_slot_encodings"          # Path to encodings from step 1
    RESULT_DIR="path/to/save/one_hot_padded_encodings" # Where to save the converted encodings

    SYSBINDER_PATH="path/to/best_model.pt"             # Path to downloaded SysBinder model
    RETRIEVAL_CORPUS_PATH="path/to/block_concept_dicts.pkl"  # Path to downloaded retrieval corpus

    # Required arguments
    ENC_TYPE="one_hot_padded"
    NUM_WORKERS=4
    BATCH_SIZE=128

    python src/data_ncb/precompute_encodings.py \
        --enc_type ${ENC_TYPE} \
        --data_dir ${DATA_DIR} \
        --sysbinder_path ${SYSBINDER_PATH} \
        --result_dir ${RESULT_DIR} \
        --retrieval_corpus_path ${RETRIEVAL_CORPUS_PATH} \
        --num_workers ${NUM_WORKERS} \
        --batch_size ${BATCH_SIZE}
    ```
    > Note: This step requires the concept-slot encodings from step 1.
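
The two data-preparation steps can also be chained in a single script. The sketch below is a hypothetical convenience wrapper, not part of the repository: the environment-variable overrides, the `precompute` helper and the `DRY_RUN` switch are assumptions of this example, and the directory check before step 2 assumes that `precompute_encodings.py` writes its output into `--result_dir`. By default it only prints the two `precompute_encodings.py` calls so the paths can be checked; setting `DRY_RUN=0` runs them.

```bash
#!/usr/bin/env bash
# Sketch: run both data-preparation steps in sequence.
# All paths are placeholders (override them via environment variables).
# By default the commands are only printed; set DRY_RUN=0 to execute them.
set -euo pipefail

DRY_RUN="${DRY_RUN:-1}"
IMAGE_DIR="${IMAGE_DIR:-path/to/CLEVR_Hans3_4}"
CONCEPT_SLOT_DIR="${CONCEPT_SLOT_DIR:-path/to/save/concept_slot_encodings}"
ONE_HOT_DIR="${ONE_HOT_DIR:-path/to/save/one_hot_padded_encodings}"
SYSBINDER_PATH="${SYSBINDER_PATH:-path/to/best_model.pt}"
RETRIEVAL_CORPUS_PATH="${RETRIEVAL_CORPUS_PATH:-path/to/block_concept_dicts.pkl}"

# Hypothetical helper: precompute ENC_TYPE INPUT_DIR OUTPUT_DIR
precompute() {
    local cmd=(python src/data_ncb/precompute_encodings.py
        --enc_type "$1"
        --data_dir "$2"
        --sysbinder_path "$SYSBINDER_PATH"
        --result_dir "$3"
        --retrieval_corpus_path "$RETRIEVAL_CORPUS_PATH"
        --num_workers 4
        --batch_size 128)
    if [[ "$DRY_RUN" == "1" ]]; then
        echo "${cmd[*]}"
    else
        "${cmd[@]}"
    fi
}

# Step 1: raw images -> concept-slot encodings
precompute concept_slot "$IMAGE_DIR" "$CONCEPT_SLOT_DIR"

# Step 2 reads the output of step 1; stop early if it is missing
if [[ "$DRY_RUN" != "1" && ! -d "$CONCEPT_SLOT_DIR" ]]; then
    echo "No concept-slot encodings found in $CONCEPT_SLOT_DIR" >&2
    exit 1
fi

# Step 2: concept-slot encodings -> one-hot-padded encodings
precompute one_hot_padded "$CONCEPT_SLOT_DIR" "$ONE_HOT_DIR"
```

Passing the arguments as a bash array keeps paths containing spaces intact when the commands are executed.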

### Training Methods

1. **Regular Training (Arthur Only)**
    - Basic training without feature selection:
    ```bash
    # Paths
    DATA_DIR="path/to/one_hot_padded_encodings"        # Path to encodings from step 2
    RES_DIR="path/to/save/checkpoints"                 # Where to save model checkpoints

    # Required arguments
    ENC_TYPE="one_hot_padded"
    MODEL="SetTransformer"  # can also be MLP or Linear

    python src/main.py \
        --epochs 30 \
        --approach "regular" \
        --lr 0.0001 \
        --seed 42 \
        --data_dir $DATA_DIR \
        --enc_type $ENC_TYPE \
        --res_dir $RES_DIR \
        --batch_size 128 \
        --num_workers 4 \
        --model $MODEL \
        --n_heads 4 \
        --set_transf_hidden 128 \
        --save_model
    ```

2. **Stochastic Frank-Wolfe (SFW) Training**
    - Optimization-based feature selection with Merlin and Morgana:
    ```bash
    # Paths
    DATA_DIR="path/to/one_hot_padded_encodings"        # Path to encodings from step 2
    PRETRAINED_PATH="path/to/pretrained/model.pth"     # Path to pretrained model from regular training
    RES_DIR="path/to/save/checkpoints"                 # Where to save model checkpoints

    # Required arguments
    ENC_TYPE="one_hot_padded"
    MODEL="SetTransformer"  # can also be MLP or Linear
    MASK_SIZE=6            # Number of features to select

    python src/main.py \
        --epochs 15 \
        --approach "sfw" \
        --lr 0.001 \
        --seed 42 \
        --data_dir $DATA_DIR \
        --enc_type $ENC_TYPE \
        --res_dir $RES_DIR \
        --batch_size 512 \
        --num_workers 4 \
        --model $MODEL \
        --n_heads 4 \
        --set_transf_hidden 128 \
        --pretrained_model \
        --pretrained_path $PRETRAINED_PATH \
        --mask_size $MASK_SIZE \
        --l1_penalty_coefficient 0.1 \
        --gamma 1.5 \
        --lr_merlin 0.01 \
        --lr_morgana 0.01 \
        --compute_prec_and_ent
    ```

3. **Learnable Feature Selection Training**
    - Neural network-based feature selection with learnable Merlin and Morgana:
    ```bash
    # Both the classifier and the feature selector are SetTransformers

    # Paths
    DATA_DIR="path/to/one_hot_padded_encodings"        # Path to encodings from step 2
    PRETRAINED_PATH="path/to/pretrained/model.pth"     # Path to pretrained model from regular training
    RES_DIR="path/to/save/checkpoints"                 # Where to save model checkpoints

    # Required arguments
    ENC_TYPE="one_hot_padded"
    MODEL="SetTransformer"  # Classifier model type
    FS_MODEL="SetTransformer"  # Feature selector model type
    MASK_SIZE=6  # Number of features to select

    python src/main.py \
        --epochs 30 \
        --approach "learn_fs" \
        --lr 0.001 \
        --seed 42 \
        --data_dir $DATA_DIR \
        --enc_type $ENC_TYPE \
        --res_dir $RES_DIR \
        --batch_size 512 \
        --num_workers 4 \
        --model $MODEL \
        --n_heads 4 \
        --set_transf_hidden 128 \
        --pretrained_model \
        --pretrained_path $PRETRAINED_PATH \
        --mask_size $MASK_SIZE \
        --gamma 0.1 \
        --lr_merlin 0.001 \
        --lr_morgana 0.001 \
        --fs_model $FS_MODEL \
        --fs_hidden_dim 256 \
        --fs_dropout 0.2 \
        --fs_n_heads 4 \
        --weight_decay_merlin 0.0001 \
        --weight_decay_morgana 0.0001 \
        --l1_penalty_coefficient 1 \
        --compute_prec_and_ent
    ```

4. **Learnable Feature Selection Training (Partially Confounded Setup)**
    - Train on dataset variants in which a given fraction (e.g. 5%) of the training and validation samples is clean, i.e. free of the confounder:
    ```bash
    # Paths
    DATA_DIR="path/to/one_hot_padded_encodings"                     # Path to encodings from step 2
    PRETRAINED_PATH="path/to/pretrained/model.pth"                  # Path to pretrained model from regular training
    RES_DIR="path/to/save/checkpoints"                              # Where to save model checkpoints
    PART_CONF_DIR="path/to/nonconfounded_one_hot_padded_encodings"  # Encodings of the nonconfounded (clean) dataset; run steps 1 and 2 on that dataset first

    # Required arguments
    ENC_TYPE="one_hot_padded"
    MODEL="SetTransformer"  # Classifier model type
    FS_MODEL="SetTransformer"  # Feature selector model type
    MASK_SIZE=6  # Number of features to select
    PART_CONF_RATIO=0.05  # Ratio of clean samples, e.g. 5%

    python src/main.py \
        --approach learn_fs \
        --batch_size 512 \
        --data_dir $DATA_DIR \
        --enc_type $ENC_TYPE \
        --epochs 50 \
        --fs_dropout 0.3 \
        --fs_hidden_dim 256 \
        --fs_model $FS_MODEL \
        --fs_n_heads 4 \
        --gamma 1 \
        --l1_penalty_coefficient 1 \
        --lr 0.001 \
        --lr_fs 0.001 \
        --mask_size $MASK_SIZE \
        --model $MODEL \
        --n_heads 4 \
        --num_workers 4 \
        --partial_conf_dir $PART_CONF_DIR \
        --partial_conf_ratio $PART_CONF_RATIO \
        --pretrained_path $PRETRAINED_PATH \
        --res_dir $RES_DIR \
        --seed 42 \
        --set_transf_hidden 128 \
        --weight_decay 0.0001 \
        --weight_decay_fs 0.0001 \
        --pretrained_model
    ```
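
To compare several clean-sample ratios (and seeds), the partially confounded command can be wrapped in a loop. This is a sketch under stated assumptions: the ratio and seed lists, the `ratio_*/seed_*` result-directory layout and the `DRY_RUN` switch are illustrative choices, while the flags are copied from the command above. Whether `src/main.py` creates nested result directories itself is not stated here, so you may need to create them first.

```bash
#!/usr/bin/env bash
# Sketch: sweep the partially confounded setup over several clean-sample
# ratios and seeds. The ratio/seed lists and the per-run directory layout
# are illustrative assumptions; the flags mirror the command above.
# Commands are printed by default; set DRY_RUN=0 to execute them.
set -euo pipefail

DRY_RUN="${DRY_RUN:-1}"
DATA_DIR="path/to/one_hot_padded_encodings"
PART_CONF_DIR="path/to/nonconfounded_one_hot_padded_encodings"
PRETRAINED_PATH="path/to/pretrained/model.pth"
RES_ROOT="path/to/save/checkpoints"

RATIOS=(0.01 0.05 0.1)
SEEDS=(0 1 2)

# Hypothetical helper: run_partial_conf RATIO SEED
# Hypothetical layout: one result directory per (ratio, seed) pair
run_partial_conf() {
    local ratio="$1" seed="$2"
    local cmd=(python src/main.py
        --approach learn_fs --epochs 50 --seed "$seed"
        --data_dir "$DATA_DIR" --enc_type one_hot_padded
        --partial_conf_dir "$PART_CONF_DIR" --partial_conf_ratio "$ratio"
        --pretrained_model --pretrained_path "$PRETRAINED_PATH"
        --res_dir "$RES_ROOT/ratio_${ratio}/seed_${seed}"
        --batch_size 512 --num_workers 4
        --model SetTransformer --n_heads 4 --set_transf_hidden 128
        --fs_model SetTransformer --fs_hidden_dim 256 --fs_dropout 0.3 --fs_n_heads 4
        --mask_size 6 --gamma 1 --l1_penalty_coefficient 1
        --lr 0.001 --lr_fs 0.001 --weight_decay 0.0001 --weight_decay_fs 0.0001)
    if [[ "$DRY_RUN" == "1" ]]; then
        echo "${cmd[*]}"
    else
        "${cmd[@]}"
    fi
}

for ratio in "${RATIOS[@]}"; do
    for seed in "${SEEDS[@]}"; do
        run_partial_conf "$ratio" "$seed"
    done
done
```

Keeping each (ratio, seed) pair in its own result directory avoids one run overwriting the checkpoints of another.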

### Approach and Model Options
- `--approach`: `regular` | `sfw` | `learn_fs`
- `--model` (classifier): `linear` | `settransformer` | `mlp`
  - `settransformer`: uses `n_heads`, `set_transf_hidden`
  - `mlp`: uses `hidden_dim`, `dropout`
- `--fs_model` (only for `learn_fs`): `mlp` | `settransformer`
  - `settransformer`: uses `fs_n_heads`, `fs_hidden_dim`, `fs_dropout`
  - `mlp`: uses `fs_hidden_dim`, `fs_dropout`
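
The model-specific options above could be assembled with a small helper like the following. It is a sketch, not part of the repository: the helper names are hypothetical; names are lower-cased before matching because the examples pass `SetTransformer` while this list uses `settransformer` (whether `src/main.py` itself is case-insensitive is not documented here); the `hidden_dim`/`dropout` options are assumed to be passed as `--hidden_dim`/`--dropout`, like the other options, and their values are placeholders. The remaining values follow the training examples.

```bash
#!/usr/bin/env bash
# Sketch: assemble the model-specific flags listed under "Approach and Model Options".
# Assumptions: case-insensitive matching of model names, and placeholder
# values for --hidden_dim/--dropout (flag names inferred from the option list).
set -euo pipefail

to_lower() { printf '%s' "$1" | tr '[:upper:]' '[:lower:]'; }

# Hypothetical helper: classifier_flags NAME -> fills CLS_FLAGS
classifier_flags() {
    case "$(to_lower "$1")" in
        settransformer) CLS_FLAGS=(--model "$1" --n_heads 4 --set_transf_hidden 128) ;;
        mlp)            CLS_FLAGS=(--model "$1" --hidden_dim 256 --dropout 0.2) ;;
        linear)         CLS_FLAGS=(--model "$1") ;;
        *) echo "unknown classifier: $1" >&2; return 1 ;;
    esac
}

# Hypothetical helper: fs_flags NAME -> fills FS_FLAGS (only used with --approach learn_fs)
fs_flags() {
    case "$(to_lower "$1")" in
        settransformer) FS_FLAGS=(--fs_model "$1" --fs_n_heads 4 --fs_hidden_dim 256 --fs_dropout 0.2) ;;
        mlp)            FS_FLAGS=(--fs_model "$1" --fs_hidden_dim 256 --fs_dropout 0.2) ;;
        *) echo "unknown feature selector: $1" >&2; return 1 ;;
    esac
}

classifier_flags "${MODEL:-SetTransformer}"
fs_flags "${FS_MODEL:-SetTransformer}"
# Print only the model-related part of a learn_fs command
echo "Model flags: ${CLS_FLAGS[*]} ${FS_FLAGS[*]}"
```

Rejecting unknown names in the shell fails fast, before any data is loaded.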

### Important Boolean Flags
- `--wandb`: Enable wandb logging.
- `--save_model`: Save model checkpoints.
- `--save_confusion_matrix`: Save confusion matrix plots.
- `--feature_distribution`: Compute the feature distribution (only available when a logger is enabled, e.g. wandb via `--wandb`).
- `--compute_prec_and_ent`: Compute precision and entropy metrics.
- `--unconf_split`: Split the training set into train/val splits, with the validation split serving as the test set. Without this flag, the original confounded dataset splits are kept.
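
Because `--feature_distribution` depends on a logger, it can help to collect the boolean flags in one place and reject invalid combinations before launching a run. In this sketch the toggles `USE_WANDB`, `FEATURE_DIST`, `SAVE_MODEL` and `COMPUTE_METRICS` are hypothetical script variables, not options of `src/main.py`, and wandb is assumed to be the logger in use.

```bash
#!/usr/bin/env bash
# Sketch: collect the optional boolean flags in one array.
# USE_WANDB, FEATURE_DIST, SAVE_MODEL and COMPUTE_METRICS are hypothetical
# toggles of this script, not options of src/main.py.
set -eo pipefail

USE_WANDB="${USE_WANDB:-0}"
FEATURE_DIST="${FEATURE_DIST:-0}"
SAVE_MODEL="${SAVE_MODEL:-1}"
COMPUTE_METRICS="${COMPUTE_METRICS:-1}"

build_bool_flags() {
    BOOL_FLAGS=()
    if [[ "$SAVE_MODEL" == "1" ]]; then BOOL_FLAGS+=(--save_model); fi
    if [[ "$COMPUTE_METRICS" == "1" ]]; then BOOL_FLAGS+=(--compute_prec_and_ent); fi
    if [[ "$USE_WANDB" == "1" ]]; then BOOL_FLAGS+=(--wandb); fi
    if [[ "$FEATURE_DIST" == "1" ]]; then
        # --feature_distribution needs a logger; wandb is assumed to be that logger
        if [[ "$USE_WANDB" != "1" ]]; then
            echo "--feature_distribution requires --wandb" >&2
            return 1
        fi
        BOOL_FLAGS+=(--feature_distribution)
    fi
}

build_bool_flags
echo "Boolean flags: ${BOOL_FLAGS[*]}"
```

The resulting array can then be appended to any of the `python src/main.py` commands above as `"${BOOL_FLAGS[@]}"`.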

### Example Scripts
The `scripts/` folder contains ready-to-use scripts for different scenarios and configurations, including:

- Encoding precomputation scripts (`scripts/precompute_encodings/`)
- Regular training scripts (`scripts/train_regular/`)
- Stochastic Frank-Wolfe (SFW) training scripts (`scripts/train_sfw/`)
- Learnable feature selection training scripts (`scripts/train_with_learnable_fs/`) 
- Training scripts in pixel space (ResNet18 on raw images) (`scripts/pixel_space/`)

These scripts can serve as templates to adapt to your own setup.


