Source code for pytext.torchscript.module

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, Dict, List, Optional, Tuple

import torch
from pytext.config import ExportConfig
from pytext.torchscript.batchutils import (
    input_size,
    limit_list,
    clip_list,
    limit_listlist,
    clip_listlist,
    limit_listlist_float,
    clip_listlist_float,
    destructure_tensor,
    destructure_tensor_list,
    destructure_any_list,
    zip_batch_any_list_list,
    zip_batch_tensor_list,
    make_batch_texts_dense,
    make_prediction_texts,
    make_prediction_texts_dense,
    max_tokens,
    nonify_listlist_float,
    validate_dense_feat,
    validate_make_prediction_batch_element,
)
from pytext.torchscript.tensorizer.normalizer import VectorNormalizer
from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer
from pytext.torchscript.utils import ScriptBatchInput, squeeze_1d, squeeze_2d
from pytext.utils.usage import log_class_usage


@torch.jit.script
def resolve_texts(
    texts: Optional[List[str]] = None, multi_texts: Optional[List[List[str]]] = None
) -> Optional[List[List[str]]]:
    if texts is not None:
        return squeeze_1d(texts)
    return multi_texts


def deprecation_warning(export_conf: ExportConfig):
    if export_conf.inference_interface is not None:
        print(
            "*************** DEPRECATION ERROR **************"
            "inference_interface config option is not available"
            "**************************************************"
        )
        raise RuntimeError("export configuration not supported")
    elif (
        (export_conf.accelerate is not None)
        or (export_conf.seq_padding_control is not None)
        or (export_conf.batch_padding_control is not None)
    ):
        msg = [
            "*********** DEPRECATION WARNING **********",
            "Modules concurrently supporting untokenized",
            "and tokenized inputs are being deprecated!",
            "",
            "Preferably, use the corresponding Pytext{Type}Module",
            "hierarchy (sans 'Script') classes to offer models",
            "including tokenization.",
            "*********************************************",
        ]
        for line in msg:
            print(line)

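# A minimal sketch of how deprecation_warning() behaves (hypothetical
# ExportConfig field values, not taken from this module):
#
#     conf = ExportConfig(seq_padding_control=[0, 32, 256])
#     deprecation_warning(conf)   # prints the deprecation warning banner
#
#     conf = ExportConfig(inference_interface="...")
#     deprecation_warning(conf)   # raises RuntimeError: config not supported
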
class ScriptPyTextEmbeddingModule(torch.jit.ScriptModule):
    def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer):
        super().__init__()
        self.model = model
        self.tensorizer = tensorizer
        log_class_usage(self.__class__)

    def validate(self, export_conf: ExportConfig):
        deprecation_warning(export_conf)
    @torch.jit.script_method
    def set_device(self, device: str):
        self.tensorizer.set_device(device)

    @torch.jit.script_method
    def get_max_seq_len(self) -> int:
        """
        This function returns the maximum sequence length for the model,
        if it is defined, otherwise raises a RuntimeError.
        """
        if hasattr(self.tensorizer, "max_seq_len"):
            if self.tensorizer.max_seq_len is not None:
                return self.tensorizer.max_seq_len

        raise RuntimeError("max_seq_len not defined")

    @torch.jit.script_method
    def get_max_batch_len(self) -> int:
        """
        This function returns the maximum batch length for the model,
        if it is defined, otherwise -1.
        """
        if hasattr(self.tensorizer, "batch_padding_control"):
            batch_padding_control = self.tensorizer.batch_padding_control
            if batch_padding_control is not None:
                return batch_padding_control[-1]
        return -1

    @torch.jit.script_method
    def set_padding_control(self, dimension: str, control: Optional[List[int]]):
        """
        This function is called to set a padding style.
        None: no padding.
        List: first element 0; round the sequence length up to the smallest
        list element larger than the input.
        """
        self.tensorizer.set_padding_control(dimension, control)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return False

    @torch.jit.script_method
    def forward_validate_dense_feat(
        self,
        dense_feat: Optional[List[List[float]]],
    ) -> List[List[float]]:
        if self.uses_dense_feat():
            if dense_feat is None:
                raise RuntimeError(
                    "Dense feature (dense_feat) is required for this model type, but not present."
                )
            else:
                return dense_feat
        else:
            if dense_feat is not None:
                raise RuntimeError(
                    "Dense feature (dense_feat) not allowed for this model type"
                )
            else:
                return []

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        return self.model(input_tensors).cpu()

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: False
        dense_feat: Optional[List[List[float]]] = None,
    ) -> torch.Tensor:
        self.forward_validate_dense_feat(dense_feat)

        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, multi_texts),
            tokens=squeeze_2d(tokens),
            languages=squeeze_1d(languages),
        )
        return self._forward(inputs)

    @torch.jit.script_method
    def forward(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: False
        dense_feat: Optional[List[List[float]]] = None,
    ):
        # returns torch.Tensor or List[Any]
        self.forward_validate_dense_feat(dense_feat)

        input_len = input_size(texts, multi_texts, tokens)
        max_batch = self.get_max_batch_len()
        if max_batch <= 0:
            max_batch = input_len

        result = self.forward_impl(
            limit_list(texts, max_batch),
            limit_listlist(multi_texts, max_batch),
            limit_listlist(tokens, max_batch),
            limit_list(languages, max_batch),
            limit_listlist_float(dense_feat, max_batch),
        )

        if input_len > max_batch:
            texts = clip_list(texts, max_batch)
            multi_texts = clip_listlist(multi_texts, max_batch)
            tokens = clip_listlist(tokens, max_batch)
            languages = clip_list(languages, max_batch)
            dense_feat = clip_listlist_float(dense_feat, max_batch)

            while input_size(texts, multi_texts, tokens) > 0:
                result_extension = self.forward_impl(
                    limit_list(texts, max_batch),
                    limit_listlist(multi_texts, max_batch),
                    limit_listlist(tokens, max_batch),
                    limit_list(languages, max_batch),
                    limit_listlist_float(dense_feat, max_batch),
                )

                # the result of forward is either a torch.Tensor or a List[Any]
                if isinstance(result, torch.Tensor):
                    result = torch.cat([result, result_extension], dim=0)
                else:
                    result.extend(result_extension)

                # prepare next iteration
                texts = clip_list(texts, max_batch)
                multi_texts = clip_listlist(multi_texts, max_batch)
                tokens = clip_listlist(tokens, max_batch)
                languages = clip_list(languages, max_batch)
                dense_feat = clip_listlist_float(dense_feat, max_batch)

        if isinstance(result, torch.Tensor):
            torch._assert(
                input_len == result.size()[0],
                "Tensor output size must match input size",
            )
        else:
            torch._assert(
                input_len == len(result), "List output size must match input size"
            )

        return result

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                Optional[List[str]],  # texts
                Optional[List[List[str]]],  # multi_texts
                Optional[List[List[str]]],  # tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # dense_feat must be None
            ]
        ],
    ):
        # List[torch.Tensor] or List[List[Any]]
        batchsize = len(batch)

        client_batch_texts: List[int] = []
        client_batch_tokens: List[int] = []
        zip_batch_list: List[int] = []

        flat_texts: List[str] = []
        flat_tokens: List[List[str]] = []
        flat_dense_feat_texts: List[List[float]] = []
        flat_dense_feat_tokens: List[List[float]] = []

        for i in range(batchsize):
            validate_make_prediction_batch_element(batch[i])
            batch_element_texts = batch[i][0]
            batch_element_tokens = batch[i][2]
            batch_element_dense_feat = batch[i][4]

            if batch_element_texts is not None:
                flat_texts.extend(batch_element_texts)
                client_batch_texts.append(len(batch_element_texts))
                flat_dense_feat_texts.extend(
                    validate_dense_feat(
                        batch_element_dense_feat,
                        len(batch_element_texts),
                        self.uses_dense_feat(),
                    )
                )
                zip_batch_list.append(1)
            elif batch_element_tokens is not None:
                flat_tokens.extend(batch_element_tokens)
                client_batch_tokens.append(len(batch_element_tokens))
                flat_dense_feat_tokens.extend(
                    validate_dense_feat(
                        batch_element_dense_feat,
                        len(batch_element_tokens),
                        self.uses_dense_feat(),
                    )
                )
                zip_batch_list.append(-1)
            else:
                # At present, we abort the entire batch if
                # any batch element is malformed.
                #
                # Possible refinement:
                # we can skip malformed requests,
                # and return a list plus an indication that one or more
                # batch elements (and which ones) were malformed
                raise RuntimeError("Malformed request.")

        if len(flat_texts) == 0 and len(flat_tokens) == 0:
            raise RuntimeError("This is not good. Empty request batch.")

        if len(flat_texts) > 0 and len(flat_tokens) > 0:
            raise RuntimeError("Mixing tokens and texts not supported in this service.")
            # flat_result_texts = self.forward(
            #     texts=flat_texts[:max_batch],
            #     multi_texts=None,
            #     tokens=None,
            #     languages=None,
            #     dense_feat=nonify_listlist_float(flat_dense_feat_texts),
            # )
            # flat_result_tokens = self.forward(
            #     texts=None,
            #     multi_texts=None,
            #     tokens=flat_tokens[:max_batch],
            #     languages=None,
            #     dense_feat=nonify_listlist_float(flat_dense_feat_tokens),
            # )
        elif len(flat_texts) > 0:
            flat_result_texts = self.forward(
                texts=flat_texts,
                multi_texts=None,
                tokens=None,
                languages=None,
                dense_feat=nonify_listlist_float(flat_dense_feat_texts),
            )
            # ignored in logic, this makes type system happy
            flat_result_tokens = flat_result_texts
        else:  # len(flat_tokens) > 0:
            flat_result_tokens = self.forward(
                texts=None,
                multi_texts=None,
                tokens=flat_tokens,
                languages=None,
                dense_feat=nonify_listlist_float(flat_dense_feat_tokens),
            )
            # ignored in logic, this makes type system happy
            flat_result_texts = flat_result_tokens

        # if torch.jit.isinstance(flat_result_tokens, torch.Tensor):
        if isinstance(flat_result_tokens, torch.Tensor):
            # destructure flat result tensor combining
            # cross-request batches and client side
            # batches into a cross-request list of
            # client-side batch tensors
            return zip_batch_tensor_list(
                zip_batch_list,
                destructure_tensor(client_batch_texts, flat_result_texts),
                destructure_tensor(client_batch_tokens, flat_result_tokens),
            )
        else:
            # destructure result list of any result type combining
            # cross-request batches and client side
            # batches into a cross-request list of
            # client-side result lists
            result_texts_any_list: List[Any] = torch.jit.annotate(List[Any], [])
            for v in flat_result_texts:
                result_texts_any_list.append(v)

            result_tokens_any_list: List[Any] = torch.jit.annotate(List[Any], [])
            for v in flat_result_tokens:
                result_tokens_any_list.append(v)

            return zip_batch_any_list_list(
                zip_batch_list,
                destructure_any_list(client_batch_texts, result_texts_any_list),
                destructure_any_list(client_batch_tokens, result_tokens_any_list),
            )

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                Optional[List[str]],  # texts
                Optional[List[List[str]]],  # multi_texts
                Optional[List[List[str]]],  # tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # dense_feat must be None
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[
        List[
            Tuple[
                Optional[List[str]],  # texts
                Optional[List[List[str]]],  # multi_texts
                Optional[List[List[str]]],  # tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # dense_feat must be None
                int,
            ]
        ]
    ]:
        # The next lines sort all cross-request batch elements by token length.
        # Note that a cross-request batch element can in turn be a client batch.
        mega_batch_key_list = [
            (max_tokens(self.tensorizer.tokenize(x[0], x[2])), n)
            for (n, x) in enumerate(mega_batch)
        ]
        sorted_mega_batch_key_list = sorted(mega_batch_key_list)
        sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list]

        # TBD: allow model server to specify batch size in goals dictionary
        max_bs: int = 10
        len_mb = len(mega_batch)
        num_batches = (len_mb + max_bs - 1) // max_bs

        batch_list: List[
            List[
                Tuple[
                    Optional[List[str]],  # texts
                    Optional[List[List[str]]],  # multi_texts
                    Optional[List[List[str]]],  # tokens
                    Optional[List[str]],  # languages
                    Optional[List[List[float]]],  # dense_feat must be None
                    int,  # position
                ]
            ]
        ] = []

        start = 0
        for _i in range(num_batches):
            end = min(start + max_bs, len_mb)
            batch_list.append(sorted_mega_batch[start:end])
            start = end

        return batch_list

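# Usage sketch (not part of the module): given an already scripted `model` and
# `tensorizer`, forward() transparently splits inputs longer than the batch
# padding control into max-batch-sized chunks:
#
#     module = ScriptPyTextEmbeddingModule(model, tensorizer)
#     module.set_padding_control("batch_length", [0, 8])   # max batch of 8
#     embeddings = module(texts=["first input", "second input"])
#
#     # make_prediction() takes a cross-request batch of
#     # (texts, multi_texts, tokens, languages, dense_feat) tuples and
#     # returns one result per client request:
#     per_client = module.make_prediction([(["a", "b"], None, None, None, None)])
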
class ScriptPyTextEmbeddingModuleIndex(ScriptPyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        index: int = 0,
    ):
        super().__init__(model, tensorizer)
        self.index: int = index
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return False

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        return self.model(input_tensors)[self.index].cpu()

class ScriptPyTextModule(ScriptPyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
    ):
        super().__init__(model, tensorizer)
        # A PyText Module is an EmbeddingModule with an output layer
        self.output_layer = output_layer

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return False

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: False
        dense_feat: Optional[List[List[float]]] = None,
    ):
        self.forward_validate_dense_feat(dense_feat)

        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, multi_texts),
            tokens=squeeze_2d(tokens),
            languages=squeeze_1d(languages),
        )
        input_tensors = self.tensorizer(inputs)
        logits = self.model(input_tensors)
        return self.output_layer(logits)

class ScriptPyTextEmbeddingModuleWithDense(ScriptPyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
        concat_dense: bool = False,
    ):
        super().__init__(model, tensorizer)
        self.normalizer = normalizer
        self.concat_dense = torch.jit.Attribute(concat_dense, bool)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return True

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor):
        input_tensors = self.tensorizer(inputs)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)
        return self.model(input_tensors, dense_tensor).cpu()

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: True
        dense_feat: Optional[List[List[float]]] = None,
    ) -> torch.Tensor:
        dense_feat = self.forward_validate_dense_feat(dense_feat)

        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, multi_texts),
            tokens=squeeze_2d(tokens),
            languages=squeeze_1d(languages),
        )
        # call model
        dense_feat = self.normalizer.normalize(dense_feat)
        dense_tensor = torch.tensor(dense_feat, dtype=torch.float)
        sentence_embedding = self._forward(inputs, dense_tensor)

        if self.concat_dense:
            return torch.cat([sentence_embedding, dense_tensor], 1)
        else:
            return sentence_embedding

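# Usage sketch (hypothetical feature values): the dense variant requires one
# feature vector per text; with concat_dense=True the normalized features are
# appended to the sentence embedding along dim 1:
#
#     module = ScriptPyTextEmbeddingModuleWithDense(
#         model, tensorizer, normalizer, concat_dense=True
#     )
#     emb = module(texts=["query"], dense_feat=[[0.2, 1.5, 0.0]])
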
class ScriptPyTextModuleWithDense(ScriptPyTextEmbeddingModuleWithDense):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
    ):
        super().__init__(model, tensorizer, normalizer)
        self.output_layer = output_layer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return True

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: True
        dense_feat: Optional[List[List[float]]] = None,
    ):
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, multi_texts),
            tokens=squeeze_2d(tokens),
            languages=squeeze_1d(languages),
        )
        input_tensors = self.tensorizer(inputs)

        dense_feat = self.normalizer.normalize(
            self.forward_validate_dense_feat(dense_feat)
        )
        dense_tensor = torch.tensor(dense_feat, dtype=torch.float)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)

        logits = self.model(input_tensors, dense_tensor)
        return self.output_layer(logits)

class ScriptPyTextEmbeddingModuleWithDenseIndex(ScriptPyTextEmbeddingModuleWithDense):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
        index: int = 0,
        concat_dense: bool = True,
    ):
        super().__init__(model, tensorizer, normalizer, concat_dense)
        self.index = torch.jit.Attribute(index, int)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return True

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor):
        input_tensors = self.tensorizer(inputs)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)
        return self.model(input_tensors, dense_tensor)[self.index].cpu()

class ScriptPyTextVariableSizeEmbeddingModule(ScriptPyTextEmbeddingModule):
    """
    Assumes model returns a tuple of representations and sequence lengths,
    then slices each example's representation according to length.
    Returns a list of tensors. The slicing is easier to do outside a traced model.
    """

    def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer):
        super().__init__(model, tensorizer)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def uses_dense_feat(self) -> bool:
        return False

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        reps, seq_lens = self.model(input_tensors)
        reps = reps.cpu()
        seq_lens = seq_lens.cpu()
        return [reps[i, : seq_lens[i]] for i in range(len(seq_lens))]

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: Optional[List[str]] = None,
        # multi_texts is of shape [batch_size, num_columns]
        multi_texts: Optional[List[List[str]]] = None,
        tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        # self.uses_dense_feat() indicates use: False
        dense_feat: Optional[List[List[float]]] = None,
    ) -> List[torch.Tensor]:
        self.forward_validate_dense_feat(dense_feat)

        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, multi_texts),
            tokens=squeeze_2d(tokens),
            languages=squeeze_1d(languages),
        )
        return self._forward(inputs)

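# Usage sketch: unlike the fixed-size module, forward() here returns one
# tensor per example, each sliced to that example's own sequence length:
#
#     module = ScriptPyTextVariableSizeEmbeddingModule(model, tensorizer)
#     reps = module(texts=["short", "a much longer input sentence"])
#     # len(reps) == 2; reps[1].size(0) is typically larger than reps[0].size(0)
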
######################## Two Tower ################################
class ScriptTwoTowerModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def set_device(self, device: str):
        self.right_tensorizer.set_device(device)
        self.left_tensorizer.set_device(device)

    @torch.jit.script_method
    def set_padding_control(self, dimension: str, control: Optional[List[int]]):
        """
        This function is called to set a padding style.
        None: no padding.
        List: first element 0; round the sequence length up to the smallest
        list element larger than the input.
        """
        self.right_tensorizer.set_padding_control(dimension, control)
        self.left_tensorizer.set_padding_control(dimension, control)

    def validate(self, export_conf: ExportConfig):
        deprecation_warning(export_conf)
class ScriptPyTextTwoTowerModule(ScriptTwoTowerModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
    ):
        super().__init__()
        self.model = model
        self.output_layer = output_layer
        self.right_tensorizer = right_tensorizer
        self.left_tensorizer = left_tensorizer

    @torch.jit.script_method
    def forward(
        self,
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        right_input_tensors = self.right_tensorizer(right_inputs)

        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )
        left_input_tensors = self.left_tensorizer(left_inputs)

        logits = self.model(right_input_tensors, left_input_tensors)
        return self.output_layer(logits)

class ScriptPyTextTwoTowerModuleWithDense(ScriptPyTextTwoTowerModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
        right_normalizer: VectorNormalizer,
        left_normalizer: VectorNormalizer,
    ):
        super().__init__(model, output_layer, right_tensorizer, left_tensorizer)
        self.right_normalizer = right_normalizer
        self.left_normalizer = left_normalizer

    @torch.jit.script_method
    def forward(
        self,
        right_dense_feat: List[List[float]],
        left_dense_feat: List[List[float]],
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ):
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        right_input_tensors = self.right_tensorizer(right_inputs)

        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )
        left_input_tensors = self.left_tensorizer(left_inputs)

        right_dense_feat = self.right_normalizer.normalize(right_dense_feat)
        left_dense_feat = self.left_normalizer.normalize(left_dense_feat)
        right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float)
        left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float)

        if self.right_tensorizer.device != "":
            right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device)
        if self.left_tensorizer.device != "":
            left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device)

        logits = self.model(
            right_input_tensors,
            left_input_tensors,
            right_dense_tensor,
            left_dense_tensor,
        )
        return self.output_layer(logits)

class ScriptPyTextTwoTowerEmbeddingModule(ScriptTwoTowerModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
    ):
        super().__init__()
        self.model = model
        self.right_tensorizer = right_tensorizer
        self.left_tensorizer = left_tensorizer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput):
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_input_tensors = self.left_tensorizer(left_inputs)
        return self.model(right_input_tensors, left_input_tensors).cpu()

    @torch.jit.script_method
    def forward(
        self,
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
    ) -> torch.Tensor:
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )
        return self._forward(right_inputs, left_inputs)

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                Optional[List[str]],  # right_texts
                Optional[List[str]],  # left_texts
                Optional[List[List[str]]],  # right_tokens
                Optional[List[List[str]]],  # left_tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # right_dense_feat
                Optional[List[List[float]]],  # left_dense_feat
            ]
        ],
    ) -> List[torch.Tensor]:
        argno = -1

        if argno == -1:
            raise RuntimeError("Argument number not specified during export.")

        batchsize = len(batch)

        # Argument types and Tuple indices
        TEXTS = 0
        # MULTI_TEXTS = 1
        # TOKENS = 2
        # LANGUAGES = 3
        # DENSE_FEAT = 4

        client_batch: List[int] = []
        # res_list: List[torch.Tensor] = []

        if argno == TEXTS:
            flat_right_texts: List[str] = []
            flat_left_texts: List[str] = []

            for i in range(batchsize):
                batch_right_element = batch[i][0]
                batch_left_element = batch[i][1]

                if batch_right_element is not None:
                    flat_right_texts.extend(batch_right_element)
                    client_batch.append(len(batch_right_element))
                else:
                    # At present, we abort the entire batch if
                    # any batch element is malformed.
                    #
                    # Possible refinement:
                    # we can skip malformed requests,
                    # and return a list plus an indication that one or more
                    # batch elements (and which ones) were malformed
                    raise RuntimeError("Malformed request.")
                if batch_left_element is not None:
                    flat_left_texts.extend(batch_left_element)
                else:
                    raise RuntimeError("Malformed request.")

            flat_result: torch.Tensor = self.forward(
                right_texts=flat_right_texts,
                left_texts=flat_left_texts,
                right_tokens=None,
                left_tokens=None,
                languages=None,
                right_dense_feat=None,
                left_dense_feat=None,
            )
        else:
            raise RuntimeError("Parameter type unsupported")

        # destructure flat result tensor combining
        # cross-request batches and client side
        # batches into a cross-request list of
        # client-side batch tensors
        return destructure_tensor(client_batch, flat_result)

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                Optional[List[str]],  # right_texts
                Optional[List[str]],  # left_texts
                Optional[List[List[str]]],  # right_tokens
                Optional[List[List[str]]],  # left_tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # right_dense_feat
                Optional[List[List[float]]],  # left_dense_feat
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[
        List[
            Tuple[
                Optional[List[str]],  # right_texts
                Optional[List[str]],  # left_texts
                Optional[List[List[str]]],  # right_tokens
                Optional[List[List[str]]],  # left_tokens
                Optional[List[str]],  # languages
                Optional[List[List[float]]],  # right_dense_feat
                Optional[List[List[float]]],  # left_dense_feat
                int,
            ]
        ]
    ]:
        argno = -1

        if argno == -1:
            raise RuntimeError("Argument number not specified during export.")

        # The next lines sort all cross-request batch elements by the token
        # length of right_texts. Note that a cross-request batch element can
        # in turn be a client batch.
        mega_batch_key_list = [
            (max_tokens(self.right_tensorizer.tokenize(x[0], x[2])), n)
            for (n, x) in enumerate(mega_batch)
        ]
        sorted_mega_batch_key_list = sorted(mega_batch_key_list)
        sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list]

        # TBD: allow model server to specify batch size in goals dictionary
        max_bs: int = 10
        len_mb = len(mega_batch)
        num_batches = (len_mb + max_bs - 1) // max_bs

        batch_list: List[
            List[
                Tuple[
                    Optional[List[str]],  # right_texts
                    Optional[List[str]],  # left_texts
                    Optional[List[List[str]]],  # right_tokens
                    Optional[List[List[str]]],  # left_tokens
                    Optional[List[str]],  # languages
                    Optional[List[List[float]]],  # right_dense_feat
                    Optional[List[List[float]]],  # left_dense_feat
                    int,  # position
                ]
            ]
        ] = []

        start = 0
        for _i in range(num_batches):
            end = min(start + max_bs, len_mb)
            batch_list.append(sorted_mega_batch[start:end])
            start = end

        return batch_list

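# Sketch of the make_batch() splitting arithmetic used above and in the other
# make_batch() implementations in this file (hypothetical sizes):
#
#     len_mb, max_bs = 25, 10
#     num_batches = (len_mb + max_bs - 1) // max_bs   # == 3: chunks of 10, 10, 5
#
# Each element keeps its trailing position index, so the caller can restore
# the original order after the sorted batches are processed.
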
class ScriptPyTextTwoTowerEmbeddingModuleWithDense(ScriptPyTextTwoTowerEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
        right_normalizer: VectorNormalizer,
        left_normalizer: VectorNormalizer,
    ):
        super().__init__(model, right_tensorizer, left_tensorizer)
        self.right_normalizer = right_normalizer
        self.left_normalizer = left_normalizer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(
        self,
        right_inputs: ScriptBatchInput,
        left_inputs: ScriptBatchInput,
        right_dense_tensor: torch.Tensor,
        left_dense_tensor: torch.Tensor,
    ):
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_input_tensors = self.left_tensorizer(left_inputs)

        if self.right_tensorizer.device != "":
            right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device)
        if self.left_tensorizer.device != "":
            left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device)

        return self.model(
            right_input_tensors,
            left_input_tensors,
            right_dense_tensor,
            left_dense_tensor,
        ).cpu()

    @torch.jit.script_method
    def forward(
        self,
        right_texts: Optional[List[str]] = None,
        left_texts: Optional[List[str]] = None,
        right_tokens: Optional[List[List[str]]] = None,
        left_tokens: Optional[List[List[str]]] = None,
        languages: Optional[List[str]] = None,
        right_dense_feat: Optional[List[List[float]]] = None,
        left_dense_feat: Optional[List[List[float]]] = None,
    ) -> torch.Tensor:
        if right_dense_feat is None or left_dense_feat is None:
            raise RuntimeError("Expect dense feature.")

        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(right_tokens),
            languages=squeeze_1d(languages),
        )
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(left_tokens),
            languages=squeeze_1d(languages),
        )

        right_dense_feat = self.right_normalizer.normalize(right_dense_feat)
        left_dense_feat = self.left_normalizer.normalize(left_dense_feat)
        right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float)
        left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float)

        sentence_embedding = self._forward(
            right_inputs, left_inputs, right_dense_tensor, left_dense_tensor
        )
        return sentence_embedding

############################################################################
#
# New module hierarchy Pytext* mirrors ScriptPytext* while reflecting
# advances in pytext models:
#  * Integrated tokenization - sole interface is texts which will be tokenized
#  * Multi-lingual models - no need to specify languages
#
# All new modules provide:
#  * Cross-request batching support
#  * Batch optimization support
#  * Sequence length and batch size padding for accelerators
#
#
# The inputs and outputs for cross-request batching with make_prediction
# are described in this post:
# https://fb.workplace.com/groups/401165540538639/permalink/556111271710731/
#
# The inputs and outputs for batch optimization with make_batch
# are described in this post:
# https://fb.workplace.com/groups/401165540538639/permalink/607830233205501/
#
#############################################################

# Pytext Classes:
class PyTextEmbeddingModule(torch.jit.ScriptModule):
    def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer):
        super().__init__()
        self.model = model
        self.tensorizer = tensorizer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def set_device(self, device: str):
        self.tensorizer.set_device(device)

    @torch.jit.script_method
    def set_padding_control(self, dimension: str, control: Optional[List[int]]):
        """
        This function is called to set a padding style.
        None: no padding.
        List: first element 0; round the sequence length up to the smallest
        list element larger than the input.
        """
        self.tensorizer.set_padding_control(dimension, control)

    @torch.jit.script_method
    def get_max_seq_len(self) -> int:
        """
        This function returns the maximum sequence length for the model,
        if it is defined, otherwise raises a RuntimeError.
        """
        if hasattr(self.tensorizer, "max_seq_len"):
            if self.tensorizer.max_seq_len is not None:
                return self.tensorizer.max_seq_len

        raise RuntimeError("max_seq_len not defined")

    @torch.jit.script_method
    def get_max_batch_len(self) -> int:
        """
        This function returns the maximum batch length for the model,
        if it is defined, otherwise -1.
        """
        if hasattr(self.tensorizer, "batch_padding_control"):
            batch_padding_control = self.tensorizer.batch_padding_control
            if batch_padding_control is not None:
                return batch_padding_control[-1]
        return -1

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        return self.model(input_tensors).cpu()

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: List[str],
    ) -> torch.Tensor:
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, None),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        return self._forward(inputs)

    @torch.jit.script_method
    def forward(
        self,
        texts: List[str],
    ) -> torch.Tensor:
        input_len = len(texts)
        max_batch = self.get_max_batch_len()
        if max_batch < 0:
            max_batch = input_len

        result = self.forward_impl(
            texts[:max_batch],
        )

        if input_len > max_batch:
            texts = texts[max_batch:]

            while len(texts) > 0:
                result_extension = self.forward_impl(
                    texts[:max_batch],
                )

                # the result of forward is either a torch.Tensor or a List[Any]
                if isinstance(result, torch.Tensor):
                    result = torch.cat([result, result_extension], dim=0)
                else:
                    result.extend(result_extension)

                texts = texts[max_batch:]

        return result

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                List[str],  # texts
            ]
        ],
    ) -> List[torch.Tensor]:
        flat_result: torch.Tensor = self.forward(
            texts=make_prediction_texts(batch),
        )

        return destructure_tensor([len(be[0]) for be in batch], flat_result)

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                List[str],  # texts
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[List[Tuple[List[str], int]]]:  # texts
        batchsize = len(mega_batch)
        if batchsize == 0:
            raise RuntimeError("Input batch must have at least 1 batch element")

        # The next lines sort all cross-request batch elements by token length.
        # Note that a cross-request batch element can in turn be a client batch.
        mega_batch_key_list = [
            (max_tokens(self.tensorizer.tokenize(x[0], None)), n)
            for (n, x) in enumerate(mega_batch)
        ]
        sorted_mega_batch_key_list = sorted(mega_batch_key_list)
        sorted_mega_batch = [mega_batch[n] for (_, n) in sorted_mega_batch_key_list]

        # TBD: allow model server to specify batch size in goals dictionary
        max_bs: int = 10
        len_mb = len(mega_batch)
        num_batches = (len_mb + max_bs - 1) // max_bs

        batch_list: List[
            List[
                Tuple[
                    List[str],  # texts
                    int,  # position
                ]
            ]
        ] = []

        start = 0
        for _i in range(num_batches):
            end = min(start + max_bs, len_mb)
            batch_list.append(sorted_mega_batch[start:end])
            start = end

        return batch_list

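# Usage sketch: the new-hierarchy module takes only `texts`; tokenization is
# done inside the tensorizer. Cross-request batching flattens the client
# batches, runs one forward pass, and destructures the result per client:
#
#     module = PyTextEmbeddingModule(model, tensorizer)
#     t = module(["hello world"])                    # torch.Tensor
#     per_client = module.make_prediction(
#         [(["a", "b"],), (["c"],)]                  # two client requests
#     )                                              # [Tensor(2, D), Tensor(1, D)]
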
# PytextLayerModule is a PytextEmbeddingModule with an additional output layer
class PyTextLayerModule(PyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
    ):
        super().__init__(model, tensorizer)
        self.output_layer = output_layer

    @torch.jit.script_method
    def forward_impl(self, texts: List[str]):
        # logits = super().forward(texts)
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, None),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        input_tensors = self.tensorizer(inputs)
        logits = self.model(input_tensors)
        # </> logits = super().forward(texts)

        return self.output_layer(logits)

# PytextEmbeddingModuleIndex is a PytextEmbeddingModule with an additional Index
class PyTextEmbeddingModuleIndex(PyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        index: int = 0,
    ):
        super().__init__(model, tensorizer)
        self.index = torch.jit.Attribute(index, int)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        return self.model(input_tensors)[self.index].cpu()

# PytextEmbeddingModuleWithDense is a PytextEmbeddingModule with an additional dense_feat
class PyTextEmbeddingModuleWithDense(PyTextEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
        concat_dense: bool = False,
    ):
        super().__init__(model, tensorizer)
        self.normalizer = normalizer
        self.concat_dense: bool = concat_dense
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor):
        input_tensors = self.tensorizer(inputs)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)
        return self.model(input_tensors, dense_tensor).cpu()

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: List[str],
        dense_feat: List[List[float]],
    ) -> torch.Tensor:
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, None),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        # call model
        dense_feat = self.normalizer.normalize(dense_feat)
        dense_tensor = torch.tensor(dense_feat, dtype=torch.float)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)

        sentence_embedding = self._forward(inputs, dense_tensor)
        if self.concat_dense:
            return torch.cat([sentence_embedding, dense_tensor], 1)
        else:
            return sentence_embedding

    @torch.jit.script_method
    def forward(
        self,
        texts: List[str],
        dense_feat: List[List[float]],
    ) -> torch.Tensor:
        input_len = len(texts)
        max_batch = self.get_max_batch_len()
        if max_batch < 0:
            max_batch = input_len

        result = self.forward_impl(texts[:max_batch], dense_feat[:max_batch])

        if input_len > max_batch:
            texts = texts[max_batch:]
            dense_feat = dense_feat[max_batch:]

            while len(texts) > 0:
                result_extension = self.forward_impl(
                    texts[:max_batch], dense_feat[:max_batch]
                )

                # the result of forward is either a torch.Tensor or a List[Any]
                if isinstance(result, torch.Tensor):
                    result = torch.cat([result, result_extension], dim=0)
                else:
                    result.extend(result_extension)

                texts = texts[max_batch:]
                dense_feat = dense_feat[max_batch:]

        return result

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                List[str],  # texts
                List[List[float]],  # dense
            ]
        ],
    ) -> List[torch.Tensor]:
        flat_texts, flat_dense = make_prediction_texts_dense(batch)

        flat_result: torch.Tensor = self.forward(
            texts=flat_texts,
            dense_feat=flat_dense,
        )

        return destructure_tensor([len(be[0]) for be in batch], flat_result)

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                List[str],  # texts
                List[List[float]],  # dense
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[List[Tuple[List[str], List[List[float]], int]]]:  # texts  # dense
        return make_batch_texts_dense(self.tensorizer, mega_batch, goals)

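# Usage sketch (hypothetical feature width): texts and dense_feat must be the
# same length; forward() clips both to the same max-batch windows in lockstep:
#
#     module = PyTextEmbeddingModuleWithDense(model, tensorizer, normalizer)
#     emb = module(["one", "two"], [[0.1, 0.2], [0.3, 0.4]])
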
# PytextLayerModuleWithDense is a PytextEmbeddingModuleWithDense with an additional output layer
class PyTextLayerModuleWithDense(PyTextEmbeddingModuleWithDense):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
    ):
        super().__init__(model, tensorizer, normalizer)
        self.output_layer = output_layer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def forward_impl(
        self,
        texts: List[str],
        dense_feat: List[List[float]],
    ):
        # logits = super().forward(texts, dense_feat)
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, None),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        input_tensors = self.tensorizer(inputs)

        dense_feat = self.normalizer.normalize(dense_feat)
        dense_tensor = torch.tensor(dense_feat, dtype=torch.float)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)

        logits = self.model(input_tensors, dense_tensor)
        # </> logits = super().forward(texts, dense_feat)

        return self.output_layer(logits)

# PytextEmbeddingModuleWithDenseIndex is a PytextEmbeddingModuleWithDense with an additional Index
class PyTextEmbeddingModuleWithDenseIndex(PyTextEmbeddingModuleWithDense):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        tensorizer: ScriptTensorizer,
        normalizer: VectorNormalizer,
        index: int = 0,
        concat_dense: bool = True,
    ):
        super().__init__(model, tensorizer, normalizer, concat_dense)
        self.index = torch.jit.Attribute(index, int)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput, dense_tensor: torch.Tensor):
        # return super()._forward(inputs, dense_tensor)[self.index].cpu()
        input_tensors = self.tensorizer(inputs)
        if self.tensorizer.device != "":
            dense_tensor = dense_tensor.to(self.tensorizer.device)
        return self.model(input_tensors, dense_tensor)[self.index].cpu()
        # </> return super()._forward(inputs, dense_tensor)[self.index].cpu()

class PyTextVariableSizeEmbeddingModule(PyTextEmbeddingModule):
    """
    Assumes model returns a tuple of representations and sequence lengths,
    then slices each example's representation according to length.
    Returns a list of tensors. The slicing is easier to do outside a traced model.
    """

    def __init__(self, model: torch.jit.ScriptModule, tensorizer: ScriptTensorizer):
        super().__init__(model, tensorizer)
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(self, inputs: ScriptBatchInput):
        input_tensors = self.tensorizer(inputs)
        reps, seq_lens = self.model(input_tensors)
        reps = reps.cpu()
        seq_lens = seq_lens.cpu()
        return [reps[i, : seq_lens[i]] for i in range(len(seq_lens))]

    @torch.jit.script_method
    def forward_impl(self, texts: List[str]) -> List[torch.Tensor]:
        inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(texts, None),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        return self._forward(inputs)

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                List[str],  # texts
            ]
        ],
    ) -> List[List[torch.Tensor]]:
        flat_result: List[torch.Tensor] = self.forward(
            texts=make_prediction_texts(batch),
        )

        return destructure_tensor_list([len(be[0]) for be in batch], flat_result)

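# Usage sketch: make_prediction() here uses destructure_tensor_list (not
# destructure_tensor) because forward() returns a list of per-example tensors:
#
#     module = PyTextVariableSizeEmbeddingModule(model, tensorizer)
#     per_client = module.make_prediction([(["a", "b"],), (["c"],)])
#     # per_client[0] is a list of 2 tensors, per_client[1] a list of 1
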
#############################################################
# PytextTwoTower Classes:
#
# mirrors the inheritance order of Pytext modules.
# *** please keep order and inheritance structure ***
# *** in sync between these two hierarchies ***
#
class PyTextTwoTowerEmbeddingModule(torch.jit.ScriptModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
    ):
        super().__init__()
        self.model = model
        self.right_tensorizer = right_tensorizer
        self.left_tensorizer = left_tensorizer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def set_device(self, device: str):
        self.right_tensorizer.set_device(device)
        self.left_tensorizer.set_device(device)

    @torch.jit.script_method
    def set_padding_control(self, dimension: str, control: Optional[List[int]]):
        """
        This function is called to set a padding style.
        None: no padding.
        List: first element 0; round the sequence length up to the smallest
        list element larger than the input.
        """
        self.right_tensorizer.set_padding_control(dimension, control)
        self.left_tensorizer.set_padding_control(dimension, control)

    @torch.jit.script_method
    def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput):
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_input_tensors = self.left_tensorizer(left_inputs)
        return self.model(right_input_tensors, left_input_tensors).cpu()

    @torch.jit.script_method
    def forward(
        self,
        right_texts: List[str],
        left_texts: List[str],
    ) -> torch.Tensor:
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        return self._forward(right_inputs, left_inputs)

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                List[str],  # right_texts
                List[str],  # left_texts
            ]
        ],
    ) -> List[torch.Tensor]:
        batchsize = len(batch)

        flat_right_texts: List[str] = []
        flat_left_texts: List[str] = []

        for i in range(batchsize):
            batch_right_element = batch[i][0]
            batch_left_element = batch[i][1]

            flat_right_texts.extend(batch_right_element)
            flat_left_texts.extend(batch_left_element)

        flat_result: torch.Tensor = self.forward(
            right_texts=flat_right_texts,
            left_texts=flat_left_texts,
        )

        return destructure_tensor([len(be[0]) for be in batch], flat_result)

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                List[str],  # right_texts
                List[str],  # left_texts
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[List[Tuple[List[str], List[str], int]]]:  # right_texts  # left_texts
        # The next lines sort all cross-request batch elements by the token
        # length of right_texts. Note that a cross-request batch element can
        # in turn be a client batch.
        mega_batch_key_list = [
            (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n)
            for (n, x) in enumerate(mega_batch)
        ]
        sorted_mega_batch_key_list = sorted(mega_batch_key_list)
        sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list]

        # TBD: allow model server to specify batch size in goals dictionary
        max_bs: int = 10
        len_mb = len(mega_batch)
        num_batches = (len_mb + max_bs - 1) // max_bs

        batch_list: List[
            List[
                Tuple[
                    List[str],  # right_texts
                    List[str],  # left_texts
                    int,  # position
                ]
            ]
        ] = []

        start = 0
        for _i in range(num_batches):
            end = min(start + max_bs, len_mb)
            batch_list.append(sorted_mega_batch[start:end])
            start = end

        return batch_list

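# Usage sketch: the right and left texts go through their own tensorizers and
# the two lists must be index-aligned per example:
#
#     module = PyTextTwoTowerEmbeddingModule(model, right_tensorizer, left_tensorizer)
#     emb = module(right_texts=["query"], left_texts=["document"])
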
class PyTextTwoTowerLayerModule(PyTextTwoTowerEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
    ):
        super().__init__(model, right_tensorizer, left_tensorizer)
        self.output_layer = output_layer

    @torch.jit.script_method
    def forward(
        self,
        right_texts: List[str],
        left_texts: List[str],
    ):
        logits = super().forward(right_texts, left_texts)
        return self.output_layer(logits)

class PyTextTwoTowerEmbeddingModuleWithDense(PyTextTwoTowerEmbeddingModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
        right_normalizer: VectorNormalizer,
        left_normalizer: VectorNormalizer,
    ):
        super().__init__(model, right_tensorizer, left_tensorizer)
        self.right_normalizer = right_normalizer
        self.left_normalizer = left_normalizer
        log_class_usage(self.__class__)

    @torch.jit.script_method
    def _forward(
        self,
        right_inputs: ScriptBatchInput,
        left_inputs: ScriptBatchInput,
        right_dense_tensor: torch.Tensor,
        left_dense_tensor: torch.Tensor,
    ):
        right_input_tensors = self.right_tensorizer(right_inputs)
        left_input_tensors = self.left_tensorizer(left_inputs)

        if self.right_tensorizer.device != "":
            right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device)
        if self.left_tensorizer.device != "":
            left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device)

        return self.model(
            right_input_tensors,
            left_input_tensors,
            right_dense_tensor,
            left_dense_tensor,
        ).cpu()

    @torch.jit.script_method
    def forward(
        self,
        right_texts: List[str],
        left_texts: List[str],
        right_dense_feat: List[List[float]],
        left_dense_feat: List[List[float]],
    ) -> torch.Tensor:
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )

        right_dense_feat = self.right_normalizer.normalize(right_dense_feat)
        left_dense_feat = self.left_normalizer.normalize(left_dense_feat)
        right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float)
        left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float)

        sentence_embedding = self._forward(
            right_inputs, left_inputs, right_dense_tensor, left_dense_tensor
        )
        return sentence_embedding

class PyTextTwoTowerLayerModuleWithDense(PyTextTwoTowerLayerModule):
    def __init__(
        self,
        model: torch.jit.ScriptModule,
        output_layer: torch.jit.ScriptModule,
        right_tensorizer: ScriptTensorizer,
        left_tensorizer: ScriptTensorizer,
        right_normalizer: VectorNormalizer,
        left_normalizer: VectorNormalizer,
    ):
        super().__init__(model, output_layer, right_tensorizer, left_tensorizer)
        self.right_normalizer = right_normalizer
        self.left_normalizer = left_normalizer

    @torch.jit.script_method
    def forward(
        self,
        right_texts: List[str],
        left_texts: List[str],
        right_dense_feat: List[List[float]],
        left_dense_feat: List[List[float]],
    ):
        right_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(right_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        right_input_tensors = self.right_tensorizer(right_inputs)

        left_inputs: ScriptBatchInput = ScriptBatchInput(
            texts=resolve_texts(left_texts),
            tokens=squeeze_2d(None),
            languages=squeeze_1d(None),
        )
        left_input_tensors = self.left_tensorizer(left_inputs)

        right_dense_feat = self.right_normalizer.normalize(right_dense_feat)
        left_dense_feat = self.left_normalizer.normalize(left_dense_feat)
        right_dense_tensor = torch.tensor(right_dense_feat, dtype=torch.float)
        left_dense_tensor = torch.tensor(left_dense_feat, dtype=torch.float)

        if self.right_tensorizer.device != "":
            right_dense_tensor = right_dense_tensor.to(self.right_tensorizer.device)
        if self.left_tensorizer.device != "":
            left_dense_tensor = left_dense_tensor.to(self.left_tensorizer.device)

        logits = self.model(
            right_input_tensors,
            left_input_tensors,
            right_dense_tensor,
            left_dense_tensor,
        )
        return self.output_layer(logits)

    @torch.jit.script_method
    def make_prediction(
        self,
        batch: List[
            Tuple[
                List[str],  # right_texts
                List[str],  # left_texts
                List[List[float]],  # right_dense_feat
                List[List[float]],  # left_dense_feat
            ]
        ],
    ) -> List[torch.Tensor]:
        batchsize = len(batch)

        flat_right_texts: List[str] = []
        flat_left_texts: List[str] = []
        flat_right_dense: List[List[float]] = []
        flat_left_dense: List[List[float]] = []

        for i in range(batchsize):
            batch_right_element = batch[i][0]
            batch_left_element = batch[i][1]
            batch_right_dense_element = batch[i][2]
            batch_left_dense_element = batch[i][3]

            flat_right_texts.extend(batch_right_element)
            flat_left_texts.extend(batch_left_element)
            flat_right_dense.extend(batch_right_dense_element)
            flat_left_dense.extend(batch_left_dense_element)

        flat_result: torch.Tensor = self.forward(
            right_texts=flat_right_texts,
            left_texts=flat_left_texts,
            right_dense_feat=flat_right_dense,
            left_dense_feat=flat_left_dense,
        )

        return destructure_tensor([len(be[0]) for be in batch], flat_result)

    @torch.jit.script_method
    def make_batch(
        self,
        mega_batch: List[
            Tuple[
                List[str],  # right_texts
                List[str],  # left_texts
                List[List[float]],  # right_dense_feat
                List[List[float]],  # left_dense_feat
                int,
            ]
        ],
        goals: Dict[str, str],
    ) -> List[
        List[
            Tuple[
                List[str],  # right_texts
                List[str],  # left_texts
                List[List[float]],  # right_dense_feat
                List[List[float]],  # left_dense_feat
                int,
            ]
        ]
    ]:
        # The next lines sort all cross-request batch elements by the token
        # length of right_texts. Note that a cross-request batch element can
        # in turn be a client batch.
        mega_batch_key_list = [
            (max_tokens(self.right_tensorizer.tokenize(x[0], None)), n)
            for (n, x) in enumerate(mega_batch)
        ]
        sorted_mega_batch_key_list = sorted(mega_batch_key_list)
        sorted_mega_batch = [mega_batch[n] for (key, n) in sorted_mega_batch_key_list]

        # TBD: allow model server to specify batch size in goals dictionary
        max_bs: int = 10
        len_mb = len(mega_batch)
        num_batches = (len_mb + max_bs - 1) // max_bs

        batch_list: List[
            List[
                Tuple[
                    List[str],  # right_texts
                    List[str],  # left_texts
                    List[List[float]],  # right_dense_feat
                    List[List[float]],  # left_dense_feat
                    int,  # position
                ]
            ]
        ] = []

        start = 0
        for _i in range(num_batches):
            end = min(start + max_bs, len_mb)
            batch_list.append(sorted_mega_batch[start:end])
            start = end

        return batch_list

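# Usage sketch (hypothetical feature widths): all four inputs must be
# index-aligned per example; make_batch() sorts by right-tower token length:
#
#     module = PyTextTwoTowerLayerModuleWithDense(
#         model, output_layer, right_tensorizer, left_tensorizer,
#         right_normalizer, left_normalizer,
#     )
#     scores = module(
#         right_texts=["q"], left_texts=["d"],
#         right_dense_feat=[[0.1]], left_dense_feat=[[0.9]],
#     )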