Source code for pytext.models.seq_models.mask_generator

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from enum import Enum
from typing import Dict, Optional, Tuple

import torch
from pytext.config import ConfigBase
from pytext.models.module import Module
from pytext.torchscript.vocab import ScriptVocabulary
from torch import Tensor
from torch.quantization import (
    convert,
    float_qparams_weight_only_qconfig,
    prepare,
)


class BeamRankingAlgorithm(Enum):
    LENGTH_CONDITIONED_RANK: str = "LENGTH_CONDITIONED_RANK"
    LENGTH_CONDITIONED_RANK_MUL: str = "LENGTH_CONDITIONED_RANK_MUL"
    AVERAGE_TOKEN_LPROB: str = "AVERAGE_TOKEN_LPROB"
    TOKEN_LPROB: str = "TOKEN_LPROB"
    LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB: str = (
        "LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB"
    )
    LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB_MULTIPLIED: str = (
        "LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB_MULTIPLIED"
    )
    LEN_ONLY: str = "LEN_ONLY"

# Sum of token prob and length prob
def length_conditioned_rank(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    return token_lprob + length_lprob

# Sum of token prob + length * length_prob
def length_conditioned_rank_mul(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    return token_lprob + target_lengths * length_lprob

# Sum of token probs only
def token_prob(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    return token_lprob

# Only length_prob
def length(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    return length_lprob

# Avg token prob
def avg_token_lprob(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    avg_log_prob = token_lprob / target_lengths.to(token_lprob.dtype)
    return avg_log_prob

# Avg token prob + length prob
def length_conditioned_avg_lprob_rank(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    avg_token_lprob_tensor = avg_token_lprob(token_lprob, length_lprob, target_lengths)
    return avg_token_lprob_tensor + length_lprob

# Avg token prob + len * length_prob
def length_conditioned_avg_lprob_rank_mul(
    token_lprob: torch.Tensor, length_lprob: torch.Tensor, target_lengths: torch.Tensor
) -> torch.Tensor:
    avg_token_lprob_tensor = avg_token_lprob(token_lprob, length_lprob, target_lengths)
    return avg_token_lprob_tensor + target_lengths * length_lprob

def get_beam_ranking_function(ranking_algorithm: BeamRankingAlgorithm):
    if ranking_algorithm == BeamRankingAlgorithm.LENGTH_CONDITIONED_RANK:
        return length_conditioned_rank
    elif ranking_algorithm == BeamRankingAlgorithm.LENGTH_CONDITIONED_RANK_MUL:
        return length_conditioned_rank_mul
    elif ranking_algorithm == BeamRankingAlgorithm.AVERAGE_TOKEN_LPROB:
        return avg_token_lprob
    elif (
        ranking_algorithm
        == BeamRankingAlgorithm.LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB
    ):
        return length_conditioned_avg_lprob_rank
    elif (
        ranking_algorithm
        == BeamRankingAlgorithm.LENGTH_CONDITIONED_AVERAGE_TOKEN_LPROB_MULTIPLIED
    ):
        return length_conditioned_avg_lprob_rank_mul
    elif ranking_algorithm == BeamRankingAlgorithm.TOKEN_LPROB:
        return token_prob
    elif ranking_algorithm == BeamRankingAlgorithm.LEN_ONLY:
        return length
    else:
        raise Exception(f"Unknown ranking algorithm {ranking_algorithm}")

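# Illustrative sketch (not part of the original source): all ranking functions share
# the same signature, so the configured algorithm can be resolved once and then
# applied to per-beam scores. The tensors below are made-up scores for two length
# beams of one source sentence.
def _example_beam_ranking() -> torch.Tensor:
    rank_fn = get_beam_ranking_function(BeamRankingAlgorithm.AVERAGE_TOKEN_LPROB)
    token_lprob = torch.tensor([[-4.0, -6.0]])  # summed token log-probs per beam
    length_lprob = torch.tensor([[-1.0, -0.5]])  # log-prob of each predicted length
    target_lengths = torch.tensor([[4, 6]])  # predicted target lengths
    # AVERAGE_TOKEN_LPROB normalizes by length -> tensor([[-1., -1.]])
    return rank_fn(token_lprob, length_lprob, target_lengths)
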
def prepare_masked_target_for_lengths(
    beam: Tensor, mask_idx: int, pad_idx: int, length_beam_size: int = 1
) -> Tuple[Tensor, Tensor]:
    # beam : bsz X beam_size
    max_len = beam.max().item()
    bsz = beam.size(0)
    # length_mask[sample_length] will give a row vector of sample_length+1 zeros
    # and rest ones
    length_mask = torch.triu(
        torch.ones(max_len, max_len, device=beam.device, dtype=beam.dtype),
        diagonal=1,
    )
    beam_indices = beam - 1
    length_mask = length_mask[beam_indices.reshape(-1)].reshape(
        bsz, length_beam_size, max_len
    )
    tgt_tokens = torch.zeros(
        bsz, length_beam_size, max_len, device=beam.device, dtype=beam.dtype
    ).fill_(mask_idx)
    tgt_tokens = (1 - length_mask) * tgt_tokens + length_mask * pad_idx
    tgt_tokens = tgt_tokens.view(bsz * length_beam_size, max_len)
    return tgt_tokens, length_mask

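# Illustrative sketch (not part of the original source): for one source sentence with
# two candidate lengths (3 and 5), the masked target carries mask_idx in the first
# `length` positions of each row and pad_idx in the remainder. The special-token
# indices below are hypothetical.
def _example_masked_target() -> Tuple[Tensor, Tensor]:
    beam = torch.tensor([[3, 5]])  # bsz=1, length_beam_size=2
    mask_idx, pad_idx = 4, 1  # hypothetical special-token indices
    tgt_tokens, length_mask = prepare_masked_target_for_lengths(
        beam, mask_idx, pad_idx, length_beam_size=2
    )
    # tgt_tokens -> tensor([[4, 4, 4, 1, 1],
    #                       [4, 4, 4, 4, 4]])
    return tgt_tokens, length_mask
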
class EmbedQuantizeType(Enum):
    BIT_8 = "8bit"
    BIT_4 = "4bit"
    NONE = "None"

class MaskedSequenceGenerator(Module):
    class Config(ConfigBase):
        beam_size: int = 3
        quantize: bool = True
        embed_quantize: EmbedQuantizeType = EmbedQuantizeType.NONE
        use_gold_length: bool = False
        force_eval_predictions: bool = True
        generate_predictions_every: int = 1
        beam_ranking_algorithm: BeamRankingAlgorithm = (
            BeamRankingAlgorithm.LENGTH_CONDITIONED_RANK
        )
        clip_target_length: bool = False
        # We use a quadratic formula to generate the max target length:
        # min(targetlen_cap, targetlen_a*x^2 + targetlen_b*x + targetlen_c)
        targetlen_cap: int = 30
        targetlen_a: float = 0
        targetlen_b: float = 2
        targetlen_c: float = 2

    @classmethod
    def from_config(
        cls,
        config,
        model,
        length_prediction,
        trg_vocab,
        quantize=False,
        embed_quantize=False,
    ):
        return cls(
            config,
            model,
            length_prediction,
            trg_vocab,
            config.beam_size,
            config.use_gold_length,
            config.beam_ranking_algorithm,
            quantize,
            config.embed_quantize,
        )

    def __init__(
        self,
        config,
        model,
        length_prediction_model,
        trg_vocab,
        beam_size,
        use_gold_length,
        beam_ranking_algorithm,
        quantize,
        embed_quantize,
    ):
        super().__init__()
        if quantize:
            self.model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear: torch.quantization.per_channel_dynamic_qconfig},
                dtype=torch.qint8,
                inplace=False,
            )
            # embedding quantization
            if embed_quantize != EmbedQuantizeType.NONE:
                # 8-bit embedding quantization
                if embed_quantize == EmbedQuantizeType.BIT_8:
                    # identify nn.Embedding modules
                    for module in self.model.modules():
                        if isinstance(module, torch.nn.Embedding):
                            module.qconfig = float_qparams_weight_only_qconfig
                    prepare(self.model, inplace=True)
                    convert(self.model, inplace=True)
                # 4-bit embedding quantization
                elif embed_quantize == EmbedQuantizeType.BIT_4:
                    raise NotImplementedError(
                        "4bit embedding quantization not yet supported"
                    )
                else:
                    raise NotImplementedError(
                        "Embedding Quantization should be either 8bit or 4bit"
                    )

            self.length_prediction_model = torch.quantization.quantize_dynamic(
                length_prediction_model,
                {torch.nn.Linear: torch.quantization.per_channel_dynamic_qconfig},
                dtype=torch.qint8,
                inplace=False,
            )
        else:
            self.model = model
            self.length_prediction_model = length_prediction_model

        self.trg_vocab = ScriptVocabulary(
            list(trg_vocab),
            pad_idx=trg_vocab.get_pad_index(),
            bos_idx=trg_vocab.get_bos_index(-1),
            eos_idx=trg_vocab.get_eos_index(-1),
            mask_idx=trg_vocab.get_mask_index(),
        )
        self.length_beam_size = beam_size
        self.use_gold_length = use_gold_length
        self.beam_ranking_algorithm = get_beam_ranking_function(
            ranking_algorithm=beam_ranking_algorithm
        )
        self.clip_target_length = config.clip_target_length
        self.targetlen_cap = config.targetlen_cap
        self.targetlen_a = config.targetlen_a
        self.targetlen_b = config.targetlen_b
        self.targetlen_c = config.targetlen_c

    def get_encoder_out(
        self,
        src_tokens: Tensor,
        dict_feats: Optional[Tuple[Tensor, Tensor, Tensor]],
        contextual_embed: Optional[Tensor],
        char_feats: Optional[Tensor],
        src_subword_begin_indices: Optional[Tensor],
        src_lengths: Tensor,
        src_index_tokens: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        embedding_input = [[src_tokens]]
        if dict_feats is not None:
            embedding_input.append(list(dict_feats))
        if contextual_embed is not None:
            embedding_input.append([contextual_embed])
        if char_feats is not None:
            embedding_input.append([char_feats])
        embeddings = self.model.source_embeddings(embedding_input)

        encoder_out = self.model.encoder(
            src_tokens, embeddings, src_lengths=src_lengths
        )
        if src_index_tokens is not None:
            encoder_out["src_index_tokens"] = src_index_tokens
        return encoder_out

    def forward(
        self,
        src_tokens: Tensor,
        dict_feats: Optional[Tuple[Tensor, Tensor, Tensor]],
        contextual_embed: Optional[Tensor],
        char_feats: Optional[Tensor],
        src_lengths: Tensor,
        src_subword_begin_indices: Optional[Tensor] = None,
        target_lengths: Optional[Tensor] = None,
        beam_size: Optional[int] = None,
        src_index_tokens: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        encoder_out = self.get_encoder_out(
            src_tokens=src_tokens,
            dict_feats=dict_feats,
            contextual_embed=contextual_embed,
            char_feats=char_feats,
            src_subword_begin_indices=src_subword_begin_indices,
            src_lengths=src_lengths,
            src_index_tokens=src_index_tokens,
        )
        encoder_mask: Optional[Tensor] = None
        if "encoder_mask" in encoder_out:
            encoder_mask = encoder_out["encoder_mask"]

        predicted_tgt_length, _ = self.length_prediction_model(
            encoder_out["encoder_out"], encoder_mask
        )
        if beam_size is not None and beam_size != self.length_beam_size:
            self.length_beam_size = beam_size

        if self.clip_target_length:
            beam_vals, beam = predicted_tgt_length.topk(self.length_beam_size, dim=1)
            len_clips = self.get_clip_length(src_lengths)
            len_clips = len_clips.reshape(-1, 1).repeat(1, self.length_beam_size)
            acceptable_lens_mask = torch.le(beam, len_clips)
            beam = beam * acceptable_lens_mask + torch.logical_not(
                acceptable_lens_mask
            ) * torch.ones_like(beam, dtype=beam.dtype, device=beam.device)
            beam_vals = beam_vals * acceptable_lens_mask + torch.logical_not(
                acceptable_lens_mask
            ) * torch.full(
                beam_vals.size(),
                float("-inf"),
                dtype=beam_vals.dtype,
                device=beam_vals.device,
            )
        else:
            beam_vals, beam = predicted_tgt_length.topk(self.length_beam_size, dim=1)

        # make sure no beams are 0 (integration test)
        beam[beam == 0] += 1
        length_prob = torch.gather(predicted_tgt_length, 1, beam)

        if self.use_gold_length:
            assert target_lengths is not None
            beam = target_lengths.reshape(-1, 1)
            self.length_beam_size = 1
            length_prob = torch.ones(beam.size(), device=target_lengths.device)

        tgt_tokens, length_mask = prepare_masked_target_for_lengths(
            beam,
            self.trg_vocab.mask_idx,
            self.trg_vocab.pad_idx,
            self.length_beam_size,
        )
        bsz = src_tokens.size(0)
        max_len = tgt_tokens.size(1)

        tiled_encoder_out = self.model.encoder.prepare_for_nar_inference(
            self.length_beam_size, encoder_out
        )

        # OneStep Generation
        pad_mask = tgt_tokens.eq(self.trg_vocab.pad_idx)
        tgt_tokens, token_probs = self.generate_non_autoregressive(
            tiled_encoder_out, tgt_tokens
        )
        tgt_tokens[pad_mask] = torch.tensor(
            self.trg_vocab.pad_idx, device=tgt_tokens.device
        ).long()
        token_probs[pad_mask] = torch.tensor(
            1.0, device=token_probs.device, dtype=token_probs.dtype
        )

        token_probs = token_probs.view(bsz, self.length_beam_size, max_len).log()
        lprobs = token_probs.sum(-1)
        hypotheses = tgt_tokens.view(bsz, self.length_beam_size, max_len)
        lprobs = lprobs.view(bsz, self.length_beam_size)

        tgt_lengths = (1 - length_mask).sum(-1)
        hyp_score = self.beam_ranking_algorithm(
            token_lprob=lprobs, length_lprob=length_prob, target_lengths=tgt_lengths
        )
        sorted_scores, indices = torch.sort(hyp_score, dim=-1, descending=True)
        all_indices = torch.arange(bsz).unsqueeze(-1)
        hypotheses = hypotheses[all_indices, indices]
        return hypotheses, beam, sorted_scores.exp(), token_probs

    def get_clip_length(self, src_lengths: Tensor):
        predicted = (
            torch.tensor(
                self.targetlen_a, dtype=src_lengths.dtype, device=src_lengths.device
            )
            * src_lengths
            * src_lengths
            + torch.tensor(
                self.targetlen_b, dtype=src_lengths.dtype, device=src_lengths.device
            )
            * src_lengths
            + torch.tensor(
                self.targetlen_c, dtype=src_lengths.dtype, device=src_lengths.device
            )
        )
        capped = torch.min(
            predicted,
            torch.tensor(
                self.targetlen_cap, dtype=src_lengths.dtype, device=src_lengths.device
            ),
        )
        return capped

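    # Illustrative note (not part of the original source): with the default Config
    # values (targetlen_a=0, targetlen_b=2, targetlen_c=2, targetlen_cap=30), a
    # source of length 10 is clipped to min(30, 0*10**2 + 2*10 + 2) = 22, while a
    # source of length 20 hits the cap: min(30, 42) = 30, i.e.
    #     self.get_clip_length(torch.tensor([10, 20]))  # -> tensor([22, 30])
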
    @torch.jit.export
    def generate_hypo(
        self, tensors: Dict[str, Tensor]
    ) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
        """
        Generates hypotheses using beam search, also returning their scores

        Inputs:
            - tensors: dictionary containing needed tensors for generation

        Outputs:
            - (hypos, lens): tuple of Tensors
                - hypos: Tensor of shape (batch_size, beam_size, MAX) containing
                  the generated tokens. MAX refers to the longest sequence in batch.
                - lens: Tensor of shape (batch_size, beam_size) containing generated
                  sequence lengths
            - _hypo_scores: Tensor of shape (batch_size, beam_size) containing the
              scores for each generated sequence
        """
        actual_src_tokens = tensors["src_tokens"]
        dict_feats: Optional[Tuple[Tensor, Tensor, Tensor]] = None
        contextual_embed: Optional[Tensor] = None
        char_feats: Optional[Tensor] = None

        if "dict_tokens" in tensors:
            dict_feats = (
                tensors["dict_tokens"],
                tensors["dict_weights"],
                tensors["dict_lengths"],
            )
        if "contextual_embed" in tensors:
            contextual_embed = tensors["contextual_embed"]
        if "char_feats" in tensors:
            char_feats = tensors["char_feats"]

        hypos, lens, hypo_scores, _token_probs = self.forward(
            actual_src_tokens,
            dict_feats,
            contextual_embed,
            char_feats,
            tensors["src_lengths"],
            src_subword_begin_indices=tensors.get("src_subword_begin_indices"),
            target_lengths=tensors["target_lengths"],
            beam_size=self.length_beam_size,
            src_index_tokens=tensors.get("src_index_tokens"),
        )
        return (hypos, lens), hypo_scores

    def generate_non_autoregressive(self, encoder_out: Dict[str, Tensor], tgt_tokens):
        decoder_out_tuple = self.model.decoder(tgt_tokens, encoder_out)
        tgt_tokens, token_probs, _ = self.model.decoder.get_probs(decoder_out_tuple)
        return tgt_tokens, token_probs

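# Illustrative sketch (not part of the original source): how forward() ranks the
# length beams once per-token probabilities are available. The tensors below are
# made-up stand-ins for the one-step decoder outputs; in forward(), pad positions
# carry probability 1.0 so they do not affect the summed log-probs.
def _example_rank_length_beams() -> Tuple[Tensor, Tensor]:
    bsz, length_beam_size, max_len = 1, 2, 4
    token_probs = torch.tensor([[[0.9, 0.8, 1.0, 1.0], [0.7, 0.6, 0.5, 0.9]]])
    hypotheses = torch.randint(5, 20, (bsz, length_beam_size, max_len))
    length_prob = torch.log(torch.tensor([[0.6, 0.4]]))  # length log-probs per beam
    tgt_lengths = torch.tensor([[2, 4]])
    lprobs = token_probs.log().sum(-1)  # summed token log-probs per beam
    hyp_score = length_conditioned_rank(lprobs, length_prob, tgt_lengths)
    sorted_scores, indices = torch.sort(hyp_score, dim=-1, descending=True)
    all_indices = torch.arange(bsz).unsqueeze(-1)
    return hypotheses[all_indices, indices], sorted_scores.exp()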