Source code for pytext.data.xlm_dictionary

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
from logging import getLogger

import numpy as np
import torch
from pytext.utils.file_io import PathManager


logger = getLogger()


BOS_WORD = "<s>"
EOS_WORD = "</s>"
PAD_WORD = "<pad>"
UNK_WORD = "<unk>"

SPECIAL_WORD = "<special%i>"
SPECIAL_WORDS = 10

SEP_WORD = SPECIAL_WORD % 0
MASK_WORD = SPECIAL_WORD % 1


class Dictionary(object):
    def __init__(self, id2word, word2id, counts):
        assert len(id2word) == len(word2id) == len(counts)
        self.id2word = id2word
        self.word2id = word2id
        self.counts = counts
        self.bos_index = word2id[BOS_WORD]
        self.eos_index = word2id[EOS_WORD]
        self.pad_index = word2id[PAD_WORD]
        self.unk_index = word2id[UNK_WORD]
        self.check_valid()

    def __len__(self):
        """
        Returns the number of words in the dictionary.
        """
        return len(self.id2word)

    def __getitem__(self, i):
        """
        Returns the word of the specified index.
        """
        return self.id2word[i]

    def __contains__(self, w):
        """
        Returns whether a word is in the dictionary.
        """
        return w in self.word2id

    def __eq__(self, y):
        """
        Compare this dictionary with another one.
        """
        self.check_valid()
        y.check_valid()
        if len(self.id2word) != len(y):
            return False
        return all(self.id2word[i] == y[i] for i in range(len(y)))

    def check_valid(self):
        """
        Check that the dictionary is valid.
        """
        assert self.bos_index == 0
        assert self.eos_index == 1
        assert self.pad_index == 2
        assert self.unk_index == 3
        assert all(
            self.id2word[4 + i] == SPECIAL_WORD % i for i in range(SPECIAL_WORDS)
        )
        assert len(self.id2word) == len(self.word2id) == len(self.counts)
        assert set(self.word2id.keys()) == set(self.counts.keys())
        for i in range(len(self.id2word)):
            assert self.word2id[self.id2word[i]] == i
        last_count = 1e18
        for i in range(4 + SPECIAL_WORDS, len(self.id2word) - 1):
            count = self.counts[self.id2word[i]]
            assert count <= last_count
            last_count = count

    def index(self, word, no_unk=False):
        """
        Returns the index of the specified word.
        """
        if no_unk:
            return self.word2id[word]
        else:
            return self.word2id.get(word, self.unk_index)

    def max_vocab(self, max_vocab):
        """
        Limit the vocabulary size.
        """
        assert max_vocab >= 1
        init_size = len(self)
        self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
        self.word2id = {v: k for k, v in self.id2word.items()}
        self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
        self.check_valid()
        logger.info(
            "Maximum vocabulary size: %i. Dictionary size: %i -> %i (removed %i words)."
            % (max_vocab, init_size, len(self), init_size - len(self))
        )

    def min_count(self, min_count):
        """
        Threshold on the word frequency counts.
        """
        assert min_count >= 0
        init_size = len(self)
        self.id2word = {
            k: v
            for k, v in self.id2word.items()
            if self.counts[self.id2word[k]] >= min_count or k < 4 + SPECIAL_WORDS
        }
        self.word2id = {v: k for k, v in self.id2word.items()}
        self.counts = {k: v for k, v in self.counts.items() if k in self.word2id}
        self.check_valid()
        logger.info(
            "Minimum frequency count: %i. Dictionary size: %i -> %i (removed %i words)."
            % (min_count, init_size, len(self), init_size - len(self))
        )

    @staticmethod
    def read_vocab(vocab_path):
        """
        Create a dictionary from a vocabulary file.
        """
        skipped = 0
        assert PathManager.isfile(vocab_path), vocab_path
        word2id = {BOS_WORD: 0, EOS_WORD: 1, PAD_WORD: 2, UNK_WORD: 3}
        for i in range(SPECIAL_WORDS):
            word2id[SPECIAL_WORD % i] = 4 + i
        counts = {k: 0 for k in word2id.keys()}
        f = PathManager.open(vocab_path, "r", encoding="utf-8")
        for i, line in enumerate(f):
            # skip lines containing a unicode line separator
            if "\u2028" in line:
                skipped += 1
                continue
            line = line.rstrip().split()
            # skip lines that are not a "word count" pair
            if len(line) != 2:
                skipped += 1
                continue
            if line[0] in word2id:
                skipped += 1
                print("%s already in vocab" % line[0])
                continue
            if not line[1].isdigit():
                skipped += 1
                print("Empty word at line %s with count %s" % (i, line))
                continue
            # shift because of extra words
            word2id[line[0]] = 4 + SPECIAL_WORDS + i - skipped
            counts[line[0]] = int(line[1])
        f.close()
        id2word = {v: k for k, v in word2id.items()}
        dico = Dictionary(id2word, word2id, counts)
        logger.info("Read %i words from the vocabulary file." % len(dico))
        if skipped > 0:
            logger.warning("Skipped %i empty lines!" % skipped)
        return dico
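
A vocabulary file for read_vocab is expected to hold one whitespace-separated "word count" pair per line, ordered by decreasing count (check_valid later asserts the counts of the regular words are non-increasing). A minimal sketch, assuming an illustrative file "vocab.txt" that is not part of the module:

# vocab.txt (illustrative contents):
#   the 1061396
#   of 593677
#   and 416629
dico = Dictionary.read_vocab("vocab.txt")
print(len(dico))              # 4 + SPECIAL_WORDS + number of vocabulary lines kept
print(dico.index("the"))      # 14: the first regular word follows the 14 reserved ids
print(dico.index("zyzzyva"))  # 3: out-of-vocabulary words map to unk_index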

    @staticmethod
    def index_data(path, bin_path, dico):
        """
        Index sentences with a dictionary.
        """
        if bin_path is not None and PathManager.isfile(bin_path):
            print("Loading data from %s ..." % bin_path)
            data = torch.load(bin_path)
            assert dico == data["dico"]
            return data

        positions = []
        sentences = []
        unk_words = {}

        # index sentences
        f = PathManager.open(path, "r", encoding="utf-8")
        for i, line in enumerate(f):
            if i % 1000000 == 0 and i > 0:
                print(i)
            s = line.rstrip().split()
            # warn about empty sentences
            if len(s) == 0:
                print("Empty sentence in line %i." % i)
            # index sentence words
            count_unk = 0
            indexed = []
            for w in s:
                word_id = dico.index(w, no_unk=False)
                # if we find a special word which is not an unknown word, skip the word
                if 0 <= word_id < 4 + SPECIAL_WORDS and word_id != 3:
                    logger.warning(
                        'Found unexpected special word "%s" (%i)!!' % (w, word_id)
                    )
                    continue
                assert word_id >= 0
                indexed.append(word_id)
                if word_id == dico.unk_index:
                    unk_words[w] = unk_words.get(w, 0) + 1
                    count_unk += 1
            # add sentence
            positions.append([len(sentences), len(sentences) + len(indexed)])
            sentences.extend(indexed)
            sentences.append(1)  # EOS index
        f.close()

        # tensorize data
        positions = np.int64(positions)
        if len(dico) < 1 << 16:
            sentences = np.uint16(sentences)
        elif len(dico) < 1 << 31:
            sentences = np.int32(sentences)
        else:
            raise Exception("Dictionary is too big.")
        assert sentences.min() >= 0
        data = {
            "dico": dico,
            "positions": positions,
            "sentences": sentences,
            "unk_words": unk_words,
        }
        if bin_path is not None:
            print("Saving the data to %s ..." % bin_path)
            torch.save(data, bin_path, pickle_protocol=4)

        return data
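
A minimal end-to-end sketch of how the two static methods are typically chained; the file paths ("vocab.txt", "train.tok", "train.pth") and the pruning threshold are purely illustrative assumptions, not part of the module:

dico = Dictionary.read_vocab("vocab.txt")
dico.max_vocab(95000)  # optional pruning; keeps the lowest 95000 ids, i.e. the specials plus the most frequent words
data = Dictionary.index_data("train.tok", "train.pth", dico)

# "sentences" is a flat array of word ids with an EOS index (1) appended after each
# sentence; "positions" holds the [start, end) offsets of every sentence in that array.
start, end = data["positions"][0]
print([dico[int(idx)] for idx in data["sentences"][start:end]])  # tokens of the first sentence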