OneGen Enables a Single LLM to Handle Both Retrieval and Generation Simultaneously
Credits:
Jintian Zhang – Cheng Peng – Mengshu Sun – Xiang Chen – Lei Liang – Zhiqiang Zhang – Jun Zhou – Huajun Chen – Ningyu Zhang – Zhejiang University – Ant Group – Zhejiang University
Abstract
Despite recent advancements in Large Language Models (LLMs), which have significantly enhanced generative capabilities across various NLP tasks, LLMs still face limitations in directly handling retrieval tasks. However, many practical applications demand the seamless integration of both retrieval and generation. This paper introduces a novel and efficient One-pass Generation and retrieval framework (OneGen), designed to improve LLMs' performance on tasks that require both generation and retrieval. The proposed framework bridges the traditionally separate training approaches for generation and retrieval by incorporating retrieval tokens generated autoregressively. This enables a single LLM to handle both tasks simultaneously in a unified forward pass. We conduct experiments on two distinct types of composite tasks, RAG and Entity Linking, to validate the pluggability, effectiveness, and efficiency of OneGen in training and inference. Furthermore, our results show that integrating generation and retrieval within the same context preserves the generative capabilities of LLMs while improving retrieval performance. To the best of our knowledge, OneGen is the first framework to enable LLMs to conduct vector retrieval during generation.
1 Introduction
In the era of Large Language Models (LLMs), many Natural Language Processing (NLP) tasks can be reduced to generation, allowing them to be addressed by a single LLM (Zhao et al., 2023; Qin et al., 2023; OpenAI, 2023; Zeng et al., 2024). While LLMs excel at language generation, they still suffer from hallucinations (e.g., factual inaccuracies) stemming from their exclusive reliance on the parametric knowledge they contain (Zhang et al., 2023b; Yao et al., 2023; Tonmoy et al., 2024). One promising approach is Retrieval-Augmented Generation (RAG) (Lewis et al., 2020; Jiang et al., 2023d; Asai et al., 2024; Mao et al., 2024; Gao et al., 2023), which augments the input by retrieving passages relevant to the query either before or during generation. Other methods (Ding et al., 2024a; Luo et al., 2023a) anchor LLM generation to an external knowledge base through Entity Linking (EL) during or after generation. These systems typically rely on a retriever at various stages of generation. However, owing to the separate training paradigms for generation and retrieval, most prior work employs a separate model for text embedding (Muennighoff et al., 2024). This pipeline approach has several drawbacks: i) Deploying and maintaining two separate models introduces additional hardware overhead and increases maintenance costs. ii) The separation of models creates two distinct representational spaces, limiting interaction between the retriever and the generator (e.g., an LLM) to text (i.e., the query). As a result, whether the query is generated by the LLM or input directly by the user, it requires an additional forward pass through the retriever, increasing inference cost. iii) In multi-turn dialogues, as illustrated in Figure 1(a), query rewriting is required for follow-up questions such as "Who is his wife?"; this rewriting adds inference overhead and risks error propagation if inaccurate. iv) The pipeline approach is difficult to optimize end-to-end and requires large amounts of training data, even though end-to-end optimization has been shown to yield significant benefits (Lin et al., 2024).
∗ Equal Contribution.
† Corresponding Author.
Figure 1: Comparison of three methods for the RAG task. (a) A two-round dialogue using RAG (retrieve and generate twice each). (b) The pipeline approach, which deploys two separate models for retrieval and generation. (c) GritLM (Muennighoff et al., 2024), which uses a single model with a switching mechanism to alternate between retrieval and generation. (d) OneGen (ours), which performs both functions automatically in the same model and the same context.
Our work introduces an efficient One-pass unified Generation and retrieval (OneGen) framework that enables an arbitrary LLM to generate and retrieve in a single forward pass. Inspired by recent success in using LLMs for text embedding (Wang et al., 2024), we expand the original vocabulary with special tokens (i.e., retrieval tokens) and allocate the retrieval task to retrieval tokens generated in an autoregressive manner. During training, retrieval tokens participate only in representation fine-tuning through contrastive learning (van den Oord et al., 2018; Rendle et al., 2009), whereas the other output tokens are trained with the standard language modeling objective. At inference time, we use retrieval tokens for efficient on-demand retrieval.
Unlike previous pipeline approaches, which require at least two models for retrieval and generation (as shown in Figure 1(b)), OneGen unifies both tasks in a single model, eliminating the need for a separate retriever. Muennighoff et al. (2024) present Generative Representational Instruction Tuning (GRIT), which aligns with this direction by training one LLM to handle both generative and embedding tasks through different prompts and attention mechanisms, as depicted by the "switch" in Figure 1(c). However, GRIT still requires independent forward passes for generation and retrieval, reducing efficiency for tasks that interleave the two.
We evaluate the effectiveness of our method on two main tasks that require both generation and retrieval: RAG (including single-hop QA, which needs a single retrieval, and multi-hop QA, which needs multiple retrievals) and Entity Linking (EL). Empirical results show that OneGen outperforms previous pipeline solutions, as well as GRIT where applicable. Specifically, OneGen achieves a +1.5pt average improvement on four single-hop QA datasets on top of Self-RAG (Asai et al., 2024), +3.3pt average F1 on two multi-hop QA datasets across three different 7B-scale LLMs, and +3.2pt average accuracy on six out-of-domain entity-linking datasets, all with less training data. Moreover, further analysis demonstrates that joint training enhances retrieval capability with no sacrifice in generation capability. In addition, we demonstrate superior inference speed and memory consumption of OneGen compared with other LLM alternatives, particularly as retrieval frequency increases. In summary, our work makes the following contributions:
i) We propose OneGen, a training-efficient, inference-efficient, and pluggable framework that is particularly suitable for tasks interleaving generation and retrieval. ii) Fine-tuned on less training data, our model demonstrates superior average performance on six RAG datasets and six entity-linking datasets. iii) We demonstrate the efficiency of OneGen at inference, highlighting significant speed improvements as query length or retrieval frequency increases, compared to other LLM alternatives. iv) From a methodological perspective, OneGen is an extension of Generative Instruction Tuning (GIT) and Representative Instruction Tuning (RIT) (as shown in Figure 1(b)). v) We contribute to the community by releasing our datasets as well as our code.
2 Preliminaries and Related Works
Most text-based tasks can be reduced to generation, retrieval, or a combination of the two. We first introduce several hybrid tasks and their common solutions in § 2.1. Then, we introduce the three roles of tokens in LLMs in § 2.2. Finally, we explain the motivation for our method in § 2.3.
2.1 Generation & Retrieval
For NLP problems related to generation or retrieval, a user input or query u = {u_1, …, u_n} and optionally a document corpus K = {d_1, …, d_|K|} (e.g., wiki articles) are given. The goal is to generate an output sequence y = {y_1, …, y_m}, to retrieve the most relevant documents E from K with respect to u, or both. We also assume that each d_i ∈ E is aligned to a subsequence, or the whole sequence, of tokens in u. We summarize the steps and the typical inputs and outputs for generation, retrieval, and two hybrid tasks in Table 1.
R → G Task leverages retrieval results to drive generation. In the simplest form, a dense retrieval model (e.g., a dense passage retriever, DPR) is used to retrieve a collection of relevant documents E given the user input u at t = 1; E is then used as additional context when generating the target sequence with a generator (e.g., an LLM) at t = 2. Retrieval-Augmented Generation (RAG) is a classic example of an R → G task. Although some efforts to train the two models end-to-end predate the LLM era (Lewis et al., 2020), most recent work uses an off-the-shelf retriever such as Contriever (Izacard et al., 2022), BM25, or a search engine (Jiang et al., 2023d). Furthermore, this task can involve multiple iterations of retrieval and generation, as in multi-hop reasoning datasets such as 2WIKI (Ho et al., 2020) and HotpotQA (Yang et al., 2018).
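To make the two-stage flow concrete, the following Python sketch mirrors the simplest R → G setup. It is only an illustration of the pipeline described above, not code from the paper; embed_fn, generate_fn, and the in-memory corpus are caller-supplied placeholders.

    # Minimal R -> G sketch: retrieve at t = 1, generate at t = 2.
    # embed_fn and generate_fn stand in for a separate retriever encoder and an LLM.
    from typing import Callable, List
    import numpy as np

    def retrieve(query_vec: np.ndarray, doc_vecs: np.ndarray, docs: List[str], k: int = 3) -> List[str]:
        scores = doc_vecs @ query_vec                      # dense similarity against every document
        top = np.argsort(-scores)[:k]
        return [docs[i] for i in top]

    def rag_answer(user_input: str,
                   docs: List[str],
                   doc_vecs: np.ndarray,
                   embed_fn: Callable[[str], np.ndarray],
                   generate_fn: Callable[[str], str]) -> str:
        query_vec = embed_fn(user_input)                   # extra forward pass through the separate retriever
        evidence = retrieve(query_vec, doc_vecs, docs)     # t = 1: retrieval
        prompt = "\n".join(evidence) + "\n\nQuestion: " + user_input
        return generate_fn(prompt)                         # t = 2: generation conditioned on the evidence

Note the extra encoder pass for the query: this is precisely the overhead that motivates unifying retrieval and generation in one model.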
G → R Task outputs retrieved documents relevant to the user query in addition to generated content, and is widely encountered in Information Retrieval (IR). A prominent example is Entity Linking (EL), which involves locating mentions and disambiguating these surface forms into entities in some Knowledge Base (KB). Early EL methods (Hoffmann et al., 2011) treat EL as decomposed subtasks, such as Mention Detection (MD) and Entity Disambiguation (ED), and solve them in sequence. More recent works frame EL as an end-to-end task, such as sequence generation (Cao et al., 2021b), question answering (Zhang et al., 2022), retrieval-augmented generation (Xiao et al., 2023), or sequence tagging (Broscheit, 2019; Ayoola et al., 2022), and outperform the early pipeline approaches. In the generative EL paradigm, MD can be modeled as a generation task in which the mentions in the original sentence are generated; ED is a typical retrieval task of retrieving the most relevant entity from the KB given a mention span.
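As a rough illustration (not the paper's implementation), the MD-then-ED decomposition can be sketched as a generation-style mention detector followed by a retrieval-style disambiguation step; detect_mentions_fn and embed_fn are hypothetical, caller-supplied components.

    # Generative-EL sketch: MD produces mention spans, ED retrieves the closest KB entity.
    from typing import Callable, Dict, List
    import numpy as np

    def link_entities(sentence: str,
                      kb_entities: List[str],
                      kb_entity_vecs: np.ndarray,
                      detect_mentions_fn: Callable[[str], List[str]],
                      embed_fn: Callable[[str], np.ndarray]) -> Dict[str, str]:
        links = {}
        for mention in detect_mentions_fn(sentence):       # MD: locate/generate mention spans
            scores = kb_entity_vecs @ embed_fn(mention)    # ED: retrieve the most relevant KB entity
            links[mention] = kb_entities[int(scores.argmax())]
        return links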
2.2 Roles of Tokens in LLMs
A token x_i is the basic unit processed by an LLM. A token in the input of an LLM serves one of three roles: 1) generating the next token, denoted role(x_i) = GEN; 2) providing context information, denoted role(x_i) = CTX; and 3) representing a sentence, denoted role(x_i) = RET. Recent works (Wang et al., 2024; Muennighoff et al., 2024) use the hidden state of the last token as the sentence representation.
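For example, the snippet below (a minimal sketch, not from the paper) extracts the final token's hidden state as a sentence embedding with Hugging Face transformers; the model name is only illustrative.

    # Use the hidden state of the last input token (role = RET) as the sentence representation.
    import torch
    from transformers import AutoModel, AutoTokenizer

    name = "meta-llama/Llama-2-7b-hf"                      # illustrative decoder-only LLM
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)

    inputs = tokenizer("Who is the wife of the director of Inception?", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    sentence_vec = outputs.last_hidden_state[:, -1, :]     # final token's hidden state as the embedding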
2.3 Motivation
Recent years have seen a rise in using LLMs to handle complex hybrid tasks, replacing traditional NLP pipelines. Before LLMs, end-to-end approaches offered advantages for combined generation and retrieval tasks: they reduced error propagation compared to pipelines and could improve efficiency through single-pass inference. However, earlier solutions are often task-specific and do not generalize across hybrid tasks. For instance, in generative EL, methods such as constrained decoding (Cao et al., 2021b) are used to retrieve entities efficiently. Our work addresses the absence of a unified LLM framework for hybrid tasks, which stems from the separate training approaches for generation and retrieval, with their distinct objectives and datasets.
Figure 2: The training framework of unified One-pass Generation and retrieval (OneGen), illustrated using RAG. Detailed training processes for other tasks can be found in Figure 6 in the Appendix.
3 OneGen: One-Pass Generation and Retrieval for LLMs
We introduce the One-pass Generation and retrieval framework (OneGen) for fine-tuning LLMs on generation, retrieval, or hybrid tasks, as shown in Figure 2. Our core idea is to integrate generation and retrieval in the same context by allocating the retrieval task to retrieval tokens generated in an autoregressive manner, thus enabling the LLM to perform both tasks in a single forward pass.
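The sketch below illustrates this idea at inference time. It is a simplified approximation rather than the authors' code: the [RQ] token name follows the paper, while the base model and the way the sequence is assembled are assumptions made for the example.

    # One forward pass yields both next-token logits (GEN) and query embeddings at [RQ] positions (RET).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "meta-llama/Llama-2-7b-hf"                          # illustrative base LLM
    tokenizer = AutoTokenizer.from_pretrained(name)
    tokenizer.add_special_tokens({"additional_special_tokens": ["[RQ]"]})
    model = AutoModelForCausalLM.from_pretrained(name)
    model.resize_token_embeddings(len(tokenizer))              # expand the vocabulary with the retrieval token

    rq_id = tokenizer.convert_tokens_to_ids("[RQ]")
    text = "Who directed Inception? [RQ] The director of Inception is"   # [RQ] emitted where retrieval is needed
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)

    logits = out.logits                                        # GEN role: next-token prediction
    hidden = out.hidden_states[-1]                             # last-layer hidden states from the same pass
    rq_pos = (inputs["input_ids"] == rq_id).nonzero()          # positions of retrieval tokens
    query_vecs = hidden[rq_pos[:, 0], rq_pos[:, 1]]            # RET role: embeddings used to search the corpus

No second model and no second forward pass is needed: the retrieval-token embeddings fall out of the same pass that produces the next-token logits.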
3.1 Overview
Notation. To ensure clarity and precision in the subsequent discussion, we standardize the notation of Table 1. Define the dataset D = {s_1, …, s_|D|}, which consists of |D| sentences of varying lengths, with each sentence s = {x_1, …, x_|s|} comprising |s| tokens. Let x_{i,j} denote the j-th token of the i-th sentence in D, and define x_{i,≤j} = {x_{i,1}, x_{i,2}, …, x_{i,j}}. We can distinguish the symbols u, y, and d defined in Table 1 by the role of the tokens x within a sentence s. Specifically, y corresponds to the segment of s where role(x) = GEN, u corresponds to the segment where role(x) = CTX, and if all tokens x in a sentence s have role(x) = CTX, then the sentence corresponds to a document d. Given the instruction dataset I, where s = {u, y} ∈ I, we have D = I ∪ K.
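As an illustration of this notation (a hypothetical example, not the paper's data format), a single RAG-style training sentence s might be annotated as follows, with u as CTX, the retrieval token as RET, and y as GEN:

    # Each token of s is paired with its role; only GEN- and RET-role tokens are optimized (§ 3.2).
    sample = [
        ("Who", "CTX"), ("directed", "CTX"), ("Inception", "CTX"), ("?", "CTX"),   # u: user input
        ("[RQ]", "RET"),                                                           # query-representation token
        ("Christopher", "GEN"), ("Nolan", "GEN"), (".", "GEN"),                    # y: generated answer
    ]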
Design. Retrieval requires encoding both the query and the documents in the same representational space. Our core idea is to incorporate query encoding into the generation process: we use the same LLM to encode both the query and the documents, without altering the model structure (e.g., the attention mechanism), unlike the approach taken by GritLM. Specifically, for query encoding we introduce a special token x_i = [RQ] with role(x_i) = RET. This token is generated by the LLM and is used to represent the query. However, assigning role(x_i) = RET would prevent the generation of the next token x_{i+1} if role(x_{i+1}) = GEN. To address this, we also introduce a
3.2 Training
Data Reconstruction. We augment the standard generation output with retrieval tokens wherever retrieval is needed, which makes our framework easily pluggable into existing methods. Generally, we insert [RQ] into the sentence s for query representation. In particular, if the query span is explicit, we add optional tokens
Training Objective. Optimization is performed only for tokens x_i ∈ s where role(x_i) ∈ {GEN, RET}. A simple application of OneGen to the RAG task is illustrated in Figure 2. Note that role(x_i) = RET iff x_i ∈ {[RQ], [RD]} (highlighted in purple in Figure 2). For tokens where role(x_i) = GEN (highlighted in orange), optimization employs L_g:
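A standard autoregressive cross-entropy form of this objective, written here as a sketch consistent with the definitions below (the exact indexing in the paper may differ slightly), is:

    \mathcal{L}_g \;=\; \sum_{s_i \in \mathcal{D}} \sum_{j} \mathbb{1}_g(x_{i,j+1}) \,
        \ell_g\!\Big( \pi_{\text{Head}} \, f_{\theta \setminus \pi_{\text{Head}}}(x_{i,\le j}),\; x_{i,j+1} \Big)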
Here, θ denotes the LLM parameters. π_Head ∈ R^{N×d} denotes the expanded vocabulary (i.e., the LM head), consisting of N d-dimensional vectors. f_{θ∖π_Head}(x_{i,≤j}) ∈ R^d denotes the d-dimensional vector produced by the LLM without the LM head after processing the first through j-th tokens. ℓ_g typically represents the cross-entropy loss function, and 1_g(x_{i,j}) is an indicator function, where 1_g(x_{i,j}) = 1 iff role(x_{i,j}) = GEN; otherwise, it is 0. For tokens where role(x_i) = RET, optimization employs L_r:
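Since L_r is defined via contrastive learning over the retrieval-token representations, a typical InfoNCE-style form is sketched below as an assumption, not the paper's exact equation: Q is the set of queries in a batch, h^{[RQ]}_q and h^{[RD]}_d are the hidden states of the [RQ] and [RD] tokens for query q and document d, sim is a similarity function (e.g., cosine), τ is a temperature, and B contains the positive document d^+ plus (in-batch) negatives.

    \mathcal{L}_r \;=\; -\sum_{q \in \mathcal{Q}} \log
        \frac{\exp\!\big(\mathrm{sim}(\mathbf{h}^{[RQ]}_{q}, \mathbf{h}^{[RD]}_{d^{+}}) / \tau\big)}
             {\sum_{d \in \mathcal{B}} \exp\!\big(\mathrm{sim}(\mathbf{h}^{[RQ]}_{q}, \mathbf{h}^{[RD]}_{d}) / \tau\big)}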