Source code for pytext.data.packed_lm_data

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List, Optional, Type

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data.bert_tensorizer import BERTTensorizerBase
from pytext.data.data import RowData
from pytext.data.sources import DataSource
from pytext.data.tensorizers import Tensorizer, TokenTensorizer


class PackedLMData(Data):
    """
    Special purpose Data object which assumes a single text tensorizer. Packs
    tokens into a square batch with no padding. Used for LM training.

    The object also takes in an optional language argument which is used for
    cross-lingual LM training.
    """

    __EXPANSIBLE__ = True
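
    # Illustrative note (a worked example, not part of the original source):
    # with max_seq_len=8, rows of 5, 6 and 4 tokens are concatenated into one
    # stream and re-cut into blocks of max_seq_len - 1 = 7 token ids each
    # (numberize_rows reserves one position per block), leaving a final short
    # block of 1 token id for the leftovers.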

    class Config(Data.Config):
        max_seq_len: int = 128

    @classmethod
    def from_config(
        cls,
        config: Config,
        schema: Dict[str, Type],
        tensorizers: Dict[str, Tensorizer],
        language: Optional[str] = None,
        rank: int = 0,
        world_size: int = 1,
        init_tensorizers: Optional[bool] = True,
    ):
        return super(PackedLMData, cls).from_config(
            config,
            schema,
            tensorizers,
            rank,
            world_size,
            language=language,
            max_seq_len=config.max_seq_len,
            init_tensorizers=init_tensorizers,
        )

    def __init__(
        self,
        data_source: DataSource,
        tensorizers: Dict[str, Tensorizer],
        batcher: Batcher = None,
        max_seq_len: int = Config.max_seq_len,
        sort_key: Optional[str] = None,
        # language is used in cross-lingual LM training
        language: Optional[str] = None,
        in_memory: Optional[bool] = False,
        init_tensorizers: Optional[bool] = True,
    ):
        super().__init__(
            data_source, tensorizers, batcher, sort_key, in_memory, init_tensorizers
        )
        assert len(list(self.tensorizers.items())) == 1
        self.tensorizer_name, self.tensorizer = list(self.tensorizers.items())[0]
        self.remainder: Dict[str, List[int]] = {"tokens": [], "segment_labels": []}
        self.max_seq_len = max_seq_len
        self.language = language
        self.batch = {Stage.TRAIN: None, Stage.EVAL: None, Stage.TEST: None}

    def _parse_row(self, row):
        """
        The output of numberization has a different number of elements
        depending on the tensorizer used. For example, the positions tensor
        is only output by the XLMTensorizer. This function unpacks the
        elements according to the specific tensorizer used.

        Additionally, since we are packing tokens into fixed-size blocks, we
        don't need the positions vector output by the call to numberize; we
        simply create one in `_format_output_row`.
        """
        numberized_row = self.tensorizer.numberize(row)
        if isinstance(self.tensorizer, BERTTensorizerBase):
            tokens, segment_labels, seq_len, _ = numberized_row
        elif isinstance(self.tensorizer, TokenTensorizer):
            tokens, seq_len, _ = numberized_row
            segment_labels = []
        else:
            raise NotImplementedError(
                "PackedLMData only supports XLMTensorizer, BERTTensorizer and "
                "TokenTensorizer."
            )
        return tokens, segment_labels, seq_len

    def _format_output_row(self, tokens, segment_labels, seq_len):
        """
        The tensorize function of each tensorizer takes in a different number
        of inputs, which may be arranged differently. This function formats
        the output dict to conform to the expectations of the tensorizer. In
        the case of the XLMTensorizer, we also need to create a new positions
        list which goes from 0 to seq_len - 1.
        """
        if isinstance(self.tensorizer, BERTTensorizerBase):
            positions = list(range(seq_len))
            return {self.tensorizer_name: (tokens, segment_labels, seq_len, positions)}
        elif isinstance(self.tensorizer, TokenTensorizer):
            # dummy token_ranges
            return {self.tensorizer_name: (tokens, seq_len, [(-1, -1)] * seq_len)}
        else:
            raise NotImplementedError(
                "PackedLMData only supports BERTTensorizer and TokenTensorizer."
            )

    def _yield_and_reset(self, row):
        packed_tokens = list(self.remainder["tokens"])
        packed_segments = list(self.remainder["segment_labels"])
        self.remainder = {"tokens": [], "segment_labels": []}
        return RowData(
            row,
            self._format_output_row(
                packed_tokens, packed_segments, len(packed_tokens)
            ),
        )
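
    # Summary of the two supported numberize layouts handled above (derived
    # from _parse_row / _format_output_row):
    #   BERTTensorizerBase (incl. XLMTensorizer):
    #       numberize -> (tokens, segment_labels, seq_len, positions)
    #   TokenTensorizer:
    #       numberize -> (tokens, seq_len, token_ranges)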

    def numberize_rows(self, rows):
        """
        This function does the actual packing. It processes rows until we
        obtain a full block of data; one position per block is reserved, so
        each yielded block holds max_seq_len - 1 token ids.
        """
        last_row = None
        for row in rows:
            last_row = row
            # If the PackedLMData object has a language member, then a
            # cross-lingual LM is being trained using monolingual data. Add
            # this language to the row, since the underlying tensorizer needs
            # it to generate language embeddings (used as segment_labels
            # below).
            if self.language:
                row["language"] = self.language
            tokens, segment_labels, seq_len = self._parse_row(row)
            remaining = self.max_seq_len - len(self.remainder["tokens"]) - 1
            while remaining < len(tokens):
                self.remainder["tokens"].extend(tokens[:remaining])
                self.remainder["segment_labels"].extend(segment_labels[:remaining])
                tokens = tokens[remaining:]
                segment_labels = segment_labels[remaining:]
                # Packed LM data doesn't respect data cardinality; each
                # yielded block stores the row being read when the block
                # filled up, not an exact corresponding row.
                yield self._yield_and_reset(row)
                remaining = self.max_seq_len - 1
            self.remainder["tokens"].extend(tokens)
            self.remainder["segment_labels"].extend(segment_labels)
        if len(self.remainder["tokens"]):
            yield self._yield_and_reset(last_row)
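
For intuition about the packing loop, here is a minimal standalone sketch (not
part of pytext; the helper name pack_token_stream is hypothetical) that
reproduces the block-cutting arithmetic of numberize_rows on plain Python
lists:

from typing import Iterable, List


def pack_token_stream(rows: Iterable[List[int]], max_seq_len: int) -> List[List[int]]:
    """Hypothetical helper mirroring PackedLMData.numberize_rows: concatenate
    token lists and cut them into blocks of max_seq_len - 1 token ids."""
    blocks: List[List[int]] = []
    remainder: List[int] = []
    for tokens in rows:
        remaining = max_seq_len - len(remainder) - 1
        while remaining < len(tokens):
            remainder.extend(tokens[:remaining])
            blocks.append(remainder)
            tokens = tokens[remaining:]
            remainder = []
            remaining = max_seq_len - 1
        remainder.extend(tokens)
    if remainder:
        blocks.append(remainder)
    return blocks


# Three rows of 5, 6 and 4 tokens with max_seq_len=8 become blocks of
# lengths 7, 7 and 1, matching the example noted in the class body above.
print([len(b) for b in pack_token_stream([[1] * 5, [2] * 6, [3] * 4], 8)])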