# Source code for pytext.torchscript.seq2seq.beam_decode

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Tuple

import torch
from torch import Tensor


class BeamDecode(torch.nn.Module):
    """
    Decodes the output of Beam Search to get the top hypotheses.

    Given the per-step beam-search bookkeeping tensors (tokens, cumulative
    scores, attention weights, and back-pointers into the previous beam),
    this module walks the back-pointers to reconstruct up to `nbest`
    finished hypotheses, sorted by (optionally length-penalized) score.
    """

    def __init__(self, eos_token_id, length_penalty, nbest, beam_size, stop_at_eos):
        super().__init__()
        # Token id that marks the end of a hypothesis.
        self.eos_token_id: int = eos_token_id
        # Exponent applied to hypothesis length when normalizing scores
        # (0 disables length normalization).
        self.length_penalty: float = length_penalty
        # Number of top hypotheses to return.
        self.nbest: int = nbest
        # Width of the beam at every decoding step.
        self.beam_size: int = beam_size
        # If True, a hypothesis is frozen once it emits EOS.
        self.stop_at_eos: bool = stop_at_eos

    @torch.no_grad()
    def forward(
        self,
        beam_tokens: Tensor,
        beam_scores: Tensor,
        token_weights: Tensor,
        beam_prev_indices: Tensor,
        num_steps: int,
    ) -> List[Tuple[Tensor, float, List[float], Tensor, Tensor]]:
        """
        Reconstruct the top hypotheses from beam-search outputs.

        Args:
            beam_tokens: (num_steps+1, beam_size) token emitted by each
                beam slot at each step (step 0 is the start token).
            beam_scores: (num_steps+1, beam_size) cumulative hypothesis
                scores; row 0 is all zeros.
            token_weights: per-step per-beam attention weights; indexed as
                token_weights[step][beam] (trailing dims are the source
                attention — assumed at least 1-D; confirm against caller).
            beam_prev_indices: (num_steps+1, beam_size) back-pointer into
                the previous step's beam.
            num_steps: maximum decoding step to consider.

        Returns:
            For each of the (up to) nbest hypotheses, a tuple of:
            (hypothesis tokens, hypothesis score, per-token scores,
             attention weights stacked along dim 1, best [position, beam]
             indices of the end state).
        """
        self._check_dimensions(
            beam_tokens, beam_scores, token_weights, beam_prev_indices, num_steps
        )
        end_states = self._get_all_end_states(
            beam_tokens, beam_scores, beam_prev_indices, num_steps
        )
        # outputs is list of the following for each hypothesis:
        # Tuple[Hypothesis, Hypothesis score, Token level scores,
        # Attention Weights, Best indices]
        outputs = torch.jit.annotate(
            List[Tuple[Tensor, float, List[float], Tensor, Tensor]], []
        )
        for state_idx in range(len(end_states)):
            state = end_states[state_idx]
            hypothesis_score = float(state[0])
            beam_indices = self._get_output_steps_to_beam_indices(
                state, beam_prev_indices
            )
            beam_output = torch.jit.annotate(List[Tensor], [])
            token_level_scores = torch.jit.annotate(List[float], [])
            position = int(state[1])
            hyp_index = int(state[2])
            # best_indices represents the ending position of one hypothesis,
            # the first index corresponds num_step, the second corresponds beam_index
            best_indices = torch.tensor([position, hyp_index])
            back_alignment_weights = torch.jit.annotate(List[Tensor], [])
            assert position + 1 == len(beam_indices)
            pos = 1
            prev_beam_index = -1
            while pos < len(beam_indices):
                beam_index = beam_indices[pos]
                beam_output.append(beam_tokens[pos][beam_index])
                if pos == 1:
                    # beam_scores[0][:] are all 0s, so the cumulative score
                    # at step 1 is already the first token's score.
                    token_level_scores.append(float(beam_scores[pos][beam_index]))
                else:
                    # Per-token score = difference of consecutive cumulative
                    # scores along the back-traced path.
                    token_level_scores.append(
                        float(beam_scores[pos][beam_index])
                        - float(beam_scores[pos - 1][prev_beam_index])
                    )
                back_alignment_weights.append(token_weights[pos][beam_index].detach())
                prev_beam_index = beam_index
                pos += 1
            outputs.append(
                (
                    torch.stack(beam_output),
                    hypothesis_score,
                    token_level_scores,
                    torch.stack(back_alignment_weights, dim=1),
                    best_indices,
                )
            )
        return outputs

    def _get_output_steps_to_beam_indices(
        self, end_state: Tensor, beam_prev_indices: Tensor
    ) -> List[int]:
        """
        Returns a mapping from each output position and the beam index that
        was picked from the beam search results.
        """
        present_position = int(end_state[1])
        beam_index = int(end_state[2])
        beam_indices = torch.jit.annotate(List[int], [])
        while present_position >= 0:
            beam_indices.insert(0, beam_index)
            # Cast to int: the raw indexing result is a 0-dim Tensor, which
            # would otherwise be inserted into a List[int].
            beam_index = int(beam_prev_indices[present_position][beam_index])
            present_position = present_position - 1
        return beam_indices

    def _add_to_end_states(
        self, end_states: List[Tensor], min_score: float, state: Tensor, min_index: int
    ) -> Tuple[List[Tensor], float, int]:
        """
        Maintains a list of atmost `nbest` highest end states.

        Returns the (possibly updated) list together with the score and
        index of the current worst entry, so the caller can thread them
        through successive calls.
        """
        if len(end_states) < self.nbest:
            end_states.append(state)
            # keep min_score and min_index updated; cast to float so
            # min_score stays a float (state[0] is a 0-dim Tensor)
            if float(state[0]) <= min_score:
                min_score = float(state[0])
                min_index = len(end_states) - 1
        elif float(state[0]) > min_score:
            # replace worst hypo with the new one
            end_states[min_index] = state
            # find new worst hypo, keep min_score and min_index updated
            min_index = -1
            # not using float("inf") temporarily bc of TorchScript bug
            # using max representable value in fp16
            min_score = 65504.0
            for idx in range(len(end_states)):
                s = end_states[idx]
                if float(s[0]) <= min_score:
                    min_index = idx
                    min_score = float(s[0])
        return end_states, min_score, min_index

    def _get_all_end_states(
        self,
        beam_tokens: Tensor,
        beam_scores: Tensor,
        beam_prev_indices: Tensor,
        num_steps: int,
    ) -> Tensor:
        """
        Return all end states and hypothesis scores for those end states.

        Each row of the returned tensor is [score, position, beam_index],
        sorted by score in descending order.
        """
        # not using float("inf") temporarily bc of TorchScript bug
        # using max representable value in fp16
        min_score = 65504.0
        min_index = -1
        end_states = torch.jit.annotate(List[Tensor], [])
        # Use torch.bool consistently (the original mixed .byte() here with
        # dtype=torch.bool below).
        prev_hypo_is_finished = torch.zeros(self.beam_size, dtype=torch.bool)
        position = 1
        while position <= num_steps:
            hypo_is_finished = torch.zeros(self.beam_size, dtype=torch.bool)
            for hyp_index in range(self.beam_size):
                prev_pos = beam_prev_indices[position][hyp_index]
                hypo_is_finished[hyp_index] = prev_hypo_is_finished[prev_pos]
                # If hypothesis was completed in the previous index,
                # then just continue
                if hypo_is_finished[hyp_index] == 0:
                    # If the present token is EOS or we have reached max_length
                    # then hypothesis is complete
                    if (beam_tokens[position][hyp_index] == self.eos_token_id) or (
                        position == num_steps
                    ):
                        if self.stop_at_eos:
                            hypo_is_finished[hyp_index] = 1
                        hypo_score = float(beam_scores[position][hyp_index])
                        if self.length_penalty != 0:
                            hypo_score = hypo_score / (
                                position ** self.length_penalty
                            )
                        end_states, min_score, min_index = self._add_to_end_states(
                            end_states,
                            min_score,
                            torch.tensor(
                                [hypo_score, float(position), float(hyp_index)]
                            ),
                            min_index,
                        )
            prev_hypo_is_finished = hypo_is_finished
            position = position + 1
        # Stack into a (num_end_states, 3) tensor and sort by score
        # (column 0) in descending order.
        stacked_states = torch.stack(end_states)
        _, sorted_end_state_indices = stacked_states[:, 0].sort(dim=0, descending=True)
        stacked_states = stacked_states[sorted_end_state_indices, :]
        return stacked_states

    def _check_dimensions(
        self,
        beam_tokens: Tensor,
        beam_scores: Tensor,
        token_weights: Tensor,
        beam_prev_indices: Tensor,
        num_steps: int,
    ) -> None:
        """
        Sanity-check that all beam-search tensors agree with the configured
        beam size (dim 1) and do not exceed num_steps + 1 steps (dim 0).
        """
        assert (
            beam_tokens.size(1) == self.beam_size
        ), "Dimension of beam_tokens : {} and beam size : {} are not consistent".format(
            beam_tokens.size(), self.beam_size
        )
        assert beam_scores.size(1) == self.beam_size, (
            "Dimension of beam_scores : {} and beam size : {} "
            "are not consistent".format(beam_scores.size(), self.beam_size)
        )
        assert token_weights.size(1) == self.beam_size, (
            "Dimension of token_weights : {} and beam size : {} "
            "are not consistent".format(token_weights.size(), self.beam_size)
        )
        assert (
            beam_prev_indices.size(1) == self.beam_size
        ), "Dimension of beam_prev_indices : {} and beam size : {} " "are not consistent".format(
            beam_prev_indices.size(), self.beam_size
        )
        assert beam_tokens.size(0) <= num_steps + 1, (
            "Dimension of beam_tokens : {} and num_steps : {} "
            "are not consistent".format(beam_tokens.size(), num_steps)
        )
        assert beam_scores.size(0) <= num_steps + 1, (
            "Dimension of beam_scores : {} and num_steps : {} "
            "are not consistent".format(beam_scores.size(), num_steps)
        )
        assert token_weights.size(0) <= num_steps + 1, (
            "Dimension of token_weights : {} and num_steps : {} "
            "are not consistent".format(token_weights.size(), num_steps)
        )
        assert beam_prev_indices.size(0) <= num_steps + 1, (
            "Dimension of beam_prev_indices : {} and num_steps : {} "
            "are not consistent".format(beam_prev_indices.size(), num_steps)
        )