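# Source code for the dense-retrieval tensorizers (BERT/RoBERTa context tensorizers and the positive-label tensorizer).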

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional, Tuple

from pytext.common.constants import SpecialTokens
from pytext.config.component import ComponentType, create_component
from import BERTTensorizer, build_fairseq_vocab
from import RoBERTaTensorizer
from import LabelTensorizer
from import Tokenizer
from import Vocabulary
from pytext.utils.file_io import PathManager

class BERTContextTensorizerForDenseRetrieval(BERTTensorizer):
    """Tensorize passage contexts for dense-retrieval training.

    ``numberize`` turns one row's positive context and optional hard-negative
    context into token ids. ``tensorize`` flattens a batch so that every
    question's positive context is immediately followed by its hard-negative
    context (when it has one).

    Only ``num_negative_ctx`` of 0 or 1 is supported. Other values are rejected
    with ``ValueError``. ``PositiveLabelTensorizerForDenseRetrieval`` locates
    each positive context assuming exactly ``num_negative_ctx`` negatives follow
    it, so silently emitting fewer negatives would misalign the labels.
    """

    def _numberize_ctx(self, ctx_contents):
        """Tokenize each text piece of a single context and numberize them together.

        Returns the result of ``tensorizer_script_impl.numberize``. This class
        unpacks it as (token_ids, segment_labels, seq_len, positions).
        """
        ctx_tokens = [self.tokenizer.tokenize(content) for content in ctx_contents]
        return self.tensorizer_script_impl.numberize(ctx_tokens)

    def numberize(self, row: Dict) -> Tuple[Any, ...]:
        """Convert one sample's contexts into token ids and the other model inputs.

        Works off of a single sample.

        Args:
            row: One sample. Reads ``row["positive_ctx"]`` (one context),
                ``row["negative_ctxs"]`` (a list of contexts) and
                ``row["num_negative_ctx"]``. Each context is a sequence of text
                pieces whose last element is the passage id. The passage id is
                dropped before tokenization.

        Returns:
            An 8-tuple:
            (token_ids, segment_labels, seq_len, positions) for the positive
            context, followed by the same four fields for the hard negative.
            With no negative (``num_negative_ctx == 0``), the negative fields
            are ``[]``, ``[]``, ``0``, ``[]``.

        Raises:
            ValueError: if ``num_negative_ctx`` is neither 0 nor 1, or is 1
                while ``negative_ctxs`` is empty. In either case the emitted
                contexts would not match the positive-label offsets.
        """
        # Don't include the passage id (last element) in what gets tokenized.
        positive = self._numberize_ctx(row["positive_ctx"][:-1])

        # One entry per negative context; each is a sequence of text pieces.
        negative_ctxs = [neg_ctx[:-1] for neg_ctx in row["negative_ctxs"]]
        num_negative_ctx = row["num_negative_ctx"]
        if num_negative_ctx == 0:
            negative = ([], [], 0, [])
        elif num_negative_ctx == 1 and negative_ctxs:
            negative = self._numberize_ctx(negative_ctxs[0])
        else:
            # Previously these cases silently dropped the negative while the
            # label tensorizer still counted it, shifting every later label.
            raise ValueError(
                f"{type(self).__name__} supports num_negative_ctx of 0 or 1 "
                f"(with at least one entry in negative_ctxs); got "
                f"num_negative_ctx={num_negative_ctx} and "
                f"{len(negative_ctxs)} negative_ctxs."
            )
        return (*positive, *negative)

    def tensorize(self, batch):
        """Flatten a numberized batch into one batch of contexts.

        For every sample, the positive context is appended first, then its
        hard-negative context if it has one (``seq_len > 0``). The four
        flattened lists are passed to ``tensorizer_script_impl.tensorize_wrapper``.
        """
        all_ctx_tokens_2d = []
        all_ctx_segment_labels_2d = []
        all_ctx_seq_lens_1d = []
        all_ctx_positions_2d = []

        def _append(token_ids, segment_labels, seq_len, positions):
            all_ctx_tokens_2d.append(token_ids)
            all_ctx_segment_labels_2d.append(segment_labels)
            all_ctx_seq_lens_1d.append(seq_len)
            all_ctx_positions_2d.append(positions)

        for (
            positive_ctx_token_ids,
            positive_ctx_segment_labels,
            positive_ctx_seq_len,
            positive_ctx_positions,
            negative_ctx_token_ids,
            negative_ctx_segment_labels,
            negative_ctx_seq_len,
            negative_ctx_positions,
        ) in batch:
            # Keep the positive and hard negative context for a given question
            # adjacent in the batch.
            _append(
                positive_ctx_token_ids,
                positive_ctx_segment_labels,
                positive_ctx_seq_len,
                positive_ctx_positions,
            )
            if negative_ctx_seq_len > 0:
                _append(
                    negative_ctx_token_ids,
                    negative_ctx_segment_labels,
                    negative_ctx_seq_len,
                    negative_ctx_positions,
                )
        return self.tensorizer_script_impl.tensorize_wrapper(
            all_ctx_tokens_2d,
            all_ctx_segment_labels_2d,
            all_ctx_seq_lens_1d,
            all_ctx_positions_2d,
        )
class RoBERTaContextTensorizerForDenseRetrieval(
    BERTContextTensorizerForDenseRetrieval, RoBERTaTensorizer
):
    """Dense-retrieval context tensorizer built on RoBERTa's tokenizer and vocab.

    ``numberize``/``tensorize`` come from
    ``BERTContextTensorizerForDenseRetrieval``, which is first in the MRO.
    Construction is delegated to ``RoBERTaTensorizer``.
    """

    class Config(RoBERTaTensorizer.Config):
        pass

    @classmethod
    def from_config(cls, config: Config):
        """Build the tensorizer from ``config``.

        Creates the tokenizer component. Loads a fairseq-format vocabulary from
        ``config.vocab_file``, replacing fairseq's special-token strings with
        PyText's ``SpecialTokens``.
        """
        tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
        # The opened handle (not a path string) is what build_fairseq_vocab gets.
        with PathManager.open(config.vocab_file) as vocab_handle:
            vocab = build_fairseq_vocab(
                vocab_file=vocab_handle,
                special_token_replacements={
                    "<pad>": SpecialTokens.PAD,
                    "<s>": SpecialTokens.BOS,
                    "</s>": SpecialTokens.EOS,
                    "<unk>": SpecialTokens.UNK,
                    "<mask>": SpecialTokens.MASK,
                },
            )
        return cls(
            columns=config.columns,
            vocab=vocab,
            tokenizer=tokenizer,
            max_seq_len=config.max_seq_len,
        )

    def __init__(
        self,
        columns: List[str] = Config.columns,
        vocab: Optional[Vocabulary] = None,
        tokenizer: Optional[Tokenizer] = None,
        max_seq_len: int = Config.max_seq_len,
    ):
        """Initialize via ``RoBERTaTensorizer.__init__`` directly.

        This explicitly calls ``RoBERTaTensorizer.__init__`` rather than the
        next ``__init__`` in the MRO.
        """
        RoBERTaTensorizer.__init__(
            self,
            columns=columns,
            vocab=vocab,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
        )
class PositiveLabelTensorizerForDenseRetrieval(LabelTensorizer):
    """Label each question with the index of its positive context.

    The index points into the flattened batch of contexts. Assumes the contexts
    are laid out as each question's positive context followed by
    ``num_negative_ctx`` negative contexts, in batch order.
    """

    def numberize(self, row: Dict):
        """Return how many negative contexts follow this row's positive context."""
        return row["num_negative_ctx"]

    def tensorize(self, batch):
        """Turn per-row negative counts into positive-context indices.

        For ``batch == [n0, n1, n2]`` the positives sit at
        ``[0, n0 + 1, n0 + n1 + 2]``. An empty batch yields ``[]``. The indices
        are passed to ``LabelTensorizer.tensorize``.
        """
        positive_indices = []
        offset = 0
        for num_negatives in batch:
            positive_indices.append(offset)
            # Skip past this row's positive context (+1) and its negatives.
            offset += num_negatives + 1
        return super().tensorize(positive_indices)