Abstract: Through end-to-end training to predict the next token, LLMs have become valuable tools for a wide range of tasks. Improving their core language-modeling ability can therefore benefit numerous downstream applications.
A successful approach to improving language modeling uses a separate planning module to predict abstract labels of future sentences and conditions the LM on these predictions. However, selecting a discrete label is non-differentiable, which prevents joint end-to-end fine-tuning of the planner with the LM. We propose an effective method that enables such joint fine-tuning. We show that naively approximating the gradient of label selection with the straight-through estimator is not effective. Instead, we use the planner's predicted label probabilities as mixing weights and condition the LM on the resulting weighted average of label embeddings, which is differentiable.
This not only enables joint fine-tuning of the planner and the LM, but also allows the LM to draw on the full label distribution predicted by the planner, retaining more information.
Our experimental results show consistent improvements in perplexity.
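The following is a minimal PyTorch sketch of the conditioning mechanism described in the abstract, not the authors' actual implementation; all names (SoftLabelConditioner, num_labels, d_model, planner_logits) are illustrative assumptions. It contrasts the differentiable soft mixture with the straight-through baseline the abstract reports as ineffective.

```python
# Hypothetical sketch of soft label conditioning; names are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftLabelConditioner(nn.Module):
    """Mixes label embeddings using the planner's predicted probabilities."""

    def __init__(self, num_labels: int, d_model: int):
        super().__init__()
        # One learned embedding per abstract label.
        self.label_embeddings = nn.Embedding(num_labels, d_model)

    def forward(self, planner_logits: torch.Tensor) -> torch.Tensor:
        # (batch, num_labels): full distribution over abstract labels.
        probs = F.softmax(planner_logits, dim=-1)
        # Weighted average of label embeddings. This matrix product is
        # differentiable, so gradients flow back into the planner, and the
        # LM sees the entire predicted distribution, not just the argmax.
        return probs @ self.label_embeddings.weight  # (batch, d_model)

    def forward_straight_through(self, planner_logits: torch.Tensor) -> torch.Tensor:
        # Baseline: hard one-hot selection in the forward pass, soft
        # gradients in the backward pass (straight-through estimator).
        probs = F.softmax(planner_logits, dim=-1)
        hard = F.one_hot(probs.argmax(dim=-1), probs.size(-1)).float()
        mixed = hard + probs - probs.detach()  # forward: hard; backward: soft
        return mixed @ self.label_embeddings.weight


# Usage: the returned vector would condition the LM (e.g., prepended as a
# soft prompt or added to hidden states; the exact interface is assumed).
planner_logits = torch.randn(4, 10)        # batch of 4, 10 abstract labels
conditioner = SoftLabelConditioner(num_labels=10, d_model=768)
cond_vec = conditioner(planner_logits)     # (4, 768)
```

Because the soft mixture keeps the whole pipeline differentiable, planner and LM can be fine-tuned jointly with a single language-modeling loss, whereas the straight-through variant discards all but the top label in the forward pass.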
Submission Length: Regular submission (no more than 12 pages of main content)
Changes Since Last Submission: Updated PDF based on reviewer comments. The attached PDF includes both the updated paper and a version highlighting the changes relative to the original PDF.
Assigned Action Editor: ~Frederic_Sala1
Submission Number: 3476