Source code for pytext.utils

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from collections.abc import Sequence


def cls_vars(cls):
    """Return the values of the public entries in ``vars(cls)``.

    Only names found in the object's own ``__dict__`` are considered
    (inherited attributes are not part of ``vars``), and any entry whose
    name starts with an underscore is skipped.

    Args:
        cls: object (typically a class) whose namespace is inspected.

    Returns:
        A list of attribute values, in namespace definition order.
    """
    public_values = []
    for attr_name, attr_value in vars(cls).items():
        # Names with a leading underscore (private and dunder) are excluded.
        if attr_name.startswith("_"):
            continue
        public_values.append(attr_value)
    return public_values
def set_random_seeds(seed, use_deterministic_cudnn):
    """Seed the Python, NumPy and PyTorch random number generators.

    When ``pytext.utils.cuda.CUDA_ENABLED`` is truthy and
    ``use_deterministic_cudnn`` is set, a warning is printed and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed: value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.
        use_deterministic_cudnn: whether to request deterministic cuDNN
            behaviour (only acted upon when CUDA is enabled).
    """
    import random

    import numpy as np
    import torch
    from pytext.utils import cuda

    # See https://pytorch.org/docs/stable/notes/randomness.html
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Nothing cuDNN-related to configure unless both conditions hold.
    if not (cuda.CUDA_ENABLED and use_deterministic_cudnn):
        return

    slowdown_warning = (
        "WARNING: Your training might be slower because you have set "
        "use_deterministic_cudnn flag to True. Read "
        "https://pytorch.org/docs/stable/notes/randomness.html and "
        "https://discuss.pytorch.org/t/what-is-the-differenc-between-cudnn-deterministic-and-cudnn-benchmark/38054 "
    )
    print(slowdown_warning)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def recursive_map(seq, func):
    """Lazily apply ``func`` to every leaf of a nested sequence.

    This is similar to the built-in ``map`` function but works for nested
    lists. Useful for transforming tensors serialized with ``.tolist()``.

    Nested containers (any ``collections.abc.Sequence`` other than ``str``)
    are rebuilt with their own type, so a list stays a list and a tuple
    stays a tuple. Strings are treated as leaves and passed to ``func``
    whole.

    Args:
        seq: iterable whose items are leaves or nested sequences.
        func: one-argument callable applied to each leaf.

    Yields:
        One value per top-level item of ``seq``: either ``func(item)`` for a
        leaf, or a new container of the same type holding the mapped items.
    """
    for item in seq:
        # ``str`` is itself a Sequence; recursing into it would evaluate
        # ``str(<generator>)`` and yield the generator's repr rather than a
        # mapped value, so strings must be handled as leaves.
        if isinstance(item, Sequence) and not isinstance(item, str):
            yield type(item)(recursive_map(item, func))
        else:
            yield func(item)
def round_seq(seq, ndigits):
    """Round every number in a nested sequence to ``ndigits`` precision.

    Useful for rounding tensors serialized with ``.tolist()``.

    Args:
        seq: nested sequence of numbers; its type must be constructible
            from an iterable, since the result is built with ``type(seq)``.
        ndigits: precision forwarded to the built-in ``round``.

    Returns:
        A new container of the same type as ``seq`` holding the rounded
        values, with the nesting produced by ``recursive_map``.
    """

    def _round_leaf(value):
        return round(value, ndigits)

    container_type = type(seq)
    rounded_items = recursive_map(seq, _round_leaf)
    return container_type(rounded_items)