Source code for pytext.data.masked_util

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional, Set

import numpy as np
from pytext.common.constants import SpecialTokens, Token
from pytext.config.component import Component, ComponentType
from pytext.config.pytext_config import ConfigBase
from pytext.data.data_structures.annotation import Annotation, Intent, Root, Slot
from pytext.data.utils import VocabBuilder, Vocabulary


class MaskedVocabBuilder(VocabBuilder):
    """``VocabBuilder`` variant that sets ``use_mask`` to ``True``.

    Args:
        delimiter: Forwarded unchanged to ``VocabBuilder.__init__``.
    """

    def __init__(self, delimiter=" "):
        super().__init__(delimiter)
        # Assigned after the base initializer so this is the value that remains
        # on the instance regardless of what the base class sets.
        self.use_mask = True
# Lookup from the string form of the MASK, BOS and EOS special tokens to the
# corresponding token objects.
SPECIAL_TOKENS: Dict[str, Token] = {
    str(special): special
    for special in (SpecialTokens.MASK, SpecialTokens.BOS, SpecialTokens.EOS)
}
class MaskingFunction(Component):
    """Base component for strategies that mask a sequence of token ids.

    Concrete subclasses implement :meth:`gen_masked_source_target`; this base
    class only stores the BOS/EOS flags and provides the shared
    :meth:`_prepare_dec_target` helper.
    """

    class Config(ConfigBase):
        # The base masking function has no configurable fields.
        pass

    __EXPANSIBLE__ = True
    __COMPONENT_TYPE__ = ComponentType.MASKING_FUNCTION

    @classmethod
    def from_config(cls, config, use_bos, use_eos):
        """Create an instance from the two flags; ``config`` is not read here."""
        return cls(use_bos, use_eos)

    def __init__(self, use_bos, use_eos):
        self.use_bos = use_bos
        self.use_eos = use_eos

    def should_mask(self, *args, **kwargs) -> bool:
        """Return ``True`` unconditionally, ignoring all arguments."""
        return True

    def gen_masked_source_target(self, tokens, *args, **kwargs):
        """Produce the masked source and its target.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()

    def _prepare_dec_target(
        self, dec_source: List[int], clean_input_tokens: List[int], vocab: Vocabulary
    ) -> List[int]:
        """Build the target sequence that matches a masked source sequence.

        Walks ``dec_source`` and ``clean_input_tokens`` in lockstep (stopping at
        the shorter of the two). Where the source holds the mask index the
        original token is emitted; everywhere else the pad index is emitted.

        Args:
            dec_source: Source ids, some of which may be the mask index.
            clean_input_tokens: The unmasked ids aligned with ``dec_source``.
            vocab: Supplies ``get_mask_index()`` and ``get_pad_index()``.

        Returns:
            A new list with one target id per aligned pair.
        """
        targets = []
        for source_id, original_id in zip(dec_source, clean_input_tokens):
            if source_id != vocab.get_mask_index():
                # Unmasked position: nothing to predict here.
                targets.append(vocab.get_pad_index())
            else:
                targets.append(original_id)
        return targets
class TreeMask(MaskingFunction):
    """Masking function that randomly masks whole subtrees of an annotation.

    The token ids are turned back into text, parsed into an ``Annotation``
    tree, and each node is either replaced by a run of mask tokens or expanded
    recursively, with the masking chance shrinking as depth grows.
    """

    class Config(ConfigBase):
        # Forwarded to ``Annotation`` when the target string is parsed.
        accept_flat_intents_slots: bool = True
        # Base of the per-depth masking threshold ``1 / factor ** depth``.
        factor: int = 2

    @classmethod
    def from_config(cls, config, use_bos, use_eos):
        """Create an instance from ``config`` and the BOS/EOS flags."""
        return cls(config.accept_flat_intents_slots, config.factor, use_bos, use_eos)

    def __init__(self, accept_flat_intents_slots, factor, use_bos, use_eos):
        super().__init__(use_bos, use_eos)
        self.accept_flat_intents_slots = accept_flat_intents_slots
        self.factor = factor

    def clean_eos_bos(self, tokens):
        """Return ``tokens`` without the first item if ``use_bos`` is set and
        without the last item if ``use_eos`` is set."""
        length = len(tokens)
        first = 1 if self.use_bos else 0
        last = -1 if self.use_eos else length
        return tokens[first:last]

    def gen_masked_tree(self, node, mask_token, depth=1):
        """Render ``node`` as text, masking it or descending into its children.

        Args:
            node: Tree node providing ``flat_str()``, ``label`` and, for
                ``Intent``/``Slot``/``Root`` nodes, ``children``.
            mask_token: String substituted for every masked position.
            depth: Depth of ``node``; passed to :meth:`should_mask`.

        Returns:
            The stripped string for this subtree.
        """
        if self.should_mask(depth):
            # One mask token per space-separated piece of the flattened node.
            width = len(node.flat_str().strip().split(" "))
            return " ".join([mask_token] * width)

        if not isinstance(node, (Intent, Slot, Root)):
            # Any other node type contributes only its label.
            return (node.label + " ").strip()

        pieces = ["[", node.label, " "]
        for child in node.children:
            pieces.append(self.gen_masked_tree(child, mask_token, depth + 1))
            pieces.append(" ")
        pieces.append("]")
        return "".join(pieces).strip()

    def should_mask(self, depth=1):
        """Draw from ``np.random.random()`` and report whether the draw falls
        below ``1 / factor ** depth``."""
        draw = np.random.random()
        return draw < 1.0 / (self.factor ** depth)

    def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Generate a tree-masked source sequence and its prediction target.

        Args:
            tokens: Token ids, including BOS/EOS when the matching flag is set.
            vocab: Vocabulary used to map between ids and token strings.

        Returns:
            ``(dec_source, dec_target)``. If the annotation cannot be built,
            the source is all mask ids and the target all pad ids, each with
            ``len(tokens)`` entries.
        """
        cleaned_tokens = self.clean_eos_bos(tokens)
        words = [vocab[idx] for idx in cleaned_tokens]
        original_target_string = " ".join(words).upper()

        try:
            annotation = Annotation(
                original_target_string,
                accept_flat_intents_slots=self.accept_flat_intents_slots,
            )
        except Exception as e:
            # This should never happen other than when testing
            print(e, original_target_string)
            dec_source = [vocab.idx[vocab.mask_token] for _ in tokens]
            dec_target = [vocab.idx[vocab.pad_token] for _ in tokens]
            return dec_source, dec_target

        assert len(annotation.root.children) == 1
        masked_text = self.gen_masked_tree(
            annotation.root.children[0], vocab.mask_token
        )

        # The masked string is split on single spaces rather than run through
        # the tensorizer's tokenize(): it contains the special mask token, and
        # tokenize() could lower-case or split that token in unexpected ways.
        # As a temporary workaround every piece is lower-cased here unless it is
        # one of the special tokens, which is mapped through SPECIAL_TOKENS.
        masked_pieces: List[str] = [
            SPECIAL_TOKENS.get(piece, piece.lower())
            for piece in masked_text.split(" ")
        ]

        dec_source = [vocab.idx.get(piece) for piece in masked_pieces]
        dec_target = self._prepare_dec_target(dec_source, cleaned_tokens, vocab)

        if self.use_bos:
            # BOS is either masked (to be predicted) or kept (target is pad).
            if self.should_mask():
                bos_source, bos_target = vocab.get_mask_index(), vocab.get_bos_index()
            else:
                bos_source, bos_target = vocab.get_bos_index(), vocab.get_pad_index()
            dec_source.insert(0, bos_source)
            dec_target.insert(0, bos_target)

        if self.use_eos:
            # Same treatment for EOS at the end of both sequences.
            if self.should_mask():
                eos_source, eos_target = vocab.get_mask_index(), vocab.get_eos_index()
            else:
                eos_source, eos_target = vocab.get_eos_index(), vocab.get_pad_index()
            dec_source.append(eos_source)
            dec_target.append(eos_target)

        return dec_source, dec_target
