Generative Pseudo Labeling for Domain Adaptation of Dense Retrievers

Note: Adapted to Haystack from Nils Reimers’ original notebook.

The NLP models we use every day were trained on a corpus of data that reflects the world of the past. In the meantime, we’ve experienced world-changing events, like the COVID pandemic, and we’d like our models to know about them. Training a model from scratch is tedious work, but what if we could just update the models with new data? Generative Pseudo Labeling (GPL) comes to the rescue.

The example below shows you how to use GPL to fine-tune a model so that it can answer the query “How is COVID-19 transmitted?”

We’re using TAS-B, a DistilBERT model that achieves state-of-the-art performance on MS MARCO (500k queries from the Bing search engine). Both DistilBERT and MS MARCO were created with data from 2018 and earlier, so the model lacks any knowledge of COVID.

For this example, we’re using just four documents. When you ask the model “How is COVID-19 transmitted?”, here are the answers that you get (dot score and document):

  • 94.84 Ebola is transmitted via direct contact with blood
  • 92.87 HIV is transmitted via sex or sharing needles
  • 92.31 Corona is transmitted via the air
  • 91.54 Polio is transmitted via contaminated water or food

You can see that the correct document is only third, outranked by Ebola and HIV information. Let’s see how we can make this better.

Efficient Domain Adaptation with GPL

This notebook demonstrates Generative Pseudo Labeling (GPL), an efficient approach to adapt existing dense retrieval models to new domains and data.

We get a collection of 10k scientific papers on COVID-19 and then fine-tune the model within 15-60 minutes (depending on your GPU) so that it includes the COVID knowledge.

If we search again with the updated model, we get the search results we would expect:

  • Query: How is COVID-19 transmitted
  • 97.70 Corona is transmitted via the air
  • 96.71 Ebola is transmitted via direct contact with blood
  • 95.14 Polio is transmitted via contaminated water or food
  • 94.13 HIV is transmitted via sex or sharing needles

Prepare the Environment

Colab: Enable the GPU runtime

Make sure you enable the GPU runtime to experience decent speed in this tutorial. Runtime -> Change Runtime type -> Hardware accelerator -> GPU

!nvidia-smi
!pip install -q datasets
!pip install "faiss-gpu>=1.6.3,<2"
!pip install -q git+https://github.com/deepset-ai/haystack.git

Logging

We configure how logging messages should be displayed and which log level should be used before importing Haystack.

Example log message: INFO - haystack.utils.preprocessing - Converting data/tutorial1/218_Olenna_Tyrell.txt

The default log level in basicConfig is WARNING, so the explicit parameter is not necessary, but it can be changed easily:

import logging

logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
# We load the TAS-B model, a state-of-the-art model trained on MS MARCO
max_seq_length = 200
model_name = "msmarco-distilbert-base-tas-b"

org_model = SentenceTransformer(model_name)
org_model.max_seq_length = max_seq_length
# We define a simple query and some documents on how diseases are transmitted.
# As TAS-B was trained on rather outdated data (2018 and older), it has no idea about COVID-19.
# So in the example below, it fails to recognize the relationship between COVID-19 and Corona.


def show_examples(model):
    query = "How is COVID-19 transmitted"
    docs = [
        "Corona is transmitted via the air",
        "Ebola is transmitted via direct contact with blood",
        "HIV is transmitted via sex or sharing needles",
        "Polio is transmitted via contaminated water or food",
    ]

    query_emb = model.encode(query)
    docs_emb = model.encode(docs)
    scores = util.dot_score(query_emb, docs_emb)[0]
    doc_scores = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)

    print("Query:", query)
    for doc, score in doc_scores:
        # print(doc, score)
        print(f"{score:0.02f}\t{doc}")


print("Original Model")
show_examples(org_model)

Get Some Data on COVID-19

We select 10k scientific publications (title + abstract) that are connected to COVID-19. As a dataset, we use TREC-COVID.

dataset = load_dataset("nreimers/trec-covid", split="train")
num_documents = 10000
corpus = []
for row in dataset:
    if len(row["title"]) > 20 and len(row["text"]) > 100:
        text = row["title"] + " " + row["text"]

        text_lower = text.lower()

        # The dataset also contains many papers on other diseases. To make the training in this demo
        # more efficient, we focus on papers that talk about COVID.
        if "covid" in text_lower or "corona" in text_lower or "sars-cov-2" in text_lower:
            corpus.append(text)

        if len(corpus) >= num_documents:
            break

print("Len Corpus:", len(corpus))

Initialize Haystack Retriever and DocumentStore

Let’s add the corpus documents to a FAISSDocumentStore and compute their embeddings with an EmbeddingRetriever.

from haystack.nodes.retriever import EmbeddingRetriever
from haystack.document_stores import FAISSDocumentStore

document_store = FAISSDocumentStore(faiss_index_factory_str="Flat", similarity="cosine")
document_store.write_documents([{"content": t} for t in corpus])


retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
    max_seq_len=max_seq_length,
    progress_bar=False,
)
document_store.update_embeddings(retriever)
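
Conceptually, a "Flat" index compares the query vector with every stored vector, and cosine similarity is the dot product of the length-normalized vectors. The following is a minimal plain-Python sketch of that idea, not the FAISS or Haystack implementation. The helper names `normalize` and `flat_cosine_search` and the toy vectors are assumptions made for illustration.

```python
import math


def normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def flat_cosine_search(query_vec, doc_vecs, top_k):
    """Exhaustively score every document; return (index, cosine) pairs, best first."""
    q = normalize(query_vec)
    scored = []
    for idx, vec in enumerate(doc_vecs):
        d = normalize(vec)
        scored.append((idx, sum(a * b for a, b in zip(q, d))))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


# Toy 2-dimensional "embeddings" (made-up numbers).
doc_vecs = [[1.0, 0.0], [3.0, 4.0], [0.0, 2.0]]
results = flat_cosine_search([2.0, 0.0], doc_vecs, top_k=2)
print(results)
```

Because every vector is normalized first, the length of a vector does not affect its rank; only its direction does.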

(Optional) Download Pre-Generated Questions or Generate Them Outside of Haystack

The first step of the GPL algorithm requires us to generate questions for a given text passage. Even though our pre-COVID trained model hasn’t seen any COVID-related content, it can still produce sensible queries by copying words from the input text. As generating questions from 10k documents is a bit slow (depending on the GPU used), we’ll download question/document pairs directly from the Hugging Face hub.

from tqdm.auto import tqdm

query_doc_pairs = []

load_queries_from_hub = True

# Generation of the queries is quite slow in Colab due to the old GPU and the limited CPU
# I pre-computed the queries and uploaded these to the HF dataset hub. Here we just download them
if load_queries_from_hub:
    generated_queries = load_dataset("nreimers/trec-covid-generated-queries", split="train")
    for row in generated_queries:
        query_doc_pairs.append({"question": row["query"], "document": row["doc"]})
else:
    # Load doc2query model
    t5_name = "doc2query/msmarco-t5-base-v1"
    t5_tokenizer = AutoTokenizer.from_pretrained(t5_name)
    t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_name).cuda()

    batch_size = 32
    queries_per_doc = 3

    for start_idx in tqdm(range(0, len(corpus), batch_size)):
        corpus_batch = corpus[start_idx : start_idx + batch_size]
        enc_inp = t5_tokenizer(
            corpus_batch, max_length=max_seq_length, truncation=True, padding=True, return_tensors="pt"
        )

        outputs = t5_model.generate(
            input_ids=enc_inp["input_ids"].cuda(),
            attention_mask=enc_inp["attention_mask"].cuda(),
            max_length=64,
            do_sample=True,
            top_p=0.95,
            num_return_sequences=queries_per_doc,
        )

        decoded_output = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        for idx, query in enumerate(decoded_output):
            corpus_id = idx // queries_per_doc
            query_doc_pairs.append({"question": query, "document": corpus_batch[corpus_id]})


print("Generated queries:", len(query_doc_pairs))
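
The index arithmetic in the generation loop relies on the generated sequences coming back grouped by input document, in input order. Here is a small sketch of that mapping in isolation, assuming this grouping holds. `pair_queries_with_docs` is a hypothetical helper and the strings are placeholders.

```python
def pair_queries_with_docs(decoded_queries, docs, queries_per_doc):
    """Map a flat list of generated queries back to their source documents.

    Assumes the layout [doc0_q0, doc0_q1, ..., doc1_q0, ...].
    """
    if len(decoded_queries) != len(docs) * queries_per_doc:
        raise ValueError("expected queries_per_doc queries for every document")
    pairs = []
    for idx, query in enumerate(decoded_queries):
        doc_id = idx // queries_per_doc  # integer division selects the source document
        pairs.append({"question": query, "document": docs[doc_id]})
    return pairs


docs = ["doc A", "doc B"]
decoded = ["a1", "a2", "a3", "b1", "b2", "b3"]
pairs = pair_queries_with_docs(decoded, docs, queries_per_doc=3)
print(pairs[2])  # last query generated for "doc A"
print(pairs[3])  # first query generated for "doc B"
```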

Use PseudoLabelGenerator to Generate Retriever Adaptation Training Data

Running PseudoLabelGenerator executes all three steps of the GPL algorithm:

  1. Question generation - optional step
  2. Negative mining
  3. Pseudo labeling (margin scoring)

The output of the PseudoLabelGenerator is the training data we’ll use to adapt our EmbeddingRetriever.
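
Steps 2 and 3 can be sketched in plain Python. This is a toy illustration under stated assumptions. `overlap_score` is a hypothetical token-overlap scorer that stands in for both the dense retriever (negative mining) and the cross-encoder (margin scoring). The dictionary keys are illustrative and are not claimed to match the format that PseudoLabelGenerator returns.

```python
def overlap_score(query, doc):
    """Hypothetical relevance score: number of lowercase tokens shared by query and doc."""
    return float(len(set(query.lower().split()) & set(doc.lower().split())))


def mine_negatives(query, positive, corpus, top_k):
    """Step 2: take the top_k highest-scoring documents that are not the positive."""
    candidates = [doc for doc in corpus if doc != positive]
    ranked = sorted(candidates, key=lambda doc: overlap_score(query, doc), reverse=True)
    return ranked[:top_k]


def pseudo_label(query, positive, negative):
    """Step 3: margin = score(query, positive) - score(query, negative)."""
    return overlap_score(query, positive) - overlap_score(query, negative)


corpus = [
    "covid spreads via airborne droplets",
    "ebola spreads via contact with blood",
    "polio spreads via contaminated water",
]
query = "how does covid spread via droplets"
positive = corpus[0]

# One training triple per mined negative, each carrying its pseudo label (the margin).
triples = [
    {"question": query, "pos_doc": positive, "neg_doc": neg, "margin": pseudo_label(query, positive, neg)}
    for neg in mine_negatives(query, positive, corpus, top_k=2)
]
for triple in triples:
    print(triple["margin"], triple["neg_doc"])
```

A large margin tells the retriever that the positive should score far above the negative. A small or negative margin flags a mined "negative" that is actually relevant, which is why the soft label is more forgiving than a hard positive/negative split.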

from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.label_generator import PseudoLabelGenerator

use_question_generator = False


if use_question_generator:
    questions_producer = QuestionGenerator(
        model_name_or_path="doc2query/msmarco-t5-base-v1",
        max_length=64,
        split_length=128,
        batch_size=32,
        num_queries_per_doc=3,
    )

else:
    questions_producer = query_doc_pairs

# We can use either QuestionGenerator or already generated questions in PseudoLabelGenerator
psg = PseudoLabelGenerator(questions_producer, retriever, max_questions_per_document=10, batch_size=32, top_k=10)
output, pipe_id = psg.run(documents=document_store.get_all_documents())

Update the Retriever

Now that we have the generated training data produced by PseudoLabelGenerator, we’ll update the EmbeddingRetriever. Let’s take a peek at the training data.

output["gpl_labels"][0]
len(output["gpl_labels"])
retriever.train(output["gpl_labels"])
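
For intuition about what training on margin labels optimizes, here is a plain-Python sketch of the MarginMSE objective as described in the GPL paper: the squared difference between the retriever's own score margin and the teacher margin, averaged over the batch. Whether `retriever.train` uses exactly this loss depends on the Haystack version, so treat this as an assumption. The embeddings and teacher margins are made-up numbers.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def margin_mse(query_embs, pos_embs, neg_embs, teacher_margins):
    """Mean squared error between student and teacher score margins."""
    errors = []
    for q, p, n, target in zip(query_embs, pos_embs, neg_embs, teacher_margins):
        student_margin = dot(q, p) - dot(q, n)
        errors.append((student_margin - target) ** 2)
    return sum(errors) / len(errors)


# Two toy training triples with 2-dimensional embeddings.
query_embs = [[1.0, 0.0], [0.0, 1.0]]
pos_embs = [[2.0, 0.0], [0.0, 3.0]]
neg_embs = [[1.0, 1.0], [1.0, 1.0]]
teacher_margins = [1.0, 4.0]

loss = margin_mse(query_embs, pos_embs, neg_embs, teacher_margins)
print(loss)
```

The first triple already matches its teacher margin and contributes no error. The second has a student margin of 2 against a teacher margin of 4, so training would push its positive and negative scores further apart.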

Verify that EmbeddingRetriever Is Adapted and Save It For Future Use

Let’s repeat our query to see if the Retriever learned about COVID and can now rank it as #1 among the answers.

print("Original Model")
show_examples(org_model)

print("\n\nAdapted Model")
show_examples(retriever.embedding_encoder.embedding_model)
retriever.save("adapted_retriever")