#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from collections import Counter as counter, defaultdict
from copy import deepcopy
from typing import (
AbstractSet,
Any,
Callable,
Counter,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
)
from pytext.data.data_structures.node import Node as NodeBase, Span
from . import (
AllConfusions,
Confusions,
PerLabelConfusions,
PRF1Metrics,
PRF1Scores,
safe_division,
)
"""
Metric classes and functions for intent-slot prediction problems.
"""
class Node(NodeBase):
    """
    Subclass of the base Node class, used for metric purposes. It is immutable so
    that hashing can be done on the class.

    Attributes:
        label (str): Label of the node.
        span (Span): Span of the node.
        children (:obj:`frozenset` of :obj:`Node`): frozenset of the node's
            children, left empty when computing bracketing metrics.
        text (str): Text the node covers (=utterance[span.start:span.end])
    """

    def __init__(
        self,
        label: str,
        span: Span,
        children: Optional[AbstractSet["Node"]] = None,
        text: Optional[str] = None,
    ) -> None:
        # Freeze the children set so the node's hash cannot change after
        # construction.
        super().__init__(
            label, span, frozenset(children) if children else frozenset(), text
        )

    def __setattr__(self, name: str, value: Any) -> None:
        # Enforce immutability. NOTE(review): NodeBase.__init__ is assumed to
        # assign attributes via object.__setattr__, bypassing this override.
        raise AttributeError("Node class is immutable.")

    def __hash__(self):
        # Hash only on (label, span); children/text do not participate.
        return hash((self.label, self.span))
class FramePredictionPair(NamedTuple):
    """A predicted intent frame paired with its gold (expected) frame."""

    predicted_frame: Node
    expected_frame: Node
class NodesPredictionPair(NamedTuple):
    """A predicted multiset of nodes paired with the expected multiset."""

    predicted_nodes: Counter[Node]
    expected_nodes: Counter[Node]
class IntentsAndSlots(NamedTuple):
    """The multisets of intent and slot nodes found in an intent frame."""

    intents: Counter[Node]
    slots: Counter[Node]
class FrameAccuracy(NamedTuple):
    """
    Frame accuracy over a collection of intent frame predictions.

    A prediction counts as correct only when the entire tree structure of the
    predicted frame matches the gold frame.
    """

    num_samples: int
    frame_accuracy: float
# Type alias: maps gold-tree depth -> FrameAccuracy for that bucket.
FrameAccuraciesByDepth = Dict[int, FrameAccuracy]
"""
Frame accuracies bucketized by depth of the gold tree.
"""
class IntentSlotMetrics(NamedTuple):
    """
    Precision/recall/F1 metrics for intents and slots.

    Attributes:
        intent_metrics: Precision/recall/F1 metrics for intents.
        slot_metrics: Precision/recall/F1 metrics for slots.
        overall_metrics: Combined precision/recall/F1 metrics for all nodes
            (merging intents and slots).
    """

    intent_metrics: Optional[PRF1Metrics]
    slot_metrics: Optional[PRF1Metrics]
    overall_metrics: Optional[PRF1Scores]

    def print_metrics(self) -> None:
        """Print whichever of the three metric groups are present."""
        if self.intent_metrics:
            print("\nIntent Metrics")
            self.intent_metrics.print_metrics()
        if self.slot_metrics:
            print("\nSlot Metrics")
            self.slot_metrics.print_metrics()
        overall = self.overall_metrics
        if overall:
            print("\nMerged Intent and Slot Metrics")
            print(
                f" P = {overall.precision * 100:.2f} "
                f"R = {overall.recall * 100:.2f}, "
                f"F1 = {overall.f1 * 100:.2f}."
            )
class AllMetrics(NamedTuple):
    """
    Aggregated class for intent-slot related metrics.

    Attributes:
        top_intent_accuracy: Accuracy of the top-level intent.
        frame_accuracy: Frame accuracy.
        frame_accuracy_top_k: Frame accuracy over the top-k predicted frames.
        frame_accuracies_by_depth: Frame accuracies bucketized by depth of the
            gold tree.
        bracket_metrics: Bracket metrics for intents and slots. For details, see
            the function `compute_intent_slot_metrics()`.
        tree_metrics: Tree metrics for intents and slots. For details, see the
            function `compute_intent_slot_metrics()`.
        loss: Cross entropy loss.
    """

    top_intent_accuracy: Optional[float]
    frame_accuracy: Optional[float]
    frame_accuracy_top_k: Optional[float]
    frame_accuracies_by_depth: Optional[FrameAccuraciesByDepth]
    bracket_metrics: Optional[IntentSlotMetrics]
    tree_metrics: Optional[IntentSlotMetrics]
    loss: Optional[float] = None

    def print_metrics(self) -> None:
        """Print the metrics that are present (and truthy, i.e. non-zero)."""
        if self.frame_accuracy:
            print(f"\n\nFrame accuracy = {self.frame_accuracy * 100:.2f}")
        if self.frame_accuracy_top_k:
            print(f"\n\nTop k frame accuracy = {self.frame_accuracy_top_k * 100:.2f}")
        # Bracket metrics first, then tree metrics, matching historical output.
        for title, metrics in (
            ("Bracket Metrics", self.bracket_metrics),
            ("Tree Metrics", self.tree_metrics),
        ):
            if metrics:
                print(f"\n\n{title}")
                metrics.print_metrics()
