#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional, Union
from pytext.common.constants import Stage
from pytext.config import doc_classification as DocClassification
from pytext.config.field_config import WordLabelConfig
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.data import Data
from pytext.data.packed_lm_data import PackedLMData
from pytext.data.tensorizers import Tensorizer
from pytext.exporters import DenseFeatureExporter
from pytext.metric_reporters import (
ClassificationMetricReporter,
CompositionalMetricReporter,
IntentSlotMetricReporter,
LanguageModelMetricReporter,
NERMetricReporter,
PairwiseRankingMetricReporter,
PureLossMetricReporter,
RegressionMetricReporter,
SequenceTaggingMetricReporter,
SquadMetricReporter,
)
from pytext.metric_reporters.channel import ConsoleChannel
from pytext.metric_reporters.language_model_metric_reporter import (
MaskedLMMetricReporter,
)
from pytext.models.bert_classification_models import NewBertModel
from pytext.models.bert_regression_model import NewBertRegressionModel
from pytext.models.doc_model import DocModel, DocRegressionModel
from pytext.models.ensembles import BaggingDocEnsembleModel, EnsembleModel
from pytext.models.joint_model import IntentSlotModel
from pytext.models.language_models.lmlstm import LMLSTM
from pytext.models.masked_lm import MaskedLanguageModel
from pytext.models.model import BaseModel
from pytext.models.pair_classification_model import BasePairwiseModel, PairwiseModel
from pytext.models.qna.bert_squad_qa import BertSquadQAModel
from pytext.models.qna.dr_qa import DrQAModel
from pytext.models.query_document_pairwise_ranking_model import (
QueryDocPairwiseRankingModel,
)
from pytext.models.representations.sparse_transformer_sentence_encoder import ( # noqa f401
SparseTransformerSentenceEncoder,
)
from pytext.models.roberta import RoBERTaWordTaggingModel
from pytext.models.semantic_parsers.rnng.rnng_parser import RNNGParser
from pytext.models.seq_models.contextual_intent_slot import ( # noqa
ContextualIntentSlotModel,
)
from pytext.models.seq_models.seqnn import SeqNNModel, SeqNNModel_Deprecated
from pytext.models.word_model import WordTaggingModel
from pytext.task import Task_Deprecated
from pytext.task.new_task import NewTask
from pytext.trainers import EnsembleTrainer, HogwildTrainer, Trainer
class QueryDocumentPairwiseRankingTask(NewTask):
    """Pairwise query/document ranking task.

    By default the model is ``QueryDocPairwiseRankingModel`` and results are
    reported through ``PairwiseRankingMetricReporter``.
    """

    class Config(NewTask.Config):
        # Default model for this task.
        model: QueryDocPairwiseRankingModel.Config = (
            QueryDocPairwiseRankingModel.Config()
        )
        # Default metric reporter for this task.
        metric_reporter: PairwiseRankingMetricReporter.Config = (
            PairwiseRankingMetricReporter.Config()
        )
class EnsembleTask(NewTask):
    """Task that trains an ensemble of sub-models, one sub-trainer per model."""

    class Config(NewTask.Config):
        # No default: an ensemble model config must always be supplied
        # (see ``example_config`` for a minimal one).
        model: EnsembleModel.Config
        trainer: EnsembleTrainer.Config = EnsembleTrainer.Config()
        metric_reporter: Union[
            ClassificationMetricReporter.Config, IntentSlotMetricReporter.Config
        ] = ClassificationMetricReporter.Config()

    def train_single_model(self, train_config, model_id, rank=0, world_size=1):
        """Train one member of the ensemble.

        Uses ``self.trainer.real_trainers[model_id]`` to train
        ``self.model.models[model_id]`` on the TRAIN-stage batches, evaluating
        on the EVAL-stage batches with the task's shared metric reporter.

        Args:
            train_config: passed through unchanged to the sub-trainer.
            model_id: index of the sub-model / sub-trainer pair to train.
            rank: accepted but not used by this method.
            world_size: accepted but not used by this method.

        Returns:
            Whatever the selected sub-trainer's ``train`` returns.
        """
        # NOTE(review): rank/world_size are not forwarded to the sub-trainer;
        # confirm this is intended for distributed training.
        sub_trainer = self.trainer.real_trainers[model_id]
        train_batches = self.data.batches(Stage.TRAIN)
        eval_batches = self.data.batches(Stage.EVAL)
        sub_model = self.model.models[model_id]
        return sub_trainer.train(
            train_batches,
            eval_batches,
            sub_model,
            self.metric_reporter,
            train_config,
        )

    @classmethod
    def example_config(cls):
        """Return a sample config: a bagging ensemble holding one default DocModel."""
        bagging_ensemble = BaggingDocEnsembleModel.Config(models=[DocModel.Config()])
        return cls.Config(model=bagging_ensemble)
