Source code for pytext.data.tokenizers.tokenizer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import json
import re
from collections import OrderedDict
from typing import List, NamedTuple, Union, Optional

from fairseq.data.encoders.gpt2_bpe import get_encoder as create_gpt2_bpe
from fairseq.data.encoders.gpt2_bpe_utils import Encoder as GPT2BPEEncoder
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType, create_component
from pytext.torchscript.tokenizer import ScriptDoNothingTokenizer, ScriptWordTokenizer
from pytext.utils.file_io import PathManager
from pytext.utils.usage import log_class_usage
from sentencepiece import SentencePieceProcessor
from transformers.tokenization_bert import (
    BasicTokenizer,
    WordpieceTokenizer,
)


class Token(NamedTuple):
    value: str
    start: int
    end: int


class Tokenizer(Component):
    """A simple regex-splitting tokenizer."""

    __COMPONENT_TYPE__ = ComponentType.TOKENIZER
    __EXPANSIBLE__ = True

    class Config(Component.Config):
        #: A regular expression for the tokenizer to split on. Tokens are the segments
        #: between the regular expression matches. The start index is inclusive of the
        #: unmatched region, and the end index is exclusive (matching the first
        #: character of the matched split region).
        split_regex: str = r"\s+"
        #: Whether token values should be lowercased or not.
        lowercase: bool = True
        #: Whether to use utf8 byte offsets
        use_byte_offsets: bool = False

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.split_regex, config.lowercase, config.use_byte_offsets)

    def __init__(self, split_regex=r"\s+", lowercase=True, use_byte_offsets=False):
        super().__init__(None)
        self.split_regex = split_regex
        self.lowercase = lowercase
        self.use_byte_offsets = use_byte_offsets

    def tokenize(self, input: str) -> List[Token]:
        tokens = []
        start = 0
        tokenize_input = input.lower() if self.lowercase else input
        for match in re.finditer(self.split_regex, tokenize_input):
            split_start, split_end = match.span()
            tokens.append(Token(tokenize_input[start:split_start], start, split_start))
            start = split_end
        tokens.append(Token(tokenize_input[start : len(input)], start, len(input)))
        if self.use_byte_offsets:
            return [
                self._convert_token(input, token) for token in tokens if token.value
            ]
        else:
            return [token for token in tokens if token.value]

    def _convert_token(self, inp: str, token: Token) -> Token:
        return Token(
            token.value,
            self._convert_char_to_byte_offsets(inp, token.start),
            self._convert_char_to_byte_offsets(inp, token.end),
        )

    def _convert_char_to_byte_offsets(self, input: str, char_offset: int) -> int:
        return len(input[:char_offset].encode("utf8"))

    def torchscriptify(self):
        # torchscriptify only supports the whitespace-splitting tokenizer
        if self.split_regex == r"\s+":
            return ScriptWordTokenizer(self.lowercase)
        else:
            raise NotImplementedError

    def decode(self, sentence: str):
        # To be overridden by subword-level tokenizers to convert token ids back to a string.
        return sentence


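# Illustrative sketch (not part of the original module): how the regex-splitting
# Tokenizer above behaves with its defaults. The sample sentence and the expected
# values in the comments are assumptions for illustration only.
def _example_regex_tokenizer():
    tokenizer = Tokenizer(split_regex=r"\s+", lowercase=True)
    tokens = tokenizer.tokenize("Hello  World")
    # Tokens carry character offsets into the original string, e.g.
    # [Token(value="hello", start=0, end=5), Token(value="world", start=7, end=12)]
    for token in tokens:
        print(token.value, token.start, token.end)

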
class DoNothingTokenizer(Tokenizer):
    """
    Tokenizer that takes a list of strings (or a JSON-encoded list of strings)
    and converts it to a list of Tokens. Useful in cases where tokenization has
    already been done beforehand.
    """

    class Config(Component.Config):
        do_nothing: str = ""

    @classmethod
    def from_config(cls, config: Config):
        return cls()

    def __init__(self):
        super().__init__(None)

    def tokenize(self, tokens: Union[List[str], str]) -> List[Token]:
        if isinstance(tokens, str):
            tokens = json.loads(tokens)
        tokens = [Token(token_text, -1, -1) for token_text in tokens if token_text]
        return tokens

    def torchscriptify(self):
        return ScriptDoNothingTokenizer()


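# Illustrative sketch (not part of the original module): DoNothingTokenizer accepts
# either a Python list of strings or a JSON-encoded list and simply wraps each
# string in a Token with dummy offsets (-1, -1). The inputs below are made up.
def _example_do_nothing_tokenizer():
    tokenizer = DoNothingTokenizer()
    tokens = tokenizer.tokenize(["hello", "world"])
    same_tokens = tokenizer.tokenize('["hello", "world"]')  # JSON string input
    assert [t.value for t in tokens] == [t.value for t in same_tokens]

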
class BERTInitialTokenizer(Tokenizer):
    """
    Basic initial tokenization for BERT. This is run prior to WordPiece and
    performs whitespace tokenization, in addition to lowercasing and accent
    removal if specified.
    """

    class Config(Tokenizer.Config):
        """Config for this class."""

    @classmethod
    def from_config(cls, config: Config):
        basic_tokenizer = BasicTokenizer(
            do_lower_case=config.lowercase,
            never_split=(
                "[UNK]",
                "[SEP]",
                "[PAD]",
                "[CLS]",
                "[MASK]",
            ),  # compatibility with HF v0.5
        )
        return cls(basic_tokenizer)

    def __init__(self, basic_tokenizer) -> None:
        self.tokenizer = basic_tokenizer
        log_class_usage(__class__)

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        if self.tokenizer.do_lower_case:
            text = self.tokenizer._run_strip_accents(text.lower())
        tokens = self.tokenizer.tokenize(text)
        end = 0
        result = []
        for token in tokens:
            start = text.find(token, end)
            if start == -1:  # safety check, this should not happen
                start = end
            end = start + len(token)
            result.append(Token(token, start, end))
        return result


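# Illustrative sketch (not part of the original module): building the BERT
# pre-tokenizer from its config. It relies on the BasicTokenizer imported above
# from the pinned transformers version; the sample text is an assumption.
def _example_bert_initial_tokenizer():
    tokenizer = BERTInitialTokenizer.from_config(BERTInitialTokenizer.Config())
    tokens = tokenizer.tokenize("Héllo, world!")
    # With lowercase=True (the default), accents are stripped before tokenizing,
    # and each resulting token keeps character offsets into the stripped text.
    for token in tokens:
        print(token.value, token.start, token.end)

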
class WordPieceTokenizer(Tokenizer):
    """Word piece tokenizer for BERT models."""

    class Config(ConfigBase):
        basic_tokenizer: BERTInitialTokenizer.Config = BERTInitialTokenizer.Config()
        wordpiece_vocab_path: str = "manifold://nlp_technologies/tree/huggingface-models/bert-base-uncased/vocab.txt"

    def __init__(self, wordpiece_vocab, basic_tokenizer, wordpiece_tokenizer) -> None:
        self.vocab = wordpiece_vocab
        self.basic_tokenizer = basic_tokenizer
        self.wordpiece_tokenizer = wordpiece_tokenizer
        log_class_usage(__class__)

    @classmethod
    def from_config(cls, config: Config):
        basic_tokenizer = create_component(
            ComponentType.TOKENIZER, config.basic_tokenizer
        )
        vocab = WordPieceTokenizer.load_vocab(config.wordpiece_vocab_path)
        wordpiece_tokenizer = WordpieceTokenizer(
            vocab=vocab, unk_token="[UNK]"
        )  # UNK is for compatibility with HF v0.5
        return cls(vocab, basic_tokenizer, wordpiece_tokenizer)

    @staticmethod
    def load_vocab(vocab_file):
        """Loads a vocabulary file into a dictionary."""
        vocab = OrderedDict()
        with PathManager.open(vocab_file, "r") as reader:
            tokens = reader.readlines()
        for index, token in enumerate(tokens):
            token = token.rstrip("\n")
            vocab[token] = index
        return vocab

    def tokenize(self, input_str: str) -> List[Token]:
        tokens = []
        for token in self.basic_tokenizer.tokenize(input_str):
            start = token.start
            for sub_token in self.wordpiece_tokenizer.tokenize(token.value):
                piece_len = (
                    len(sub_token)
                    if not sub_token.startswith("##")
                    else (len(sub_token) - 2)  # account for ##
                )
                if sub_token == "[UNK]":
                    # this fixes the bug wherein piece_len = 5 for all [UNK]
                    piece_len = len(token.value)
                end = start + piece_len
                tokens.append(Token(sub_token, start, end))
                start = end
        return [token for token in tokens if token.value]


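# Illustrative sketch (not part of the original module): running WordPiece on top of
# the BERT pre-tokenizer. The local vocab path below is a hypothetical stand-in for a
# real BERT vocab file; sub-token offsets all point back into the original text.
def _example_wordpiece_tokenizer():
    config = WordPieceTokenizer.Config(
        wordpiece_vocab_path="/tmp/bert-base-uncased-vocab.txt"  # hypothetical path
    )
    tokenizer = WordPieceTokenizer.from_config(config)
    tokens = tokenizer.tokenize("unaffable")
    # A word missing from the vocab is typically split into pieces such as
    # ["un", "##aff", "##able"], each carrying start/end offsets into "unaffable".
    for token in tokens:
        print(token.value, token.start, token.end)

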
class PickleableGPT2BPEEncoder(GPT2BPEEncoder):
    """Fairseq's encoder stores the regex module as a local reference on its
    encoders, which means they can't be saved via pickle.dumps or torch.save.
    This subclass modifies the save/load logic so that it doesn't store the
    module, and restores the reference after re-inflating."""

    def __getstate__(self):
        # make a shallow copy of state to avoid side effect on the original object
        state = copy.copy(vars(self))
        state.pop("re")
        return state

    def __setstate__(self, state):
        vars(self).update(state)
        import regex

        self.re = regex


class GPT2BPETokenizer(Tokenizer):
    """Tokenizer for GPT-2 and RoBERTa."""

    class Config(ConfigBase):
        bpe_encoder_path: str = (
            "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/encoder.json"
        )
        bpe_vocab_path: str = (
            "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/vocab.bpe"
        )
        lowercase: bool = False

    @classmethod
    def from_config(cls, config: Config):
        # TODO: T57433776 remove once FairSeq supports PathManager
        config.bpe_encoder_path = PathManager.get_local_path(config.bpe_encoder_path)
        config.bpe_vocab_path = PathManager.get_local_path(config.bpe_vocab_path)

        bpe = create_gpt2_bpe(config.bpe_encoder_path, config.bpe_vocab_path)
        # This hacks the bpe instance to be picklable
        bpe = copy.copy(bpe)
        bpe.__class__ = PickleableGPT2BPEEncoder

        return cls(bpe, config.lowercase)

    def __init__(self, bpe: GPT2BPEEncoder, lowercase: bool = False):
        self.bpe = bpe
        self.lowercase = lowercase
        log_class_usage(__class__)

    def tokenize(self, input_str: str) -> List[Token]:
        if self.lowercase:
            bpe_ids = self.bpe.encode(input_str.lower())
        else:
            bpe_ids = self.bpe.encode(input_str)
        char_tokens = [self.bpe.decoder[id].lstrip(u"\u0120") for id in bpe_ids]
        # fix for incorrect decoding of utf-8 chars
        for i, char_token in enumerate(char_tokens):
            try:
                char_tokens[i] = bytearray(
                    [self.bpe.byte_decoder[char] for char in char_token]
                ).decode("utf-8")
            # handles BPE breaking a single multi-byte char into pieces
            except UnicodeDecodeError:
                continue
        lengths = [len(token) for token in char_tokens]
        tokens = []
        end = 0
        for length, id, char_token in zip(lengths, bpe_ids, char_tokens):
            start = input_str.find(char_token, end)
            end = start + length
            tokens.append(Token(str(id), start, end))
            # handles bad start/end indices cascading to subsequent tokens
            if len(tokens) > 1 and end < tokens[-2].end:
                end = tokens[-2].end
        return [token for token in tokens if token.value]

    def decode(self, sentence: str):
        bpe_tokens = []
        for i in sentence.split():
            if i.isdigit():
                bpe_tokens.append(int(i))
        return self.bpe.decode(bpe_tokens)


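# Illustrative sketch (not part of the original module): GPT-2 BPE tokenization and
# round-tripping through decode(). The encoder/vocab paths are hypothetical local
# copies of the GPT-2 BPE files rather than the manifold:// defaults above.
def _example_gpt2_bpe_tokenizer():
    config = GPT2BPETokenizer.Config(
        bpe_encoder_path="/tmp/gpt2/encoder.json",  # hypothetical path
        bpe_vocab_path="/tmp/gpt2/vocab.bpe",  # hypothetical path
    )
    tokenizer = GPT2BPETokenizer.from_config(config)
    tokens = tokenizer.tokenize("Hello world")
    # Token values are BPE ids as strings; offsets index into the input text.
    id_string = " ".join(token.value for token in tokens)
    round_trip = tokenizer.decode(id_string)  # expected to recover "Hello world"
    print(round_trip)

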
class CppProcessorMixin:
    """Cpp processors like SentencePiece don't pickle well; reload them."""

    def _load_processor(self):
        raise NotImplementedError

    def __getstate__(self):
        state = dict(vars(self))
        state.pop("processor")
        return state

    def __setstate__(self, state):
        vars(self).update(state)
        self._load_processor()


class SentencePieceTokenizer(Tokenizer, CppProcessorMixin):
    """Sentence piece tokenizer."""

    class Config(ConfigBase):
        sp_model_path: str = ""
        max_input_text_length: Optional[int] = None

    def __init__(
        self, sp_model_path: str = "", max_input_text_length: Optional[int] = None
    ):
        self.sp_model_path = sp_model_path
        self.max_input_text_length = max_input_text_length
        self._load_processor()
        log_class_usage(__class__)

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.sp_model_path, config.max_input_text_length)

    def tokenize(self, input_str: str) -> List[Token]:
        if (
            hasattr(self, "max_input_text_length")
            and self.max_input_text_length is not None
        ):
            input_str = input_str[: self.max_input_text_length]

        pieces = self.processor.EncodeAsPieces(input_str)
        tokens = []
        # calculate start and end indices of each piece
        end = 0
        for piece in pieces:
            original_piece = piece.lstrip("\u2581")
            start = input_str.find(original_piece, end)
            end = start + len(original_piece)
            tokens.append(Token(piece, start, end))
        return tokens

    def _load_processor(self):
        self.processor = SentencePieceProcessor()
        self.processor.Load(PathManager.get_local_path(self.sp_model_path))

    def torchscriptify(self):
        return ScriptDoNothingTokenizer()


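# Illustrative sketch (not part of the original module): SentencePiece tokenization
# with a hypothetical model path. Because of CppProcessorMixin, instances can be
# pickled; the C++ processor is dropped in __getstate__ and reloaded in __setstate__.
def _example_sentencepiece_tokenizer():
    import pickle

    tokenizer = SentencePieceTokenizer(sp_model_path="/tmp/spm.model")  # hypothetical
    tokens = tokenizer.tokenize("hello world")
    # Piece values keep the "\u2581" marker; offsets index into the original string.
    restored = pickle.loads(pickle.dumps(tokenizer))  # reloads the C++ processor
    assert [t.value for t in restored.tokenize("hello world")] == [
        t.value for t in tokens
    ]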