On the Sample Complexity of Next-Token Prediction
Abstract: Next-token prediction with cross-entropy loss is the objective of choice in sequence and language modeling. Despite its widespread use, little theoretical analysis exists on the generalization of models trained with this objective. In this work, we analyze empirical risk minimization for sequential inputs generated by order-$k$ Markov chains. Assuming bounded and Lipschitz logit functions, our results show that the in-sample prediction error decays optimally with the number of tokens, whereas the out-of-sample error incurs an additional term governed by the mixing properties of the Markov chain. These rates depend on the statistical complexity of the hypothesis class and, unlike those of classical $k$-gram estimators, need not scale exponentially with the order of the Markov chain. Finally, we discuss the possibility of achieving generalization rates independent of mixing.
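As a rough sketch of the setting described in the abstract (the notation below is ours and not taken from the paper): writing $\mathcal{F}$ for a class of bounded, Lipschitz logit functions over a vocabulary $\mathcal{V}$, and $x_1, \dots, x_T$ for a sequence drawn from an order-$k$ Markov chain, next-token prediction with cross-entropy loss corresponds to the empirical risk minimization problem
$$
\hat{f} \in \arg\min_{f \in \mathcal{F}} \; \frac{1}{T-k} \sum_{t=k+1}^{T} \ell_{\mathrm{CE}}\bigl(f(x_{t-k:t-1}),\, x_t\bigr),
\qquad
\ell_{\mathrm{CE}}(z, x) = -\log \frac{\exp(z_x)}{\sum_{v \in \mathcal{V}} \exp(z_v)},
$$
where $f(x_{t-k:t-1}) \in \mathbb{R}^{|\mathcal{V}|}$ is the vector of logits assigned to the next token given the length-$k$ context. In-sample error then refers to prediction on the contexts observed during training, while out-of-sample error refers to fresh sequences from the same chain, which is where mixing enters.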
Submission Number: 251