# entity-tracking-probing
Probing experiments to understand the entity tracking task.


# Prior State Probes

### Script
To cache hidden states
```commandline
./scripts/cache_codellama13b_probe_the_activations.qsub 
./scripts/cache_codellama13b_probe_the_activations_moveContent.qsub  # for codellama with moveContent split
```

To train the probes:
```commandline
./scripts/load_and_train_probe_llama3_405B.qsub  # for llama405b
./scripts/load_and_train_probe_codellama13b.qsub  # for codellama 13b
./scripts/load_and_train_probe_codellama13b_moveContent.qsub  # for codellama13b with moveContent split
```
### Data
- `data/boxes-dataset-v1/few-shot_boxes_nso_exp2_max3/{train/test}-subsample-states-gpt.jsonl`
- `data/boxes-dataset-v1/few-shot_boxes_nso_exp2_max3_move_content/{train/test}-subsample-states-gpt.jsonl` (for the moveContent split)

(We use the subsampled splits because they are large enough for probing.)
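The `{train/test}` placeholder in the paths above expands to one file per split. A minimal sketch for locating and parsing them, assuming the files are standard JSON Lines (one object per line); the record fields are not documented here, so the loader only returns plain dicts, and `split_paths`/`read_jsonl` are hypothetical helpers rather than functions from this repo:

```python
import json
from pathlib import Path

# Path template copied from the README; "{split}" stands in for "{train/test}".
TEMPLATE = "data/boxes-dataset-v1/few-shot_boxes_nso_exp2_max3/{split}-subsample-states-gpt.jsonl"


def split_paths(template: str = TEMPLATE) -> dict[str, Path]:
    """Expand the split placeholder into one path per split (hypothetical helper)."""
    return {split: Path(template.format(split=split)) for split in ("train", "test")}


def read_jsonl(path: Path) -> list[dict]:
    """Parse one JSON object per non-empty line (hypothetical helper)."""
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


paths = split_paths()
print(paths["test"])  # data/boxes-dataset-v1/.../test-subsample-states-gpt.jsonl
```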




# Remove Mechanism
## Ternary Probe Training

In my code, I refer to these ternary probes as `phrase probe`.
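A ternary probe is a three-way classifier on top of a hidden state. The sketch below assumes the common linear-probe form (one weight row and bias per class, softmax over the three logits); whether the repo's `phrase probe` is exactly linear, and what the three classes mean, is defined in the code and not stated here. `probe_forward` and `probe_predict` are hypothetical names:

```python
import math

NUM_CLASSES = 3  # "ternary": three output classes (labels 0, 1, 2 here are placeholders)


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probe_forward(weights: list[list[float]], bias: list[float], hidden: list[float]) -> list[float]:
    """Linear probe: logit_c = w_c . h + b_c for each class c, then softmax."""
    assert len(weights) == NUM_CLASSES == len(bias)
    logits = [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weights, bias)]
    return softmax(logits)


def probe_predict(weights: list[list[float]], bias: list[float], hidden: list[float]) -> int:
    """Index of the most probable class."""
    probs = probe_forward(weights, bias, hidden)
    return max(range(NUM_CLASSES), key=probs.__getitem__)


# Toy 3-dimensional hidden state with identity weights: class 1 has the largest logit.
print(probe_predict([[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0.0, 0.0, 0.0], [0.1, 2.0, -1.0]))  # 1
```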
### Script
To cache the model representations, see:
```commandline
./scripts/cache_codellama13b_phrase_probe_activations.qsub
./scripts/cache_gpt2_phrase_probe_activations.qsub
```
Some important arguments here are:
- `condition_on`: which token's hidden states to condition the probe on:
  - `object_all_local`: condition on the object, local states
  - `number_all_local`: condition on box_id (often referred to as `number` in the code), local states
  - `number_all_cumulative`: condition on box_id, global states. When caching, this and `number_all_local` produce the same cache, so just use `number_all_local` when caching.
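Because the two `number_*` options share one cache, a small normalization step avoids caching the same activations twice. A minimal sketch; `cache_condition_on` is a hypothetical helper, and the comment about where local vs. global matters is an inference from the note above, not something the scripts state:

```python
CONDITION_ON_OPTIONS = {"object_all_local", "number_all_local", "number_all_cumulative"}


def cache_condition_on(condition_on: str) -> str:
    """Map a probe's `condition_on` value to the value to use when *caching*.

    Per the README, `number_all_cumulative` and `number_all_local` yield the same
    cached hidden states; presumably local vs. global (cumulative) only changes
    the probe targets at training time.
    """
    if condition_on not in CONDITION_ON_OPTIONS:
        raise ValueError(f"unknown condition_on: {condition_on!r}")
    if condition_on == "number_all_cumulative":
        return "number_all_local"
    return condition_on


print(cache_condition_on("number_all_cumulative"))  # number_all_local
```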

For `codellama13b`, make sure to use a 2-GPU distributed `torchrun` job in 16-bit precision (an 8-bit cache does not produce good probes).
You will also notice these jobs are submitted with `#$ -pe omp 28`; this is needed because storing many hidden states requires a lot of memory.
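A back-of-the-envelope estimate shows why the cache needs so much memory: a dense cache holds one hidden-size vector per cached token per layer, at 2 bytes per value in 16-bit precision. CodeLlama-13B has 40 decoder layers with hidden size 5120; the example and token counts below are hypothetical placeholders, not the dataset's actual sizes:

```python
def cache_size_bytes(n_examples: int, tokens_per_example: int, n_layers: int,
                     hidden_size: int, bytes_per_value: int = 2) -> int:
    """Bytes needed for a dense cache of hidden states (16-bit floats use 2 bytes each)."""
    return n_examples * tokens_per_example * n_layers * hidden_size * bytes_per_value


# CodeLlama-13B: 40 layers, hidden size 5120. Example/token counts are hypothetical.
size = cache_size_bytes(n_examples=5000, tokens_per_example=10, n_layers=40, hidden_size=5120)
print(f"{size / 1e9:.1f} GB")  # 20.5 GB
```

The estimate scales linearly in every factor, so caching more token positions per example (as the `*_all_*` conditions suggest) grows the footprint proportionally.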

To load the cached activations and train the probes, see:
```commandline
./scripts/load_and_train_phrase_probe_codellama13b.qsub
./scripts/load_and_train_phrase_probe_gpt2.qsub
```

Since we train one probe per layer, for `codellama13b` I usually submit 4 jobs, each looping over 10 probes (each probe takes around 20-30 min to train),
but customize the split however you want.
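One way to divide the layers across jobs is into contiguous, near-equal chunks. A minimal sketch, assuming 40 layers for CodeLlama-13B (which matches 4 jobs of 10 probes); `layer_chunks` is a hypothetical helper, not part of the qsub scripts:

```python
def layer_chunks(n_layers: int, n_jobs: int) -> list[list[int]]:
    """Split layer indices 0..n_layers-1 into n_jobs contiguous, near-equal chunks."""
    base, extra = divmod(n_layers, n_jobs)
    chunks, start = [], 0
    for j in range(n_jobs):
        size = base + (1 if j < extra else 0)  # first `extra` jobs take one more layer
        chunks.append(list(range(start, start + size)))
        start += size
    return chunks


jobs = layer_chunks(40, 4)
print([(c[0], c[-1]) for c in jobs])  # [(0, 9), (10, 19), (20, 29), (30, 39)]
# At 20-30 min per probe, each 10-probe job takes roughly 200-300 min.
```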

### Data
The training data used here is in `entity-tracking-gemma/data/boxes_altAlways_default_maxop12_5k`.
Specifically, training uses `train-gpt.jsonl` and testing uses `test-subsample-states-gpt.jsonl`. We need the full train split
because the class labels are very imbalanced with 700 probes.
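Imbalance matters because a probe can reach high accuracy by always predicting the majority class. Besides using more training data (the approach described above), a common complementary remedy is to weight the loss by inverse class frequency. This sketch shows that technique only; it is not claimed to be what the training scripts do, and the label distribution is made up:

```python
from collections import Counter


def inverse_frequency_weights(labels: list[int], num_classes: int = 3) -> list[float]:
    """Weight class c by n / (num_classes * count_c), so rarer classes weigh more.

    Classes that never appear get weight 0.0 instead of dividing by zero.
    """
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]


# Hypothetical, deliberately skewed ternary label distribution.
labels = [0] * 90 + [1] * 8 + [2] * 2
print([round(w, 2) for w in inverse_frequency_weights(labels)])  # [0.37, 4.17, 16.67]
```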

## Intervention with Ternary Probes
Before running interventions, we need to run baseline model inference to 1) get the model's behavioral accuracy and 2) collect
examples where the model succeeds. The most important intervention scripts are:
```commandline
./scripts/intervene_phrase_probe_codellama13b_8bit_null_1put.qsub  # null the single put operation in the query box
./scripts/intervene_phrase_probe_codellama13b_8bit_null_1remove.qsub  # null the single remove operation in the query box
./scripts/intervene_phrase_probe_codellama13b_8bit_null_1remove_put_globally_removed.qsub  # for the specific dataset where a globally removed object is put back in
./scripts/intervene_phrase_probe_gpt2_null_1remove.qsub  # for gpt2
```

Since we only use 100 examples in most cases, each run should take under 10 minutes per layer.
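The baseline step boils down to scoring the model's predictions and keeping the examples it gets right, since an intervention can only be judged on examples where the model originally succeeds. A minimal sketch, assuming exact string match against the gold answer (the real scoring may normalize answers differently); both function names are hypothetical:

```python
def behavioral_accuracy(predictions: list[str], answers: list[str]) -> float:
    """Fraction of examples where the model's prediction exactly matches the gold answer."""
    if len(predictions) != len(answers):
        raise ValueError("predictions and answers must have the same length")
    if not answers:
        return 0.0
    correct = sum(p == a for p, a in zip(predictions, answers))
    return correct / len(answers)


def successful_indices(predictions: list[str], answers: list[str], limit: int = 100) -> list[int]:
    """Indices of up to `limit` examples the model gets right, in dataset order."""
    hits = [i for i, (p, a) in enumerate(zip(predictions, answers)) if p == a]
    return hits[:limit]


preds = ["apple", "key", "coin", "book"]
gold = ["apple", "map", "coin", "book"]
print(behavioral_accuracy(preds, gold))  # 0.75
print(successful_indices(preds, gold, limit=2))  # [0, 2]
```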


## Ternary Probe Analysis
All analysis functions are in the `src.analysis.plot_phrase_probe_results` module; run it with
```commandline
python -m src.analysis.plot_phrase_probe_results
```
They produce plots in `plots/`; just comment/uncomment the functions you need.
Functions include:
- plot probe accuracy across layers
- plot probe similarities across layers
- code that looks at basis reconstruction 
- code that correlates probe error with behavioral error
- code that looks at probe error across operations
- code that looks at probe error types (which heuristics are being used)
- code that looks at probe error across context/operation index
- plot intervention success across layers
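Probe similarity across layers is often measured as the cosine similarity between the probes' weight vectors. A minimal sketch of that measure, assuming each layer's probe is summarized by one vector (for a ternary probe, e.g. one class's weight row or the flattened weight matrix); this is not necessarily the metric the plotting code uses, and the function names are hypothetical:

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def layer_similarity_matrix(probe_vectors: list[list[float]]) -> list[list[float]]:
    """Pairwise cosine similarity between per-layer probe weight vectors."""
    return [[cosine(u, v) for v in probe_vectors] for u in probe_vectors]


# Three toy "layers" whose probe directions rotate from x to y.
m = layer_similarity_matrix([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
print([[round(x, 3) for x in row] for row in m])
```

The resulting matrix can be passed to any heatmap routine; values near 1 off the diagonal suggest neighboring layers encode the distinction along similar directions.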


# Utilities that may be helpful

## Anything Data format related
For conversion between JSON/TSV, subsampling, etc., check out
```commandline
./entity-tracking-probing/utils/*
```
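For reference, a JSONL-to-TSV conversion of the kind these utilities handle can be sketched with the standard library alone. This is a hypothetical helper, not the repo's implementation; it assumes flat records (nested values would be written as their Python string form):

```python
import csv
import json
from pathlib import Path


def jsonl_to_tsv(src: Path, dst: Path) -> int:
    """Convert a JSONL file of flat objects to TSV; returns the number of rows written.

    Columns follow the key order of the first record. Keys missing from a later
    record become empty cells; keys not in the first record raise ValueError.
    """
    with src.open(encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    if not rows:
        dst.write_text("", encoding="utf-8")
        return 0
    with dst.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]), delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```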
For plotting/analysis scripts, see
```commandline
entity-tracking-probing/*
```
Run them like `python -m src.analysis.plot_phrase_probe_results`.
