Source code for pytext.models.representations.huggingface_bert_sentence_encoder

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
from typing import List, Tuple

import torch
from pytext.config import ConfigBase
from pytext.models.representations.transformer_sentence_encoder_base import (
    TransformerSentenceEncoderBase,
)
from pytext.utils.file_io import PathManager
from pytext.utils.usage import log_class_usage
from transformers.modeling_bert import BertConfig, BertModel


class HuggingFaceBertSentenceEncoder(TransformerSentenceEncoderBase):
    """
    Generate sentence representation using the open source HuggingFace BERT
    model. This class implements loading the model weights from a
    pre-trained model file.
    """

    class Config(TransformerSentenceEncoderBase.Config, ConfigBase):
        bert_cpt_dir: str = (
            "manifold://nlp_technologies/tree/huggingface-models/bert-base-uncased/"
        )
        load_weights: bool = True

    def __init__(
        self, config: Config, output_encoded_layers: bool, *args, **kwargs
    ) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        # Load config
        config_file = os.path.join(config.bert_cpt_dir, "config.json")
        local_config_path = PathManager.get_local_path(config_file)
        bert_config = BertConfig.from_json_file(local_config_path)
        print("Bert model config {}".format(bert_config))
        # Instantiate model.
        model = BertModel(bert_config)
        weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
        # load pre-trained weights if weights_path exists
        if config.load_weights and PathManager.isfile(weights_path):
            with PathManager.open(weights_path, "rb") as fd:
                state_dict = torch.load(fd)

            missing_keys: List[str] = []
            unexpected_keys: List[str] = []
            error_msgs: List[str] = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            for key in list(state_dict.keys()):
                new_key = None
                if key.endswith("LayerNorm.gamma"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
                if key.endswith("LayerNorm.beta"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
                if new_key is not None:
                    state_dict[new_key] = state_dict.pop(key)

            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=""):
                local_metadata = (
                    {} if metadata is None else metadata.get(prefix[:-1], {})
                )
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            load(model, prefix="" if hasattr(model, "bert") else "bert.")
            if len(missing_keys) > 0:
                print(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                print(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )

        self.bert = model
        log_class_usage(__class__)

    def _encoder(self, input_tuple: Tuple[torch.Tensor, ...]):
        tokens, pad_mask, segment_labels, _ = input_tuple
        last_encoded_layers, pooled_output, encoded_layers = self.bert(
            tokens,
            attention_mask=pad_mask,
            token_type_ids=segment_labels,
            output_hidden_states=True,
        )
        return encoded_layers, pooled_output

    def _embedding(self):
        # used to tie weights in MaskedLM model
        return self.bert.embeddings.word_embeddings
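# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It assumes a local
# checkpoint directory containing "config.json" and, optionally,
# "pytorch_model.bin"; the path, batch size, and sequence length below are
# illustrative placeholders rather than values taken from the PyText codebase.
if __name__ == "__main__":
    encoder_config = HuggingFaceBertSentenceEncoder.Config(
        bert_cpt_dir="/path/to/bert-base-uncased/",  # hypothetical local path
        load_weights=True,
    )
    encoder = HuggingFaceBertSentenceEncoder(
        encoder_config, output_encoded_layers=True
    )

    # _encoder expects (tokens, pad_mask, segment_labels, positions); the
    # positions tensor is unused by this encoder but kept for interface parity.
    tokens = torch.randint(0, encoder.bert.config.vocab_size, (2, 16))
    pad_mask = torch.ones(2, 16, dtype=torch.long)  # 1 marks a real token
    segment_labels = torch.zeros(2, 16, dtype=torch.long)
    positions = torch.arange(16).unsqueeze(0).repeat(2, 1)

    encoded_layers, pooled_output = encoder._encoder(
        (tokens, pad_mask, segment_labels, positions)
    )
    print(pooled_output.shape)  # (batch_size, hidden_size)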