#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List, Optional
from pytext.common.constants import DatasetFieldName, Stage
from pytext.data.data_structures.annotation import CLOSE, OPEN, escape_brackets
from pytext.metrics.intent_slot_metrics import (
FramePredictionPair,
Node,
Span,
compute_all_metrics,
)
from pytext.utils.data import (
byte_length,
get_substring_from_offsets,
merge_token_labels_to_slot,
parse_slot_string,
)
from .channel import Channel, ConsoleChannel, FileChannel
from .metric_reporter import MetricReporter
# String key "doc_label_names". It is not referenced by any code visible in
# this module -- presumably consumed by other modules; TODO confirm.
DOC_LABEL_NAMES = "doc_label_names"
def create_frame(text, intent_label, slot_names_str, byte_len):
    """Build a frame: an intent node with one child node per slot.

    Args:
        text: Utterance text stored on the root node.
        intent_label: Label given to the root node.
        slot_names_str: Slot description handed to ``parse_slot_string``;
            every parsed slot contributes its ``label``, ``start`` and ``end``.
        byte_len: End offset of the root span, which always starts at 0.

    Returns:
        The root ``Node``; its children are held in a set.
    """
    root_span = Span(0, byte_len)
    slot_nodes = {
        Node(label=parsed.label, span=Span(parsed.start, parsed.end))
        for parsed in parse_slot_string(slot_names_str)
    }
    return Node(
        label=intent_label,
        span=root_span,
        children=slot_nodes,
        text=text,
    )
def frame_to_str(frame: Node):
    """Render a frame as a bracketed annotation string.

    The result starts with ``OPEN``, the escaped frame label and a space.
    Children are visited in ascending order of ``span.start``; for each one
    the escaped text between the previous child and this one is emitted,
    followed by ``OPEN``, the escaped child label, a space, the escaped child
    text, a space and ``CLOSE``. Finally the escaped text after the last
    child is emitted, followed by a space and ``CLOSE``.

    Args:
        frame: Root node providing ``label``, ``text`` and ``children``.

    Returns:
        The annotation string.
    """
    pieces = [OPEN, escape_brackets(frame.label), " "]
    cursor = 0  # offset just past the previously emitted child

    ordered_children = sorted(frame.children, key=lambda child: child.span.start)
    for child in ordered_children:
        begin = child.span.start
        end = child.span.end

        # Plain text lying between the previous child and this one.
        gap_text = get_substring_from_offsets(frame.text, cursor, begin)
        pieces.append(escape_brackets(gap_text))

        # The child itself, wrapped in its own bracket pair.
        pieces.append(OPEN)
        pieces.append(escape_brackets(child.label))
        pieces.append(" ")
        child_text = get_substring_from_offsets(frame.text, begin, end)
        pieces.append(escape_brackets(child_text))
        pieces.append(" ")
        pieces.append(CLOSE)

        cursor = end

    # Whatever text remains after the last child, then close the frame.
    tail_text = get_substring_from_offsets(frame.text, cursor, None)
    pieces.append(escape_brackets(tail_text))
    pieces.append(" ")
    pieces.append(CLOSE)

    return "".join(pieces)