class DocumentClassificationTask(NewTask):
    """Document classification task.

    Defaults to ``DocModel`` reported through ``ClassificationMetricReporter``;
    ``PureLossMetricReporter`` may be configured instead.
    """

    class Config(NewTask.Config):
        model: BaseModel.Config = DocModel.Config()
        metric_reporter: Union[
            ClassificationMetricReporter.Config, PureLossMetricReporter.Config
        ] = ClassificationMetricReporter.Config()
        # For a multi-label classification task, choose
        # MultiLabelClassificationMetricReporter.
        # NOTE(review): that reporter's Config is not a member of the Union
        # above, so as written it presumably cannot be selected here — confirm.
class DocumentRegressionTask(NewTask):
    """Document regression task.

    Defaults to ``DocRegressionModel`` reported through
    ``RegressionMetricReporter``.
    """

    class Config(NewTask.Config):
        model: DocRegressionModel.Config = DocRegressionModel.Config()
        metric_reporter: RegressionMetricReporter.Config = (
            RegressionMetricReporter.Config()
        )
class NewBertClassificationTask(DocumentClassificationTask):
    """Document classification using ``NewBertModel`` as the default model.

    Inherits the metric-reporter configuration of
    ``DocumentClassificationTask``.
    """

    class Config(DocumentClassificationTask.Config):
        # Only the default model is overridden.
        model: NewBertModel.Config = NewBertModel.Config()
class NewBertPairClassificationTask(DocumentClassificationTask):
    """Classification over a pair of text columns with ``NewBertModel``.

    The default BERT tensorizer reads columns ``text1`` and ``text2``
    (``max_seq_len=128``), and the metric reporter is pointed at the same
    two columns.
    """

    class Config(DocumentClassificationTask.Config):
        model: NewBertModel.Config = NewBertModel.Config(
            inputs=NewBertModel.Config.BertModelInput(
                tokens=BERTTensorizer.Config(
                    columns=["text1", "text2"],
                    max_seq_len=128,
                )
            )
        )
        metric_reporter: ClassificationMetricReporter.Config = (
            ClassificationMetricReporter.Config(text_column_names=["text1", "text2"])
        )
class BertPairRegressionTask(DocumentRegressionTask):
    """Regression task using ``NewBertRegressionModel`` as the default model.

    Inherits the metric-reporter configuration of ``DocumentRegressionTask``.
    """

    class Config(DocumentRegressionTask.Config):
        # Only the default model is overridden.
        model: NewBertRegressionModel.Config = NewBertRegressionModel.Config()
class WordTaggingTask(NewTask):
    """Word tagging task with ``WordTaggingModel`` as the default model."""

    class Config(NewTask.Config):
        model: WordTaggingModel.Config = WordTaggingModel.Config()
        metric_reporter: SequenceTaggingMetricReporter.Config = (
            SequenceTaggingMetricReporter.Config()
        )

    @classmethod
    def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        """Build a ``SequenceTaggingMetricReporter`` for this task.

        Args:
            config: task config; its ``metric_reporter`` section is used.
            tensorizers: must contain a ``"labels"`` entry, which is handed
                to the reporter.

        Returns:
            The reporter created by ``SequenceTaggingMetricReporter.from_config``.
        """
        reporter_config = config.metric_reporter
        label_tensorizer = tensorizers["labels"]
        return SequenceTaggingMetricReporter.from_config(
            reporter_config, label_tensorizer
        )
class IntentSlotTask(NewTask):
    """Joint intent/slot task.

    Defaults to ``IntentSlotModel`` reported through
    ``IntentSlotMetricReporter``.
    """

    class Config(NewTask.Config):
        model: IntentSlotModel.Config = IntentSlotModel.Config()
        metric_reporter: IntentSlotMetricReporter.Config = (
            IntentSlotMetricReporter.Config()
        )
class LMTask(NewTask):
    """Language-modeling task.

    Defaults to ``LMLSTM`` reported through ``LanguageModelMetricReporter``.
    """

    class Config(NewTask.Config):
        model: LMLSTM.Config = LMLSTM.Config()
        metric_reporter: LanguageModelMetricReporter.Config = (
            LanguageModelMetricReporter.Config()
        )
class MaskedLMTask(NewTask):
    """Masked language-modeling task.

    Unlike most tasks here it also overrides the data config, defaulting to
    ``PackedLMData``; the model defaults to ``MaskedLanguageModel`` and the
    reporter to ``MaskedLMMetricReporter``.
    """

    class Config(NewTask.Config):
        data: Data.Config = PackedLMData.Config()
        model: MaskedLanguageModel.Config = MaskedLanguageModel.Config()
        metric_reporter: MaskedLMMetricReporter.Config = (
            MaskedLMMetricReporter.Config()
        )
