Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation
Abstract: Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory.
A RAG model consists of two serially connected components, a retriever and a generator.
A major challenge in end-to-end optimization of the RAG model is that it requires marginalization over relevant passages (modeled as discrete latent variables) from a knowledge base.
Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates.
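To make the marginalization concrete, the standard RAG training objective can be written as follows (the notation here is illustrative and may differ from the paper's: x is the input, y the output, z a retrieved passage, p_eta the retriever, p_theta the generator):

\log p(y \mid x) \;=\; \log \sum_{z \in \mathcal{Z}} p_\eta(z \mid x)\, p_\theta(y \mid x, z).

Top-K marginalization truncates this sum to the K passages with the highest retrieval scores, which biases the gradient toward those passages; VRAG instead maximizes a variational lower bound with sampled passages, and the resulting score-function (REINFORCE-style) gradient estimates are typically high-variance.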
In this paper, we propose and develop joint stochastic approximation (JSA) based end-to-end training of RAG, which is referred to as JSA-RAG.
The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful in estimating discrete latent variable models.
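As a rough illustration of how a JSA-style training step could look for RAG, here is a minimal toy sketch in PyTorch. The auxiliary posterior proposal q_phi(z|x,y), the Metropolis independence sampler, and all names below are assumptions for illustration, not the paper's implementation.

# Toy sketch of one JSA-style training step for RAG (assumed design, not the paper's code).
# Generative model: retriever p_eta(z|x) times generator p_theta(y|x,z), with passage index z latent.
# An auxiliary inference model q_phi(z|x,y) serves as the proposal of a Metropolis independence
# sampler (MIS) targeting the posterior p(z|x,y); the accepted sample "completes" the data.
import torch

def mis_step(z_prev, log_prior, log_lik, log_q):
    """One MIS move over candidate passage indices.
    log_prior[k] = log p_eta(z=k|x), log_lik[k] = log p_theta(y|x,z=k), log_q[k] = log q_phi(z=k|x,y)."""
    z_prop = torch.distributions.Categorical(logits=log_q).sample()
    # Importance weight w(z) = p_eta(z|x) p_theta(y|x,z) / q_phi(z|x,y);
    # accept with probability min(1, w(z_prop) / w(z_prev)).
    log_w = log_prior + log_lik - log_q
    accept = torch.rand(()) < (log_w[z_prop] - log_w[z_prev]).clamp(max=0.0).exp()
    return z_prop if accept else z_prev

def jsa_losses(z, log_prior, log_lik, log_q):
    """Completed-data objectives given the accepted passage z: maximize
    log p_eta(z|x) + log p_theta(y|x,z) for the model and log q_phi(z|x,y) for the proposal."""
    return -(log_prior[z] + log_lik[z]), -log_q[z]

# Toy usage: logits over K candidate passages stand in for the three neural components.
K = 8
prior_logits = torch.zeros(K, requires_grad=True)   # stands in for the retriever p_eta
lik_logits = torch.randn(K, requires_grad=True)     # stands in for the generator p_theta
q_logits = torch.zeros(K, requires_grad=True)       # stands in for the posterior proposal q_phi
opt = torch.optim.SGD([prior_logits, lik_logits, q_logits], lr=0.1)
z = torch.tensor(0)                                  # persistent Markov chain state
for _ in range(100):
    lp, ll, lq = (t.log_softmax(-1) for t in (prior_logits, lik_logits, q_logits))
    z = mis_step(z, lp.detach(), ll.detach(), lq.detach())
    loss_model, loss_q = jsa_losses(z, lp, ll, lq)
    opt.zero_grad(); (loss_model + loss_q).backward(); opt.step()

Because gradients are taken only on the log-probabilities of the sampled (completed) passage, there is no score-function term, which is one way to see how a JSA-style estimator can yield lower-variance gradients than the variational approach.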
Extensive experiments are conducted on five datasets across two tasks (open-domain question answering and knowledge-grounded dialog), showing that JSA-RAG significantly outperforms both vanilla RAG and VRAG.
Further analysis demonstrates the efficacy of JSA-RAG from the perspectives of generation, retrieval, and low-variance gradient estimation.
Paper Type: Long
Research Area: Machine Learning for NLP
Research Area Keywords: retrieval-augmented generation
Contribution Types: NLP engineering experiment, Publicly available software and/or pre-trained models, Theory
Languages Studied: English
Submission Number: 920