Source code for pytext.metric_reporters.compositional_utils
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Tuple
[docs]def extract_beam_subtrees(beam: List[List[str]]) -> List[List[str]]:
return list(
filter(
lambda processed_pred: processed_pred is not None,
map(lambda tree_pred: extract_subtree(tree_pred), beam),
)
)
[docs]def filter_invalid_beams(beam: List[List[str]]) -> List[List[str]]:
return list(filter(lambda pred: is_valid_tree(pred), beam))
[docs]def is_valid_tree(beam: List[str]) -> bool:
paren_stack: List[int] = []
for i, token in enumerate(beam):
if len(token) == 0:
continue
if token[0] == "[":
paren_stack.append(i)
elif token[-1] == "]":
if len(paren_stack) == 0:
return False
paren_stack.pop()
return len(paren_stack) == 0
[docs]def extract_subtree(beam: List[str]) -> Optional[List[str]]:
paren_stack: List[int] = []
longest_valid_tree: Tuple[int, int] = (-1, -1)
# determine what the valid subtree is
# in the prediction
for i, token in enumerate(beam):
if len(token) == 0:
# skip over empty tokens since we check
# characters in the token strings
continue
if token[0] == "[":
# if the token begints with "[" it is the
# start of a sequence so appending to stack
paren_stack.append(i)
elif token[-1] == "]":
# found a potential end of a valid subsequence
if len(paren_stack) == 0:
# cannot find a valid tree
# reset stack and continue
paren_stack = []
continue
# the valid subtree is from current position to
# the start sequence at `tree_start`
tree_start: int = paren_stack.pop()
if (i - tree_start) > (longest_valid_tree[1] - longest_valid_tree[0]):
# if this subsequence is longer than the mosdt valid one
# this is the longest valid subsequence
longest_valid_tree = (tree_start, i)
if longest_valid_tree == (-1, -1):
# no valid sequence was found
return None
valid_tree_start, valid_tree_end = longest_valid_tree
return beam[valid_tree_start : valid_tree_end + 1]