class PairwiseClassificationTask(NewTask):
    """Classification over a pair of text columns.

    The model field accepts any ``BasePairwiseModel`` config and defaults to
    ``PairwiseModel``; the metric reporter reads columns ``text1``/``text2``.
    """

    class Config(NewTask.Config):
        model: BasePairwiseModel.Config = PairwiseModel.Config()
        metric_reporter: ClassificationMetricReporter.Config = (
            ClassificationMetricReporter.Config(text_column_names=["text1", "text2"])
        )
class RoBERTaNERTask(NewTask):
    """NER task with ``RoBERTaWordTaggingModel`` as the default model."""

    class Config(NewTask.Config):
        model: RoBERTaWordTaggingModel.Config = RoBERTaWordTaggingModel.Config()
        metric_reporter: NERMetricReporter.Config = NERMetricReporter.Config()

    @classmethod
    def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        """Build an ``NERMetricReporter`` that writes to the console.

        Label names and the padding index are both taken from the
        ``"tokens"`` tensorizer (its ``labels_vocab`` and ``labels_pad_idx``).

        Args:
            config: task config (not consulted by this method).
            tensorizers: must contain a ``"tokens"`` entry.

        Returns:
            A new ``NERMetricReporter`` with a single ``ConsoleChannel``.
        """
        # NOTE(review): config.metric_reporter is ignored here — confirm
        # that is intentional.
        channels = [ConsoleChannel()]
        token_tensorizer = tensorizers["tokens"]
        # Reads the vocab's private ``_vocab`` list directly.
        label_names = list(token_tensorizer.labels_vocab._vocab)
        return NERMetricReporter(
            channels=channels,
            label_names=label_names,
            pad_idx=token_tensorizer.labels_pad_idx,
        )
class SeqNNTask(NewTask):
    """Sequence classification task with ``SeqNNModel`` as the default model.

    The classification metric reporter is configured to read the
    ``text_seq`` column.
    """

    class Config(NewTask.Config):
        model: SeqNNModel.Config = SeqNNModel.Config()
        metric_reporter: ClassificationMetricReporter.Config = (
            ClassificationMetricReporter.Config(text_column_names=["text_seq"])
        )
class SquadQATask(NewTask):
    """SQuAD-style question answering task.

    The model may be either ``BertSquadQAModel`` or ``DrQAModel`` (the
    default); results are reported through ``SquadMetricReporter``.
    """

    class Config(NewTask.Config):
        model: Union[
            BertSquadQAModel.Config, DrQAModel.Config
        ] = DrQAModel.Config()
        metric_reporter: SquadMetricReporter.Config = SquadMetricReporter.Config()
class SemanticParsingTask(NewTask):
    """Semantic parsing task built around ``RNNGParser`` and ``HogwildTrainer``.

    The parser only supports a batch size of 1 and must not report training
    metrics; both constraints are validated when the task is constructed.
    """

    class Config(NewTask.Config):
        model: RNNGParser.Config = RNNGParser.Config()
        trainer: HogwildTrainer.Config = HogwildTrainer.Config()
        metric_reporter: CompositionalMetricReporter.Config = (
            CompositionalMetricReporter.Config()
        )

    def __init__(
        self,
        data: Data,
        model: RNNGParser,
        metric_reporter: CompositionalMetricReporter,
        trainer: HogwildTrainer,
    ):
        """Initialize the task and validate RNNG-specific constraints.

        Args:
            data: data source; ``data.batcher`` must use batch size 1 for the
                train, eval and test stages.
            model: the RNNG parser.
            metric_reporter: compositional metric reporter.
            trainer: Hogwild trainer; ``trainer.config.report_train_metrics``
                must be ``False``.

        Raises:
            AssertionError: if any batch size is not 1, or if training-metric
                reporting is enabled. Raised explicitly rather than via
                ``assert`` so the checks still run under ``python -O``; the
                exception type is unchanged for existing callers.
        """
        super().__init__(data, model, metric_reporter, trainer)
        batcher = data.batcher
        if not (
            batcher.train_batch_size == 1
            and batcher.eval_batch_size == 1
            and batcher.test_batch_size == 1
        ):
            raise AssertionError("RNNGParser only supports batch size = 1")
        # Trees are not necessarily valid mid-training, so training metrics
        # must be disabled.
        if trainer.config.report_train_metrics is not False:
            raise AssertionError(
                "Disable report_train_metrics because trees are not necessarily "
                "valid during training"
            )