# Source code for pytext.models.representations.slot_attention

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.config.module_config import SlotAttentionType
from pytext.models.module import Module
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence


class SlotAttention(Module):
    """Self-attention over the words of a sequence, used for slot tagging.

    For every word ``i`` an attention distribution over all words ``j`` of the
    same sequence is computed, and the attention-weighted sum of the word
    representations (the "context") is concatenated to the word's own
    representation. Any other attention type (e.g. ``NO_ATTENTION``) returns
    the (padded) input unchanged.

    Scoring functions, selected by ``config.attention_type``:

    * ``CONCAT``:   ``v' tanh(W [h_i; h_j])`` (additive attention)
    * ``MULTIPLY``: ``h_i' W h_j`` (bilinear attention)
    * ``DOT``:      ``h_i' h_j`` (no parameters)
    """

    class Config(ConfigBase):
        # Hidden size of the additive scoring MLP (only used by CONCAT).
        attn_dimension: int = 64
        # Which scoring function to use; the default disables attention.
        attention_type: SlotAttentionType = SlotAttentionType.NO_ATTENTION

    def __init__(self, config: Config, n_input: int, batch_first: bool = True) -> None:
        """Create the scoring parameters required by ``config.attention_type``.

        Args:
            config: attention configuration.
            n_input: size of each word representation (last input dimension).
            batch_first: whether inputs are laid out ``(batch, seq_len, dim)``
                (True) or ``(seq_len, batch, dim)`` (False).
        """
        super().__init__()
        self.batch_first = batch_first
        self.attention_type = config.attention_type
        # Scores are v' tanh(W [h_i; h_j]) ("CONCAT"), h_i' W h_j ("MULTIPLY")
        # or h_i' h_j ("DOT", which needs no parameters).
        # Attribute names are kept stable so saved state dicts still load.
        if self.attention_type == SlotAttentionType.CONCAT:
            self.attention_add = nn.Sequential(
                nn.Linear(2 * n_input, config.attn_dimension, bias=False),
                nn.Tanh(),
                nn.Linear(config.attn_dimension, 1, bias=False),
            )
        elif self.attention_type == SlotAttentionType.MULTIPLY:
            self.attention_mult = nn.Linear(n_input, n_input, bias=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Append a self-attention context vector to every word representation.

        Args:
            inputs: word representations, either a padded tensor laid out as
                ``(batch, seq_len, dim)`` when ``batch_first`` is True or
                ``(seq_len, batch, dim)`` otherwise, or a ``PackedSequence``,
                which is padded first.

        Returns:
            A tensor in the same layout as the padded input whose last
            dimension is ``2 * dim`` (``[h_i; context_i]``), or the padded
            input itself when no supported attention type is configured.
        """
        if isinstance(inputs, PackedSequence):
            inputs, _ = pad_packed_sequence(inputs, batch_first=self.batch_first)

        if self.attention_type not in (
            SlotAttentionType.CONCAT,
            SlotAttentionType.MULTIPLY,
            SlotAttentionType.DOT,
        ):
            return inputs

        # Attention is computed over dim 1, which therefore must be the word
        # axis. Without this transpose, batch_first=False inputs would attend
        # across the batch instead of across the words of each sequence.
        words = inputs if self.batch_first else inputs.transpose(0, 1)
        # words: (bsz, num_words, dim)
        bsz, num_words, dim = words.size()

        if self.attention_type == SlotAttentionType.CONCAT:
            # Build every (h_i, h_j) pair: (bsz, num_words, num_words, 2 * dim).
            queries = words.unsqueeze(2).expand(bsz, num_words, num_words, dim)
            keys = words.unsqueeze(1).expand(bsz, num_words, num_words, dim)
            scores = self.attention_add(torch.cat((queries, keys), 3)).squeeze(3)
        else:
            # MULTIPLY scores h_i' (W h_j); DOT scores h_i' h_j.
            keys = (
                words
                if self.attention_type == SlotAttentionType.DOT
                else self.attention_mult(words)
            )
            scores = torch.bmm(words, keys.transpose(1, 2))

        # scores[b, i, j]: relevance of word j to query word i.
        # NOTE(review): padded positions (e.g. from an unpacked PackedSequence)
        # are not masked, so they still receive attention weight -- confirm
        # this is intended before changing it (it would alter trained models).
        weights = F.softmax(scores, dim=2)
        # context[b, i] = sum_j weights[b, i, j] * h_j. A batched matmul gives
        # this directly, without a (bsz, num_words, num_words, dim) operand.
        context = torch.bmm(weights, words)
        output = torch.cat((words, context), 2)
        return output if self.batch_first else output.transpose(0, 1)