#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional
from pytext.common.constants import (
BatchContext,
DatasetFieldName,
RawExampleFieldName,
Stage,
)
from pytext.data.data_structures.annotation import INVALID_TREE_STR, Annotation
from pytext.data.tensorizers import Tensorizer
from pytext.metric_reporters.channel import ConsoleChannel
from pytext.metric_reporters.compositional_metric_reporter import (
CompositionalMetricReporter,
)
from pytext.metric_reporters.metric_reporter import MetricReporter
from pytext.metric_reporters.seq2seq_metric_reporter import (
Seq2SeqFileChannel,
Seq2SeqMetricReporter,
)
from pytext.metrics.intent_slot_metrics import FramePredictionPair, compute_all_metrics
from .seq2seq_utils import stringify
[docs]class CompositionalSeq2SeqFileChannel(Seq2SeqFileChannel):
def __init__(self, stages, file_path, tensorizers, accept_flat_intents_slots):
super().__init__(stages, file_path, tensorizers)
self.accept_flat_intents_slots = accept_flat_intents_slots
[docs] def get_title(self, context_keys=()):
return (
"row_index",
"text",
"predicted_output_sequence",
"prediction",
"target",
)
[docs] def validated_annotation(self, predicted_output_sequence):
try:
tree = Annotation(
predicted_output_sequence,
accept_flat_intents_slots=self.accept_flat_intents_slots,
).tree
except (ValueError, IndexError):
tree = Annotation(INVALID_TREE_STR).tree
return tree.flat_str()
[docs] def gen_content(self, metrics, loss, preds, targets, scores, context):
batch_size = len(targets)
for i in range(batch_size):
yield [
context[BatchContext.INDEX][i],
context[DatasetFieldName.RAW_SEQUENCE][i],
self.tensorizers["trg_seq_tokens"].stringify(preds[i][0]).upper(),
self.validated_annotation(
self.tensorizers["trg_seq_tokens"].stringify(preds[i][0]).upper()
),
self.validated_annotation(
self.tensorizers["trg_seq_tokens"].stringify(targets[i]).upper()
),
]
[docs]class Seq2SeqCompositionalMetricReporter(Seq2SeqMetricReporter):
def __init__(self, channels, log_gradient, tensorizers, accept_flat_intents_slots):
super().__init__(channels, log_gradient, tensorizers)
self.accept_flat_intents_slots = accept_flat_intents_slots
[docs] class Config(MetricReporter.Config):
accept_flat_intents_slots: Optional[bool] = False
[docs] @classmethod
def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
return cls(
[
ConsoleChannel(),
CompositionalSeq2SeqFileChannel(
[Stage.TEST],
config.output_path,
tensorizers,
config.accept_flat_intents_slots,
),
],
config.log_gradient,
tensorizers,
config.accept_flat_intents_slots,
)
def _reset(self):
super()._reset()
self.all_target_lens: List = []
self.all_src_tokens: List = []
self.all_target_trees: List = []
self.all_pred_trees: List = []
[docs] def calculate_metric(self):
all_metrics = compute_all_metrics(
self.create_frame_prediction_pairs(),
overall_metrics=True,
calculated_loss=self.calculate_loss(),
)
return all_metrics
[docs] def create_frame_prediction_pairs(self):
return [
FramePredictionPair(
CompositionalMetricReporter.tree_to_metric_node(pred_tree),
CompositionalMetricReporter.tree_to_metric_node(target_tree),
)
for pred_tree, target_tree in zip(
self.all_pred_trees, self.all_target_trees
)
]
[docs] def aggregate_targets(self, new_batch, context=None):
if new_batch is None:
return
target_vocab = self.tensorizers["trg_seq_tokens"].vocab
target_pad_token = target_vocab.get_pad_index()
target_bos_token = target_vocab.get_bos_index()
target_eos_token = target_vocab.get_eos_index()
cleaned_targets = [
self._remove_tokens(
target, [target_pad_token, target_eos_token, target_bos_token]
)
for target in self._make_simple_list(new_batch[0])
]
self.aggregate_data(self.all_targets, cleaned_targets)
self.aggregate_data(self.all_target_lens, new_batch[1])
target_trees = [
self.stringify_annotation_tree(target, target_vocab)
for target in cleaned_targets
]
self.aggregate_data(self.all_target_trees, target_trees)
[docs] def aggregate_preds(self, new_batch, context=None):
if new_batch is None:
return
target_vocab = self.tensorizers["trg_seq_tokens"].vocab
target_pad_token = target_vocab.get_pad_index()
target_bos_token = target_vocab.get_bos_index()
target_eos_token = target_vocab.get_eos_index()
cleaned_preds = [
self._remove_tokens(
pred, [target_pad_token, target_eos_token, target_bos_token]
)
for pred in self._make_simple_list(new_batch)
]
self.aggregate_data(self.all_preds, cleaned_preds)
pred_trees = [
self.stringify_annotation_tree(pred[0], target_vocab)
for pred in cleaned_preds
]
self.aggregate_data(self.all_pred_trees, pred_trees)
[docs] def stringify_annotation_tree(self, tree_tokens, tree_vocab):
stringified_tree_str = stringify(tree_tokens, tree_vocab._vocab)
return self.get_annotation_from_string(stringified_tree_str)
[docs] def get_annotation_from_string(self, stringified_tree_str: str) -> Annotation:
try:
tree = Annotation(
stringified_tree_str.upper(),
accept_flat_intents_slots=self.accept_flat_intents_slots,
).tree
except (ValueError, IndexError):
tree = Annotation(INVALID_TREE_STR).tree
return tree
[docs] def batch_context(self, raw_batch, batch):
return {
DatasetFieldName.RAW_SEQUENCE: [
row["source_sequence"] for row in raw_batch
],
BatchContext.INDEX: [
row[RawExampleFieldName.ROW_INDEX] for row in raw_batch
],
}