Source code for pytext.models.ensembles.bagging_doc_ensemble

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

import torch
from pytext.models.doc_model import DocModel

from .ensemble import EnsembleModel


class BaggingDocEnsembleModel(EnsembleModel):
    """Ensemble class that uses bagging for ensembling document classification
    models: the logits produced by every sub-model are averaged element-wise.
    """

    class Config(EnsembleModel.Config):
        """Configuration class for `BaggingDocEnsembleModel`.

        These attributes describe the sub-models that make up an instance of
        `BaggingDocEnsembleModel`.

        Attributes:
            models (List[DocModel.Config]): List of document classification
                model configurations, one per ensemble member.
        """

        models: List[DocModel.Config]

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run every document classification sub-model and average their logits.

        All positional and keyword arguments are passed unchanged to each
        sub-model in ``self.models``.

        Returns:
            torch.Tensor: Element-wise mean of the sub-model logits. It has the
            same shape as each sub-model's logits. All sub-models are assumed
            to return logits of identical shape; ``torch.stack`` requires it.
        """
        # Call each sub-model through ``__call__`` (not ``.forward()``) so that
        # any forward hooks registered on the sub-modules still run.
        logits = [model(*args, **kwargs) for model in self.models]
        # Stack along a new trailing "ensemble member" axis, then average over
        # it. For 2-D logits this is the same layout as unsqueeze(2) + cat on
        # dim 2, but it also works for logits of any rank (including 1-D).
        return torch.stack(logits, dim=-1).mean(dim=-1)