Source code for pytext.models.masking_utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum

import numpy as np
import torch


class MaskingStrategy(Enum):
    RANDOM = "random"
    FREQUENCY = "frequency_based"
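
A hypothetical usage sketch, not part of the listed source: the enum's string values are meant to round-trip from a configuration setting, which can then select one of the masking functions defined below.

strategy = MaskingStrategy("frequency_based")   # parsed from a config value (assumed)
assert strategy is MaskingStrategy.FREQUENCY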
def random_masking(tokens: torch.Tensor, mask_prob: float) -> torch.Tensor:
    """
    Function to mask tokens randomly.

    Inputs:
        1) tokens: Tensor with token ids of shape (batch_size x seq_len)
        2) mask_prob: Probability of masking a particular token

    Outputs:
        mask: Tensor with the same shape as the input tokens (batch_size x seq_len),
        with masked tokens represented by a 1 and everything else as 0.
    """
    batch_size, seq_len = tokens.size()
    num_masked_per_seq = int(seq_len * mask_prob)

    # mark the first num_masked_per_seq positions of every row, then shuffle
    # each row independently so the masked positions are random per sequence
    mask = np.zeros((batch_size, seq_len), dtype=np.int_)
    mask[:, :num_masked_per_seq] = 1
    for row in mask:
        np.random.shuffle(row)

    mask = torch.from_numpy(mask).to(tokens.device)
    return mask
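
A minimal usage sketch, not part of the original source; the example tensor and the mask token id are assumptions for illustration.

tokens = torch.tensor([[11, 12, 13, 14, 15, 16],
                       [21, 22, 23, 24, 25, 26]])
mask = random_masking(tokens, mask_prob=0.5)     # three 1s per row, positions shuffled
MASK_ID = 0                                      # hypothetical mask token id
masked_tokens = tokens.masked_fill(mask.bool(), MASK_ID)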
def frequency_based_masking(
    tokens: torch.Tensor, token_sampling_weights: np.ndarray, mask_prob: float
) -> torch.Tensor:
    """
    Function to mask tokens based on frequency.

    Inputs:
        1) tokens: Tensor with token ids of shape (batch_size x seq_len)
        2) token_sampling_weights: numpy array indexed by token id, with each
           element representing the sampling weight associated with the
           corresponding token in tokens
        3) mask_prob: Probability of masking a particular token

    Outputs:
        mask: Tensor with the same shape as the input tokens (batch_size x seq_len),
        with masked tokens represented by a 1 and everything else as 0.
    """
    batch_size, seq_len = tokens.size()
    num_masked_per_batch = int(batch_size * seq_len * mask_prob)

    indices = tokens.cpu().numpy().flatten()
    # get the sampling weight associated with each token
    weights = np.take(token_sampling_weights, indices)
    # sample positions to mask (without replacement) based on the computed weights
    tokens_to_mask = np.random.choice(
        len(weights), num_masked_per_batch, replace=False, p=weights / weights.sum()
    )

    mask = torch.zeros(batch_size * seq_len)
    mask[tokens_to_mask] = 1
    mask = mask.view(batch_size, seq_len).long().to(tokens.device)
    return mask
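
A minimal usage sketch, not part of the original source; the vocabulary size and the count-based weights are assumptions, chosen to illustrate that token_sampling_weights is indexed by token id.

vocab_size = 100                                              # assumed vocabulary size
token_counts = np.random.randint(1, 50, size=vocab_size)      # assumed corpus counts per token id
token_sampling_weights = token_counts / token_counts.sum()    # frequent tokens get higher mask weight
tokens = torch.randint(0, vocab_size, (2, 6))
mask = frequency_based_masking(tokens, token_sampling_weights, mask_prob=0.25)
# int(2 * 6 * 0.25) == 3 positions are masked across the whole batch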