Source code for pytext.task.tasks

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional, Union

from pytext.common.constants import Stage
from pytext.config import doc_classification as DocClassification
from pytext.config.field_config import WordLabelConfig
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.data import Data
from pytext.data.packed_lm_data import PackedLMData
from pytext.data.tensorizers import Tensorizer
from pytext.exporters import DenseFeatureExporter
from pytext.metric_reporters import (
    ClassificationMetricReporter,
    CompositionalMetricReporter,
    IntentSlotMetricReporter,
    LanguageModelMetricReporter,
    NERMetricReporter,
    PairwiseRankingMetricReporter,
    PureLossMetricReporter,
    RegressionMetricReporter,
    SequenceTaggingMetricReporter,
    SquadMetricReporter,
)
from pytext.metric_reporters.channel import ConsoleChannel
from pytext.metric_reporters.language_model_metric_reporter import (
    MaskedLMMetricReporter,
)
from pytext.models.bert_classification_models import NewBertModel
from pytext.models.bert_regression_model import NewBertRegressionModel
from pytext.models.doc_model import DocModel, DocRegressionModel
from pytext.models.ensembles import BaggingDocEnsembleModel, EnsembleModel
from pytext.models.joint_model import IntentSlotModel
from pytext.models.language_models.lmlstm import LMLSTM
from pytext.models.masked_lm import MaskedLanguageModel
from pytext.models.model import BaseModel
from pytext.models.pair_classification_model import BasePairwiseModel, PairwiseModel
from pytext.models.qna.bert_squad_qa import BertSquadQAModel
from pytext.models.qna.dr_qa import DrQAModel
from pytext.models.query_document_pairwise_ranking_model import (
    QueryDocPairwiseRankingModel,
)
from pytext.models.representations.sparse_transformer_sentence_encoder import (  # noqa f401
    SparseTransformerSentenceEncoder,
)
from pytext.models.roberta import RoBERTaWordTaggingModel
from pytext.models.semantic_parsers.rnng.rnng_parser import RNNGParser
from pytext.models.seq_models.contextual_intent_slot import (  # noqa
    ContextualIntentSlotModel,
)
from pytext.models.seq_models.seqnn import SeqNNModel, SeqNNModel_Deprecated
from pytext.models.word_model import WordTaggingModel
from pytext.task import Task_Deprecated
from pytext.task.new_task import NewTask
from pytext.trainers import EnsembleTrainer, HogwildTrainer, Trainer