class IntentSlotConfusions(NamedTuple):
    """
    Aggregated class for intent and slot confusions.

    Attributes:
        intent_confusions: Confusion counts for intents.
        slot_confusions: Confusion counts for slots.
    """

    intent_confusions: Confusions
    slot_confusions: Confusions
def _compare_nodes(
    predicted_nodes: Counter[Node],
    expected_nodes: Counter[Node],
    per_label_confusions: Optional[PerLabelConfusions] = None,
) -> Confusions:
    """
    Count TP/FP/FN between the predicted and expected node multisets, optionally
    accumulating per-label counts into `per_label_confusions`.
    """
    hits = predicted_nodes & expected_nodes  # multiset intersection -> TPs
    extra = predicted_nodes - hits  # predicted but not expected -> FPs
    missed = expected_nodes - hits  # expected but not predicted -> FNs
    if per_label_confusions:
        for bucket, nodes in (("TP", hits), ("FP", extra), ("FN", missed)):
            for node, count in nodes.items():
                per_label_confusions.update(node.label, bucket, count)
    return Confusions(
        TP=sum(hits.values()),
        FP=sum(extra.values()),
        FN=sum(missed.values()),
    )
def _get_intents_and_slots(frame: Node, tree_based: bool) -> IntentsAndSlots:
    """
    Collect all intent and slot nodes of `frame` into two multisets.

    The root is treated as an intent and levels alternate intent/slot. When
    `tree_based` is False, each node is replaced by a childless copy so that
    comparisons ignore subtree structure (bracket-based metrics).
    """
    intents: Counter[Node] = counter()
    slots: Counter[Node] = counter()

    def visit(node: Node, is_intent: bool) -> None:
        for child in node.children:
            visit(child, not is_intent)
        if not tree_based:
            # Re-create the node without children; only label/span/text matter.
            node = type(node)(node.label, deepcopy(node.span), text=node.text)
        bucket = intents if is_intent else slots
        bucket[node] += 1

    visit(frame, True)
    return IntentsAndSlots(intents=intents, slots=slots)
def compare_frames(
    predicted_frame: Node,
    expected_frame: Node,
    tree_based: bool,
    intent_per_label_confusions: Optional[PerLabelConfusions] = None,
    slot_per_label_confusions: Optional[PerLabelConfusions] = None,
) -> IntentSlotConfusions:
    """
    Compares two intent frames and returns TP, FP, FN counts for intents and slots.
    Optionally collects the per label TP, FP, FN counts.

    Args:
        predicted_frame: Predicted intent frame.
        expected_frame: Gold intent frame.
        tree_based: Whether to get the tree-based confusions (if True) or
            bracket-based confusions (if False). For details, see the function
            `compute_intent_slot_metrics()`.
        intent_per_label_confusions: If provided, update the per label confusions
            for intents as well. Defaults to None.
        slot_per_label_confusions: If provided, update the per label confusions
            for slots as well. Defaults to None.

    Returns:
        IntentSlotConfusions, containing confusion counts for intents and slots.
    """
    predicted = _get_intents_and_slots(predicted_frame, tree_based=tree_based)
    expected = _get_intents_and_slots(expected_frame, tree_based=tree_based)
    return IntentSlotConfusions(
        intent_confusions=_compare_nodes(
            predicted.intents, expected.intents, intent_per_label_confusions
        ),
        slot_confusions=_compare_nodes(
            predicted.slots, expected.slots, slot_per_label_confusions
        ),
    )
def compute_prf1_metrics(
    nodes_pairs: Sequence[NodesPredictionPair],
) -> Tuple[AllConfusions, PRF1Metrics]:
    """
    Computes precision/recall/F1 metrics given a list of predicted and expected
    sets of nodes.

    Args:
        nodes_pairs: List of predicted and expected node sets.

    Returns:
        A tuple, of which the first member contains the confusion information,
        and the second member contains the computed precision/recall/F1 metrics.
    """
    all_confusions = AllConfusions()
    per_label = all_confusions.per_label_confusions
    for pair in nodes_pairs:
        all_confusions.confusions += _compare_nodes(
            pair.predicted_nodes, pair.expected_nodes, per_label
        )
    return all_confusions, all_confusions.compute_metrics()