class MaskEverything(MaskingFunction):
    """Masking function that puts the mask index at every source position."""

    def gen_masked_tree(self, node, mask_token, depth=1):
        """Return one ``mask_token`` per space-separated piece of the node's
        stripped ``flat_str()``, joined by single spaces. ``depth`` is unused."""
        span = node.flat_str().strip().split(" ")
        return " ".join([mask_token] * len(span))

    def gen_masked_source_target(self, tokens, vocab: Vocabulary):
        """Return an all-mask source and the target derived from ``tokens``.

        Args:
            tokens: Token ids to mask.
            vocab: Supplies the mask index (and the pad index used by the
                shared target helper).

        Returns:
            ``(dec_source, dec_target)`` where ``dec_source`` holds the mask
            index once per item in ``tokens``.
        """
        dec_source: List[int] = []
        for _ in tokens:
            dec_source.append(vocab.get_mask_index())
        dec_target = self._prepare_dec_target(dec_source, tokens, vocab)
        return dec_source, dec_target
class RandomizedMaskingFunction(MaskingFunction):
    """Masking function that masks a randomly chosen subset of positions.

    The number of masked positions and the positions themselves are drawn from
    a private ``np.random.RandomState`` created from ``seed``.
    """

    class Config(MaskingFunction.Config):
        # Seed passed to ``np.random.RandomState``; ``None`` leaves seeding to
        # numpy.
        seed: Optional[int] = None
        # Lower bound used when drawing how many positions to mask.
        minimum_masks: int = 1

    @classmethod
    def from_config(cls, config: Config, use_bos: bool, use_eos: bool):
        """Create an instance from ``config`` and the BOS/EOS flags."""
        return cls(config.seed, config.minimum_masks, use_bos, use_eos)

    def __init__(
        self, seed: Optional[int], minimum_masks: int, use_bos: bool, use_eos: bool
    ):
        super().__init__(use_bos, use_eos)
        self.random = np.random.RandomState(seed)
        self.minimum_masks = minimum_masks

    def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Mask a random subset of ``tokens`` and build the matching target.

        When ``len(tokens) > minimum_masks`` the mask count is drawn with
        ``randint(minimum_masks, len(tokens))`` (upper bound exclusive).
        Shorter inputs cannot satisfy that range, so every position is masked
        instead; an empty input yields two empty lists.

        Args:
            tokens: Token ids to mask.
            vocab: Supplies the mask index (and the pad index used by the
                shared target helper).

        Returns:
            ``(dec_source, dec_target)``: ``tokens`` with the chosen positions
            replaced by the mask index, and the target produced by
            ``_prepare_dec_target``.
        """
        num_tokens = len(tokens)
        if num_tokens == 0:
            # Nothing to mask; avoid sampling from an empty population.
            return [], []

        if num_tokens > self.minimum_masks:
            num_masks = self.random.randint(self.minimum_masks, num_tokens)
        else:
            # randint raises ValueError when low >= high, which used to crash
            # on sequences no longer than ``minimum_masks``. Mask as many
            # positions as exist instead.
            num_masks = num_tokens

        ind: Set[int] = set(
            self.random.choice(num_tokens, size=num_masks, replace=False)
        )
        dec_source: List[int] = [
            vocab.get_mask_index() if idx in ind else token
            for idx, token in enumerate(tokens)
        ]
        dec_target = self._prepare_dec_target(dec_source, tokens, vocab)
        return dec_source, dec_target
class NoOpMaskingFunction(MaskingFunction):
    """Masking function that returns the input tokens as the source unchanged."""

    class Config(MaskingFunction.Config):
        # Passed to ``np.random.RandomState`` in ``__init__``.
        seed: Optional[int] = None
        # Stored on the instance; not read by ``gen_masked_source_target``.
        minimum_masks: int = 1

    @classmethod
    def from_config(cls, config: Config, use_bos: bool, use_eos: bool):
        """Create an instance from ``config`` and the BOS/EOS flags."""
        return cls(config.seed, config.minimum_masks, use_bos, use_eos)

    def __init__(
        self, seed: Optional[int], minimum_masks: int, use_bos: bool, use_eos: bool
    ):
        super().__init__(use_bos, use_eos)
        self.random = np.random.RandomState(seed)
        self.minimum_masks = minimum_masks

    def gen_masked_source_target(self, tokens: List[int], vocab: Vocabulary):
        """Return ``tokens`` itself as the source together with its target.

        The target comes from ``_prepare_dec_target(tokens, tokens, vocab)``,
        i.e. the input is compared against itself.

        Args:
            tokens: Token ids; returned as-is (same object) as the source.
            vocab: Vocabulary handed to the shared target helper.

        Returns:
            ``(tokens, dec_target)``.
        """
        return tokens, self._prepare_dec_target(tokens, tokens, vocab)