## Requirements
Create the conda environment with `environment.yaml`.
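
A typical way to do this is sketched below. `conda env create -f` is the standard conda command for building an environment from a spec file; the environment name is not stated in this README, so `ENV_NAME` is a placeholder to be replaced with the `name:` value from `environment.yaml`.

```bash
# Create the environment from the spec file in the repository root.
conda env create -f environment.yaml

# Placeholder (assumption): use the `name:` field defined in environment.yaml.
ENV_NAME="your-env-name"
conda activate "${ENV_NAME}"
```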

## Usage
* **Step 1: Prepare datasets**
Place the dataset `.json` files in the `datasets/` directory; see `utils/data_utils` for details. Create n-shot
datasets for training, the corresponding corrupted datasets for the FV head localization test, and 0-shot datasets for
injection. Customize other datasets as needed.
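
The example commands in this README use dataset names of the form `<task>.json_<shots>_<num_examples>`, with a `_corrupted` suffix for the localization test set. As a rough sketch, assuming that pattern also holds for your own datasets, the three names for one task and shot count can be built as follows. The `dataset_name` helper is hypothetical and not part of the repository, and the sizes 100 and 25 are simply taken from the 5-shot example; check `utils/data_utils` for the actual naming logic.

```bash
# Hypothetical helper (not part of the repository): builds a name following
# the pattern seen in this README, <task>.json_<shots>_<num_examples>[suffix].
# $1 = task, $2 = shot count, $3 = number of examples, $4 = optional suffix
dataset_name() {
  printf '%s.json_%s_%s%s\n' "$1" "$2" "$3" "${4:-}"
}

TASK_NAME="chinese-ambiguous"
SHOT=5

# Sizes (100 / 25 / 100) are assumptions copied from the 5-shot example.
TRAIN_DATASET_NAME="$(dataset_name "${TASK_NAME}" "${SHOT}" 100)"
TEST_DATASET_NAME="$(dataset_name "${TASK_NAME}" "${SHOT}" 25 "_corrupted")"
INJECT_DATASET_NAME="$(dataset_name "${TASK_NAME}" 0 100)"

echo "${TRAIN_DATASET_NAME}"
echo "${TEST_DATASET_NAME}"
echo "${INJECT_DATASET_NAME}"
```

The resulting variables can be passed directly to the `--train_dataset_name`, `--test_dataset_name`, and `--inject_dataset_name` options.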

* **Step 2: Localize the FV head (per task)**
For each task and each shot number (e.g. 3/5/10), run:
```bash
# Template
MODEL_NAME="model"
TASK_NAME="task"
TRAIN_DATASET_NAME="train_dataset_name"
TEST_DATASET_NAME="test_dataset_name"
INJECT_DATASET_NAME="inject_dataset_name"
SHOT=3
SEED=42
DEVICE="cuda"

python3 extract_n_shot_function_vector.py \
  --model_name "${MODEL_NAME}" \
  --task_name "${TASK_NAME}" \
  --train_dataset_name "${TRAIN_DATASET_NAME}" \
  --test_dataset_name "${TEST_DATASET_NAME}" \
  --inject_dataset_name "${INJECT_DATASET_NAME}" \
  --device "${DEVICE}" \
  --seed "${SEED}"

# Example: 5-shot chinese-ambiguous
python3 extract_n_shot_function_vector.py \
  --model_name "google/gemma-2-2b" \
  --task_name "chinese-ambiguous" \
  --train_dataset_name "chinese-ambiguous.json_5_100" \
  --test_dataset_name "chinese-ambiguous.json_5_25_corrupted" \
  --inject_dataset_name "chinese-ambiguous.json_0_100" \
  --device "cuda" \
  --seed 42
```
For batch runs, see `script/extract_function_vectors_chinese-ambiguous.sh`.
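
A loop over several shot counts might look roughly like the following. This is only a sketch, not the contents of the batch script: it assumes the dataset sizes from the 5-shot example (100 training, 25 corrupted, 100 zero-shot examples) apply to every shot count, and the `DRY_RUN` variable and `run_extraction` function are hypothetical names introduced here. With `DRY_RUN=1` (the default) the commands are only printed.

```bash
MODEL_NAME="google/gemma-2-2b"
TASK_NAME="chinese-ambiguous"
DRY_RUN="${DRY_RUN:-1}"   # hypothetical switch: 1 = print commands only

# Hypothetical wrapper: builds the extraction command for one shot count.
run_extraction() {
  shot="$1"
  # Replace the function's positional parameters with the full command line.
  set -- python3 extract_n_shot_function_vector.py \
    --model_name "${MODEL_NAME}" \
    --task_name "${TASK_NAME}" \
    --train_dataset_name "${TASK_NAME}.json_${shot}_100" \
    --test_dataset_name "${TASK_NAME}.json_${shot}_25_corrupted" \
    --inject_dataset_name "${TASK_NAME}.json_0_100" \
    --device "cuda" \
    --seed 42
  if [ "${DRY_RUN}" = "1" ]; then
    echo "$*"     # dry run: show the command without executing it
  else
    "$@"          # real run: execute the command
  fi
}

for SHOT in 3 5 10; do
  run_extraction "${SHOT}"
done
```

Set `DRY_RUN=0` in the environment to actually launch the runs once the printed commands look right.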

* **Step 3: Run sub-task extraction (per task)**
Customize the sub-task in `extract_sub_function_vector_per_prompt_*.py` as needed. For each task and each shot number (e.g. 3/5/10), run the corresponding script as shown below.
By default, the results are saved under the `logs/{model_name}/{task_name}` directory, and `attn_scores*`, `per_prompt_function_vector*`,
and `per_prompt_intervened_success_rate*` files are generated.
```bash
# Full-model Q/K, with the V of the full model's unambiguous examples changed to uncontextualized V
#
# Notes:
# - token_types mapping:
#     3-shot -> 6
#     5-shot -> 8
#     10-shot -> 13
# - --corruption_with_ablation is OPTIONAL (only enable it when needed)

python3 extract_sub_function_vector_per_prompt_10_shot.py \
  --model_name "google/gemma-2-2b" \
  --task_name "chinese-ambiguous" \
  --sub_FV_name "full_model_v_all_clean_Ex4710" \
  --token_types 13 \
  --corrupted_dataset_name "chinese-ambiguous.json_10_100" \
  --train_dataset_name "chinese-ambiguous.json_10_100" \
  --inject_dataset_name "chinese-ambiguous.json_0_100" \
  --CIE_path "mean_causal_indirect_effects_chinese-ambiguous.json_10_100_chinese-ambiguous.json_10_25_corrupted.pt" \
  --device "cuda" \
  --seed 42 \
  --corruption_with_ablation
```
More examples are available under the `script/` directory.
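
Both `--token_types` and `--CIE_path` depend on the shot count. The sketch below encodes the mapping from the notes in the Step 3 command and rebuilds the CIE file name. The `mean_causal_indirect_effects_<train>_<test>.pt` pattern is inferred from the single 10-shot example and may not hold in general, and both function names are hypothetical.

```bash
# Hypothetical helper: maps a shot count to the --token_types value
# listed in the Step 3 notes (3 -> 6, 5 -> 8, 10 -> 13).
token_types_for_shot() {
  case "$1" in
    3)  echo 6 ;;
    5)  echo 8 ;;
    10) echo 13 ;;
    *)  echo "unsupported shot count: $1" >&2; return 1 ;;
  esac
}

# Hypothetical helper: builds the CIE file name (pattern is an assumption).
# $1 = train dataset name, $2 = corrupted test dataset name
cie_path() {
  printf 'mean_causal_indirect_effects_%s_%s.pt\n' "$1" "$2"
}

TASK_NAME="chinese-ambiguous"
SHOT=10
TOKEN_TYPES="$(token_types_for_shot "${SHOT}")"
CIE_PATH="$(cie_path "${TASK_NAME}.json_${SHOT}_100" "${TASK_NAME}.json_${SHOT}_25_corrupted")"

echo "--token_types ${TOKEN_TYPES} --CIE_path ${CIE_PATH}"
```

Shot counts outside the documented mapping are rejected on purpose, since this README does not state a general rule for them.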