Keywords: Science of language models, phase transition, matrix completion, interpretability, BERT
TL;DR: BERT solves low-rank matrix completion in an interpretable manner, with a sudden drop in MSE loss.
Abstract: We formulate the low-rank matrix completion problem as a masked language modeling (MLM) task, and train a BERT model to solve this task. We find that BERT succeeds in matrix completion and outperforms the classical nuclear norm minimization method. Moreover, the mean-squared-error (MSE) loss curve displays an early plateau followed by a sudden drop to near-optimal values, despite no changes in the training procedure or hyperparameters. To gain interpretability insights, we examine the model's predictions, attention heads, and hidden states before and after this transition. We observe that (i) the model transitions from simply copying the masked input to accurately predicting the masked entries; (ii) the attention heads transition to interpretable patterns relevant to the task; and (iii) the embeddings and hidden states encode information relevant to the problem.
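A minimal sketch of the formulation the abstract describes: a low-rank matrix is serialized entry by entry into a token sequence, with a random subset of entries replaced by a mask token that the model must predict, analogous to `[MASK]` in MLM. The matrix size, rank, masking rate, and tokenization below are illustrative assumptions, not the paper's exact setup.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical instance: a random rank-2 matrix of size 6x6.
n, r = 6, 2
M = rng.standard_normal((n, r)) @ rng.standard_normal((r, n))

# Hide a random ~30% of entries, mimicking MLM masking.
mask = rng.random((n, n)) < 0.3
MASK = "[MASK]"

# Serialize the matrix row by row into a token sequence; masked
# entries become [MASK] tokens for the model to fill in.
tokens = [
    MASK if mask[i, j] else f"{M[i, j]:.2f}"
    for i in range(n)
    for j in range(n)
]

print(len(tokens), tokens.count(MASK))
```

A trained model would then be asked to recover the hidden entries from the visible ones, which is possible because the matrix is low-rank.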
Student Paper: Yes
Submission Number: 29