January 9, 2023

Atlas: Few-shot learning with retrieval augmented language models

By: Gautier Izacard, Patrick Lewis, Maria Lomeli, Lucas Hosseini, Fabio Petroni, Timo Schick, Jane Dwivedi-Yu, Armand Joulin, Sebastian Riedel, Edouard Grave

TLDR

We released the code for our Atlas project [1] on GitHub, as well as pretrained Atlas model checkpoints, an index, and Wikipedia corpora. We present how to build a Q&A system that is trained using 100 examples and that has a small memory footprint thanks to our codebase’s usability features.

Motivation

Open domain question answering refers to the task of providing answers to natural language questions about a wide variety of domains and topics. Recently, purely parametric approaches based on large language models have made great progress on this task. These approaches are still limited by the fact that all the knowledge is stored in the weights of the model. Consequently, all the knowledge about the world is static, and these models can sometimes generate nonfactual answers.

On the other hand, retrieval-augmented language models combine the best of generative models and search. Retrieval-augmented models use a large collection of text documents from which evidence is retrieved before generating the answers. This enables the model to update its knowledge about the world easily, often leading to more factual answers. These models rely on two parts: (1) the retriever, which given a query or question returns a small set of relevant documents from the large corpus, and (2) the reader, which performs reading comprehension based on the query and the retrieved documents to generate an answer to the given query.
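
The two-part structure described above can be sketched in a few lines. This is only a toy illustration, not the Atlas implementation: the retriever here ranks passages by word overlap instead of dense embeddings, the reader is a stand-in that returns the top passage instead of generating an answer, and the corpus and function names are hypothetical.

```python
import re
from collections import Counter


def tokenize(text):
    """Lowercase the text and split it into word tokens."""
    return re.findall(r"\w+", text.lower())


def retrieve(query, corpus, k=2):
    """Toy retriever: return the k passages sharing the most tokens with the query."""
    query_tokens = Counter(tokenize(query))

    def overlap(passage):
        # Size of the multiset intersection between query and passage tokens.
        return sum((query_tokens & Counter(tokenize(passage))).values())

    return sorted(corpus, key=overlap, reverse=True)[:k]


def read(query, passages):
    """Stand-in reader: a real reader generates an answer from the query and passages."""
    return passages[0] if passages else ""


# Hypothetical three-passage corpus.
corpus = [
    "Paris is the capital of France.",
    "The Nile is a river in Africa.",
    "Mount Everest is the highest mountain on Earth.",
]
query = "What is the capital of France?"
top_passages = retrieve(query, corpus, k=2)
answer = read(query, top_passages)
print(top_passages)
print(answer)
```

Because the corpus sits outside the model, changing what the system knows only requires editing `corpus`, which is the property the paragraph above attributes to retrieval-augmented models.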

While they’ve obtained promising results on standard Q&A benchmarks, retrieval-augmented models still have limitations. First, these models need to build an index that is then used for retrieval. This leads to large memory requirements, especially when retrieving from large collections that can contain billions of documents. Second, both the retriever and the reader need to be trained on the target task, which often requires many examples to obtain strong performance. In contrast, large language models — which have exhibited zero and few-shot abilities — are able to perform tasks from only a few examples. To address these issues we introduce Atlas, a novel retrieval-augmented model.

What is Atlas?

Atlas is a retrieval-augmented language model pretrained on unlabeled data that exhibits few-shot abilities on knowledge-intensive tasks, such as Q&A and fact-checking. This new model builds on several recent research projects from FAIR (Fundamental AI Research). For retrieval, Atlas uses a dense retriever, based on a bi-encoder architecture, which is initialized with the Contriever model [2]. The reader is based on a sequence-to-sequence model, which is initialized with the T5 model and uses the FiD [3] architecture to efficiently process a large number of retrieved documents. These two components, the retriever and the reader, are then jointly pretrained using an MLM (masked language modeling) task. Knowledge distillation [4,5] is used to train the retriever component, using signals from the reader.
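
To make the bi-encoder and distillation ideas more concrete, here is a minimal sketch in plain Python. It assumes that the retriever scores a passage by the dot product of independently computed query and passage embeddings, and it uses a KL divergence between a reader-derived target distribution and the retriever distribution as one common form of distillation loss. The embeddings and target values are hypothetical, and this is not necessarily the exact objective used to train Atlas.

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(target, predicted):
    """KL(target || predicted) for two discrete distributions."""
    return sum(t * math.log(t / p) for t, p in zip(target, predicted) if t > 0)


# Hypothetical embeddings. In a bi-encoder, the query and each passage are
# encoded independently, so passage embeddings can be computed ahead of time.
query_emb = [0.2, 0.9, 0.1]
passage_embs = [
    [0.1, 0.8, 0.0],
    [0.9, 0.1, 0.3],
    [0.0, 0.2, 0.7],
]

# Retriever: similarity scores turned into a distribution over passages.
scores = [dot(query_emb, p) for p in passage_embs]
retriever_dist = softmax(scores)

# Hypothetical signal from the reader: how useful each passage was for answering.
reader_target = [0.7, 0.2, 0.1]

# Distillation: the retriever is trained to move its distribution toward the target.
loss = kl_divergence(reader_target, retriever_dist)
print(retriever_dist, loss)
```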

Atlas is a competitive model that can be used to build a Q&A system from a few examples based on the following properties:

  • Atlas is based on a pretrained language model, so it already contains a large amount of information in its language model parameters. It only requires fine-tuning on a few examples to learn how to solve the Q&A task.
  • Atlas uses efficient fine-tuning strategies, which allow it to adapt the retriever to the downstream task without recomputing the full index when the retriever is updated. Namely, it uses query-side fine-tuning, which fixes the passage encoder and trains only the query encoder. Because the passage embeddings stay unchanged, the index does not need to be rebuilt after each training step.
  • Atlas can use different Faiss (Facebook AI Similarity Search) compressed indexes. For example, we use the one based on product quantization [7] that allows us to reduce the memory requirements to fine-tune and run Atlas models.
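
As a rough back-of-the-envelope illustration of why product quantization reduces memory, the sketch below compares the storage needed for uncompressed vectors with that of fixed-size codes. The numbers are assumptions made for illustration only: 30 million passages, 768-dimensional float32 embeddings, and a code size interpreted as bytes per vector. The estimate also ignores the small codebooks a product quantizer stores, and it does not reproduce the measurements reported for Atlas.

```python
def flat_index_bytes(num_vectors, dim, bytes_per_value=4):
    """Storage for uncompressed vectors (float32 by default)."""
    return num_vectors * dim * bytes_per_value


def pq_index_bytes(num_vectors, code_size):
    """Storage for product-quantized codes, assuming code_size bytes per vector."""
    return num_vectors * code_size


NUM_PASSAGES = 30_000_000  # hypothetical corpus size
DIM = 768                  # assumed embedding dimension

flat = flat_index_bytes(NUM_PASSAGES, DIM)
print(f"flat index: {flat / 1e9:.2f} GB")
for code_size in (192, 64):
    pq = pq_index_bytes(NUM_PASSAGES, code_size)
    print(f"code size {code_size}: {pq / 1e9:.2f} GB ({flat // pq}x smaller)")
```

Under these assumptions, a smaller code size shrinks the index proportionally, at the cost of a coarser approximation of each embedding.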

The following sections discuss how to use these techniques in practice to build a Q&A system from 100 training examples.

Using the Atlas codebase to build a Q&A system

To use the Atlas codebase, we first need to set up the repository. We have released the pretrained Atlas model checkpoints, together with the passage embeddings and passages, for a variety of Wikipedia corpora and model sizes. For this blog post, we used a language model of size XL (3 billion parameters) trained using the Wikipedia December 2018 corpus. The queries for the NQ (Natural Questions) task can also be downloaded from our source directories, and are available from the original publication [6].

On a Linux machine, we clone the Atlas repository and create the conda environment:

git clone https://github.com/facebookresearch/atlas.git
cd atlas
conda create --name atlas-env python=3.8
conda activate atlas-env
conda install pytorch==1.11.0 cudatoolkit=11.3 -c pytorch
conda install -c pytorch faiss-gpu=1.7.3 cudatoolkit=11.3
pip install -r requirements.txt

The $DATA_DIR folder is used to store everything we download.

The corpus that was used to train the model can be downloaded with:

python preprocessing/download_corpus.py --corpus corpora/wiki/enwiki-dec2018 --output_directory $DATA_DIR

To download the model checkpoints:

python preprocessing/download_model.py --model models/atlas/xl --output_directory $DATA_DIR

To download the prebuilt indexes:

python preprocessing/download_index.py --index indices/atlas/wiki/xl --output_directory $DATA_DIR

Finally, to download the NQ data:

python preprocessing/prepare_qa.py --output_directory $DATA_DIR

The above scripts download the corpus to $DATA_DIR/corpora/wiki/enwiki-dec2018, the requested model to $DATA_DIR/models/atlas/xl, the prebuilt indexes to $DATA_DIR/indices/atlas/wiki/xl, and the NQ data to $DATA_DIR/nq_data.

Once the pretrained model and datasets are available, we can do the few-shot finetuning and conduct an evaluation of the model with respect to some metrics of interest for the Q&A task.

To do this, we create our own dataset of 100 questions and answers in a jsonl file. This file should be in the following format:

{"question": "when is season 2 of punisher coming out on netflix", "answers": ["in 2019"]}

For this example, we collected 100 questions and answers from the NQ data at random. The corresponding jsonl file is in $DATA_DIR/nq_data/train.100-shot.jsonl. Then, we create a folder to collect the experiments output and the run logs:

mkdir $DATA_DIR/experiments
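
For reference, a few-shot file in the format shown above could be produced with a short script like the following. This is a sketch under assumptions, not part of the Atlas codebase: the helper names are hypothetical, and the pool of examples is synthetic and written to a temporary directory so that the snippet runs on its own. In practice the pool would be the downloaded NQ training examples.

```python
import json
import os
import random
import tempfile


def sample_few_shot(examples, k, seed=0):
    """Pick k examples at random, with a fixed seed for reproducibility."""
    rng = random.Random(seed)
    return rng.sample(examples, k)


def write_jsonl(examples, path):
    """Write one JSON object per line, keeping only the expected fields."""
    with open(path, "w", encoding="utf-8") as f:
        for ex in examples:
            record = {"question": ex["question"], "answers": ex["answers"]}
            f.write(json.dumps(record) + "\n")


def read_jsonl(path):
    """Read a jsonl file back into a list of dictionaries."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Synthetic stand-in for a larger pool of question/answer pairs.
pool = [
    {"question": f"hypothetical question {i}", "answers": [f"answer {i}"]}
    for i in range(500)
]
few_shot = sample_few_shot(pool, k=100)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "train.100-shot.jsonl")
    write_jsonl(few_shot, path)
    loaded = read_jsonl(path)

print(len(loaded), loaded[0])
```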

We now pass all the paths where the models and data were downloaded to the torchrun command. In a machine with 8 GPUs (A100 with 80 GB of RAM), we can run the following to conduct the finetuning:

torchrun --standalone --nnodes 1 --nproc_per_node 8 finetune_qa.py \
    --train_data $DATA_DIR/nq_data/train.100-shot.jsonl \
    --eval_data $DATA_DIR/nq_data/test.jsonl \
    --name "my_finetuning_experiment_baseline_xl" \
    --checkpoint_dir $DATA_DIR/experiments/ \
    --total_steps 31 \
    --index_mode flat \
    --model_path $DATA_DIR/models/atlas/xl \
    --load_index_path $DATA_DIR/indices/atlas/wiki/xl \
    --reader_model_type google/t5-xl-lm-adapt

Alternatively, we can conduct the finetuning using a product quantized index, with light compression corresponding to a code size of 192. To use a Faiss product-quantized index, we pass --index_mode faiss and --faiss_index_type pq, and set the compression parameter with --faiss_code_size 192:

torchrun --standalone --nnodes 1 --nproc_per_node 8 finetune_qa.py \
    --train_data $DATA_DIR/nq_data/train.100-shot.jsonl \
    --eval_data $DATA_DIR/nq_data/test.jsonl \
    --name "my_finetuning_experiment_pq_192_xl" \
    --checkpoint_dir $DATA_DIR/experiments/ \
    --total_steps 31 \
    --index_mode faiss \
    --faiss_index_type pq \
    --faiss_code_size 192 \
    --model_path $DATA_DIR/models/atlas/xl \
    --load_index_path $DATA_DIR/indices/atlas/wiki/xl \
    --reader_model_type google/t5-xl-lm-adapt

We can also use a more aggressive compression parameter — corresponding to a code size of 64 — and conduct the finetuning with the product quantized index using only 2 GPUs:

torchrun --standalone --nnodes 1 --nproc_per_node 2 finetune_qa.py \
    --train_data $DATA_DIR/nq_data/train.100-shot.jsonl \
    --eval_data $DATA_DIR/nq_data/test.jsonl \
    --name "my_finetuning_experiment_pq_64_xl" \
    --checkpoint_dir $DATA_DIR/experiments/ \
    --total_steps 31 \
    --index_mode faiss \
    --faiss_index_type pq \
    --faiss_code_size 64 \
    --model_path $DATA_DIR/models/atlas/xl \
    --load_index_path $DATA_DIR/indices/atlas/wiki/xl \
    --reader_model_type google/t5-xl-lm-adapt

In all three experiments, we finetuned the model with --total_steps 31. We can optionally provide a --save_index_path to save our Faiss compressed indexes as well as document embeddings. These can then be reused if needed to do further fine-tuning. The generated answers and retrieved passages are stored in a .jsonl file called test-step-<TOTAL_STEPS>.jsonl. This file can be useful to compute additional metrics of interest, such as recall@r. The recall@r metric measures the percentage of examples for which at least one of the first r retrieved passages contains a correct answer. Here, we used the exact match metric, which counts the number of generated answers that exactly match one of the true answers.
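
The two metrics mentioned above can be sketched as follows. This is an illustrative implementation under assumptions, not the evaluation code from the Atlas repository: the field names ("generation", "answers", "passages") are hypothetical, answers are compared after a simple normalization (lowercasing, removing punctuation and articles), and a passage counts as a hit if it contains a normalized answer as a substring. The functions return fractions, which can be multiplied by 100 to obtain percentages.

```python
import re
import string


def normalize(text):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, answers):
    """True if the normalized prediction equals one of the normalized answers."""
    return normalize(prediction) in {normalize(a) for a in answers}


def exact_match_score(examples):
    """Fraction of examples whose generated answer exactly matches a true answer."""
    hits = sum(exact_match(ex["generation"], ex["answers"]) for ex in examples)
    return hits / len(examples)


def recall_at_r(examples, r):
    """Fraction of examples with a true answer in one of the first r passages."""
    hits = 0
    for ex in examples:
        top = [normalize(p) for p in ex["passages"][:r]]
        if any(normalize(a) in p for a in ex["answers"] for p in top):
            hits += 1
    return hits / len(examples)


# Hypothetical records standing in for lines of the output file.
examples = [
    {
        "answers": ["in 2019"],
        "generation": "In 2019.",
        "passages": [
            "Season 2 was released in 2019 on Netflix.",
            "The Punisher is a TV series.",
        ],
    },
    {
        "answers": ["Paris"],
        "generation": "Lyon",
        "passages": [
            "Lyon is a city in France.",
            "Paris is the capital of France.",
        ],
    },
]

print(exact_match_score(examples))
print(recall_at_r(examples, r=1), recall_at_r(examples, r=2))
```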

Table 1. Exact match for a flat index vs. product quantized indexes with different compression parameters.

Table 1 contains the exact match metric and memory requirements for the three experiments: the baseline with no compression, and two compression levels, low and high. We observed memory savings in both cases that use product quantized indexes. In the low compression case, we match the baseline accuracy. In the high compression case, even though we lose points in accuracy, we run the experiment with significantly fewer resources (2 GPUs vs. 8 GPUs). Please refer to section 5 of the paper "Atlas: Few-shot Learning with Retrieval Augmented Language Models" for more details about the effect of the --faiss_code_size parameter on this metric.

Conclusion

We’ve provided a step-by-step overview of some of the features of the Atlas codebase for a Q&A task, highlighting its ease of use. The Atlas retrieval-augmented model is also suitable for a variety of other knowledge-intensive tasks. Please refer to the GitHub repository for further details about these tasks.

References

[1] Gautier Izacard, Patrick Lewis, Maria Lomeli, Lucas Hosseini, Fabio Petroni, Timo Schick, Jane Dwivedi-Yu, Armand Joulin, Sebastian Riedel, and Edouard Grave. Few-shot learning with retrieval augmented language models, 2022.

[2] Gautier Izacard, Mathilde Caron, Lucas Hosseini, Sebastian Riedel, Piotr Bojanowski, Armand Joulin, and Edouard Grave. Unsupervised dense information retrieval with contrastive learning. TMLR, 2022.

[3] Gautier Izacard and Edouard Grave. Leveraging passage retrieval with generative models for open domain question answering. arXiv preprint arXiv:2007.01282, 2020.

[4] Gautier Izacard and Edouard Grave. Distilling knowledge from reader to retriever for question answering. International Conference on Learning Representations, 2021.

[5] Devendra Singh Sachan, Siva Reddy, William Hamilton, Chris Dyer, and Dani Yogatama. End-to-end training of multi-document reader and retriever for open-domain question answering, 2021.

[6] Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, Kristina Toutanova, Llion Jones, Matthew Kelcey, Ming-Wei Chang, Andrew M. Dai, Jakob Uszkoreit, Quoc Le, and Slav Petrov. Natural questions: a benchmark for question answering research. Transactions of the ACL, 2019.

[7] Hervé Jégou, Matthijs Douze, and Cordelia Schmid. Product quantization for nearest neighbor search. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2011.