def compute_intent_slot_metrics(
    frame_pairs: Sequence[FramePredictionPair],
    tree_based: bool,
    overall_metrics: bool = True,
) -> IntentSlotMetrics:
    """
    Given a list of predicted and gold intent frames, computes precision, recall
    and F1 metrics for intents and slots, either in tree-based or bracket-based
    manner.

    The following assumptions are taken on intent frames:
    1. The root node is an intent,
    2. Children of intents are always slots, and children of slots are always
       intents.

    For tree-based metrics, a node (an intent or slot) in the predicted frame is
    considered a true positive only if the subtree rooted at this node has an
    exact copy in the gold frame, otherwise it is considered a false positive. A
    false negative is a node in the gold frame that does not have an exact
    subtree match in the predicted frame.

    For bracket-based metrics, a node in the predicted frame is considered a true
    positive if there is a node in the gold frame having the same label and span
    (but not necessarily the same children). The definitions of false positives
    and false negatives are similar to the above.

    Args:
        frame_pairs: List of predicted and gold intent frames.
        tree_based: Whether to compute tree-based metrics (if True) or
            bracket-based metrics (if False).
        overall_metrics: Whether to compute overall (merging intents and slots)
            metrics or not. Defaults to True.

    Returns:
        IntentSlotMetrics, containing precision/recall/F1 metrics for intents and
        slots.
    """
    intents_pairs: List[NodesPredictionPair] = []
    slots_pairs: List[NodesPredictionPair] = []
    for predicted_frame, expected_frame in frame_pairs:
        predicted = _get_intents_and_slots(predicted_frame, tree_based=tree_based)
        expected = _get_intents_and_slots(expected_frame, tree_based=tree_based)
        intents_pairs.append(NodesPredictionPair(predicted.intents, expected.intents))
        slots_pairs.append(NodesPredictionPair(predicted.slots, expected.slots))

    intent_confusions, intent_metrics = compute_prf1_metrics(intents_pairs)
    slot_confusions, slot_metrics = compute_prf1_metrics(slots_pairs)

    merged = None
    if overall_metrics:
        merged = (
            intent_confusions.confusions + slot_confusions.confusions
        ).compute_metrics()

    return IntentSlotMetrics(
        intent_metrics=intent_metrics,
        slot_metrics=slot_metrics,
        overall_metrics=merged,
    )
def compute_top_intent_accuracy(frame_pairs: Sequence[FramePredictionPair]) -> float:
    """
    Computes accuracy of the top-level intent.

    Args:
        frame_pairs: List of predicted and gold intent frames.

    Returns:
        Prediction accuracy of the top-level intent.
    """
    num_correct = sum(
        1 for predicted, expected in frame_pairs if predicted.label == expected.label
    )
    return safe_division(num_correct, len(frame_pairs))
def compute_frame_accuracy(frame_pairs: Sequence[FramePredictionPair]) -> float:
    """
    Computes frame accuracy given a list of predicted and gold intent frames.

    Args:
        frame_pairs: List of predicted and gold intent frames.

    Returns:
        Frame accuracy. For a prediction, frame accuracy is achieved if the
        entire tree structure of the predicted frame matches that of the gold
        frame.
    """
    num_correct = sum(1 for predicted, expected in frame_pairs if predicted == expected)
    return safe_division(num_correct, len(frame_pairs))
def compute_frame_accuracy_top_k(
    frame_pairs: List[FramePredictionPair], all_frames: List[List[Node]]
) -> float:
    """
    Fraction of samples whose gold frame appears anywhere in its list of top-k
    predicted frames (`all_frames[i]` holds the ranked candidates for sample i).
    """
    num_samples = len(frame_pairs)
    num_correct = 0
    for i, top_k_frames in enumerate(all_frames):
        _, expected_frame = frame_pairs[i]
        if any(frame == expected_frame for frame in top_k_frames):
            num_correct += 1
    return safe_division(num_correct, num_samples)
