Abstract: Continual Pre-Training (CPT) has become a popular and effective method for adapting strong foundation models to specific downstream tasks. In this work, we explore the **learning dynamics** throughout the CPT process for large language models (LLMs).
We specifically focus on how general and downstream domain performance evolves at each training step, with domain performance measured via validation losses.
We observe that the CPT loss curve fundamentally characterizes the transition from one curve to another hidden curve, and can be described by decoupling the effects of distribution shift and learning rate (LR) annealing.
We derive a CPT scaling law that combines these two factors, enabling the prediction of loss at any (continual) training step and across learning rate schedules (LRS) in CPT.
Our formulation provides a comprehensive understanding of several critical factors in CPT, including the learning rate, training steps, and the distribution distance between the PT and CPT datasets.
Moreover, our approach can be adapted to customize training hyper-parameters for different CPT goals, such as balancing general and domain-specific performance.
Extensive experiments demonstrate that our scaling law holds across various CPT datasets and training hyper-parameters.
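For intuition, a decomposition of the kind described above might be sketched as follows; the specific functional form and the symbols $L_0$, $A$, $\alpha$, $C$, $S_{\mathrm{anneal}}$, and $\Delta_{\mathrm{shift}}$ are illustrative assumptions, not the formula derived in the paper:

$$
L_{\mathrm{CPT}}(s) \;\approx\; L_0 \;+\; A\Big(\textstyle\sum_{i \le s}\eta_i\Big)^{-\alpha} \;-\; C\,S_{\mathrm{anneal}}(s) \;+\; \Delta_{\mathrm{shift}}(s),
$$

where $\eta_i$ is the learning rate at step $i$, $S_{\mathrm{anneal}}(s)$ accumulates the effect of LR annealing, and $\Delta_{\mathrm{shift}}(s)$ reflects the distribution distance between the PT and CPT datasets.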
Lay Summary: Continual Pre-Training (CPT) of large language models aims to enhance their abilities in specific downstream domains (e.g., coding, finance, math) while mitigating the substantial costs of re-training. However, it remains unclear how the training process progresses and how different factors influence performance.
In this work, we explore the learning dynamics throughout the CPT process. We specifically focus on how general and downstream domain performance evolves at each training step, with domain performance measured via validation losses.
We observe that the CPT loss curve fundamentally characterizes the transition from one curve to another hidden curve, and can be described by decoupling the effects of distribution shift and learning rate (LR) annealing. We propose a CPT scaling law that captures these effects and predicts model performance at any training step and under different learning rate schedules.
Our scaling law provides a comprehensive understanding of several critical factors in CPT and can be adapted to optimize training hyper-parameters for different CPT goals, such as balancing general and domain-specific performance.
Primary Area: Deep Learning->Large Language Models
Keywords: Continual Pre-Training, Large Language Models, Learning Dynamics
Submission Number: 5503