class IntentSlotMetricReporter(MetricReporter):
    """Metric reporter for joint intent and slot predictions.

    Every prediction and every target is turned into a frame by
    ``create_frame`` (an intent node whose children are slot nodes). The
    frames are paired positionally and scored with ``compute_all_metrics``.

    The accumulators ``self.all_preds``, ``self.all_targets`` and
    ``self.all_scores`` are not created in this class -- presumably they are
    provided by ``MetricReporter``; verify against the base class.
    """

    __EXPANSIBLE__ = True

    def __init__(
        self,
        doc_label_names: List[str],
        word_label_names: List[str],
        use_bio_labels: bool,
        channels: List[Channel],
        slot_column_name: str = "slots",
        text_column_name: str = "text",
        token_tensorizer_name: str = "tokens",
    ) -> None:
        """Store the label vocabularies and column/tensorizer names.

        Args:
            doc_label_names: Intent labels, indexed by predicted/target
                intent index.
            word_label_names: Word (slot) labels, indexed by predicted word
                label index.
            use_bio_labels: Forwarded to ``merge_token_labels_to_slot`` when
                predictions are aggregated.
            channels: Passed to ``MetricReporter.__init__``.
            slot_column_name: Key of the slot values in a raw data row.
            text_column_name: Key of the utterance text in a raw data row.
            token_tensorizer_name: Key in ``batch`` of the token tensorizer
                output read by ``batch_context``.
        """
        super().__init__(channels)
        self.doc_label_names = doc_label_names
        self.word_label_names = word_label_names
        self.use_bio_labels = use_bio_labels
        self.slot_column_name = slot_column_name
        self.text_column_name = text_column_name
        self.token_tensorizer_name = token_tensorizer_name

    class Config(MetricReporter.Config):
        # No options beyond those declared on MetricReporter.Config.
        pass

    @classmethod
    def from_config(cls, config, tensorizers: Optional[Dict] = None):
        """Create a reporter from a config and the task's tensorizers.

        The token tensorizer is the first of ``"text_feats"`` and ``"tokens"``
        found in ``tensorizers``. Output goes to a ``ConsoleChannel`` and to a
        ``FileChannel`` for ``Stage.TEST`` writing to ``config.output_path``.

        Args:
            config: Object providing ``output_path``.
            tensorizers: Mapping that must contain ``"doc_labels"``,
                ``"word_labels"`` and a token tensorizer. The ``None``
                default only mirrors the signature; a mapping is required
                because membership is tested on it.

        Returns:
            A new instance of ``cls``.

        Raises:
            KeyError: If a required tensorizer is missing, including the case
                where neither ``"text_feats"`` nor ``"tokens"`` is present.
        """
        # TODO this part should be handled more elegantly
        candidate_names = ("text_feats", "tokens")
        for name in candidate_names:
            if name in tensorizers:
                token_tensorizer_name = name
                break
        else:
            # Without this guard the name would stay unbound and the
            # constructor call below would fail with an UnboundLocalError
            # that hides the real cause.
            raise KeyError(
                "IntentSlotMetricReporter needs a token tensorizer named one "
                f"of {candidate_names}; got {sorted(tensorizers)}"
            )

        word_labels = tensorizers["word_labels"]
        return cls(
            tensorizers["doc_labels"].vocab,
            word_labels.vocab,
            getattr(word_labels, "use_bio_labels", False),
            [ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
            word_labels.slot_column,
            tensorizers[token_tensorizer_name].text_column,
            token_tensorizer_name,
        )

    def aggregate_preds(self, batch_preds, batch_context):
        """Convert a batch of predictions to frames and store them.

        Args:
            batch_preds: Pair ``(intent_preds, word_preds)``; the entries are
                used as indices into ``doc_label_names`` and
                ``word_label_names`` respectively.
            batch_context: Mapping holding the texts, ``SEQ_LENS`` and
                ``TOKEN_RANGE`` of the batch, as filled in by
                ``batch_context``.
        """
        intent_preds, word_preds = batch_preds
        self.all_preds.extend(
            [
                create_frame(
                    text,
                    self.doc_label_names[intent_pred],
                    merge_token_labels_to_slot(
                        # Only the first seq_len entries are used
                        # (presumably the rest is padding -- TODO confirm).
                        token_range[0:seq_len],
                        [self.word_label_names[p] for p in word_pred[0:seq_len]],
                        self.use_bio_labels,
                    ),
                    byte_length(text),
                )
                for text, intent_pred, word_pred, seq_len, token_range in zip(
                    batch_context[self.text_column_name],
                    intent_preds,
                    word_preds,
                    batch_context[DatasetFieldName.SEQ_LENS],
                    batch_context[DatasetFieldName.TOKEN_RANGE],
                )
            ]
        )

    def aggregate_targets(self, batch_targets, batch_context):
        """Convert a batch of targets to frames and store them.

        Args:
            batch_targets: Sequence whose first element holds the intent
                target indices; other elements are ignored.
            batch_context: Mapping holding the texts, ``RAW_WORD_LABEL`` slot
                strings and ``SEQ_LENS`` of the batch.
        """
        intent_targets = batch_targets[0]
        self.all_targets.extend(
            [
                create_frame(
                    text,
                    self.doc_label_names[intent_target],
                    raw_slot_label,
                    byte_length(text),
                )
                # SEQ_LENS is not needed to build the frame; it stays in the
                # zip so the number of produced frames is unchanged.
                for text, intent_target, raw_slot_label, _seq_len in zip(
                    batch_context[self.text_column_name],
                    intent_targets,
                    batch_context[DatasetFieldName.RAW_WORD_LABEL],
                    batch_context[DatasetFieldName.SEQ_LENS],
                )
            ]
        )

    def get_raw_slot_str(self, raw_data_row):
        """Return the row's slot values joined with commas.

        Args:
            raw_data_row: Mapping with an iterable under
                ``self.slot_column_name``; each item is converted with
                ``str``.
        """
        return ",".join(str(x) for x in raw_data_row[self.slot_column_name])

    def aggregate_scores(self, batch_scores):
        """Store one ``(intent_score, slot_score)`` tuple per example.

        Args:
            batch_scores: Pair ``(intent_scores, slot_scores)``; both objects
                must provide ``tolist()``.
        """
        intent_scores, slot_scores = batch_scores
        self.all_scores.extend(zip(intent_scores.tolist(), slot_scores.tolist()))

    def predictions_to_report(self):
        """
        Generate human readable predictions
        """
        return [frame_to_str(frame) for frame in self.all_preds]

    def targets_to_report(self):
        """
        Generate human readable targets
        """
        return [frame_to_str(frame) for frame in self.all_targets]

    def calculate_metric(self):
        """Score all stored frames.

        Predictions and targets are paired by position and passed to
        ``compute_all_metrics`` with ``frame_accuracy=True``.

        Returns:
            Whatever ``compute_all_metrics`` returns.
        """
        return compute_all_metrics(
            [
                FramePredictionPair(pred_frame, target_frame)
                for pred_frame, target_frame in zip(self.all_preds, self.all_targets)
            ],
            frame_accuracy=True,
        )

    def batch_context(self, raw_batch, batch):
        """Extend the base context with the fields the aggregators need.

        Args:
            raw_batch: Iterable of raw data rows (mappings).
            batch: Mapping from tensorizer name to that tensorizer's output.
                Elements 1 and 2 of the token tensorizer output are stored as
                ``SEQ_LENS`` and ``TOKEN_RANGE``.

        Returns:
            The context from ``MetricReporter.batch_context`` with the texts,
            ``SEQ_LENS``, ``TOKEN_RANGE`` and ``RAW_WORD_LABEL`` added.
        """
        context = super().batch_context(raw_batch, batch)
        context[self.text_column_name] = [
            row[self.text_column_name] for row in raw_batch
        ]
        context[DatasetFieldName.SEQ_LENS] = batch[self.token_tensorizer_name][
            1
        ].tolist()
        context[DatasetFieldName.TOKEN_RANGE] = batch[self.token_tensorizer_name][
            2
        ].tolist()
        context[DatasetFieldName.RAW_WORD_LABEL] = [
            self.get_raw_slot_str(raw_data_row) for raw_data_row in raw_batch
        ]
        return context

    def get_model_select_metric(self, metrics):
        """Return ``metrics.frame_accuracy``, the value used for model selection."""
        return metrics.frame_accuracy