Source code for pytext.data.xlm_tensorizer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List, Optional, Tuple

import torch
from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary
from pytext.common.constants import SpecialTokens
from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import (
    BERTTensorizerBase,
    BERTTensorizerBaseScriptImpl,
    build_fairseq_vocab,
)
from pytext.data.tokenizers import Tokenizer
from pytext.data.utils import Vocabulary
from pytext.data.xlm_constants import LANG2ID_15
from pytext.torchscript.utils import ScriptBatchInput
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager
from pytext.utils.lazy import lazy_property

class XLMTensorizerScriptImpl(BERTTensorizerBaseScriptImpl):
    def __init__(
        self,
        tokenizer: Tokenizer,
        vocab: Vocabulary,
        max_seq_len: int,
        language_vocab: List[str],
        default_language: str,
    ):
        super().__init__(tokenizer, vocab, max_seq_len)
        self.language_vocab = ScriptVocabulary(language_vocab)
        self.default_language = torch.jit.Attribute(default_language, str)

    def _lookup_tokens(
        self, tokens: List[Tuple[str, int, int]], max_seq_len: Optional[int] = None
    ) -> Tuple[List[int], List[int], List[int]]:
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        return self.vocab_lookup(
            tokens,
            bos_idx=self.vocab.eos_idx,
            eos_idx=self.vocab.eos_idx,
            use_eos_token_for_bos=True,
            max_seq_len=max_seq_len,
        )
    def numberize(
        self,
        per_sentence_tokens: List[List[Tuple[str, int, int]]],
        per_sentence_languages: List[int],
    ) -> Tuple[List[int], List[int], int, List[int]]:
        tokens: List[int] = []
        segment_labels: List[int] = []  # e.g. language_ids
        seq_len: int = 0
        positions: List[int] = []
        max_seq_len: int = self.max_seq_len // len(per_sentence_tokens)

        for idx, single_sentence_tokens in enumerate(per_sentence_tokens):
            lookup_ids: List[int] = self._lookup_tokens(
                single_sentence_tokens, max_seq_len=max_seq_len
            )[0]
            lookup_ids = self._wrap_numberized_tokens(lookup_ids, idx)

            tokens.extend(lookup_ids)
            segment_labels.extend([per_sentence_languages[idx]] * len(lookup_ids))

        seq_len = len(tokens)
        positions = [i for i in range(seq_len)]

        return tokens, segment_labels, seq_len, positions
    def forward(
        self, inputs: ScriptBatchInput
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Wire up tokenize(), numberize() and tensorize() functions for data
        processing.
        """
        batch_size: int = self.batch_size(inputs)
        row_size: int = self.row_size(inputs)

        languages: Optional[List[List[str]]] = inputs.languages
        if languages is None:
            languages = [[self.default_language] * row_size] * batch_size

        tokens_2d: List[List[int]] = []
        segment_labels_2d: List[List[int]] = []
        seq_lens_1d: List[int] = []
        positions_2d: List[List[int]] = []

        for idx in range(self.batch_size(inputs)):
            tokens: List[List[Tuple[str, int, int]]] = self.tokenize(
                self.get_texts_by_index(inputs.texts, idx),
                self.get_tokens_by_index(inputs.tokens, idx),
            )

            language_ids: List[int] = [
                self.language_vocab.idx.get(
                    languages[idx][0], self.language_vocab.unk_idx
                )
            ] * row_size

            numberized: Tuple[List[int], List[int], int, List[int]] = self.numberize(
                tokens, language_ids
            )
            tokens_2d.append(numberized[0])
            segment_labels_2d.append(numberized[1])
            seq_lens_1d.append(numberized[2])
            positions_2d.append(numberized[3])

        return self.tensorize(tokens_2d, segment_labels_2d, seq_lens_1d, positions_2d)
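The packing done by numberize() above can be illustrated in isolation. The sketch below is not part of the module: the token ids, EOS id, and language ids are made up for illustration, and it only mimics the bookkeeping (EOS wrapping on both sides, per-sentence language ids as segment labels, flat positions) rather than calling the tensorizer itself.

# Minimal sketch (illustration only): packing a two-sentence example the way
# XLMTensorizerScriptImpl.numberize() does. All ids below are hypothetical.
EOS_ID = 2
sent_a = [101, 102, 103]  # hypothetical token ids for the first sentence
sent_b = [201, 202]       # hypothetical token ids for the second sentence
languages = [0, 1]        # e.g. en -> 0, de -> 1

tokens, segment_labels = [], []
for ids, lang in zip([sent_a, sent_b], languages):
    wrapped = [EOS_ID] + ids + [EOS_ID]  # XLM uses EOS for both BOS and EOS
    tokens.extend(wrapped)
    segment_labels.extend([lang] * len(wrapped))

seq_len = len(tokens)
positions = list(range(seq_len))

print(tokens)          # [2, 101, 102, 103, 2, 2, 201, 202, 2]
print(segment_labels)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]
print(seq_len)         # 9
print(positions)       # [0, 1, 2, ..., 8]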
class XLMTensorizer(BERTTensorizerBase):
    """
    Tensorizer for cross-lingual LM tasks. Works for a single sentence as well
    as a sentence pair.
    """

    __TENSORIZER_SCRIPT_IMPL__ = XLMTensorizerScriptImpl

    class Config(BERTTensorizerBase.Config):
        vocab_file: str = "/mnt/vol/nlp_technologies/xlm/vocab_xnli_15"
        tokenizer: Tokenizer.Config = Tokenizer.Config()
        max_vocab: int = 95000
        min_count: int = 0
        # language identifiers for extracting the language from a row of data
        # during numberize
        language_column: str = "language"
        # language-to-id mapping used to obtain language embeddings
        lang2id: Dict[str, int] = LANG2ID_15
        # Controls whether language is being read from the data file (which
        # is what happens for finetuning) or being added during processing
        # (which is what happens during pretraining)
        has_language_in_data: bool = False
        # controls whether we train with language embeddings or not
        use_language_embeddings: bool = True
    @classmethod
    def from_config(cls, config: Config):
        tokenizer = create_component(ComponentType.TOKENIZER, config.tokenizer)
        with PathManager.open(config.vocab_file) as file_path:
            vocab = build_fairseq_vocab(
                dictionary_class=MaskedLMDictionary,
                vocab_file=file_path,
                max_vocab=config.max_vocab,
                min_count=config.min_count,
                special_token_replacements={
                    "<unk>": SpecialTokens.UNK,
                    "<pad>": SpecialTokens.PAD,
                    "</s>": SpecialTokens.EOS,
                    "<mask>": SpecialTokens.MASK,
                },
            )
        return cls(
            columns=config.columns,
            vocab=vocab,
            tokenizer=tokenizer,
            max_seq_len=config.max_seq_len,
            language_column=config.language_column,
            lang2id=config.lang2id,
            use_language_embeddings=config.use_language_embeddings,
            has_language_in_data=config.has_language_in_data,
        )
    def __init__(
        self,
        columns: List[str] = Config.columns,
        vocab: Vocabulary = None,
        tokenizer: Tokenizer = None,
        max_seq_len: int = Config.max_seq_len,
        language_column: str = Config.language_column,
        lang2id: Dict[str, int] = Config.lang2id,
        use_language_embeddings: bool = Config.use_language_embeddings,
        has_language_in_data: bool = Config.has_language_in_data,
    ) -> None:
        assert len(columns) <= 2, "More than 2 text fields are not supported."

        super().__init__(
            columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
        )
        self.language_column = language_column
        self.lang2id = lang2id
        self.use_language_embeddings = use_language_embeddings
        self.has_language_in_data = has_language_in_data
        # unlike BERT, XLM uses the EOS token for both beginning and end of
        # sentence
        self.bos_token = self.vocab.eos_token
        self.default_language = "en"

    @property
    def column_schema(self):
        schema = super().column_schema
        if self.has_language_in_data:
            schema += [(self.language_column, str)]
        return schema

    @lazy_property
    def tensorizer_script_impl(self):
        # invert lang2id into a list indexed by language id so the scripted
        # implementation can build a language vocabulary from it
        languages = [0] * (max(list(self.lang2id.values())) + 1)
        for k, v in self.lang2id.items():
            languages[v] = k
        return self.__TENSORIZER_SCRIPT_IMPL__(
            tokenizer=self.tokenizer,
            vocab=self.vocab,
            language_vocab=languages,
            max_seq_len=self.max_seq_len,
            default_language=self.default_language,
        )
    def get_lang_id(self, row: Dict, col: str) -> int:
        # generate lang embeddings. if training without lang embeddings, use
        # the first language as the lang_id (there will always be one lang)
        if self.use_language_embeddings:
            lang = row[col]
            assert lang in self.lang2id, f"language {lang} not supported in {row}"
            lang_id = self.lang2id[lang]
            return lang_id
        else:
            # use En as default
            return self.lang2id.get(self.default_language, 0)
    def numberize(self, row: Dict) -> Tuple[Any, ...]:
        per_sentence_tokens = [
            self.tokenizer.tokenize(row[column]) for column in self.columns
        ]
        per_sentence_languages = [self.get_lang_id(row, self.language_column)] * len(
            self.columns
        )
        return self.tensorizer_script_impl.numberize(
            per_sentence_tokens, per_sentence_languages
        )
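A hedged usage sketch, not from the original source: how one might instantiate the tensorizer from its Config and numberize a single row. The vocabulary file path, the column values, and calling numberize() directly outside the usual PyText data pipeline are all assumptions made for illustration.

# Usage sketch (assumptions: a local copy of the XLM/XNLI vocabulary exists at
# the hypothetical path below, and rows carry "text" and "language" columns).
from pytext.data.xlm_tensorizer import XLMTensorizer

config = XLMTensorizer.Config(
    columns=["text"],
    vocab_file="./vocab_xnli_15",  # hypothetical local vocab path
    max_seq_len=256,
    has_language_in_data=True,     # read the language from each row
    use_language_embeddings=True,
)
tensorizer = XLMTensorizer.from_config(config)

row = {"text": "hello world", "language": "en"}
tokens, segment_labels, seq_len, positions = tensorizer.numberize(row)
# tokens:         token ids with EOS at both ends
# segment_labels: the language id for "en" repeated for every token
# positions:      0 .. seq_len - 1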