Source code for pytext.torchscript.tensorizer.roberta

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional, Tuple

import torch
from pytext.torchscript.utils import pad_2d, pad_2d_mask

from .bert import ScriptBERTTensorizerBase


class ScriptRoBERTaTensorizer(ScriptBERTTensorizerBase):
    @torch.jit.script_method
    def _lookup_tokens(self, tokens: List[Tuple[str, int, int]]) -> List[int]:
        # Map (token_text, start, end) triples to vocab ids, wrapping the
        # sentence with RoBERTa's distinct BOS/EOS tokens; keep only the ids
        # and drop the start/end offsets returned by vocab_lookup.
        return self.vocab_lookup(
            tokens,
            bos_idx=self.vocab.bos_idx,
            eos_idx=self.vocab.eos_idx,
            use_eos_token_for_bos=False,
            max_seq_len=self.max_seq_len,
        )[0]
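
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this module, and not pytext's actual
# vocab_lookup): it shows the shape of the data _lookup_tokens works with.
# Each input token is a (token_text, start_char, end_char) triple from the
# tokenizer; the output is a list of vocab ids wrapped with BOS/EOS, the
# sentence format RoBERTa expects. The dict vocab and the indices below are
# made up for illustration; the real vocab_lookup also truncates to
# max_seq_len and returns the start/end offsets alongside the ids.
# ---------------------------------------------------------------------------
from typing import Dict  # List and Tuple are already imported above


def toy_lookup(
    token_triples: List[Tuple[str, int, int]],
    vocab: Dict[str, int],
    bos_idx: int,
    eos_idx: int,
) -> List[int]:
    return [bos_idx] + [vocab[text] for text, _start, _end in token_triples] + [eos_idx]


# toy_lookup([("hello", 0, 5), ("world", 6, 11)], {"hello": 7, "world": 8}, 0, 2)
# -> [0, 7, 8, 2]
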
class ScriptRoBERTaTensorizerWithIndices(ScriptBERTTensorizerBase):
    @torch.jit.script_method
    def _lookup_tokens(
        self, tokens: List[Tuple[str, int, int]]
    ) -> Tuple[List[int], List[int], List[int]]:
        # Same lookup as above, but keep the start/end character offsets
        # alongside the token ids.
        return self.vocab_lookup(
            tokens,
            bos_idx=self.vocab.bos_idx,
            eos_idx=self.vocab.eos_idx,
            use_eos_token_for_bos=False,
            max_seq_len=self.max_seq_len,
        )

    @torch.jit.script_method
    def numberize(
        self, text_row: Optional[List[str]], token_row: Optional[List[List[str]]]
    ) -> Tuple[List[int], List[int], List[int], List[int]]:
        token_ids: List[int] = []
        start_indices: List[int] = []
        end_indices: List[int] = []
        positions: List[int] = []
        per_sentence_tokens: List[List[Tuple[str, int, int]]] = self.tokenize(
            text_row, token_row
        )

        # Concatenate all sentences of the row into a single sequence of
        # token ids and character offsets.
        for idx, per_sentence_token in enumerate(per_sentence_tokens):
            lookup_ids, start_ids, end_ids = self._lookup_tokens(per_sentence_token)
            lookup_ids = self._wrap_numberized_tokens(lookup_ids, idx)

            token_ids.extend(lookup_ids)
            start_indices.extend(start_ids)
            end_indices.extend(end_ids)

        seq_len = len(token_ids)
        positions = list(range(seq_len))

        return token_ids, start_indices, end_indices, positions

    @torch.jit.script_method
    def tensorize(
        self,
        texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[List[str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        tokens_2d: List[List[int]] = []
        start_indices_2d: List[List[int]] = []
        end_indices_2d: List[List[int]] = []
        positions_2d: List[List[int]] = []

        # Numberize every row in the batch.
        for idx in range(self.batch_size(texts, tokens)):
            numberized: Tuple[
                List[int], List[int], List[int], List[int]
            ] = self.numberize(
                self.get_texts_by_index(texts, idx),
                self.get_tokens_by_index(tokens, idx),
            )
            tokens_2d.append(numberized[0])
            start_indices_2d.append(numberized[1])
            end_indices_2d.append(numberized[2])
            positions_2d.append(numberized[3])

        # Pad each per-row list to a common length and convert to tensors;
        # only the token tensor's padding mask is kept.
        tokens, pad_mask = pad_2d_mask(
            tokens_2d,
            pad_value=self.vocab.pad_idx,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )
        start_indices, _ = pad_2d_mask(
            start_indices_2d,
            pad_value=0,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )
        end_indices, _ = pad_2d_mask(
            end_indices_2d,
            pad_value=0,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )
        positions, _ = pad_2d_mask(
            positions_2d,
            pad_value=0,
            seq_padding_control=self.seq_padding_control,
            max_seq_pad_len=self.max_seq_len,
            batch_padding_control=self.batch_padding_control,
        )

        # An empty device string means "leave the tensors where they are".
        if self.device == "":
            return tokens, pad_mask, start_indices, end_indices, positions
        else:
            return (
                tokens.to(self.device),
                pad_mask.to(self.device),
                start_indices.to(self.device),
                end_indices.to(self.device),
                positions.to(self.device),
            )
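
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module): assumes `tensorizer` is an already
# constructed ScriptRoBERTaTensorizerWithIndices, typically taken from a
# scripted PyText model rather than built by hand; `encode_batch` is a
# hypothetical helper added here for illustration. tensorize() pads the batch
# and returns token ids, a padding mask, per-token start/end character
# offsets, and position ids as tensors. List, Tuple, and torch are already
# imported at the top of this module.
# ---------------------------------------------------------------------------
def encode_batch(
    tensorizer: ScriptRoBERTaTensorizerWithIndices, texts: List[List[str]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    tokens, pad_mask, start_indices, end_indices, positions = tensorizer.tensorize(
        texts=texts
    )
    return tokens, pad_mask


# Example call with a batch of two single-sentence rows:
# tokens, pad_mask = encode_batch(tensorizer, [["hello world"], ["a second row"]])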