def compute_metric_at_k(
    references: List[Node],
    hypothesis: List[List[Node]],
    metric_fn: Callable[[Node, Node], bool] = lambda f1, f2: f1 == f2,
) -> List[float]:
    """
    Computes a boolean metric at each position in the ranked list of hypothesis,
    and returns an average for each position over all examples.
    By default metric_fn is comparing if frames are equal.
    """
    num_samples = len(references)
    # Position (0-based) of the first correct frame in each ranked list, -1 if
    # the list contains no correct frame.
    pos_correct = []
    max_hyp_count = 0
    # Iterate over ranked list of hypothesis and remember what was the position
    # of the first correct frame.
    for reference, hyp_list in zip(references, hypothesis):
        correct_index = -1
        for rank, predicted_frame in enumerate(hyp_list):
            # Track the number of positions seen (rank + 1), not the last
            # 0-based index. The original `max(rank, max_hyp_count)` dropped the
            # final position, so e.g. single-hypothesis lists produced
            # max_hyp_count == 0 and the function returned [].
            max_hyp_count = max(rank + 1, max_hyp_count)
            if metric_fn(predicted_frame, reference):
                correct_index = rank
                break
        pos_correct.append(correct_index)
    res = [0] * max_hyp_count
    # A hit at position p counts as a hit for every cutoff k > p, so propagate
    # it to all later positions.
    for pos in pos_correct:
        if pos >= 0:
            for i in range(pos, max_hyp_count):
                res[i] += 1
    return [safe_division(res_i, num_samples) for res_i in res]
def compute_frame_accuracies_by_depth(
    frame_pairs: Sequence[FramePredictionPair],
) -> FrameAccuraciesByDepth:
    """
    Given a list of predicted and gold intent frames, splits the predictions into
    buckets according to the depth of the gold trees, and computes frame accuracy
    for each bucket.

    Args:
        frame_pairs: List of predicted and gold intent frames.

    Returns:
        FrameAccuraciesByDepth, a map from depths to their corresponding frame
        accuracies.
    """
    buckets: Dict[int, List[FramePredictionPair]] = defaultdict(list)
    for frame_pair in frame_pairs:
        buckets[frame_pair.expected_frame.get_depth()].append(frame_pair)
    return {
        depth: FrameAccuracy(len(pairs), compute_frame_accuracy(pairs))
        for depth, pairs in buckets.items()
    }
def compute_all_metrics(
    frame_pairs: Sequence[FramePredictionPair],
    top_intent_accuracy: bool = True,
    frame_accuracy: bool = True,
    frame_accuracies_by_depth: bool = True,
    bracket_metrics: bool = True,
    tree_metrics: bool = True,
    overall_metrics: bool = False,
    all_predicted_frames: Optional[List[List[Node]]] = None,
    calculated_loss: Optional[float] = None,
    length_metrics: Optional[Dict] = None,
) -> AllMetrics:
    """
    Given a list of predicted and gold intent frames, computes intent-slot
    related metrics.

    Args:
        frame_pairs: List of predicted and gold intent frames.
        top_intent_accuracy: Whether to compute top intent accuracy or not.
            Defaults to True.
        frame_accuracy: Whether to compute frame accuracy or not. Defaults to
            True.
        frame_accuracies_by_depth: Whether to compute frame accuracies by depth
            or not. Defaults to True.
        bracket_metrics: Whether to compute bracket metrics or not. Defaults to
            True.
        tree_metrics: Whether to compute tree metrics or not. Defaults to True.
        overall_metrics: If `bracket_metrics` or `tree_metrics` is true, decides
            whether to compute overall (merging intents and slots) metrics for
            them. Defaults to False.
        all_predicted_frames: If provided, the ranked top-k predicted frames for
            each sample, used to compute top-k frame accuracy. Defaults to None.
        calculated_loss: Cross entropy loss to record in the result. Defaults to
            None.
        length_metrics: NOTE(review): accepted but currently unused by this
            function; kept for interface compatibility with callers.

    Returns:
        AllMetrics which contains intent-slot related metrics.
    """
    frame_accuracy_top_k = 0
    if all_predicted_frames:
        frame_accuracy_top_k = compute_frame_accuracy_top_k(
            frame_pairs, all_predicted_frames
        )
    top_intent = (
        compute_top_intent_accuracy(frame_pairs) if top_intent_accuracy else None
    )
    accuracy = compute_frame_accuracy(frame_pairs) if frame_accuracy else None
    accuracies = (
        compute_frame_accuracies_by_depth(frame_pairs)
        if frame_accuracies_by_depth
        else None
    )
    bracket = (
        compute_intent_slot_metrics(
            frame_pairs, tree_based=False, overall_metrics=overall_metrics
        )
        if bracket_metrics
        else None
    )
    tree = (
        compute_intent_slot_metrics(
            frame_pairs, tree_based=True, overall_metrics=overall_metrics
        )
        if tree_metrics
        else None
    )
    return AllMetrics(
        top_intent,
        accuracy,
        frame_accuracy_top_k,
        accuracies,
        bracket,
        tree,
        calculated_loss,
    )