Source code for pytext.models.output_layers.distance_output_layer

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from enum import IntEnum, unique
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from pytext.config.component import create_loss
from pytext.data.utils import Vocabulary
from pytext.fields import FieldMeta
from pytext.loss import (
    BinaryCrossEntropyLoss,
    CosineEmbeddingLoss,
    MAELoss,
    MSELoss,
    NLLLoss,
)
from pytext.models.output_layers.output_layer_base import OutputLayerBase
from pytext.utils.label import get_label_weights
from pytext.utils.usage import log_class_usage


@unique
class OutputScore(IntEnum):
    """How a cosine similarity is turned into the score used for prediction.

    The values are integers so that a config may refer to a member by its
    numeric value; ``@unique`` rejects accidental duplicate values.
    """

    # Use the cosine similarity as-is.
    raw_cosine = 1
    # Rescale the cosine similarity with (cos + 1) / 2.
    norm_cosine = 2
    # Pass the cosine similarity through a sigmoid.
    sigmoid_cosine = 3
class PairwiseCosineDistanceOutputLayer(OutputLayerBase):
    """Output layer that scores a pair of representations by cosine similarity.

    ``logits`` is a pair ``(left, right)`` of tensors; the similarity is
    computed with ``F.cosine_similarity(left, right, dim=1)``. Depending on the
    configured loss the pair is either handed to the loss untouched
    (``CosineEmbeddingLoss``), converted to two-class log-probabilities
    (``NLLLoss``), or reduced to the similarity itself (all other losses).
    """

    __EXPANSIBLE__ = True

    class Config(OutputLayerBase.Config):
        # Loss applied to the (transformed) pair of representations.
        loss: Union[
            BinaryCrossEntropyLoss.Config,
            CosineEmbeddingLoss.Config,
            MAELoss.Config,
            MSELoss.Config,
            NLLLoss.Config,
        ] = CosineEmbeddingLoss.Config()
        # A pair is predicted as label 1 when its score is >= this threshold.
        score_threshold: float = 0.9
        # Which transformation of the cosine similarity is thresholded.
        score_type: OutputScore = OutputScore.norm_cosine
        # Optional per-label weights, keyed by label name.
        label_weights: Optional[Dict[str, float]] = None

    @classmethod
    def from_config(
        cls,
        config,
        metadata: Optional[FieldMeta] = None,
        labels: Optional[Vocabulary] = None,
    ):
        """Build the layer from its ``Config``.

        Args:
            config: Instance of :class:`Config`.
            metadata: Not used by this method.
            labels: Label vocabulary. Its ``idx`` mapping is read when
                ``config.label_weights`` is set, and its entries become the
                target names when it is not ``None``.

        Returns:
            A new instance of ``cls``.
        """
        label_weights = (
            get_label_weights(labels.idx, config.label_weights)
            if config.label_weights
            else None
        )
        assert config.score_type in (
            OutputScore.raw_cosine,
            OutputScore.norm_cosine,
            OutputScore.sigmoid_cosine,
        ), f"Invalid score_type {config.score_type}. See OutputScore enum."
        target_names = list(labels) if labels is not None else None
        return cls(
            target_names,
            create_loss(config.loss, weight=label_weights),
            config.score_threshold,
            config.score_type,
        )

    def __init__(
        self,
        target_names: Optional[List[str]] = None,
        loss_fn: Optional[
            Union[
                BinaryCrossEntropyLoss,
                CosineEmbeddingLoss,
                MAELoss,
                MSELoss,
                NLLLoss,
            ]
        ] = None,
        score_threshold: float = Config.score_threshold,
        score_type: OutputScore = Config.score_type,
    ):
        """Store the scoring options and delegate the rest to the base class.

        Args:
            target_names: Names of the target labels, forwarded to the base
                class.
            loss_fn: Loss object, forwarded to the base class.
            score_threshold: Minimum score for predicting label 1.
            score_type: Which score is compared against ``score_threshold``.
        """
        super().__init__(target_names, loss_fn)
        self.score_threshold = score_threshold
        self.score_type = score_type
        log_class_usage(__class__)

    def get_loss(
        self,
        logits: Tuple[torch.Tensor, torch.Tensor],
        targets: torch.Tensor,
        context: Optional[Dict[str, Any]] = None,
        reduce: bool = True,
    ) -> torch.Tensor:
        """Adapt ``logits``/``targets`` to the configured loss and compute it.

        Args:
            logits: Pair of tensors whose cosine similarity is taken along
                ``dim=1``.
            targets: Label tensor; a value of 0 is treated as the negative
                label.
            context: Forwarded unchanged to the base-class ``get_loss``.
            reduce: Forwarded unchanged to the base-class ``get_loss``.

        Returns:
            Whatever the base-class ``get_loss`` returns for the transformed
            inputs.
        """

        def _transform_logits(logits):
            if isinstance(self.loss_fn, CosineEmbeddingLoss):
                # The pair is passed through as-is for this loss.
                return logits
            if isinstance(self.loss_fn, NLLLoss):
                # Log probabilities shall be returned in this case.
                # The factor 2.3 = log(10) is used to keep the loss consistent
                # with the scores used for metric reporting.
                cosine_sim_scores = F.cosine_similarity(logits[0], logits[1], dim=1)
                pos_scores = cosine_sim_scores * 2.0 * 2.3
                neg_scores = (1.0 - cosine_sim_scores) * 2.0 * 2.3
                # Column 0 is the negative class, column 1 the positive class.
                return F.log_softmax(
                    torch.stack((neg_scores, pos_scores), dim=1), dim=1
                )
            return F.cosine_similarity(logits[0], logits[1], dim=1)

        def _transform_targets(targets):
            if isinstance(self.loss_fn, (BinaryCrossEntropyLoss, NLLLoss)):
                return targets
            if isinstance(self.loss_fn, (MAELoss, MSELoss)):
                return targets.to(dtype=torch.float)
            # Replace label = 0 with -1 because we're using cosine_embedding_loss.
            targets = targets.masked_fill(mask=(targets == 0), value=-1.0)
            return targets.to(dtype=torch.float)

        logits = _transform_logits(logits)
        targets = _transform_targets(targets)
        return super().get_loss(logits, targets, context, reduce)

    def get_pred(
        self,
        logits: Tuple[torch.Tensor, torch.Tensor],
        targets: torch.Tensor,
        *args,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Threshold the pair's similarity score into a binary prediction.

        Args:
            logits: Pair of tensors whose cosine similarity is taken along
                ``dim=1``.
            targets: Not used by this method.

        Returns:
            ``(preds, scores)`` where ``preds`` is a long tensor holding 1 when
            the selected score is ``>= self.score_threshold`` and 0 otherwise,
            and ``scores`` has two columns: negative score first, positive
            score second.
        """
        cosine_sim_scores = F.cosine_similarity(logits[0], logits[1], dim=1)
        if self.score_type == OutputScore.raw_cosine:
            preds = (cosine_sim_scores >= self.score_threshold).to(dtype=torch.long)
            # Since metric reporting depends on returning a distribution over
            # labels, we fake one over two labels: column 1 carries the scaled
            # similarity and column 0 its scaled complement.
            scores = torch.zeros(logits[0].size(0), 2, device=logits[0].device)
            scores[:, 0] = (1.0 - cosine_sim_scores[:]) * 2.0
            scores[:, 1] = cosine_sim_scores[:] * 2.0
        else:
            pos_scores, neg_scores = (
                get_norm_cosine_scores(cosine_sim_scores)
                if self.score_type == OutputScore.norm_cosine
                else get_sigmoid_scores(cosine_sim_scores)
            )
            preds = (pos_scores >= self.score_threshold).to(dtype=torch.long)
            scores = torch.cat(
                (neg_scores.unsqueeze(1), pos_scores.unsqueeze(1)), dim=1
            )
        return preds, scores
class DenseRetrievalOutputLayer(PairwiseCosineDistanceOutputLayer):
    """Output layer scoring every question against every context by dot product.

    ``logits`` is a pair ``(question_logits, context_logits)``; the score
    matrix is ``question_logits @ context_logits.T`` followed by a row-wise
    log-softmax.
    """

    def get_loss(
        self,
        logits: Tuple[torch.Tensor, torch.Tensor],
        targets: torch.Tensor,
        context: Optional[Dict[str, Any]] = None,
        reduce: bool = True,
    ) -> torch.Tensor:
        """Apply ``self.loss_fn`` to the row-wise log-probabilities.

        Args:
            logits: Pair ``(question_logits, context_logits)``.
            targets: Passed to ``self.loss_fn`` as the target argument.
            context: Not used by this method.
            reduce: Passed positionally to ``self.loss_fn``.

        Returns:
            The value returned by ``self.loss_fn``.
        """
        # Loss computation pointer: https://fburl.com/inf9ra38
        log_probs = self._get_log_probs(logits)
        # only supports NLL loss
        loss = self.loss_fn(log_probs, targets, reduce)
        return loss

    def get_pred(
        self,
        logits: Tuple[torch.Tensor, torch.Tensor],
        targets: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Return the best-scoring context index per question plus raw inputs.

        Args:
            logits: Pair ``(question_logits, context_logits)``.
            targets: Not used by this method.

        Returns:
            ``((preds, question_list, context_list), log_probs)`` where
            ``preds`` holds the index of the maximum log-probability in each
            row, and the two lists are the detached, CPU-side ``tolist()``
            copies of the question and context tensors.
        """
        log_probs = self._get_log_probs(logits)
        _, preds = torch.max(log_probs, 1)
        # Pack in the question and context logits alongside the predictions;
        # presumably consumed downstream for computing Avg Rank (confirm
        # against the metric reporter).
        question_logits, context_logits = logits
        return (
            (
                preds,
                question_logits.detach().cpu().tolist(),
                context_logits.detach().cpu().tolist(),
            ),
            log_probs,
        )  # Expected tuple: prediction, scores

    def _get_log_probs(self, logits):
        """Return the row-wise log-softmax of the question/context dot products."""
        # question_logits: (bsz X rep_dim); context_logits: (bsz * 2, rep_dim)
        # NOTE(review): the shapes above come from the original comment and are
        # not enforced here; the code only needs matching rep_dim.
        question_logits, context_logits = logits
        dot_products = torch.matmul(
            question_logits, torch.transpose(context_logits, 0, 1)
        )  # (bsz X bsz*num_negs+1)
        return F.log_softmax(dot_products, dim=1)
def get_norm_cosine_scores(cosine_sim_scores):
    """Rescale cosine similarities with ``(cos + 1) / 2``.

    Args:
        cosine_sim_scores: Tensor of cosine similarities.

    Returns:
        ``(pos_scores, neg_scores)`` where ``pos_scores`` is the rescaled
        similarity and ``neg_scores`` is ``1 - pos_scores``.
    """
    # Min-max normalization. It is monotonic, so the ordering of the scores
    # is preserved.
    pos_scores = (cosine_sim_scores + 1.0) / 2.0
    neg_scores = 1.0 - pos_scores
    return pos_scores, neg_scores
def get_sigmoid_scores(cosine_sim_scores):
    """Squash cosine similarities with a sigmoid.

    Args:
        cosine_sim_scores: Tensor of cosine similarities.

    Returns:
        ``(pos_scores, neg_scores)`` where ``pos_scores`` is
        ``sigmoid(cosine_sim_scores)`` and ``neg_scores`` is
        ``1 - pos_scores``.
    """
    pos_scores = torch.sigmoid(cosine_sim_scores)
    neg_scores = 1.0 - pos_scores
    return pos_scores, neg_scores