Source code for pytext.models.representations.contextual_intent_slot_rep

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Tuple, Union

import torch
from pytext.models.module import create_module
from pytext.utils.usage import log_class_usage

from .bilstm_doc_slot_attention import BiLSTMDocSlotAttention
from .docnn import DocNNRepresentation
from .jointcnn_rep import JointCNNRepresentation
from .representation_base import RepresentationBase


class ContextualIntentSlotRepresentation(RepresentationBase):
    """
    Representation for a contextual intent slot model.

    The inputs are two embeddings: a word-level embedding containing dictionary
    features, and a sequence (contexts) level embedding. See the following diagram
    for the representation implementation that combines the two embeddings.
    Seq_representation is concatenated with word_embeddings.

    ::

        +-----------+
        | word_embed|--------------------------->+   +--------------------+
        +-----------+                            |   | doc_representation |
        +-----------+   +-------------------+    |-->+--------------------+
        | seq_embed |-->| seq_representation|--->+   | word_representation|
        +-----------+   +-------------------+        +--------------------+
                                                      joint_representation
    """
    class Config(RepresentationBase.Config):
        sen_representation: DocNNRepresentation.Config = DocNNRepresentation.Config()
        seq_representation: DocNNRepresentation.Config = DocNNRepresentation.Config()
        joint_representation: Union[
            BiLSTMDocSlotAttention.Config, JointCNNRepresentation.Config
        ] = BiLSTMDocSlotAttention.Config()
    def __init__(self, config: Config, embed_dim: Tuple[int, ...]) -> None:
        super().__init__(config)
        assert len(embed_dim) == 2
        # Sentence-level representation: encodes each context sentence from the
        # sequence (contexts) embedding.
        self.sen_rep = create_module(config.sen_representation, embed_dim=embed_dim[1])
        self.sen_representation_dim = self.sen_rep.representation_dim
        # Sequence-level representation: encodes the sequence of sentence
        # representations into a single context vector.
        self.seq_rep = create_module(
            config.seq_representation, embed_dim=self.sen_representation_dim
        )
        self.seq_representation_dim = self.seq_rep.representation_dim
        # Joint representation: consumes word embeddings concatenated with the
        # expanded context vector.
        self.joint_rep = create_module(
            config.joint_representation,
            embed_dim=embed_dim[0] + self.seq_representation_dim,
        )
        self.doc_representation_dim = self.joint_rep.doc_representation_dim
        self.word_representation_dim = self.joint_rep.word_representation_dim
        log_class_usage(__class__)
    def forward(
        self,
        word_seq_embed: Tuple[torch.Tensor, torch.Tensor],
        word_lengths: torch.Tensor,
        seq_lengths: torch.Tensor,
        *args,
    ) -> List[torch.Tensor]:
        (word_embed, seq_embed) = word_seq_embed
        # Encode each context sentence independently.
        (bsz, max_num_sen, max_seq_len, dim) = seq_embed.size()
        rep = self.sen_rep(seq_embed.view(bsz * max_num_sen, max_seq_len, dim))
        sentence_reps = rep.view(bsz, max_num_sen, self.sen_representation_dim)
        # Encode the sequence of sentence representations into one context vector.
        seq_out = self.seq_rep(embedded_tokens=sentence_reps)
        # Broadcast the context vector over word positions and concatenate it with
        # the word embeddings before the joint representation.
        bsz, max_seq_len, dim = word_embed.size()
        seq_rep_expand = seq_out.view(bsz, 1, -1).repeat(1, max_seq_len, 1)
        new_embed = torch.cat([seq_rep_expand, word_embed], 2)
        return self.joint_rep(new_embed, word_lengths)
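
The sketch below is not part of the PyText source; it only illustrates the shape flow that forward implements, using plain tensors with mean-pooling stand-ins for the sen_rep, seq_rep, and joint_rep modules. All dimensions are illustrative assumptions.

# Minimal shape-flow sketch (assumed dimensions; mean pooling replaces the
# real sentence/sequence/joint sub-modules).
import torch

bsz, max_num_sen, max_seq_len, word_dim, seq_dim = 2, 3, 5, 16, 8

word_embed = torch.rand(bsz, max_seq_len, word_dim)              # word-level embedding
seq_embed = torch.rand(bsz, max_num_sen, max_seq_len, seq_dim)   # contexts embedding

# Stand-in for sen_rep: pool each context sentence into one vector.
sentence_reps = seq_embed.view(bsz * max_num_sen, max_seq_len, seq_dim).mean(dim=1)
sentence_reps = sentence_reps.view(bsz, max_num_sen, seq_dim)

# Stand-in for seq_rep: pool the sentence vectors into one context vector.
seq_out = sentence_reps.mean(dim=1)                              # (bsz, seq_dim)

# As in forward(): broadcast the context vector over word positions and
# concatenate it with the word embeddings before the joint representation.
seq_rep_expand = seq_out.view(bsz, 1, -1).repeat(1, max_seq_len, 1)
new_embed = torch.cat([seq_rep_expand, word_embed], 2)
print(new_embed.shape)  # torch.Size([2, 5, 24]) == (bsz, max_seq_len, seq_dim + word_dim)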