#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
# Global switch read by the helpers below: when True they move data to the
# GPU (via .cuda() / torch.cuda.* constructors); when False everything stays
# on the CPU. NOTE(review): presumably flipped at runtime by setup code
# elsewhere — confirm.
CUDA_ENABLED: bool = False
# Not read anywhere in this section of the file; presumably the number of
# distributed worker processes, set by a launcher elsewhere — TODO confirm.
DISTRIBUTED_WORLD_SIZE: int = 1
def Variable(data, *args, **kwargs):
    """Wrap *data* in ``torch.autograd.Variable``, honoring ``CUDA_ENABLED``.

    When ``CUDA_ENABLED`` is set, *data* is first moved to the GPU with
    ``.cuda()``. Any extra positional/keyword arguments are forwarded
    unchanged to ``torch.autograd.Variable``.

    NOTE: the ``Variable`` API is deprecated in modern PyTorch; it is kept
    here for backward compatibility with existing callers.
    """
    payload = data.cuda() if CUDA_ENABLED else data
    return torch.autograd.Variable(payload, *args, **kwargs)
def var_to_numpy(v):
    """Return the contents of tensor *v* as a NumPy array.

    The tensor is detached from the autograd graph (so tensors with
    ``requires_grad=True`` convert cleanly) and moved to host memory before
    the conversion. ``Tensor.cpu()`` returns the original object when the
    tensor is already on the CPU, so this works for CPU and CUDA tensors
    alike. It no longer depends on the global ``CUDA_ENABLED`` flag: a CUDA
    tensor passed while the flag is False used to raise inside ``.numpy()``.
    """
    return v.detach().cpu().numpy()
def zerovar(*size):
    """Return a zero-filled tensor of shape *size*, wrapped by ``Variable``.

    Because the result goes through ``Variable``, it is placed on the GPU
    when ``CUDA_ENABLED`` is set.
    """
    zeros = torch.zeros(*size)
    return Variable(zeros)
def FloatTensor(*args):
    """Build a float tensor via the legacy typed constructor.

    Uses ``torch.cuda.FloatTensor`` when ``CUDA_ENABLED`` is set, otherwise
    ``torch.FloatTensor``; *args* are forwarded unchanged.
    """
    constructor = torch.cuda.FloatTensor if CUDA_ENABLED else torch.FloatTensor
    return constructor(*args)
def LongTensor(*args):
    """Build an int64 tensor via the legacy typed constructor.

    Uses ``torch.cuda.LongTensor`` when ``CUDA_ENABLED`` is set, otherwise
    ``torch.LongTensor``; *args* are forwarded unchanged.
    """
    constructor = torch.cuda.LongTensor if CUDA_ENABLED else torch.LongTensor
    return constructor(*args)
def GetTensor(tensor):
    """Return *tensor*, moved to the GPU with ``.cuda()`` if ``CUDA_ENABLED``.

    When CUDA is disabled the very same object is returned untouched.
    """
    if not CUDA_ENABLED:
        return tensor
    return tensor.cuda()
def tensor(data, dtype):
    """Create a tensor from *data* with the given *dtype* on ``device()``.

    The target device is whatever ``device()`` reports: ``"cpu"``, or the
    current CUDA device when ``CUDA_ENABLED`` is set.
    """
    target_device = device()
    return torch.tensor(data, dtype=dtype, device=target_device)
def device():
    """Return the device string the helpers in this module target.

    Returns ``"cpu"`` when ``CUDA_ENABLED`` is False, otherwise
    ``"cuda:<index>"`` built from ``torch.cuda.current_device()``.
    """
    if not CUDA_ENABLED:
        return "cpu"
    return "cuda:{}".format(torch.cuda.current_device())