#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
import torch
import torch.jit as jit
import torch.nn as nn
from caffe2.python.crf_predict import apply_crf
from pytext.utils.usage import log_class_usage
class CRF(nn.Module):
"""
Compute the log-likelihood of the input assuming a conditional random field
model.
Args:
num_tags: The number of tags
"""
def __init__(
self, num_tags: int, ignore_index: int, default_label_pad_index: int
) -> None:
if num_tags <= 0:
raise ValueError(f"Invalid number of tags: {num_tags}")
super().__init__()
self.num_tags = num_tags
        # Add two states at the end to accommodate start and end states
        # The (i, j) element represents the transition score from state i to state j
self.transitions = nn.Parameter(torch.Tensor(num_tags + 2, num_tags + 2))
self.start_tag = num_tags
self.end_tag = num_tags + 1
self.reset_parameters()
self.ignore_index = ignore_index
self.default_label_pad_index = default_label_pad_index
log_class_usage(__class__)
    def reset_parameters(self) -> None:
nn.init.uniform_(self.transitions, -0.1, 0.1)
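        # Assign a large negative score to transitions into the start state
        # and out of the end state, so decoding never selects them.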
self.transitions.data[:, self.start_tag] = -10000
self.transitions.data[self.end_tag, :] = -10000
    def get_transitions(self):
return self.transitions.data
    def set_transitions(self, transitions: torch.Tensor):
self.transitions.data = transitions
    def forward(
self, emissions: torch.Tensor, tags: torch.Tensor, reduce: bool = True
) -> torch.Tensor:
"""
Compute log-likelihood of input.
Args:
emissions: Emission values for different tags for each input. The
expected shape is batch_size * seq_len * num_labels. Padding is
should be on the right side of the input.
tags: Actual tags for each token in the input. Expected shape is
batch_size * seq_len
"""
mask = self._make_mask_from_targets(tags)
numerator = self._compute_joint_llh(emissions, tags, mask)
denominator = self._compute_log_partition_function(emissions, mask)
llh = numerator - denominator
return llh if not reduce else torch.mean(llh)
    @jit.export
def decode(self, emissions: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
"""
Given a set of emission probabilities, return the predicted tags.
Args:
emissions: Emission probabilities with expected shape of
batch_size * seq_len * num_labels
seq_lens: Length of each input.
"""
mask = self._make_mask_from_seq_lens(seq_lens)
result = self._viterbi_decode(emissions, mask)
return result
def _compute_joint_llh(
self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
seq_len = emissions.shape[1]
        # The log-likelihood of a given input is computed from the known
        # correct tag at each timestep and its corresponding emission value.
        # Since the actual tags for each timestep are known, the sum of
        # transition scores is computed as well.
        # The sum of emission and transition scores gives the final score for
        # the input.
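        # In symbols, with emission scores e and transitions T, the joint
        # score of a length-L tag sequence y is:
        #   score(y) = T[start, y_1] + sum_t e_t[y_t]
        #              + sum_{t > 1} T[y_{t-1}, y_t] + T[y_L, end]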
llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1)
llh += emissions[:, 0, :].gather(1, tags[:, 0].view(-1, 1)) * mask[
:, 0
].unsqueeze(1)
for idx in range(1, seq_len):
old_state, new_state = (
tags[:, idx - 1].view(-1, 1),
tags[:, idx].view(-1, 1),
)
emission_scores = emissions[:, idx, :].gather(1, new_state)
transition_scores = self.transitions[old_state, new_state]
llh += (emission_scores + transition_scores) * mask[:, idx].unsqueeze(1)
        # The index of the last tag in each row is the sum of that row's mask
        # minus 1.
last_tag_indices = mask.sum(1, dtype=torch.long) - 1
last_tags = tags.gather(1, last_tag_indices.view(-1, 1))
llh += self.transitions[last_tags.squeeze(1), self.end_tag].unsqueeze(1)
return llh.squeeze(1)
def _compute_log_partition_function(
self, emissions: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
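        # This is the forward algorithm in log space. Starting from
        # alpha_1(j) = e_1(j) + T[start, j], it applies the recursion
        #   alpha_t(j) = e_t(j) + logsumexp_i(alpha_{t-1}(i) + T[i, j])
        # and returns logsumexp_j(alpha_L(j) + T[j, end]), the log partition
        # function that normalizes the joint scores.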
seq_len = emissions.shape[1]
log_prob = emissions[:, 0].clone()
log_prob += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)
for idx in range(1, seq_len):
broadcast_emissions = emissions[:, idx].unsqueeze(1)
broadcast_transitions = self.transitions[
: self.start_tag, : self.start_tag
].unsqueeze(0)
broadcast_logprob = log_prob.unsqueeze(2)
score = broadcast_logprob + broadcast_emissions + broadcast_transitions
score = torch.logsumexp(score, 1)
log_prob = score * mask[:, idx].unsqueeze(1) + log_prob.squeeze(1) * (
1 - mask[:, idx].unsqueeze(1)
)
log_prob += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
return torch.logsumexp(log_prob.squeeze(1), 1)
def _viterbi_decode(
self, emissions: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
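        # This is the Viterbi algorithm in log space: the same recursion as
        # the forward algorithm with logsumexp replaced by max. Backpointers
        # (best_paths) and running best scores are stored at every step so
        # the highest-scoring tag sequence can be recovered by backtracking.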
tensor_device = emissions.device
seq_len = emissions.shape[1]
mask = mask.to(torch.uint8)
log_prob = emissions[:, 0].clone()
log_prob += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)
        # At each step, keep track of the total score as if this step were the
        # last valid one in the sequence.
end_scores = log_prob + self.transitions[
: self.start_tag, self.end_tag
].unsqueeze(0)
best_scores_list: List[torch.Tensor] = []
        # Needed for TorchScript, since an empty list is assumed to be a list
        # of tensors
        empty_data: List[int] = []
        # If the input has only one token, the empty tensor in best_paths
        # keeps torch.cat() from crashing
        best_paths_list = [torch.tensor(empty_data, device=tensor_device).long()]
best_scores_list.append(end_scores.unsqueeze(1))
for idx in range(1, seq_len):
broadcast_emissions = emissions[:, idx].unsqueeze(1)
            broadcast_transitions = self.transitions[
                : self.start_tag, : self.start_tag
            ].unsqueeze(0)
            broadcast_log_prob = log_prob.unsqueeze(2)
            score = broadcast_emissions + broadcast_transitions + broadcast_log_prob
max_scores, max_score_indices = torch.max(score, 1)
best_paths_list.append(max_score_indices.unsqueeze(1))
            # Store the scores in case this is the last valid step.
end_scores = max_scores + self.transitions[
: self.start_tag, self.end_tag
].unsqueeze(0)
best_scores_list.append(end_scores.unsqueeze(1))
log_prob = max_scores
best_scores = torch.cat(best_scores_list, 1).float()
best_paths = torch.cat(best_paths_list, 1)
_, max_indices_from_scores = torch.max(best_scores, 2)
valid_index_tensor = torch.tensor(0, device=tensor_device).long()
if self.ignore_index == self.default_label_pad_index:
# No label for padding, so use 0 index.
padding_tensor = valid_index_tensor
else:
padding_tensor = torch.tensor(
self.ignore_index, device=tensor_device
).long()
        # The label for the last position is always based on the index with
        # the max score. For illegal timesteps, we set it to ignore_index.
labels = max_indices_from_scores[:, seq_len - 1]
labels = self._mask_tensor(labels, 1 - mask[:, seq_len - 1], padding_tensor)
all_labels = labels.unsqueeze(1).long()
# For Viterbi decoding, we start at the last position and go towards first
for idx in range(seq_len - 2, -1, -1):
            # There are two ways to obtain labels for tokens at a particular
            # position.
            # Option 1: Use the labels obtained from the previous position to
            # index the path at the present position. This is used for all
            # positions except the last one in the sequence.
            # Option 2: Use the indices with maximum scores obtained during
            # Viterbi decoding. This is used for the token at the last
            # position.
            # For option 1, invalid indices must be converted to 0 so that
            # lookups don't fail.
indices_for_lookup = all_labels[:, -1].clone()
indices_for_lookup = self._mask_tensor(
indices_for_lookup,
indices_for_lookup == self.ignore_index,
valid_index_tensor,
)
            # Option 1 is used here when the previous timestep (idx + 1) was
            # valid.
indices_from_prev_pos = (
best_paths[:, idx, :]
.gather(1, indices_for_lookup.view(-1, 1).long())
.squeeze(1)
)
indices_from_prev_pos = self._mask_tensor(
indices_from_prev_pos, (1 - mask[:, idx + 1]), padding_tensor
)
            # Option 2 is used when the last timestep was not valid, which
            # means idx + 1 is the last position in the sequence.
indices_from_max_scores = max_indices_from_scores[:, idx]
indices_from_max_scores = self._mask_tensor(
indices_from_max_scores, mask[:, idx + 1], padding_tensor
)
            # Results from options 1 and 2 must be combined, since rows in a
            # batch can have sequences of varying lengths.
labels = torch.where(
indices_from_max_scores == self.ignore_index,
indices_from_prev_pos,
indices_from_max_scores,
)
# Set to ignore_index if present state is not valid.
labels = self._mask_tensor(labels, (1 - mask[:, idx]), padding_tensor)
all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)
return torch.flip(all_labels, [1])
def _make_mask_from_targets(self, targets):
mask = targets.ne(self.ignore_index).float()
return mask
def _make_mask_from_seq_lens(self, seq_lens):
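        # For example (illustrative values): seq_lens = [3, 1] with max_len 3
        # yields the mask [[1, 1, 1], [1, 0, 0]].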
seq_lens = seq_lens.view(-1, 1)
max_len = torch.max(seq_lens)
range_tensor = torch.arange(max_len, device=seq_lens.device).unsqueeze(0)
range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
mask = (range_tensor < seq_lens).float()
return mask
def _mask_tensor(self, score_tensor, mask_condition, mask_value):
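        # Where mask_condition holds, take mask_value; elsewhere keep the
        # original entries of score_tensor.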
masked_tensor = torch.where(mask_condition, mask_value, score_tensor)
return masked_tensor
    def export_to_caffe2(self, workspace, init_net, predict_net, logits_output_name):
        """
        Export the CRF layer to caffe2 by manually adding the necessary
        operators to the init_net and predict_net.

        Args:
            workspace: current caffe2 workspace
            init_net: caffe2 init net created by the current graph
            predict_net: caffe2 net created by the current graph
            logits_output_name: blob name of the logits output of the
                caffe2 net

        Returns:
            string: The updated predictions blob name
        """
crf_transitions = init_net.AddExternalInput(init_net.NextName())
workspace.FeedBlob(str(crf_transitions), self.get_transitions().numpy())
logits_squeezed = predict_net.Squeeze(logits_output_name, dims=[0])
new_logits = apply_crf(
init_net, predict_net, crf_transitions, logits_squeezed, self.num_tags
)
new_logits = predict_net.ExpandDims(new_logits, dims=[0])
predict_net.Copy(new_logits, logits_output_name)
return logits_output_name