# Source code for pytext.torchscript.utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import List, Optional, Tuple

import torch


class ScriptInputType(Enum):
    """How each row of a TorchScript model's input batch is laid out."""

    # the row holds one single sentence
    text = 1
    # the row holds the list of tokens of one single sentence
    token = 2

    def is_text(self):
        """Return True only for the ``text`` member."""
        return ScriptInputType.text is self

    def is_token(self):
        """Return True only for the ``token`` member."""
        return ScriptInputType.token is self
# ===== the following section should be replaced once JIT provides native support
@torch.jit.script
def list_max(l: List[int]):
    """Return the largest element of ``l``; fails if ``l`` is empty."""
    max_value = l[0]  # fine to throw if empty
    for i in range(len(l) - 1):  # don't forget the +1
        max_value = max(max_value, l[i + 1])
    return max_value


@torch.jit.script
def list_str_index(l: List[str], element: str, start: int) -> int:
    """
    Equivalent to: list.index(v, start)

    Unlike ``list.index``, returns -1 instead of raising when ``element``
    is not found at or after ``start``.
    """
    for i, t in enumerate(l[start:]):
        if t == element:
            return start + i
    return -1


@torch.jit.script
def list_membership(item: int, list: List[int]):
    """Return True if ``item`` occurs in ``list`` (equivalent to ``item in list``)."""
    item_present = False
    for i in list:
        if item == i:
            item_present = True
    return item_present


