Source code for pytext.torchscript.tokenizer.bpe

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import io
from typing import Dict, List

import torch
from pytext.torchscript.utils import utf8_chars
from pytext.utils.file_io import PathManager


class ScriptBPE(torch.jit.ScriptModule):
    """Byte-pair encoding implementation in TorchScript.

    vocab_file should be a file-like object separated by newlines, where each line
    consists of a word and a count separated by whitespace. Words in the vocab
    therefore can't contain space (according to python regex \\s). The vocab file
    should be sorted according to the importance of each token, and they will be
    merged in this priority; the actual score values are irrelevant.

    eow_token should be a string that is appended to the last character and token,
    and that token is used at each step in the process and returned at the end.
    You should set this to be consistent with the EOW signature used however you
    generated your ScriptBPE vocab file.

    >>> import io
    >>> vocab_file = io.StringIO('''
    hello_EOW 20
    world_EOW 18
    th 17
    is_EOW 16
    bpe_EOW 15
    ! 14
    h 13
    t 6
    s_EOW 2
    i -1
    ii -2
    ''')
    >>> bpe = ScriptBPE.from_vocab_file(vocab_file)
    >>> bpe.tokenize(["hello", "world", "this", "is", "bpe"])
    ["hello_EOW", "world_EOW", "th", "is_EOW", "is_EOW", "bpe_EOW"]
    >>> bpe.tokenize(["iiiis"])
    ["ii", "i", "is_EOW"]
    """

    def __init__(self, vocab: Dict[str, int], eow: str = "_EOW"):
        """vocab is a dictionary from BPE segments, including any EOW elements,
        to their priority in joining. Priority must be an integer, should not be
        negative, and should not contain ties.

        Segments with negative priorities will be ignored. In the case of ties,
        ties will be broken according to left-to-right byte order precedence,
        but this behavior isn't guaranteed and may change in the future.

        eow should be a string which corresponds to the EOW used in the
        vocab dictionary.
        """
        super().__init__()
        self.vocab = torch.jit.Attribute(vocab, Dict[str, int])
        self.eow = torch.jit.Attribute(eow, str)
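    # Illustrative sketch (not part of the original source): ScriptBPE can also be
    # constructed directly from a score dict instead of a vocab file. The vocab
    # below is made up; higher scores merge first.
    #
    #     bpe = ScriptBPE({"th": 4, "t": 3, "h": 2, "e_EOW": 1}, eow="_EOW")
    #     bpe.tokenize(["the"])  # -> ["th", "e_EOW"]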
    @classmethod
    def from_vocab_file(cls, vocab_file: io.IOBase) -> "ScriptBPE":
        return cls(cls.load_vocab(vocab_file))
    @classmethod
    def from_vocab_filename(cls, vocab_filename: str) -> "ScriptBPE":
        with PathManager.open(vocab_filename) as vocab_file:
            return cls(cls.load_vocab(vocab_file))
    @staticmethod
    def load_vocab(file: io.IOBase) -> Dict[str, int]:
        def read_words(lines):
            for line in lines:
                if not line.strip():
                    continue
                yield line.strip().split(maxsplit=1)[0]

        words = list(read_words(file))
        num_words = len(words)

        # We don't care about counts, except that we want them to be
        # non-negative and non-overlapping. We want to prioritize pairs
        # which come first in the vocab file. So ignore counts in the file
        # and score them according to the reverse of their index in the file.
        return {word: num_words - i for i, word in enumerate(words)}
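    # For example (hypothetical file contents), the counts in the file are ignored
    # and words are scored by reverse file order:
    #
    #     ScriptBPE.load_vocab(io.StringIO("hello_EOW 20\nth 17\nh 13\n"))
    #     # -> {"hello_EOW": 3, "th": 2, "h": 1}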
    @torch.jit.script_method
    def bpe_token(self, token: str) -> List[str]:
        # If the full token is in the vocab, we're done.
        full_token = token + self.eow
        # `in` is not implemented; read this as `if full_token in self.vocab`
        if self.vocab.get(full_token) is not None:
            return [full_token]

        # Split the word into parts, with the last part having EOW attached.
        # Any part (character or char + EOW) not in the vocab on its own
        # should be removed. EOW should always be attached to the last remaining
        # token.
        parts = utf8_chars(token)

        # while parts and parts[-1] + self.eow not in self.vocab
        while len(parts) > 0 and self.vocab.get(parts[-1] + self.eow) is None:
            parts.pop()

        # The word consisted entirely of unknown characters.
        if len(parts) == 0:
            return [self.eow]

        parts[-1] += self.eow

        # Remove any other obscure characters not in the vocab.
        # There is no easy way to iterate backwards or create descending ranges,
        # so use a while loop.
        i = 0
        while i < len(parts):
            # if parts[i] not in self.vocab
            if self.vocab.get(parts[i]) is None:
                parts.pop(i)
            else:
                i += 1

        # We compare vocab dict scores to this value, so this is where we assume
        # vocab dict values are non-negative.
        NOT_IN_VOCAB = -1

        # `break` is not implemented
        should_break = False

        # Keep going until no more part pairs are in the vocab.
        # In obscure cases this could also get down to a single token, e.g. if
        # we filter out some character and rebuild up to a single token.
        while len(parts) > 1 and not should_break:
            # Create part pairs, then join the part pair with the highest score
            # in the vocab. In pure python, this could be implemented as
            #   max(range(len(parts) - 1),
            #       key=lambda i: self.vocab.get(parts[i] + parts[i + 1], -1))
            max_pair_index = 0
            max_pair_value = NOT_IN_VOCAB

            # We structure the vocabulary to not have ties, but they can come up
            # anyway, for instance in cases with repeated tokens or when passing
            # in vocabs not created with ScriptBPE.load_vocab. In the case of a
            # tie between the values of joined segments, they'll be joined
            # prioritizing the first pair in the token according to byte order,
            # i.e. left in LTR and right in RTL languages. For instance, if the
            # vocab contains "aa" but not "aaa", then
            # bpe_token("aaa") -> ["aa", "a"]. If the vocab contains "ab" and
            # "bc" mapped to the same priority, but not "abc", then
            # bpe_token("abc") -> ["ab", "c"].
            for pair_index in range(len(parts) - 1):
                joined = parts[pair_index] + parts[pair_index + 1]
                pair_value = self.vocab.get(joined, NOT_IN_VOCAB)
                if pair_value > max_pair_value:
                    max_pair_value = pair_value
                    max_pair_index = pair_index

            if max_pair_value == NOT_IN_VOCAB:
                # No pairs found in the vocab; we're done!
                should_break = True
            else:
                # `break` and `continue` are not supported; only run this block
                # if we wouldn't want to break out after the above step.

                # Combine the parts pair with the highest priority in the vocab.
                # len(parts) shrinks by 1 each iteration, so we should be bounded
                # as linear in token length.
                # Subscript assignment is not implemented.
                p1, p2 = parts[max_pair_index : max_pair_index + 2]
                parts = parts[:max_pair_index] + [p1 + p2] + parts[max_pair_index + 2 :]

        return parts

    @torch.jit.script_method
    def tokenize(self, tokens: List[str]) -> List[str]:
        bpe_tokens = torch.jit.annotate(List[str], [])

        for token in tokens:
            # `extend` is not implemented
            for part in self.bpe_token(token):
                bpe_tokens.append(part)

        return bpe_tokens

    def __getstate__(self):
        """These implement pickling for ScriptBPE modules. TorchScript models
        can't be pickled normally.
        See https://github.com/pytorch/pytorch/issues/15116 for more context;
        in the meantime, for TorchScript modules that might want to be pickled
        (this one is often included in, say, tensorizer/tokenizer state that we
        want in snapshots) we need to implement a custom __getstate__ and
        __setstate__ for pickling.
        """
        return {"vocab": self.vocab, "eow": self.eow}

    def __setstate__(self, state):
        ScriptBPE.__init__(self, state["vocab"], state["eow"])
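# Usage sketch (illustrative; not part of the original module). The vocab file
# contents below are made up, and the counts in the file are ignored by
# load_vocab. The resulting module also pickles via the custom
# __getstate__/__setstate__ above.
#
#     import io
#     from pytext.torchscript.tokenizer.bpe import ScriptBPE
#
#     vocab_file = io.StringIO("hello_EOW 20\nw 5\no 4\nr 3\nl 2\nd_EOW 1\n")
#     bpe = ScriptBPE.from_vocab_file(vocab_file)
#     bpe.tokenize(["hello", "world"])
#     # -> ["hello_EOW", "w", "o", "r", "l", "d_EOW"]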