Source code for pytext.data.test.mask_tensorizers_test

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import types
import unittest
from typing import List

from pytext.data.masked_tensorizer import MaskedTokenTensorizer
from pytext.data.masked_util import (
    MaskEverything,
    RandomizedMaskingFunction,
    TreeMask,
)
from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import TSVDataSource
from pytext.utils.test import import_tests_module


tests_module = import_tests_module()


class MaskTensorizersTest(unittest.TestCase):
    """Tests for ``MaskedTokenTensorizer`` combined with several masking functions.

    Every test numberizes and tensorizes a small set of hand-written rows and
    inspects the masked source (and, in one test, the masked target) that the
    tensorizer produces.
    """

    def setUp(self):
        """Load the TSV fixture and build a ``TreeMask`` tensorizer initialized on it."""
        self.data = TSVDataSource(
            SafeFileWrapper(tests_module.test_file("compositional_seq2seq_unit.tsv")),
            test_file=None,
            eval_file=None,
            field_names=["text", "seqlogical"],
            schema={"text": str, "seqlogical": str},
        )
        self.masked_tensorizer = MaskedTokenTensorizer.from_config(
            MaskedTokenTensorizer.Config(
                column="seqlogical", masking_function=TreeMask.Config()
            )
        )
        self._initialize_tensorizer(self.masked_tensorizer)

    def _initialize_tensorizer(self, tensorizer, data=None):
        """Drive ``tensorizer.initialize()`` with every training row of ``data``.

        Args:
            tensorizer: object whose ``initialize()`` returns a generator that
                accepts rows through ``send``.
            data: data source exposing a ``train`` iterable; defaults to
                ``self.data`` when ``None``.
        """
        if data is None:
            data = self.data
        init = tensorizer.initialize()
        init.send(None)  # prime the generator before feeding it rows
        for row in data.train:
            init.send(row)
        init.close()

    @staticmethod
    def _directions_rows():
        """Return a fresh single-row batch holding the "quickest route home" example.

        A new list/dict is built on every call so tests never share row objects.
        """
        return [
            {
                "text": "find me the quickest route home",
                "seqlogical": "[in:get_directions find me the quickest route [sl:destination [in:get_location_home home ] ] ]",
            }
        ]

    @staticmethod
    def _mask_only_at_depth_3(tensorizer):
        """Replace ``tensorizer.mask.should_mask`` so it is true only for ``depth == 3``."""

        def should_mask(self, depth=1):
            return depth == 3

        tensorizer.mask.should_mask = types.MethodType(should_mask, tensorizer.mask)

    @staticmethod
    def _masked_source_tokens(tensorizer, rows) -> List[str]:
        """Tensorize ``rows`` and decode the first row's masked source via the vocab."""
        masked_results = tensorizer.tensorize(
            [tensorizer.numberize(row) for row in rows]
        )
        # tensorize() yields a 5-tuple; only the masked source (4th item) is needed.
        _, _, _, all_masked_source, _ = masked_results
        vocab = tensorizer.vocab
        return [vocab[tok] for tok in all_masked_source[0]]

    def _check_random_masking_with_seed_2(self):
        """Assert the masked source produced by ``RandomizedMaskingFunction(seed=2)``."""
        masked_tensorizer = MaskedTokenTensorizer.from_config(
            MaskedTokenTensorizer.Config(
                column="seqlogical",
                masking_function=RandomizedMaskingFunction.Config(seed=2),
            )
        )
        self._initialize_tensorizer(masked_tensorizer)
        vocab = masked_tensorizer.vocab

        masked_tokens = self._masked_source_tokens(
            masked_tensorizer, self._directions_rows()
        )

        mask = vocab[vocab.get_mask_index()]
        target: List[str] = [
            mask,
            mask,
            "me",
            mask,
            mask,
            mask,
            "[sl:destination",
            mask,
            "home",
            mask,
            mask,
            mask,
        ]
        self.assertEqual(masked_tokens, target)

    def test_basic_tree_masking(self):
        """Masked positions target the real token; unmasked positions target pad."""
        rows = [
            {
                "text": "delays in tempe",
                "seqlogical": "[in:get_info_traffic delays in [sl:location tempe ] ]",
            },
            {
                "text": "find me the quickest route home",
                "seqlogical": "[in:get_directions find me the quickest route [sl:destination [in:get_location_home home ] ] ]",
            },
        ]
        vocab = self.masked_tensorizer.vocab

        masked_results = self.masked_tensorizer.tensorize(
            [self.masked_tensorizer.numberize(row) for row in rows]
        )
        all_tokens, _, _, all_masked_source, all_masked_target = masked_results

        mask_index = vocab.get_mask_index()
        pad_index = vocab.get_pad_index()

        # unittest assertions are used instead of bare ``assert`` so the checks
        # still run when Python is started with -O.
        for tokens, masked_source, masked_target in zip(
            all_tokens, all_masked_source, all_masked_target
        ):
            self.assertEqual(len(masked_source), len(masked_target))
            self.assertEqual(len(tokens), len(masked_target))

            for token, source, target in zip(tokens, masked_source, masked_target):
                if source == mask_index:
                    # For masked tokens, the target is the real token.
                    self.assertEqual(
                        target, token, str(target) + " != " + str(token)
                    )
                else:
                    # For unmasked tokens, the target is the pad token.
                    self.assertEqual(target, pad_index)

    def test_mask_at_depth_k(self):
        """With ``should_mask`` true only at depth 3, only the innermost span is masked."""
        vocab = self.masked_tensorizer.vocab
        self._mask_only_at_depth_3(self.masked_tensorizer)

        masked_source_tokens = self._masked_source_tokens(
            self.masked_tensorizer, self._directions_rows()
        )

        mask = vocab[vocab.get_mask_index()]
        self.assertEqual(
            masked_source_tokens,
            [
                "[in:get_directions",
                "find",
                "me",
                "the",
                "quickest",
                "route",
                "[sl:destination",
                mask,
                mask,
                mask,
                "]",
                "]",
            ],
        )

    def test_tree_mask_with_bos_eos(self):
        """Depth-3 masking with BOS/EOS enabled wraps the sequence in those tokens."""
        masked_tensorizer = MaskedTokenTensorizer.from_config(
            MaskedTokenTensorizer.Config(
                column="seqlogical",
                masking_function=TreeMask.Config(),
                add_bos_token=True,
                add_eos_token=True,
            )
        )
        self._initialize_tensorizer(masked_tensorizer)
        vocab = masked_tensorizer.vocab
        self._mask_only_at_depth_3(masked_tensorizer)

        masked_source_tokens = self._masked_source_tokens(
            masked_tensorizer, self._directions_rows()
        )

        self.assertEqual(
            masked_source_tokens,
            [
                vocab.bos_token,
                "[in:get_directions",
                "find",
                "me",
                "the",
                "quickest",
                "route",
                "[sl:destination",
                vocab.mask_token,
                vocab.mask_token,
                vocab.mask_token,
                "]",
                "]",
                vocab.eos_token,
            ],
        )

    def test_mask_all(self):
        """``MaskEverything`` turns every source position into the mask token."""
        masked_tensorizer = MaskedTokenTensorizer.from_config(
            MaskedTokenTensorizer.Config(
                column="seqlogical", masking_function=MaskEverything.Config()
            )
        )
        self._initialize_tensorizer(masked_tensorizer)
        vocab = masked_tensorizer.vocab

        masked_tokens = self._masked_source_tokens(
            masked_tensorizer, self._directions_rows()
        )

        self.assertEqual(
            masked_tokens, [vocab[vocab.get_mask_index()]] * len(masked_tokens)
        )

    def test_mask_random(self):
        """``RandomizedMaskingFunction`` with ``seed=2`` yields a fixed mask pattern."""
        self._check_random_masking_with_seed_2()

    def test_mask_no_op(self):
        """Run the same seeded random-masking check as ``test_mask_random``.

        NOTE(review): this test has always been identical to
        ``test_mask_random`` (same ``RandomizedMaskingFunction`` config, same
        expected output). The name suggests a no-op masking function was
        intended; confirm and replace the config if so.
        """
        self._check_random_masking_with_seed_2()