@torch.jit.script
def reverse_tensor_list(int_list: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return a new list holding the tensors of ``int_list`` in reverse order."""
    l_len = len(int_list)
    res = []
    for idx in range(l_len):
        res.append(int_list[l_len - idx - 1])
    return res


@torch.jit.script
def long_tensor_2d(shape: Tuple[int, int], fill_value: int = 0) -> torch.Tensor:
    """Return a new 2d torch.LongTensor with size according to shape.
    The values of this tensor will be fill_value."""
    # torch.full keeps both dimensions even when shape[0] == 0; building the
    # tensor from a nested Python list collapsed that case to a 1d tensor.
    return torch.full([shape[0], shape[1]], fill_value, dtype=torch.long)


@torch.jit.script
def pad_2d_mask(
    input: List[List[int]], pad_value: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad a list to a 2d tensor. Returns a pair of tensors, the padded tensor
    as well as a mask tensor. The mask tensor has the same shape as the padded
    tensor, with a 1 in the position of non-pad values and a 0 in the position
    of pads."""
    max_len = 0
    for i in input:
        max_len = max(max_len, len(i))
    tensor = long_tensor_2d((len(input), max_len), pad_value)
    mask = long_tensor_2d((len(input), max_len), 0)
    for i in range(len(input)):
        for j in range(len(input[i])):
            tensor[i][j] = input[i][j]
            mask[i][j] = 1
    return tensor, mask


# ========= end section


@torch.jit.script
def utf8_chars(s: str) -> List[str]:
    """An implementation of UTF8 character iteration in TorchScript.
    There are no bitwise operations in torchscript, so we compare directly to
    integer values. There isn't a lot of validation, for instance if you pass
    in an improperly encoded string with an out-of-place continuation byte,
    or with a non-left-to-right byte order, you'll get unexpected results
    and likely throw. Torch itself takes in unicode strings and encodes them
    as UTF8, so that should be actively hard to do.

    The logic is simple: look at the current start-of-character byte.
    If its high bit is 0, it's a 1-byte character. Otherwise, the number of
    bytes is the number of leading 1s in its binary representation, so
    find that number by comparing it directly to ints with the appropriate
    representation, then append that many bytes as a character and move past
    them to the next start byte.
    """
    chars = torch.jit.annotate(List[str], [])
    i = 0
    while i < len(s):
        byte = ord(s[i])
        if byte < 0b10000000:
            chars.append(s[i])
            i += 1
        else:
            if byte < 0b11100000:
                num_bytes = 2
            elif byte < 0b11110000:
                num_bytes = 3
            elif byte < 0b11111000:
                num_bytes = 4
            elif byte < 0b11111100:
                num_bytes = 5
            elif byte < 0b11111110:
                num_bytes = 6
            elif byte < 0b11111111:
                num_bytes = 7
            else:
                num_bytes = 8
            chars.append(s[i : i + num_bytes])
            i += num_bytes
    return chars


@torch.jit.script
def make_sequence_lengths(batch: List[List[str]]) -> List[int]:
    """Return the number of tokens in each sentence of ``batch``."""
    seq_lens = torch.jit.annotate(List[int], [])
    for sentence in batch:
        seq_lens.append(len(sentence))
    return seq_lens


@torch.jit.script
def pad_2d(
    batch: List[List[int]], seq_lens: List[int], pad_idx: int
) -> List[List[int]]:
    """Append ``pad_idx`` to each row of ``batch`` until it reaches ``max(seq_lens)``.

    The rows are modified in place and ``batch`` itself is returned. Rows
    already at least ``max(seq_lens)`` long are left unchanged.
    Fails if ``seq_lens`` is empty.
    """
    pad_to_length = list_max(seq_lens)
    for sentence in batch:
        for _ in range(pad_to_length - len(sentence)):
            sentence.append(pad_idx)
    return batch


@torch.jit.script
def add_special_token_2d(
    values: List[List[int]], special_token: int = 0, use_first_value: bool = False
) -> List[List[int]]:
    """Return copies of each row of ``values`` wrapped in a marker token.

    Each row becomes ``[token] + row + [token]``. ``token`` is
    ``special_token``. When ``use_first_value`` is True and the row is
    non-empty, ``token`` is instead that row's own first value.
    """
    results = torch.jit.annotate(List[List[int]], [])
    for value in values:
        # Pick the marker per row. Reassigning ``special_token`` itself would
        # leak one row's first value into later empty rows.
        token = special_token
        if use_first_value and len(value) > 0:
            token = value[0]
        result = torch.jit.annotate(List[int], [])
        result.append(token)
        result.extend(value)
        result.append(token)
        results.append(result)
    return results


@torch.jit.script
def add_bos_eos_2d(values: List[List[int]], bos: int, eos: int) -> List[List[int]]:
    """Return copies of each row of ``values`` as ``[bos] + row + [eos]``."""
    results = torch.jit.annotate(List[List[int]], [])
    for value in values:
        result = torch.jit.annotate(List[int], [])
        result.append(bos)
        result.extend(value)
        result.append(eos)
        results.append(result)
    return results


@torch.jit.script
def make_byte_inputs(
    batch: List[List[str]], max_byte_len: int, offset_for_non_padding: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode a batch of tokenized sentences as per-token character codes.

    Returns ``(byte_ids, seq_lens)``:

    - ``byte_ids`` is a zero-initialised LongTensor of shape
      ``(len(batch), max tokens per sentence, max_byte_len)``. Each token's
      ``ord`` values (truncated to ``max_byte_len``) are shifted by
      ``offset_for_non_padding``.
    - ``seq_lens`` is a tensor with the token count of each sentence.

    Fails if ``batch`` is empty.
    """
    seq_lens = make_sequence_lengths(batch)
    max_num_tokens = list_max(seq_lens)
    byte_ids = torch.zeros(len(batch), max_num_tokens, max_byte_len, dtype=torch.long)
    for batch_index in range(len(batch)):
        sentence = batch[batch_index]
        for token_index in range(len(sentence)):
            token = sentence[token_index]
            for byte_index in range(min(len(token), max_byte_len)):
                s = token[byte_index]
                # use empty string as eos because torchscript does not support chr()
                # NOTE(review): indexing within len(token) never yields "", so this
                # branch looks unreachable; kept to preserve behavior.
                if s == "":
                    v = 256
                else:
                    v = ord(s)
                # add offset_for_non_padding to conform to Fairseq pre-training
                byte_ids[batch_index][token_index][byte_index] = (
                    v + offset_for_non_padding
                )
    return byte_ids, torch.tensor(seq_lens)


@torch.jit.script
def squeeze_1d(inputs: Optional[List[str]]) -> Optional[List[List[str]]]:
    """Wrap every string of ``inputs`` in its own one-element list.

    Returns None when ``inputs`` is None. Despite the name, this adds a
    nesting level.
    """
    result: Optional[List[List[str]]] = None
    if inputs is not None:
        result = torch.jit.annotate(List[List[str]], [])
        for line in inputs:
            result.append([line])
    return result


@torch.jit.script
def squeeze_2d(inputs: Optional[List[List[str]]]) -> Optional[List[List[List[str]]]]:
    """Wrap every row of ``inputs`` in its own one-element list.

    Returns None when ``inputs`` is None. Despite the name, this adds a
    nesting level.
    """
    result: Optional[List[List[List[str]]]] = None
    if inputs is not None:
        result = torch.jit.annotate(List[List[List[str]]], [])
        for line in inputs:
            result.append([line])
    return result