#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import re
from typing import Any, Dict, Tuple, Union
import torch
from pytext.common.constants import Padding, VocabMeta
from pytext.config.field_config import EmbedInitStrategy
from pytext.utils import data as data_utils, precision
try:
    from torchtext.legacy import data as textdata
except ImportError:
    from torchtext import data as textdata
from torchtext.vocab import Vocab


def create_fields(fields_config, field_cls_dict):
    """Create a field instance for each configured field class, keyed by name.

    Fields whose config entry is unset (falsy) are skipped.
    """
    fields_config_dict = fields_config._asdict()
    return {
        name: field_cls.from_config(fields_config_dict[name])
        for name, field_cls in field_cls_dict.items()
        if fields_config_dict[name]
    }
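
# A hedged usage sketch (`FeatConfig` and the config values are hypothetical
# stand-ins): only fields with a truthy config entry are created.
#
#     FeatConfig = namedtuple("FeatConfig", ["word_feat", "dense_feat"])
#     fields = create_fields(
#         FeatConfig(word_feat=word_feat_config, dense_feat=None),
#         {"word_feat": TextFeatureField, "dense_feat": FloatVectorField},
#     )
#     # -> {"word_feat": <TextFeatureField>}; dense_feat is skipped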


def create_label_fields(label_configs, label_cls_dict):
    """Create a label field for each label config, keyed by the config's name.

    A single config (not wrapped in a list) is also accepted.
    """
    if not isinstance(label_configs, list):
        label_configs = [label_configs]
    labels: Dict[str, Field] = {}
    for label_config in label_configs:
        labels[label_config._name] = label_cls_dict[label_config._name].from_config(
            label_config
        )
    return labels
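
# A hedged sketch (`doc_label_config` is a stand-in config object): each
# config's `_name` selects the matching field class.
#
#     labels = create_label_fields(
#         doc_label_config,  # single config, wrapped into a list internally
#         {doc_label_config._name: DocLabelField},
#     )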


class Field(textdata.Field):
    @classmethod
    def from_config(cls, config):
        print(f"creating field {cls.__name__}")
        return cls(**config._asdict())

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def pad_length(self, n):
        """
        Make the pad length a multiple of 8 to support fp16 training.
        """
        return precision.pad_length(n)
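
    # Illustrative behavior (assuming `precision.pad_length` rounds n up to
    # the next multiple of 8 only when fp16 is enabled; multiples of 8 keep
    # tensor dims friendly to fp16 kernels):
    #
    #     field = Field(sequential=True, batch_first=True)
    #     field.pad_length(13)  # -> 16 with fp16 enabled, 13 otherwise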


class NestedField(Field, textdata.NestedField):
    """Field for nested (list-of-lists) inputs; combines Field's fp16 padding
    with torchtext's NestedField behavior."""

    pass


class RawField(textdata.RawField):
    """RawField that additionally records whether it is a model target."""

    def __init__(self, *args, is_target=False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.is_target = is_target


class VocabUsingField(Field):
    """Base class for all fields that need to build a vocabulary."""

def __init__(
self,
pretrained_embeddings_path="",
embed_dim=0,
embedding_init_strategy=EmbedInitStrategy.RANDOM,
vocab_file="",
vocab_size="",
vocab_from_train_data=True, # add tokens from train data to vocab
vocab_from_all_data=False, # add tokens from train, eval, test data to vocab
# add tokens from pretrained embeddings to vocab
vocab_from_pretrained_embeddings=False,
min_freq=1,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.pretrained_embeddings_path = pretrained_embeddings_path
self.vocab_file = vocab_file
self.vocab_size = vocab_size
self.vocab_from_train_data = vocab_from_train_data
self.vocab_from_all_data = vocab_from_all_data
self.vocab_from_pretrained_embeddings = vocab_from_pretrained_embeddings
self.min_freq = min_freq
self.embed_dim = embed_dim
self.embedding_init_strategy = embedding_init_strategy
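
# A hedged construction sketch (the embeddings path is hypothetical): tokens
# can be drawn from training data only (the default), from all splits, from a
# vocab file, and/or from the pretrained embedding file itself.
#
#     field = VocabUsingField(
#         pretrained_embeddings_path="embeddings.txt",  # hypothetical path
#         embed_dim=300,
#         vocab_from_pretrained_embeddings=True,
#         min_freq=2,  # drop tokens seen fewer than 2 times
#         sequential=True,
#         batch_first=True,
#     )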


class VocabUsingNestedField(VocabUsingField, NestedField):
    """Base class for all nested fields that need to build a vocabulary."""

    pass


class DocLabelField(Field):
    """Field for document-level labels (e.g. document classification)."""

    def __init__(self, **kwargs) -> None:
super().__init__(
sequential=False,
batch_first=True,
tokenize=data_utils.no_tokenize,
unk_token=None, # Don't include unk in the list of labels
)


class WordLabelField(Field):
    """Field for word-level (token tagging) labels."""

    def __init__(self, use_bio_labels, **kwargs):
super().__init__(
sequential=True,
batch_first=True,
tokenize=data_utils.simple_tokenize,
pad_token=Padding.WORD_LABEL_PAD,
unk_token=None, # Don't include unk in the list of labels
)
self.use_bio_labels = use_bio_labels


class TextFeatureField(VocabUsingField):
    """Field for tokenized text features."""

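    # Placeholder batch of token ids (shape 2 x 1, on CPU), presumably used
    # when tracing/exporting the model.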
dummy_model_input = torch.tensor([[1], [1]], dtype=torch.long, device="cpu")
def __init__(
self,
pretrained_embeddings_path="",
embed_dim=0,
embedding_init_strategy=EmbedInitStrategy.RANDOM,
vocab_file="",
vocab_size="",
vocab_from_train_data=True,
vocab_from_all_data=False,
vocab_from_pretrained_embeddings=False,
postprocessing=None,
use_vocab=True,
include_lengths=True,
batch_first=True,
sequential=True,
pad_token=VocabMeta.PAD_TOKEN,
unk_token=VocabMeta.UNK_TOKEN,
init_token=None,
eos_token=None,
lower=False,
tokenize=data_utils.no_tokenize,
fix_length=None,
pad_first=None,
min_freq=1,
**kwargs,
) -> None:
super().__init__(
pretrained_embeddings_path=pretrained_embeddings_path,
embed_dim=embed_dim,
embedding_init_strategy=embedding_init_strategy,
vocab_file=vocab_file,
vocab_size=vocab_size,
vocab_from_train_data=vocab_from_train_data,
vocab_from_all_data=vocab_from_all_data,
vocab_from_pretrained_embeddings=vocab_from_pretrained_embeddings,
postprocessing=postprocessing,
use_vocab=use_vocab,
include_lengths=include_lengths,
batch_first=batch_first,
sequential=sequential,
pad_token=pad_token,
unk_token=unk_token,
init_token=init_token,
eos_token=eos_token,
lower=lower,
tokenize=tokenize,
fix_length=fix_length,
pad_first=pad_first,
min_freq=min_freq,
)
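
    # A minimal end-to-end sketch using the inherited torchtext API
    # (`train_ds` is a stand-in torchtext Dataset):
    #
    #     text_field = TextFeatureField()
    #     text_field.build_vocab(train_ds, min_freq=text_field.min_freq)
    #     padded, lengths = text_field.process([["hello", "world"], ["hi"]])
    #     # include_lengths=True makes process() also return lengths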


class SeqFeatureField(VocabUsingNestedField):
    """Nested field for sequences of texts (each example is a list of token lists)."""

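    # Placeholder nested input (2 examples x 1 sequence x 1 token id), one
    # nesting level deeper than TextFeatureField.dummy_model_input.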
dummy_model_input = torch.tensor([[[1]], [[1]]], dtype=torch.long, device="cpu")
def __init__(
self,
pretrained_embeddings_path="",
embed_dim=0,
embedding_init_strategy=EmbedInitStrategy.RANDOM,
vocab_file="",
vocab_size="",
vocab_from_train_data=True,
vocab_from_all_data=False,
vocab_from_pretrained_embeddings=False,
postprocessing=None,
use_vocab=True,
include_lengths=True,
pad_token=VocabMeta.PAD_SEQ,
init_token=None,
eos_token=None,
tokenize=data_utils.no_tokenize,
nesting_field=None,
**kwargs,
):
super().__init__(
pretrained_embeddings_path=pretrained_embeddings_path,
embed_dim=embed_dim,
embedding_init_strategy=embedding_init_strategy,
vocab_file=vocab_file,
vocab_size=vocab_size,
vocab_from_train_data=vocab_from_train_data,
vocab_from_all_data=vocab_from_all_data,
vocab_from_pretrained_embeddings=vocab_from_pretrained_embeddings,
postprocessing=postprocessing,
use_vocab=use_vocab,
include_lengths=include_lengths,
pad_token=pad_token,
init_token=init_token,
eos_token=eos_token,
tokenize=tokenize,
nesting_field=nesting_field
if nesting_field is not None
else TextFeatureField(include_lengths=False, lower=False),
)
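
    # Sketch: the nesting TextFeatureField numericalizes the inner token
    # lists, so each example is a list of texts (e.g. utterances in a dialog;
    # `train_ds` is a stand-in Dataset):
    #
    #     seq_field = SeqFeatureField()
    #     seq_field.build_vocab(train_ds)
    #     batch = seq_field.process([[["hello", "world"], ["bye"]]])
    #     # include_lengths=True adds length tensors to the result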


class FloatField(Field):
    """Field for a single float value per example."""

    def __init__(self, **kwargs):
super().__init__(
sequential=False,
use_vocab=False,
batch_first=True,
tokenize=data_utils.no_tokenize,
dtype=torch.float,
unk_token=None,
)


class FloatVectorField(Field):
    """Field for fixed-dimension vectors of floats."""

    def __init__(self, dim=0, dim_error_check=False, **kwargs):
super().__init__(
use_vocab=False,
batch_first=True,
tokenize=self._parse_vector,
dtype=torch.float,
preprocessing=textdata.Pipeline(
float
), # Convert each string to float. float() takes care of whitespace.
fix_length=dim,
pad_token=0, # For irregular sized vectors, pad the missing units with 0s.
)
self.dim_error_check = dim_error_check # dims in data should match config
self.dummy_model_input = torch.tensor(
[[1.0] * dim, [1.0] * dim], dtype=torch.float, device="cpu"
)

    def _parse_vector(self, s):
if not s:
return None
if s[0] == "[":
s = s[1:] # Remove '['
if s[-1] == "]":
s = s[:-1] # Remove ']'
        s_array = re.split("[, ]", s)  # Split on ',' or ' '
        ret_val = [tok for tok in s_array if tok]  # Drop empty strings left by the split
        assert not self.dim_error_check or len(ret_val) == self.fix_length, (
            f"FloatVectorField has mismatched dimensions. Config dims: "
            f"{self.fix_length}, data: {len(ret_val)}, actual vector: {ret_val}"
        )
return ret_val
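
    # Parsing sketch: `_parse_vector` (the tokenizer above) strips brackets
    # and splits on ',' or ' '; the `preprocessing` pipeline then casts each
    # piece to float.
    #
    #     fv = FloatVectorField(dim=3)
    #     fv._parse_vector("[0.1, 0.2, 0.3]")  # -> ["0.1", "0.2", "0.3"]
    #     fv.preprocess("[0.1, 0.2, 0.3]")     # -> [0.1, 0.2, 0.3]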


class ActionField(VocabUsingField):
    """Field for parser action sequences; the vocabulary is built from all data splits."""

    def __init__(self, **kwargs):
super().__init__(
use_vocab=True,
sequential=True,
batch_first=True,
tokenize=data_utils.no_tokenize,
unk_token=None, # Don't include UNK in the list of labels
pad_token=None, # Don't include PAD in the list of labels
vocab_from_all_data=True, # Actions must be built from all data.
)
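
    # Sketch: actions cannot fall back to UNK or PAD (both are None above),
    # so the vocabulary must cover every split (stand-in Datasets below):
    #
    #     action_field = ActionField()
    #     action_field.build_vocab(train_ds, eval_ds, test_ds)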