#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Tuple, Dict
import torch
from pytext.models.joint_model import IntentSlotModel
from pytext.models.model import Model
from pytext.models.output_layers import CRFOutputLayer
from pytext.utils.usage import log_class_usage
from .ensemble import EnsembleModel
class BaggingIntentSlotEnsembleModel(EnsembleModel):
    """Ensemble class that uses bagging for ensembling intent-slot models.

    Args:
        config (Config): Configuration object specifying all the
            parameters of BaggingIntentSlotEnsemble.
        models (List[Model]): List of intent-slot model objects.

    Attributes:
        use_crf (bool): Whether to use CRF for word tagging task.
        output_layer (IntentSlotOutputLayer): Output layer of intent-slot
            model responsible for computing loss and predictions.
    """

    class Config(EnsembleModel.Config):
        """Configuration class for `BaggingIntentSlotEnsemble`.

        These attributes are used by `Ensemble.from_config()` to construct
        instance of `BaggingIntentSlotEnsemble`.

        Attributes:
            models (List[IntentSlotModel.Config]): List of intent-slot model
                configurations.
            use_crf (bool): Whether to use CRF for word tagging task.
        """

        models: List[IntentSlotModel.Config]
        use_crf: bool = False

    def __init__(self, config: Config, models: List[Model], *args, **kwargs) -> None:
        super().__init__(config, models)
        # CRF usage is detected from the (merged) output layer rather than
        # trusted from config, so it always matches the actual layer type.
        self.use_crf = isinstance(self.output_layer.word_output, CRFOutputLayer)
        log_class_usage(__class__)

    def merge_sub_models(self) -> None:
        """Merges all sub-models' transition matrices when using CRF.
        Otherwise does nothing.
        """
        # To get the transition_matrix for the ensemble model, we average the
        # transition matrices of the children models.
        if not self.use_crf:
            return
        transition_matrix = torch.mean(
            torch.cat(
                tuple(
                    # unsqueeze(0) adds a leading "model" axis so the
                    # per-model matrices can be stacked and averaged.
                    model.output_layer.word_output.crf.get_transitions().unsqueeze(0)
                    for model in self.models
                ),
                dim=0,
            ),
            dim=0,
        )
        self.output_layer.word_output.crf.set_transitions(transition_matrix)

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Call `forward()` method of each intent-slot sub-model by passing all
        arguments and named arguments to the sub-models, collect the logits from
        them and average their values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Averaged document-level and
            word-level logits from the ensemble.
        """
        logit_d_list, logit_w_list = None, None
        for model in self.models:
            logit_d, logit_w = model.forward(*args, **kwargs)
            # Add an ensemble axis (dim 2 for doc logits, dim 3 for word
            # logits) so sub-model outputs can be concatenated along it.
            logit_d, logit_w = logit_d.unsqueeze(2), logit_w.unsqueeze(3)
            if logit_d_list is None:
                logit_d_list = logit_d
            else:
                logit_d_list = torch.cat([logit_d_list, logit_d], dim=2)
            if logit_w_list is None:
                logit_w_list = logit_w
            else:
                logit_w_list = torch.cat([logit_w_list, logit_w], dim=3)
        # Average over the ensemble axis, restoring the original shapes.
        return torch.mean(logit_d_list, dim=2), torch.mean(logit_w_list, dim=3)

    def torchscriptify(self, tensorizers, traced_model):
        """Delegate TorchScript export to the first sub-model, passing the
        merged CRF output layer (if any) so the exported model uses the
        ensemble's averaged transitions.
        """
        return self.models[0].torchscriptify(
            tensorizers,
            traced_model,
            merged_output_layer=self.output_layer if self.use_crf else None,
        )

    def load_state_dict(
        self,
        state_dict: Dict[str, torch.Tensor],
        strict: bool = True,
    ):
        """Load ensemble weights, then fan each ``models.<i>.*`` entry out to
        the corresponding sub-model so sub-models are loaded individually too.
        """
        super().load_state_dict(state_dict=state_dict, strict=strict)
        for i, m in enumerate(self.models):
            submodel_state_dict = {}
            for key, val in state_dict.items():
                split_key = key.split(".")
                # Keys look like "models.<index>.<param path>"; keep only the
                # entries belonging to sub-model i, with the prefix stripped.
                if split_key[0] == "models" and int(split_key[1]) == i:
                    submodel_state_dict[".".join(split_key[2:])] = val
            m.load_state_dict(state_dict=submodel_state_dict, strict=strict)