#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, NamedTuple, Optional, Tuple
import torch
from torch import Tensor
# ===== The following section should be replaced once the JIT provides native support
@torch.jit.script
def list_str_index(l: List[str], element: str, start: int) -> int:
    """
    Return the index of the first occurrence of ``element`` in ``l`` at or
    after position ``start``, like ``list.index(element, start)``.

    Differences from ``list.index``: returns -1 instead of raising when the
    element is not found. As with ``list.index``, a negative ``start`` counts
    from the end of the list, and the returned index is always a
    non-negative absolute position (or -1).
    """
    n = len(l)
    if start < 0:
        # Normalize like list.index does. Without this, a match found after
        # a negative start would be reported as a negative index, and a
        # match at the last element with start=-1 would return -1, which is
        # indistinguishable from "not found".
        start = max(0, start + n)
    # Walk indices directly instead of iterating a sliced copy of the list.
    for i in range(start, n):
        if l[i] == element:
            return i
    return -1
@torch.jit.script
def validate_padding_control(padding_control: Optional[List[int]]) -> bool:
    """Check that a padding-control list is usable by ``pad_length``.

    ``None`` (no padding control) is always accepted. A provided list must
    hold at least two entries, and its first entry must be 0.
    """
    if padding_control is not None:
        # The length test runs first, so an empty list never reaches the
        # [0] lookup.
        return len(padding_control) >= 2 and padding_control[0] == 0
    return True
@torch.jit.script
def pad_length(
    len: int, padding_control: Optional[List[int]], max_len: int = -1
) -> int:
    """Round ``len`` up to a padding bucket, then optionally cap it.

    With ``padding_control`` given, the result is the first entry of
    ``padding_control`` (scanned in order) that is >= ``len``. If no entry
    qualifies, ``len`` is kept. When ``max_len`` is positive, the result is
    then clamped to at most ``max_len``.

    Raises NotImplementedError when ``validate_padding_control`` rejects
    ``padding_control``.
    """
    if not validate_padding_control(padding_control):
        raise NotImplementedError
    padded = len
    if padding_control is not None:
        for bucket in padding_control:
            if bucket >= len:
                padded = bucket
                break
    if max_len > 0 and padded > max_len:
        padded = max_len
    return padded
@torch.jit.script
def reverse_tensor_list(int_list: List[torch.Tensor]) -> List[torch.Tensor]:
    """Return a new list holding the tensors of ``int_list`` in reverse order.

    The input list is left untouched. The tensors themselves are shared, not
    cloned.
    """
    reversed_list = torch.jit.annotate(List[torch.Tensor], [])
    for idx in range(len(int_list) - 1, -1, -1):
        reversed_list.append(int_list[idx])
    return reversed_list
