# What parts of the context matter for correct prediction?

We want to know which context tokens are actually needed to correctly predict the next token in a memorized random string.
Understanding this will help us determine how models memorize random strings, i.e. whether they rely on the full context or only on local context.

## Methodology

We create random strings and train Pythia-1b models to memorize them for 50 epochs.

Then, we randomly sample token positions in the string and perform binary search to find the smallest number of tokens needed to still predict each sampled token correctly.
To do so, we create relevance scores for each token before the target token and then search for a value of $k$, such that using the top-$k$ tokens according to the scores we can correctly predict the target token, but with $k-1$ tokens we cannot.
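
The search described above can be sketched as follows. This is a minimal illustration, not the actual experiment code: `min_context_size` and the `predicts_correctly` callback are hypothetical names, and in the real setup the callback would run the model on the reduced context. Binary search is only exact under the assumption that adding more top-scored tokens never turns a correct prediction into an incorrect one.

```python
from typing import Callable, List, Optional, Sequence


def min_context_size(
    scores: Sequence[float],
    predicts_correctly: Callable[[List[int]], bool],
) -> Optional[int]:
    """Return the smallest k such that the top-k scored context positions
    are enough for a correct prediction, or None if even the full context fails.

    Assumes monotonicity: if the top-k positions suffice, so do the top-(k+1).
    """
    n = len(scores)
    # Context positions ordered from highest to lowest relevance score.
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    if not predicts_correctly(ranked):
        return None
    lo, hi = 0, n  # invariant: top-hi suffices, every k < lo does not
    while lo < hi:
        mid = (lo + hi) // 2
        if predicts_correctly(ranked[:mid]):
            hi = mid
        else:
            lo = mid + 1
    return lo


# Toy stand-in for the model: the prediction is "correct" whenever
# context positions 5 and 7 are both kept.
def toy_oracle(kept: List[int]) -> bool:
    return {5, 7} <= set(kept)


# "Preceding" scores for a target at position 8: closer positions score higher.
toy_scores = list(range(8))
toy_k = min_context_size(toy_scores, toy_oracle)
```

In the toy example the ranking is 7, 6, 5, ..., so three tokens are needed before both required positions are included.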

We use two types of scores here:
- *preceding*: we simply score tokens based on their proximity to the target token, i.e. if the target token is at position $n$, the $(n-1)$st token has the highest score, followed by the $(n-2)$nd, etc.
- *attention*: we score tokens based on the attention weights assigned by the model. We investigate attention at three layers, the first, the middle, and the final layer, and use the mean over all attention heads.
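
A minimal sketch of the two scoring schemes, using plain nested lists in place of real attention tensors. The function names are hypothetical. The sketch also assumes that the relevant attention row is the one for the last context position (index `target_pos - 1`), since that is the position from which the target token is predicted; the text above does not state which row is used.

```python
from typing import List


def preceding_scores(target_pos: int) -> List[float]:
    """Score context positions 0..target_pos-1 by proximity to the target:
    the closer a position is, the higher its score."""
    return [float(i) for i in range(target_pos)]


def attention_scores(
    layer_attn: List[List[List[float]]], target_pos: int
) -> List[float]:
    """Average one layer's attention over all heads.

    layer_attn[h][q][j] is the weight head h assigns from query position q
    to key position j. Assumption: the query row used is target_pos - 1.
    """
    query_pos = target_pos - 1
    num_heads = len(layer_attn)
    return [
        sum(head[query_pos][j] for head in layer_attn) / num_heads
        for j in range(target_pos)
    ]


# Toy attention for 2 heads over a 4-token sequence; only the row for
# query position 2 is filled in.
head_a = [[0.0] * 4 for _ in range(4)]
head_b = [[0.0] * 4 for _ in range(4)]
head_a[2][:3] = [0.25, 0.5, 0.25]
head_b[2][:3] = [0.75, 0.0, 0.25]

toy_attention_scores = attention_scores([head_a, head_b], target_pos=3)
toy_preceding_scores = preceding_scores(3)
```

With Hugging Face Transformers, per-layer attention weights of shape `(batch_size, num_heads, sequence_length, sequence_length)` are returned when the model is called with `output_attentions=True`; whether the experiments extracted them this way is not stated here.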

Additionally, we hide information about the unused tokens, i.e. the tokens in the prompt that are not among the top-$k$, in two ways:
- *masked*: we use the attention mask to force the attention to the respective tokens to be 0. This might be an abuse of attention masks though, since they're usually meant to be used in a block-contiguous manner.
- *shuffled*: we change the unused tokens to random other tokens sampled from the string's alphabet (here the 26 Roman letters). Since the results are sensitive to the choice of replacement letters, we sample 10 different replacements and report averaged scores.
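
The two hiding methods could look roughly like this. This is a sketch under simplifying assumptions: each character is treated as one token, the alphabet is taken to be the 26 lowercase letters, and "random other tokens" is read as "a letter different from the original one". The helper names are hypothetical.

```python
import random
import string
from typing import Iterable, List


def mask_unused(seq_len: int, keep: Iterable[int]) -> List[int]:
    """Attention mask with 1 for kept positions and 0 for hidden ones.
    Unlike a padding mask, the 0s and 1s can be interleaved."""
    keep_set = set(keep)
    return [1 if i in keep_set else 0 for i in range(seq_len)]


def shuffle_unused(
    text: str,
    keep: Iterable[int],
    rng: random.Random,
    alphabet: str = string.ascii_lowercase,  # assumed alphabet
) -> str:
    """Replace every position that is not kept with a random different letter."""
    keep_set = set(keep)
    out = []
    for i, ch in enumerate(text):
        if i in keep_set:
            out.append(ch)
        else:
            out.append(rng.choice([c for c in alphabet if c != ch]))
    return "".join(out)


toy_mask = mask_unused(5, keep=[1, 3])

# Results depend on the sampled replacements, so draw several variants
# (10 here, matching the number of samples mentioned above) and average
# whatever metric is computed on them.
rng = random.Random(0)
toy_variants = [shuffle_unused("abcde", keep=[1, 3], rng=rng) for _ in range(10)]
```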

The results below show, for each type of context (scoring type + hiding method), the average number of tokens needed to correctly predict the target token, the actual tokens needed, and, for the 64-character strings, additionally the attention maps.
In the plots showing the minimum number of tokens needed for correct prediction, the y-axis shows two numbers: first the index of the target token that is being predicted and second the minimum number of tokens needed to predict it correctly.

## Takeaways

Some interesting observations from the results:
- If we want to pick a minimal context using attention while still predicting the next token correctly, the layer matters. In most cases, picking according to the middle layer's attention (layer 8 for Pythia-1b) yields the smallest number of required context tokens, whereas using the first layer's attention needs somewhat more tokens, and using the last layer's requires the most. So there do seem to be differences in the attention patterns at different layers. Note that I just averaged over all attention heads, so there might be more nuance at the level of individual heads.
- Using just the closest preceding tokens, i.e. ignoring attention, is the best choice in almost all cases, i.e. it requires the smallest number of context tokens for correct predictions. Usually, this value and the one from using the middle layer's attention are quite close. So it seems that our initial intuition of using preceding tokens was a good one.
- There is a lot of variance in the number of tokens needed to correctly predict the target token. For some tokens only a very small context is needed, but for others the context has to be very large. For example, in the 512-token string with shuffling, token 373 only needs 7 tokens of context and token 469 only 3 tokens, whereas token 306 needs 215 tokens and token 327 needs 54 tokens of context. It would be interesting to understand what is causing these large differences.
- Shuffling, i.e. replacing unused tokens with random ones picked from the string's alphabet, seems to require far fewer context tokens than masking, i.e. setting the attention mask for the unused tokens to 0. For shuffling, the numbers are averages over 10 different replacement samples. I'm not quite sure why this is happening, but I could imagine that using attention masks to hide tokens in the middle of the string uses the mechanism in a way it isn't meant for. Usually the mask only hides positions before the start of the actual sequence, so the pattern is always $n$ zeros followed by $m$ ones, never interleaved 0s and 1s. Maybe that's causing problems.

## Next steps

I'm planning to run a few more variants of this experiment:
- Using preceding + first tokens: when using attention maps at the middle and final layer, we always pick the first token, so it might be worth checking whether using just the first token plus the tokens immediately before the target token shrinks the required context even further.
- More ways of hiding unused tokens in the context: in addition to masking out and shuffling/perturbing unused tokens, I want to look at pruning them, i.e. removing the unused tokens from the prompt and placing the used ones next to each other. This might help to understand to what extent the relative position of tokens matters, i.e. whether it's important that token a appeared 20 positions before the target token and token b 10 positions before, or whether it's only important that a appeared before b.
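
Both planned variants are easy to express as position selections. The sketch below is only an illustration of the idea, with hypothetical function names; it has not been run as part of the experiments.

```python
from typing import Iterable, List, Sequence


def first_plus_preceding(target_pos: int, k: int) -> List[int]:
    """Keep the first token plus the k-1 tokens immediately before the target.
    Context positions are 0..target_pos-1."""
    if k <= 0 or target_pos <= 0:
        return []
    recent = list(range(max(1, target_pos - (k - 1)), target_pos))
    return [0] + recent


def prune(tokens: Sequence[str], keep: Iterable[int]) -> List[str]:
    """Drop all unused tokens and put the kept ones next to each other.
    Order is preserved, but the original distances between tokens are lost."""
    return [tokens[i] for i in sorted(set(keep))]


toy_keep = first_plus_preceding(target_pos=10, k=3)
toy_pruned = prune(list("abcdefghij"), toy_keep)
```

In the toy example the kept positions are 0, 8 and 9, so the pruned prompt consists of the letters `a`, `i`, `j`: the first token now sits directly before the two most recent ones, which is exactly the change in relative position that the pruning experiment would probe.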
