Grounding by Trying: LLMs with Reinforcement Learning-Enhanced Retrieval

Stanford University, Databricks, Physical Intelligence, Google DeepMind
Figure: LeReT significantly improves retrieval and generation.

Abstract

Mitigating hallucinations is a prerequisite for trusting answers generated by large language models (LLMs), which are prone to making convincing but inaccurate claims. Grounding answers in data generated and verified by humans provides a natural avenue for improving the reliability of LLMs. However, it can be hard to capture the facts relevant to a user question based on semantic similarity alone, especially as questions become more complex and the relevant facts become more indirect. What if LLMs could query for relevant facts based on the user question? While this can enable retrieving relevant but indirect facts, the zero-shot performance of instruction-tuned LLMs leaves much to be desired, and generating supervision on how to retrieve relevant facts can be expensive and retriever dependent. Our key insight is that LLMs can learn to retrieve relevant facts by trying different queries and learning to upweight the queries that result in relevant facts. This leads to our reinforcement learning based framework, Learning to Retrieve by Trying (LeReT), in which the LLM generates queries for multi-hop retrieval and uses preference-based reinforcement learning to improve those queries. Our experimental results demonstrate that LeReT can improve absolute retrieval accuracy by up to 29% and downstream generator evaluations by 17%. The simplicity and flexibility of LeReT allow it to be applied to arbitrary retrievers and make it a promising technique for improving general LLM pipelines.

Method Overview

We focus mainly on multi-hop retrieval pipelines, where an LLM is given a question and asked to generate a query that is used to retrieve documents. The retrieved documents are then used to generate the next query. Finally, all of the retrieved documents are fed into another LLM to answer the original question. LeReT improves the LLM's ability to generate effective queries by sampling a diverse set of queries and training the LLM on this dataset using preference-based optimization.
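For concreteness, here is a minimal Python sketch of this multi-hop loop. The helpers `generate_query`, `retrieve`, and `generate_answer` are hypothetical stand-ins for the query LLM, the retriever, and the answer LLM; the paper's actual pipeline may differ in its interfaces.

```python
# A minimal sketch of the multi-hop retrieval pipeline described above.
# `generate_query`, `retrieve`, and `generate_answer` are hypothetical helpers
# standing in for the query LLM, the retriever, and the answer LLM.

def multi_hop_pipeline(question: str, num_hops: int = 2, top_k: int = 5) -> str:
    """Run `num_hops` rounds of query generation and retrieval, then answer."""
    retrieved_docs: list[str] = []
    for _ in range(num_hops):
        # The query LLM conditions on the question and all documents retrieved so far.
        query = generate_query(question, retrieved_docs)
        # Any retriever can be plugged in here; LeReT is retriever-agnostic.
        retrieved_docs.extend(retrieve(query, top_k=top_k))
    # A separate (typically larger) LLM answers from the accumulated documents.
    return generate_answer(question, retrieved_docs)
```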
Figure: LeReT uses few-shot prompting to sample diverse and effective search queries, then uses the retrieval reward to rank them. The collected preference pairs are used to fine-tune the model.

Prompt-Driven Diverse Query Generation

We introduce a new method to sample a diverse set of retrieval queries in multi-hop pipelines. Specifically, for a given question we sample a diverse set of queries by prompting the model with different few-shot prompts created by DSPy. For each query, we retrieve a set of relevant facts and apply a reward metric to the retrieved facts to create preference pairs. Additionally, for multi-hop retrieval, query generation is conditioned on the facts retrieved in previous hops: at each hop, we randomly sample a set of facts from the previous hop with probability proportional to its reward.
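The sketch below illustrates what this data collection could look like for a single hop. The helpers `sample_query`, `retrieve`, and `reward`, as well as the exact pair format, are assumptions for illustration rather than the paper's implementation; the reward could be, for example, recall of the gold supporting facts.

```python
import itertools
import random

# A minimal sketch, under assumed interfaces, of preference-pair collection for
# one hop. `sample_query(question, context, fewshot_prompt)` stands in for the
# LLM prompted with one of several DSPy-bootstrapped few-shot prompts,
# `retrieve(query)` for the retriever, and `reward(docs, gold_facts)` for the
# retrieval reward (e.g., recall of the gold supporting facts).

def collect_preference_pairs(question, context, fewshot_prompts, gold_facts):
    # Sample one query per few-shot prompt to obtain a diverse candidate set.
    candidates = []
    for prompt in fewshot_prompts:
        query = sample_query(question, context, prompt)
        docs = retrieve(query)
        candidates.append((query, docs, reward(docs, gold_facts)))

    # Any two candidates with different rewards yield a (chosen, rejected) pair.
    pairs = []
    for (q1, _, r1), (q2, _, r2) in itertools.combinations(candidates, 2):
        if r1 != r2:
            chosen, rejected = (q1, q2) if r1 > r2 else (q2, q1)
            pairs.append({"context": (question, context),
                          "chosen": chosen, "rejected": rejected})
    return candidates, pairs

def sample_next_hop_facts(candidates):
    # For the next hop, pick one candidate's retrieved facts with probability
    # proportional to its reward (small floor so zero-reward samples stay possible).
    weights = [max(r, 1e-6) for _, _, r in candidates]
    _, docs, _ = random.choices(candidates, weights=weights, k=1)[0]
    return docs
```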

Model training

We first perform a standard supervised fine-tuning (SFT) step on the model; because the queries were collected with few-shot prompts, this step can also be seen as context distillation. We then train the model on the collected preference pairs using IPO.
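For reference, here is a minimal PyTorch sketch of the IPO objective applied in this stage. It operates on summed sequence log-probabilities of the chosen and rejected queries under the policy being trained and under the frozen SFT reference model; the function signature and the hyperparameter name `beta` are our own choices for illustration, not taken from the paper.

```python
import torch

# A minimal sketch of the IPO objective (Azar et al., 2023). Inputs are summed
# token log-probabilities of the chosen and rejected queries under the current
# policy and under the frozen SFT reference model; `beta` controls the
# regularization strength (an assumed hyperparameter name).

def ipo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Log-ratio of policy to reference for the preferred and dispreferred queries.
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    margin = chosen_logratio - rejected_logratio
    # IPO regresses this margin toward 1 / (2 * beta) with a squared loss.
    return ((margin - 1.0 / (2.0 * beta)) ** 2).mean()
```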

Results

We test LeReT on two multi-hop datasets, HotpotQA and HoVer, with two different base models, Llama 3 8B and Gemma 2 9B. We evaluate LeReT's ability to improve downstream generation by feeding the retrieved documents into Llama 3.1 70B.
Figure: LeReT significantly improves retrieval accuracy and generation across both datasets and base models.

BibTeX

@misc{hsu2024groundingtryingllmsreinforcement,
      title={Grounding by Trying: LLMs with Reinforcement Learning-Enhanced Retrieval}, 
      author={Sheryl Hsu and Omar Khattab and Chelsea Finn and Archit Sharma},
      year={2024}
}