Source code for pytext.torchscript.tensorizer.xlm

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional, Tuple

import torch
from pytext.torchscript.utils import pad_2d_mask
from pytext.torchscript.vocab import ScriptVocabulary

from .tensorizer import ScriptTensorizer, VocabLookup


class ScriptXLMTensorizer(ScriptTensorizer):
    def __init__(
        self,
        tokenizer: torch.jit.ScriptModule,
        token_vocab: ScriptVocabulary,
        language_vocab: ScriptVocabulary,
        max_seq_len: int,
        default_language: str,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.token_vocab = token_vocab
        self.language_vocab = language_vocab
        self.token_vocab_lookup = VocabLookup(token_vocab)
        self.language_vocab_lookup = VocabLookup(language_vocab)

        self.max_seq_len = torch.jit.Attribute(max_seq_len, int)
        self.default_language = torch.jit.Attribute(default_language, str)

    @torch.jit.script_method
    def tokenize(
        self,
        text_row: Optional[List[str]],
        token_row: Optional[List[List[str]]],
        language_row: List[str],
    ) -> Tuple[List[List[Tuple[str, int, int]]], List[List[Tuple[str, int, int]]]]:
        per_sentence_tokens: List[List[Tuple[str, int, int]]] = []
        per_sentence_languages: List[List[Tuple[str, int, int]]] = []

        if text_row is not None:
            """
            Tokenize each text into a list of tokens. For example:
            text_row = ["hello world", "this is sentence"]
            per_sentence_tokens = [["hello", "world"], ["this", "is", "sentence"]]
            """
            for idx, text in enumerate(text_row):
                sentence_tokens: List[Tuple[str, int, int]] = self.tokenizer.tokenize(
                    text
                )
                sentence_languages: List[Tuple[str, int, int]] = [
                    (language_row[idx], token[1], token[2]) for token in sentence_tokens
                ]
                per_sentence_tokens.append(sentence_tokens)
                per_sentence_languages.append(sentence_languages)
        elif token_row is not None:
            """
            Tokenize each token into sub-tokens (for example, BPE):
            token_row = [["hello", "world"], ["this", "is", "sentence"]]
            per_sentence_tokens = [
                ["he", "llo", "wo", "rld"], ["th", "is", "is", "sen", "tence"]
            ]
            """
            for idx, sentence_raw_tokens in enumerate(token_row):
                sentence_tokens: List[Tuple[str, int, int]] = []
                sentence_languages: List[Tuple[str, int, int]] = []
                for raw_token in sentence_raw_tokens:
                    sub_tokens: List[Tuple[str, int, int]] = self.tokenizer.tokenize(
                        raw_token
                    )
                    sub_languages: List[Tuple[str, int, int]] = [
                        (language_row[idx], token[1], token[2]) for token in sub_tokens
                    ]
                    sentence_tokens.extend(sub_tokens)
                    sentence_languages.extend(sub_languages)
                per_sentence_tokens.append(sentence_tokens)
                per_sentence_languages.append(sentence_languages)

        return per_sentence_tokens, per_sentence_languages

    @torch.jit.script_method
    def _lookup_tokens(
        self,
        tokens: List[Tuple[str, int, int]],
        languages: List[Tuple[str, int, int]],
        max_seq_len: int,
    ) -> Tuple[List[int], List[int]]:
        token_ids: List[int] = self.token_vocab_lookup(
            tokens,
            bos_idx=self.token_vocab.eos_idx,
            eos_idx=self.token_vocab.eos_idx,
            use_eos_token_for_bos=True,
            max_seq_len=max_seq_len,
        )[0]
        language_special_idx: int = self.language_vocab.idx.get(
            languages[0][0], self.language_vocab.unk_idx
        )
        language_ids = self.language_vocab_lookup(
            languages,
            bos_idx=language_special_idx,
            eos_idx=language_special_idx,
            use_eos_token_for_bos=True,
            max_seq_len=max_seq_len,
        )[0]
        return token_ids, language_ids

    @torch.jit.script_method
    def numberize(
        self,
        text_row: Optional[List[str]],
        token_row: Optional[List[List[str]]],
        language_row: List[str],
    ) -> Tuple[List[int], List[int], int, List[int]]:
        per_sentence_tokens, per_sentence_languages = self.tokenize(
            text_row, token_row, language_row
        )
        token_ids: List[int] = []
        language_ids: List[int] = []
        max_seq_len: int = self.max_seq_len // len(per_sentence_tokens)

        # concatenate tokens from each text in the same row into a single list of tokens
        for idx in range(len(per_sentence_tokens)):
            lookup_token_ids, lookup_language_ids = self._lookup_tokens(
                per_sentence_tokens[idx], per_sentence_languages[idx], max_seq_len
            )
            token_ids.extend(lookup_token_ids)
            language_ids.extend(lookup_language_ids)
        seq_len: int = len(token_ids)
        positions: List[int] = [i for i in range(seq_len)]

        return token_ids, language_ids, seq_len, positions

    @torch.jit.script_method
    def tensorize(
        self,
        texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[List[str]]]] = None,
        languages: Optional[List[List[str]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # unwrap Optional
        batch_size: int = self.batch_size(texts, tokens)
        row_size: int = self.row_size(texts, tokens)
        if languages is None:
            languages = [[self.default_language] * row_size] * batch_size

        tokens_2d: List[List[int]] = []
        languages_2d: List[List[int]] = []
        seq_len_2d: List[int] = []
        positions_2d: List[List[int]] = []

        for idx in range(batch_size):
            numberized: Tuple[List[int], List[int], int, List[int]] = self.numberize(
                self.get_texts_by_index(texts, idx),
                self.get_tokens_by_index(tokens, idx),
                languages[idx],
            )
            tokens_2d.append(numberized[0])
            languages_2d.append(numberized[1])
            seq_len_2d.append(numberized[2])
            positions_2d.append(numberized[3])

        tokens, pad_mask = pad_2d_mask(
            tokens_2d,
            pad_value=self.token_vocab.pad_idx,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )
        languages, _ = pad_2d_mask(
            languages_2d,
            pad_value=0,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )
        positions, _ = pad_2d_mask(
            positions_2d,
            pad_value=0,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )

        if self.device == "":
            return tokens, pad_mask, languages, positions
        else:
            return (
                tokens.to(self.device),
                pad_mask.to(self.device),
                languages.to(self.device),
                positions.to(self.device),
            )
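For illustration only (not part of the module above): a minimal sketch of how the tensorizer might be driven end to end, assuming a scripted tokenizer and the two vocabulary word lists are already available. The names scripted_tokenizer, token_vocab_list, and language_vocab_list are placeholders, not identifiers defined in pytext.

# Hypothetical usage sketch; scripted_tokenizer, token_vocab_list and
# language_vocab_list are assumed to exist and are not part of this module.
from pytext.torchscript.vocab import ScriptVocabulary

tensorizer = ScriptXLMTensorizer(
    tokenizer=scripted_tokenizer,                      # a torch.jit.ScriptModule tokenizer
    token_vocab=ScriptVocabulary(token_vocab_list),
    language_vocab=ScriptVocabulary(language_vocab_list),
    max_seq_len=256,
    default_language="en",
)

# A batch of two rows, one text per row; since `languages` is omitted,
# every token is tagged with the default language ("en").
tokens, pad_mask, languages, positions = tensorizer.tensorize(
    texts=[["hello world"], ["this is sentence"]]
)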