Source code for pytext.models.ensembles.bagging_intent_slot_ensemble

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Tuple, Dict

import torch
from pytext.models.joint_model import IntentSlotModel
from pytext.models.model import Model
from pytext.models.output_layers import CRFOutputLayer
from pytext.utils.usage import log_class_usage

from .ensemble import EnsembleModel


class BaggingIntentSlotEnsembleModel(EnsembleModel):
    """Ensemble class that uses bagging for ensembling intent-slot models.

    Args:
        config (Config): Configuration object specifying all the parameters of
            BaggingIntentSlotEnsemble.
        models (List[Model]): List of intent-slot model objects.

    Attributes:
        use_crf (bool): Whether to use CRF for the word tagging task.
        output_layer (IntentSlotOutputLayer): Output layer of the intent-slot
            model, responsible for computing loss and predictions.
    """
    class Config(EnsembleModel.Config):
        """Configuration class for `BaggingIntentSlotEnsemble`. These
        attributes are used by `Ensemble.from_config()` to construct an
        instance of `BaggingIntentSlotEnsemble`.

        Attributes:
            models (List[IntentSlotModel.Config]): List of intent-slot model
                configurations.
            use_crf (bool): Whether to use CRF for the word tagging task.
        """

        models: List[IntentSlotModel.Config]
        use_crf: bool = False
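As a rough illustration (not taken from the pytext docs), such a configuration might be assembled as below; that `IntentSlotModel.Config()` is constructible with its defaults, and that pytext config classes accept keyword arguments, are assumptions of this sketch:

# A minimal sketch, assuming Config classes accept keyword arguments and
# that IntentSlotModel.Config() works with its default values.
ensemble_config = BaggingIntentSlotEnsembleModel.Config(
    models=[IntentSlotModel.Config() for _ in range(3)],  # three bagged sub-models
    use_crf=True,  # declared above; note __init__ re-derives use_crf from the output layer type
)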
    def __init__(self, config: Config, models: List[Model], *args, **kwargs) -> None:
        super().__init__(config, models)
        self.use_crf = isinstance(self.output_layer.word_output, CRFOutputLayer)
        log_class_usage(__class__)
    def merge_sub_models(self) -> None:
        """Merges all sub-models' transition matrices when using CRF.
        Otherwise does nothing.
        """
        # To get the transition matrix for the ensemble model, we average
        # the transition matrices of the child models.
        if not self.use_crf:
            return
        transition_matrix = torch.mean(
            torch.cat(
                tuple(
                    model.output_layer.word_output.crf.get_transitions().unsqueeze(0)
                    for model in self.models
                ),
                dim=0,
            ),
            dim=0,
        )
        self.output_layer.word_output.crf.set_transitions(transition_matrix)
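The `cat`/`unsqueeze` plumbing above is just an element-wise mean over the sub-models' transition matrices. A minimal standalone sketch with made-up tensors (nothing below is pytext API):

import torch

# Three hypothetical per-model CRF transition matrices (num_tags x num_tags).
transitions = [torch.randn(4, 4) for _ in range(3)]

# torch.stack adds the leading dimension in one step; averaging over it is
# equivalent to the cat/unsqueeze formulation used in merge_sub_models().
merged = torch.stack(transitions, dim=0).mean(dim=0)
merged_cat = torch.mean(
    torch.cat(tuple(t.unsqueeze(0) for t in transitions), dim=0), dim=0
)
assert torch.allclose(merged, merged_cat)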
    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calls the `forward()` method of each intent-slot sub-model, passing
        all arguments and keyword arguments through, collects the document and
        word logits from each sub-model, and averages their values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Averaged document logits and
            word logits from the ensemble.
        """
        logit_d_list, logit_w_list = None, None
        for model in self.models:
            logit_d, logit_w = model.forward(*args, **kwargs)
            logit_d, logit_w = logit_d.unsqueeze(2), logit_w.unsqueeze(3)

            if logit_d_list is None:
                logit_d_list = logit_d
            else:
                logit_d_list = torch.cat([logit_d_list, logit_d], dim=2)

            if logit_w_list is None:
                logit_w_list = logit_w
            else:
                logit_w_list = torch.cat([logit_w_list, logit_w], dim=3)

        return torch.mean(logit_d_list, dim=2), torch.mean(logit_w_list, dim=3)
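The same stack-and-mean pattern underlies the logit averaging. A self-contained sketch with invented shapes (3 sub-models, batch size 2, 5 document classes, 7 tokens, 6 slot labels):

import torch

# Hypothetical per-sub-model logits.
doc_logits = [torch.randn(2, 5) for _ in range(3)]      # document-level
word_logits = [torch.randn(2, 7, 6) for _ in range(3)]  # word-level

# Stacking along a fresh trailing dimension and averaging over it reproduces
# the incremental cat/unsqueeze loop in forward().
avg_doc = torch.stack(doc_logits, dim=2).mean(dim=2)    # shape (2, 5)
avg_word = torch.stack(word_logits, dim=3).mean(dim=3)  # shape (2, 7, 6)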
    def torchscriptify(self, tensorizers, traced_model):
        # Export via the first sub-model; when CRF is in use, hand it the
        # merged output layer so the exported model uses the averaged
        # transition matrix.
        return self.models[0].torchscriptify(
            tensorizers,
            traced_model,
            merged_output_layer=self.output_layer if self.use_crf else None,
        )
    def load_state_dict(
        self,
        state_dict: Dict[str, torch.Tensor],
        strict: bool = True,
    ):
        super().load_state_dict(state_dict=state_dict, strict=strict)
        # Sub-model parameters are namespaced as "models.<i>.<param>" in the
        # ensemble's state dict; strip the prefix and route each slice to the
        # corresponding child model.
        for i, m in enumerate(self.models):
            submodel_state_dict = {}
            for key, val in state_dict.items():
                split_key = key.split(".")
                if split_key[0] == "models" and int(split_key[1]) == i:
                    submodel_state_dict[".".join(split_key[2:])] = val
            m.load_state_dict(state_dict=submodel_state_dict, strict=strict)
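To make the key routing concrete, here is a standalone sketch with a hypothetical state dict; the parameter names are invented for illustration:

import torch

# Sub-model parameters are namespaced "models.<i>.<param>"; other keys, such
# as the merged output layer's, are handled by the parent load_state_dict call.
state_dict = {
    "models.0.embedding.weight": torch.zeros(10, 4),
    "models.1.embedding.weight": torch.ones(10, 4),
    "output_layer.word_output.crf.transitions": torch.zeros(6, 6),
}

# Slice out sub-model 1's parameters with the "models.1." prefix stripped,
# mirroring the inner loop of load_state_dict().
i = 1
submodel_state_dict = {
    ".".join(key.split(".")[2:]): val
    for key, val in state_dict.items()
    if key.split(".")[0] == "models" and int(key.split(".")[1]) == i
}
assert list(submodel_state_dict) == ["embedding.weight"]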