Source code for pytext.models.representations.traced_transformer_encoder

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Optional, Tuple

import torch
import torch.cuda
import torch.nn as nn
from fairseq.modules import (
    TransformerSentenceEncoder as TransformerSentenceEncoderModule,
)
from pytext.utils.usage import log_class_usage


# Wrapper for TransformerSentenceEncoder to enable tracing
class TraceableTransformerWrapper(nn.Module):
    def __init__(self, eager_encoder: TransformerSentenceEncoderModule) -> None:
        super().__init__()
        # Tracing relies on the encoder returning tensors (not Python lists),
        # which fairseq only does when it was built with traceable=True.
        assert hasattr(eager_encoder, "traceable")
        assert eager_encoder.traceable
        self.eager_encoder = eager_encoder
        log_class_usage(__class__)
    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.eager_encoder(
            tokens,
            segment_labels,
            positions=positions,
            token_embeddings=token_embeddings,
            attn_mask=attn_mask,
        )
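
# A minimal tracing sketch (illustrative, not from the original module): given
# a fairseq encoder constructed with traceable=True, the wrapper can be traced
# directly with a sample batch; the shapes and values below are placeholders.
#
#     wrapper = TraceableTransformerWrapper(eager_encoder)
#     sample_tokens = torch.randint(2, 1000, (8, 16), dtype=torch.long)
#     traced = torch.jit.trace(wrapper, (sample_tokens,), check_trace=False)
#     encoded, pooled = traced(sample_tokens)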


class TracedTransformerEncoder(nn.Module):
    def __init__(
        self,
        eager_encoder: TransformerSentenceEncoderModule,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        traceable_encoder = TraceableTransformerWrapper(eager_encoder)
        traced_encoder_inputs = self._prepare_inputs(
            tokens, segment_labels, positions, token_embeddings, attn_mask
        )
        # Remember which optional inputs were present at trace time; the traced
        # graph expects the same set of inputs at every subsequent call.
        self.has_segment_labels = segment_labels is not None
        self.has_positions = positions is not None
        self.iter_ = 0
        # Do not check the trace because of non-deterministic ops (e.g. dropout).
        self.traced_encoder = torch.jit.trace(
            traceable_encoder, tuple(traced_encoder_inputs), check_trace=False
        )
        if torch.cuda.is_available():
            try:
                import torch_tvm

                torch_tvm.enable(
                    device_type="gpu",
                    device="cuda",
                    device_id=torch.cuda.current_device(),
                    is_training=True,
                )
                print("Using TVM in traced transformer")
            except ImportError:
                print("Not using TVM in traced transformer")
        log_class_usage(__class__)
    def forward(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        # The traced graph was recorded with a fixed input set, so the same
        # optional inputs must be provided (or omitted) as at trace time.
        assert self.has_segment_labels == (segment_labels is not None)
        assert self.has_positions == (positions is not None)
        traced_encoder_inputs = self._prepare_inputs(
            tokens, segment_labels, positions, token_embeddings, attn_mask
        )
        self.iter_ += 1
        if self.iter_ % 100 == 0:
            # Periodically profile one forward pass and print per-op CUDA times.
            print("Iter: ", self.iter_)
            with torch.autograd.profiler.profile(
                enabled=True, use_cuda=True, record_shapes=True
            ) as prof:
                encoded_layers, pooled_output = self.traced_encoder(
                    *traced_encoder_inputs
                )
            print(
                prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time")
            )
        else:
            encoded_layers, pooled_output = self.traced_encoder(*traced_encoder_inputs)
        # Unstack the stacked per-layer hidden states into a list of tensors.
        encoded_layers = list(torch.unbind(encoded_layers))
        return encoded_layers, pooled_output
    def _prepare_inputs(
        self,
        tokens: torch.Tensor,
        segment_labels: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        token_embeddings: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        # Collect only the inputs that were actually provided; tracing cannot
        # accept None placeholders in the input tuple.
        inputs = [tokens]
        if segment_labels is not None:
            inputs += [segment_labels]
        if positions is not None:
            inputs += [positions]
        if token_embeddings is not None:
            inputs += [token_embeddings]
        if attn_mask is not None:
            inputs += [attn_mask]
        return inputs
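

# Illustrative end-to-end usage (a minimal sketch, not part of the original
# module). Hyperparameter values are placeholders, and `traceable=True` is
# assumed to be accepted by the installed fairseq version, matching the
# assertions in TraceableTransformerWrapper and the call pattern in its
# forward().
if __name__ == "__main__":
    eager_encoder = TransformerSentenceEncoderModule(
        padding_idx=1,
        vocab_size=1000,
        num_encoder_layers=2,
        embedding_dim=64,
        num_attention_heads=4,
        traceable=True,  # required by TraceableTransformerWrapper's assertions
    )
    # Token ids in [2, 1000) avoid the padding index; shape is (batch, seq_len).
    tokens = torch.randint(2, 1000, (8, 16), dtype=torch.long)
    encoder = TracedTransformerEncoder(eager_encoder, tokens)
    encoded_layers, pooled_output = encoder(tokens)
    print(len(encoded_layers), pooled_output.shape)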