@torch.jit.script
def pad_2d_mask(
    input: List[List[int]],
    pad_value: int = 0,
    seq_padding_control: Optional[List[int]] = None,
    max_seq_pad_len: int = -1,
    batch_padding_control: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad a list of int lists to a 2d tensor and build a mask for it.

    Returns ``(tensor, mask)``. Both are ``torch.long`` tensors of shape
    ``(padded_batch_len, padded_seq_len)``:

    - ``padded_seq_len`` is the longest row length, passed through
      ``pad_length`` with ``seq_padding_control`` and ``max_seq_pad_len``.
      Rows longer than it are truncated.
    - ``padded_batch_len`` is ``len(input)``, passed through ``pad_length``
      with ``batch_padding_control``. Extra rows are filled entirely with
      ``pad_value``.

    ``mask`` holds 1 where ``tensor`` differs from ``pad_value`` and 0
    elsewhere. It is computed from values, not positions, so an input
    element equal to ``pad_value`` is also masked out.

    An empty ``input`` is treated as having sequence length 0.
    """
    # An empty batch has no longest row; use 0 instead of letting max()
    # fail on an empty list.
    max_seq_len = 0
    if len(input) > 0:
        # List comprehension required for TorchScript
        max_seq_len = max([len(i) for i in input])  # noqa
    max_seq_len = pad_length(max_seq_len, seq_padding_control, max_seq_pad_len)
    max_batch_len = pad_length(len(input), batch_padding_control, -1)
    # Explicit dtype: some older PyTorch releases returned a float tensor
    # from torch.full with an integer fill value, which would silently turn
    # the int ids into floats.
    tensor = torch.full((max_batch_len, max_seq_len), pad_value, dtype=torch.long)
    for i in range(len(input)):
        # Copy at most max_seq_len elements; longer rows are truncated.
        numel_to_copy = min(len(input[i]), max_seq_len)
        tensor[i][:numel_to_copy] = torch.tensor(input[i][:numel_to_copy])
    mask = tensor.ne(pad_value).to(torch.long)
    return tensor, mask
@torch.jit.script
def pad_float_tensor(
    input_tensor: torch.Tensor,
    batch_padding_control: Optional[List[int]] = None,
) -> torch.Tensor:
    """Pad a tensor with zeros along its first (batch) dimension.

    The target batch size is ``pad_length(batch_size, batch_padding_control)``.
    The zero rows are appended after the existing rows with ``torch.cat``.
    The padding block is created as ``torch.float`` on the same device as
    ``input_tensor``.
    """
    # Under TorchScript, ``shape`` is a mutable List[int]; it is reused as
    # the shape of the padding block, which differs only in dim 0.
    shape = input_tensor.shape
    bs = shape[0]
    max_batch_len = pad_length(bs, batch_padding_control, -1)
    shape[0] = max_batch_len - bs
    # Allocate on the input's device; a CPU-only pad block would make
    # torch.cat fail for inputs on other devices (e.g. CUDA).
    pad_tensor = torch.zeros(shape, dtype=torch.float, device=input_tensor.device)
    new_tensor = torch.cat([input_tensor, pad_tensor], 0)
    return new_tensor
# ========= end section
@torch.jit.script
def utf8_chars(s: str) -> List[str]:
    """Split ``s`` into its UTF-8 characters inside TorchScript.

    Torch takes in unicode strings and encodes them as UTF-8, so within
    TorchScript, indexing ``s`` and calling ``ord`` work on individual bytes.
    TorchScript has no bitwise operators, so each character's width is read
    off its start byte by plain comparisons:

    - a start byte below 0b10000000 is a 1-byte character;
    - otherwise the width is the number of leading 1 bits, found by
      counting how many ascending thresholds the start byte reaches.

    That many bytes are emitted as one character, and scanning resumes at
    the next start byte.

    There is little validation. Badly encoded input, such as a misplaced
    continuation byte or a non-left-to-right byte order, gives unexpected
    pieces and will likely throw. Because Torch does the UTF-8 encoding
    itself, such input should be hard to produce.
    """
    # Each threshold the start byte is >= adds one byte to the width,
    # starting from 2 for any byte with its high bit set.
    thresholds: List[int] = [
        0b11100000,
        0b11110000,
        0b11111000,
        0b11111100,
        0b11111110,
        0b11111111,
    ]
    pieces = torch.jit.annotate(List[str], [])
    pos = 0
    while pos < len(s):
        lead = ord(s[pos])
        width = 1
        if lead >= 0b10000000:
            width = 2
            for threshold in thresholds:
                if lead < threshold:
                    break
                width += 1
        pieces.append(s[pos : pos + width])
        pos += width
    return pieces
@torch.jit.script
def truncate_tokens(
    batch: List[List[str]], max_seq_len: int, pad_token: str
) -> List[List[str]]:
    """Limit every sentence in ``batch`` to ``max_seq_len`` tokens.

    A non-positive ``max_seq_len`` disables truncation. An empty sentence is
    replaced by ``[pad_token]`` so that no row of the result is empty.
    """
    out = torch.jit.annotate(List[List[str]], [])
    limit_enabled = max_seq_len > 0
    for tokens in batch:
        if len(tokens) == 0:
            out.append([pad_token])
        elif not limit_enabled:
            out.append(tokens)
        else:
            out.append(tokens[:max_seq_len])
    return out
@torch.jit.script
def make_sequence_lengths(batch: List[List[str]]) -> List[int]:
    """Return the number of tokens in each sentence of ``batch``, in order."""
    return [len(sentence) for sentence in batch]
@torch.jit.script
def pad_2d(
    batch: List[List[int]], seq_lens: List[int], pad_idx: int, max_len: int = -1
) -> List[List[int]]:
    """Bring every row of ``batch`` to a common length, in place.

    The target length is ``max(seq_lens)``, capped at ``max_len`` when
    ``max_len`` is positive. Shorter rows are extended with ``pad_idx`` and
    longer rows lose trailing elements. The same (mutated) ``batch`` is
    returned.
    """
    target = max(seq_lens)
    if max_len > 0 and max_len < target:
        target = max_len
    for row in batch:
        while len(row) < target:
            row.append(pad_idx)
        while len(row) > target:
            row.pop()
    return batch
@torch.jit.script
def pad_3d(
    batch: List[List[List[int]]],
    tokens_lengths: List[List[int]],
    pad_idx: int,
) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Pad a batch of tokenized sentences (lists of int lists) in place.

    - Every sentence is brought up to the largest ``len(tokens_lengths[k])``
      by appending empty tokens. A matching 0 is appended to that sentence's
      entry in ``tokens_lengths``.
    - Every token is then extended with ``pad_idx`` up to the largest value
      found in ``tokens_lengths``. Tokens already longer are left as-is.

    A sentence with no tokens (empty ``tokens_lengths`` entry) is allowed.
    Both inputs are mutated, and the pair ``(batch, tokens_lengths)`` is
    returned.

    NOTE(review): the target token width is taken from ``tokens_lengths``,
    not from the actual token lists; presumably callers keep the two in
    sync. TODO confirm.
    """
    pad_to_1d: int = 0
    pad_to_2d: int = 0
    for tokens_length in tokens_lengths:
        pad_to_1d = max(pad_to_1d, len(tokens_length))
        # Skip empty sentences: max() of an empty list would raise.
        if len(tokens_length) > 0:
            pad_to_2d = max(pad_to_2d, max(tokens_length))
    for sentence, sentence_len in zip(batch, tokens_lengths):
        for _ in range(pad_to_1d - len(sentence)):
            sentence.append(torch.jit.annotate(List[int], []))
            sentence_len.append(0)
        for token in sentence:
            for _ in range(pad_to_2d - len(token)):
                token.append(pad_idx)
    return batch, tokens_lengths
@torch.jit.script
def pad_2d_float(
    batch: List[List[float]],
    seq_lens: List[int],
    pad_val: float = 0.0,
    max_len: int = -1,
) -> List[List[float]]:
    """Float counterpart of ``pad_2d``: equalize row lengths in place.

    Rows are padded with ``pad_val`` or truncated to ``max(seq_lens)``. When
    ``max_len`` is positive, that target is first capped at ``max_len``. The
    same (mutated) ``batch`` is returned.
    """
    longest = max(seq_lens)
    target = min(longest, max_len) if max_len > 0 else longest
    for row in batch:
        delta = target - len(row)
        if delta < 0:
            for _ in range(-delta):
                row.pop()
        else:
            for _ in range(delta):
                row.append(pad_val)
    return batch
@torch.jit.script
def pad_3d_float(
    batch: List[List[List[float]]], seq_lens: List[int], pad_val: float = 0.0
) -> List[List[List[float]]]:
    """Pad a batch of float matrices (lists of float lists) in place.

    - Every innermost list is extended with ``pad_val`` to the longest
      innermost list in the whole batch.
    - Every middle-level list is then extended with rows of ``pad_val`` until
      it has ``max(seq_lens)`` rows.

    Nothing is truncated. The same (mutated) ``batch`` is returned.
    """
    target_rows = max(seq_lens)
    target_cols = -1
    for rows in batch:
        for row in rows:
            if len(row) > target_cols:
                target_cols = len(row)
    for rows in batch:
        for row in rows:
            while len(row) < target_cols:
                row.append(pad_val)
        while len(rows) < target_rows:
            rows.append([pad_val] * target_cols)
    return batch
@torch.jit.script
def add_special_token_2d(
    values: List[List[int]], special_token: int = 0, use_first_value: bool = False
) -> List[List[int]]:
    """Wrap each row of ``values`` with a marker token at both ends.

    Each output row is ``[marker] + row + [marker]``. The marker is
    ``special_token``, except when ``use_first_value`` is set and the row is
    non-empty; then the row's own first element is used. The choice is made
    independently for every row. New lists are returned, and ``values`` is
    not modified.
    """
    results = torch.jit.annotate(List[List[int]], [])
    for value in values:
        # Choose the marker per row. Reassigning ``special_token`` itself
        # would leak one row's first value into later (empty) rows.
        token = special_token
        if use_first_value and len(value) > 0:
            token = value[0]
        result = torch.jit.annotate(List[int], [])
        result.append(token)
        result.extend(value)
        result.append(token)
        results.append(result)
    return results
@torch.jit.script
def add_bos_eos_2d(values: List[List[int]], bos: int, eos: int) -> List[List[int]]:
    """Return new rows ``[bos] + row + [eos]`` for every row of ``values``.

    ``values`` itself is not modified.
    """
    return [[bos] + row + [eos] for row in values]
@torch.jit.script
def make_byte_inputs(
    batch: List[List[str]], max_byte_len: int, offset_for_non_padding: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode a batch of token lists as a 3d tensor of character codes.

    Returns ``(byte_ids, seq_lens)``:

    - ``byte_ids``: ``torch.long`` tensor of shape
      ``(len(batch), max tokens per sentence, max_byte_len)``. Each of the
      first ``max_byte_len`` characters of each token is stored as
      ``ord(char) + offset_for_non_padding``. Every other position stays 0.
    - ``seq_lens``: tensor with the number of tokens in each sentence.

    Characters are taken by indexing the token string; under TorchScript
    that is presumably per UTF-8 byte (see ``utf8_chars``). TODO confirm.
    An empty ``batch`` yields a ``(0, 0, max_byte_len)`` tensor.
    """
    seq_lens = make_sequence_lengths(batch)
    # An empty batch has no longest sentence; use 0 instead of letting
    # max() fail on an empty list.
    max_num_tokens = 0
    if len(seq_lens) > 0:
        max_num_tokens = max(seq_lens)
    byte_ids = torch.zeros(len(batch), max_num_tokens, max_byte_len, dtype=torch.long)
    for batch_index in range(len(batch)):
        sentence = batch[batch_index]
        for token_index in range(len(sentence)):
            token = sentence[token_index]
            for byte_index in range(min(len(token), max_byte_len)):
                s = token[byte_index]
                # use empty string as eos because torchscript not support chr()
                # NOTE(review): ``s`` comes from indexing within len(token),
                # so it is never "" and this branch looks unreachable. An
                # empty-string token leaves its slot all zeros rather than
                # 256. Confirm the intended EOS handling with callers.
                if s == "":
                    v = 256
                else:
                    v = ord(s)
                # add offset_for_non_padding to conform to Fairseq pre-training
                byte_ids[batch_index][token_index][byte_index] = v + offset_for_non_padding
    return byte_ids, torch.tensor(seq_lens)
@torch.jit.script
def squeeze_1d(inputs: Optional[List[str]]) -> Optional[List[List[str]]]:
    """Wrap each string of ``inputs`` in its own one-element list.

    ``None`` is passed through unchanged.
    """
    if inputs is not None:
        return [[line] for line in inputs]
    return None
@torch.jit.script
def squeeze_2d(inputs: Optional[List[List[str]]]) -> Optional[List[List[List[str]]]]:
    """Wrap each inner list of ``inputs`` in its own one-element list.

    ``None`` is passed through unchanged.
    """
    if inputs is not None:
        return [[line] for line in inputs]
    return None
def float_tensor_list1D(input_tensor: Tensor) -> List[float]:
    """Convert a 1-D tensor into a list of Python floats, element by element.

    The explicit loop (instead of ``Tensor.tolist``) presumably keeps the
    function compilable by TorchScript. TODO confirm.

    Raises:
        AssertionError: if ``input_tensor`` is not 1-dimensional.
    """
    assert input_tensor.dim() == 1, "float_tensor_list1D expects a 1-D tensor"
    result: List[float] = []
    for idx in range(input_tensor.size(0)):
        result.append(float(input_tensor[idx]))
    return result