How Do Transformers Fill in the Blanks? A Case Study on Matrix Completion

Published: 18 Jun 2024 · Last Modified: 19 Jul 2024 · TF2M 2024 Poster · CC BY 4.0
Keywords: Science of language models, phase transition, matrix completion, interpretability, BERT
TL;DR: BERT solves low-rank matrix completion in an interpretable manner, with a sudden drop in MSE loss during training.
Abstract: Completing masked sequences is an important problem in language modeling, and analyzing how Transformer models perform this task is crucial for understanding their mechanisms. In this direction, we formulate the low-rank matrix completion problem as a masked language modeling (MLM) task and train a BERT model to solve it. We find that BERT succeeds in matrix completion and outperforms the classical nuclear norm minimization method. Moreover, the mean squared error (MSE) loss curve displays an early plateau followed by a sudden drop to near-optimal values, despite no changes in the training procedure or hyperparameters. To gain interpretability insights, we examine the model's predictions, attention heads, and hidden states before and after this transition. Concretely, we observe that (i) the model transitions from simply copying the masked input to accurately predicting the masked entries; (ii) the attention heads transition to interpretable patterns relevant to the task; and (iii) the embeddings and hidden states encode information relevant to the problem.
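The abstract does not spell out how a matrix is turned into a masked token sequence, so the following is only a minimal sketch of one plausible formulation: generate a low-rank matrix, discretize its entries into a small token vocabulary (an assumption, not necessarily the paper's scheme), and mask a random subset of entries so that a BERT-style model must recover them. The bin count, mask fraction, and `MASK_ID` below are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: a 7x7 rank-2 matrix whose entries will serve as "tokens".
n, r = 7, 2
M = rng.standard_normal((n, r)) @ rng.standard_normal((r, n))

# Discretize entries into bins so each entry maps to a vocabulary id
# (one possible tokenization; the paper's exact scheme may differ).
n_bins = 100
edges = np.linspace(M.min(), M.max(), n_bins + 1)
tokens = np.clip(np.digitize(M.flatten(), edges) - 1, 0, n_bins - 1)

# Randomly mask a fraction of entries, as in masked language modeling.
MASK_ID = n_bins            # reserve an extra id for the [MASK] token
mask_frac = 0.3
mask = rng.random(tokens.shape) < mask_frac
masked_tokens = np.where(mask, MASK_ID, tokens)

# `masked_tokens` is the input sequence an MLM-style model would see;
# `tokens[mask]` are the targets it must recover (the missing matrix entries).
print(masked_tokens.reshape(n, n))
```

Under this framing, the classical baseline mentioned in the abstract (nuclear norm minimization) would instead recover the missing continuous entries directly from the observed ones, without any discretization.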
Submission Number: 12