[docs]class QueryDocumentPairwiseRankingTask(NewTask):
[docs] class Config(NewTask.Config): model: QueryDocPairwiseRankingModel.Config = ( QueryDocPairwiseRankingModel.Config() ) metric_reporter: PairwiseRankingMetricReporter.Config = ( PairwiseRankingMetricReporter.Config() )
[docs]class EnsembleTask(NewTask):
[docs] class Config(NewTask.Config): model: EnsembleModel.Config trainer: EnsembleTrainer.Config = EnsembleTrainer.Config() metric_reporter: Union[ ClassificationMetricReporter.Config, IntentSlotMetricReporter.Config ] = ClassificationMetricReporter.Config()
[docs] def train_single_model(self, train_config, model_id, rank=0, world_size=1): return self.trainer.real_trainers[model_id].train( self.data.batches(Stage.TRAIN), self.data.batches(Stage.EVAL), self.model.models[model_id], self.metric_reporter, train_config, )
[docs] @classmethod def example_config(cls): return cls.Config( model=BaggingDocEnsembleModel.Config(models=[DocModel.Config()]) )
[docs]class DocumentClassificationTask(NewTask):
[docs] class Config(NewTask.Config): model: BaseModel.Config = DocModel.Config() metric_reporter: Union[ ClassificationMetricReporter.Config, PureLossMetricReporter.Config ] = (ClassificationMetricReporter.Config())
# for multi-label classification task, # choose MultiLabelClassificationMetricReporter
[docs] @classmethod def format_prediction(cls, predictions, scores, context, target_names): for prediction, score in zip(predictions, scores): score_with_name = {n: s for n, s in zip(target_names, score.tolist())} yield { "prediction": target_names[prediction.data], "score": score_with_name, }
[docs]class DocumentRegressionTask(NewTask):
[docs] class Config(NewTask.Config): model: DocRegressionModel.Config = DocRegressionModel.Config() metric_reporter: RegressionMetricReporter.Config = ( RegressionMetricReporter.Config() )
[docs]class NewBertClassificationTask(DocumentClassificationTask):
[docs] class Config(DocumentClassificationTask.Config): model: NewBertModel.Config = NewBertModel.Config()
[docs]class NewBertPairClassificationTask(DocumentClassificationTask):
[docs] class Config(DocumentClassificationTask.Config): model: NewBertModel.Config = NewBertModel.Config( inputs=NewBertModel.Config.BertModelInput( tokens=BERTTensorizer.Config( columns=["text1", "text2"], max_seq_len=128 ) ) ) metric_reporter: ClassificationMetricReporter.Config = ( ClassificationMetricReporter.Config(text_column_names=["text1", "text2"]) )
[docs]class BertPairRegressionTask(DocumentRegressionTask):
[docs] class Config(DocumentRegressionTask.Config): model: NewBertRegressionModel.Config = NewBertRegressionModel.Config()
[docs]class WordTaggingTask(NewTask):
[docs] class Config(NewTask.Config): model: WordTaggingModel.Config = WordTaggingModel.Config() metric_reporter: SequenceTaggingMetricReporter.Config = ( SequenceTaggingMetricReporter.Config() )
[docs] @classmethod def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]): return SequenceTaggingMetricReporter.from_config( config.metric_reporter, tensorizers["labels"] )
[docs]class IntentSlotTask(NewTask):
[docs] class Config(NewTask.Config): model: IntentSlotModel.Config = IntentSlotModel.Config() metric_reporter: IntentSlotMetricReporter.Config = ( IntentSlotMetricReporter.Config() )
[docs]class LMTask(NewTask):
[docs] class Config(NewTask.Config): model: LMLSTM.Config = LMLSTM.Config() metric_reporter: LanguageModelMetricReporter.Config = ( LanguageModelMetricReporter.Config() )
[docs]class MaskedLMTask(NewTask):
[docs] class Config(NewTask.Config): data: Data.Config = PackedLMData.Config() model: MaskedLanguageModel.Config = MaskedLanguageModel.Config() metric_reporter: MaskedLMMetricReporter.Config = ( MaskedLMMetricReporter.Config() )
[docs]class PairwiseClassificationTask(NewTask):
[docs] class Config(NewTask.Config): model: BasePairwiseModel.Config = PairwiseModel.Config() metric_reporter: ClassificationMetricReporter.Config = ( ClassificationMetricReporter.Config(text_column_names=["text1", "text2"]) )
[docs]class RoBERTaNERTask(NewTask):
[docs] class Config(NewTask.Config): model: RoBERTaWordTaggingModel.Config = RoBERTaWordTaggingModel.Config() metric_reporter: NERMetricReporter.Config = NERMetricReporter.Config()
[docs] @classmethod def create_metric_reporter(cls, config: Config, tensorizers: Dict[str, Tensorizer]): return NERMetricReporter( channels=[ConsoleChannel()], label_names=list(tensorizers["tokens"].labels_vocab._vocab), pad_idx=tensorizers["tokens"].labels_pad_idx, )
[docs]class SeqNNTask(NewTask):
[docs] class Config(NewTask.Config): model: SeqNNModel.Config = SeqNNModel.Config() metric_reporter: ClassificationMetricReporter.Config = ( ClassificationMetricReporter.Config(text_column_names=["text_seq"]) )
[docs]class SquadQATask(NewTask):
[docs] class Config(NewTask.Config): model: Union[BertSquadQAModel.Config, DrQAModel.Config] = DrQAModel.Config() metric_reporter: SquadMetricReporter.Config = SquadMetricReporter.Config()
[docs]class SemanticParsingTask(NewTask):
[docs] class Config(NewTask.Config): model: RNNGParser.Config = RNNGParser.Config() trainer: HogwildTrainer.Config = HogwildTrainer.Config() metric_reporter: CompositionalMetricReporter.Config = ( CompositionalMetricReporter.Config() )
def __init__( self, data: Data, model: RNNGParser, metric_reporter: CompositionalMetricReporter, trainer: HogwildTrainer, ): super().__init__(data, model, metric_reporter, trainer) assert ( (data.batcher.train_batch_size == 1) and (data.batcher.eval_batch_size == 1) and (data.batcher.test_batch_size == 1) ), "RNNGParser only supports batch size = 1" assert trainer.config.report_train_metrics is False, ( "Disable report_train_metrics because trees are not necessarily " "valid during training" )