pytext.task package
Submodules
pytext.task.accelerator_lowering module
-
class
pytext.task.accelerator_lowering.
AcceleratorBiLSTM
(biLSTM)[source]¶ Bases:
torch.nn.modules.module.Module
-
forward
(embedded_tokens: torch.Tensor, seq_lengths: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]) → Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
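The note can be illustrated without PyTorch. The toy sketch below is not PyTorch's implementation; `TinyModule` and `Doubler` are hypothetical names. It only shows why registered hooks run when the instance is called but are skipped when forward() is called directly, which is why the Accelerator* wrappers in this module should be invoked as `module(...)`.

```python
from typing import Any, Callable, List


class TinyModule:
    """Toy stand-in for torch.nn.Module (hypothetical, for illustration only)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[["TinyModule", Any, Any], None]] = []

    def register_forward_hook(self, hook: Callable[["TinyModule", Any, Any], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation here

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        # Hooks are dispatched only on this path; calling forward() directly skips them.
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


class Doubler(TinyModule):
    def forward(self, x: int) -> int:
        return 2 * x


seen: List[int] = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
m(3)          # hook fires
m.forward(4)  # hook silently skipped
print(seen)   # [6]
```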
-
-
class
pytext.task.accelerator_lowering.
AcceleratorTransformer
(transformer)[source]¶ Bases:
torch.nn.modules.module.Module
-
forward
(tokens: torch.Tensor) → List[torch.Tensor][source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
class
pytext.task.accelerator_lowering.
AcceleratorTransformerLayersInternal
(layers)[source]¶ Bases:
torch.nn.modules.module.Module
-
forward
(encoded: torch.Tensor, padding_mask: torch.Tensor) → List[torch.Tensor][source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
-
-
pytext.task.accelerator_lowering.
accelerator_lstm_inputs
(model: torch.nn.modules.module.Module, trace: torch.jit.ScriptFunction, export_options: pytext.config.pytext_config.ExportConfig, dataset_iterable: Iterable[T_co], module_path)[source]¶
-
pytext.task.accelerator_lowering.
accelerator_transformerLayers_inputs
(model: torch.nn.modules.module.Module, trace: torch.jit.ScriptFunction, export_options: pytext.config.pytext_config.ExportConfig, dataset_iterable: Iterable[T_co], module_path)[source]¶
pytext.task.disjoint_multitask module
-
class
pytext.task.disjoint_multitask.
DisjointMultitask
(target_task_name, exporters, **kwargs)[source]¶ Bases:
pytext.task.task.TaskBase
Modules that have the same shared_module_key and type share parameters. Only the first instance of such a module should be configured in the tasks list.
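A minimal sketch of the sharing rule, assuming a simple registry keyed by (module type, shared_module_key). `ToyModule` and `build_modules` are hypothetical names; this is not PyText's actual mechanism, only an illustration of "same key and same type means one shared instance, configured once".

```python
from typing import Dict, List, Optional, Tuple


class ToyModule:
    """Hypothetical stand-in for a configured module that owns parameters."""

    def __init__(self, kind: str) -> None:
        self.kind = kind


def build_modules(specs: List[Tuple[str, Optional[str]]]) -> List[ToyModule]:
    """specs: (module_type, shared_module_key) pairs, one per task, in tasks-list order."""
    shared: Dict[Tuple[str, str], ToyModule] = {}
    built: List[ToyModule] = []
    for kind, key in specs:
        if key is None:
            built.append(ToyModule(kind))  # no key: never shared
            continue
        if (kind, key) not in shared:
            # Only the first occurrence is actually constructed from its config.
            shared[(kind, key)] = ToyModule(kind)
        built.append(shared[(kind, key)])  # later tasks reuse the same instance
    return built


mods = build_modules([("encoder", "enc"), ("encoder", "enc"), ("decoder", None)])
print(mods[0] is mods[1], mods[0] is mods[2])  # True False
```

Note that a matching key alone is not enough: in this sketch, two modules with the same key but different types stay separate, mirroring the "same shared_module_key and type" condition above.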
-
export
(multitask_model, export_path, metric_channels, export_onnx_path=None) Wrapper method to export a PyTorch model to a Caffe2 model using Exporter.
Parameters: - export_path (str) – file path of the exported Caffe2 model
- metric_channels – channels to output the PyTorch model’s execution graph to
- export_onnx_path (str) – file path of the exported ONNX model
-
classmethod
from_config
(task_config: pytext.task.disjoint_multitask.DisjointMultitask.Config, metadata=None, model_state=None, tensorizers=None, rank=0, world_size=1) Create the task from config, and optionally load metadata/model_state. This function creates components including DataHandler, Trainer, MetricReporter and Exporter, and wires them up.
Parameters: - task_config (Task.Config) – the config of the current task
- metadata – saved global context of this task, e.g. vocabulary; generated by DataHandler if it is None
- model_state – saved model parameters, loaded into the model when given
-
-
class
pytext.task.disjoint_multitask.
NewDisjointMultitask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task._NewTask
Multitask training based on underlying subtasks. To share parameters between modules from different tasks, specify the same shared_module_key. Only the first instance of each shared module should be configured in the tasks list. Only the multitask trainer (not the per-task trainers) is used.
-
export
(model, export_path, metric_channels=None, export_onnx_path=None) Wrapper method to export a PyTorch model to a Caffe2 model using Exporter.
Parameters: - export_path (str) – file path of the exported Caffe2 model
- metric_channels (List[Channel]) – outputs of the model’s execution graph
- export_onnx_path (str) – file path of the exported ONNX model
-
classmethod
from_config
(task_config: pytext.task.disjoint_multitask.NewDisjointMultitask.Config, unused_metadata=None, model_state=None, tensorizers=None, rank=0, world_size=1) Create the task from config, and optionally load metadata/model_state. This function creates components including DataHandler, Trainer, MetricReporter and Exporter, and wires them up.
Parameters: - task_config (Task.Config) – the config of the current task
- metadata – saved global context of this task, e.g. vocabulary; generated by DataHandler if it is None
- model_state – saved model parameters, loaded into the model when given
-
pytext.task.new_task module
-
class
pytext.task.new_task.
NewTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task._NewTask
-
pytext.task.new_task.
create_schema
(tensorizers: Dict[str, pytext.data.tensorizers.Tensorizer], extra_schema: Optional[Dict[str, Type[CT_co]]] = None) → Dict[str, Type[CT_co]][source]¶
-
pytext.task.new_task.
create_tensorizers
(model_inputs: Union[pytext.models.model.BaseModel.Config.ModelInput, Dict[str, pytext.data.tensorizers.Tensorizer.Config]]) → Dict[str, pytext.data.tensorizers.Tensorizer][source]¶
-
pytext.task.new_task.
sort
(input, dim=-1, descending=False, *, out=None) -> (Tensor, LongTensor) Sorts the elements of the input tensor along a given dimension in ascending order by value.
If dim is not given, the last dimension of the input is chosen.
If descending is True, then the elements are sorted in descending order by value.
A namedtuple of (values, indices) is returned, where the values are the sorted values and indices are the indices of the elements in the original input tensor.
Parameters: - input (Tensor) – the input tensor.
- dim (int, optional) – the dimension to sort along
- descending (bool, optional) – controls the sorting order (ascending or descending)
Keyword Arguments: out (tuple, optional) – the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
Example:
>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted
tensor([[-0.2162,  0.0608,  0.6719,  2.3332],
        [-0.5793,  0.0061,  0.6058,  0.9497],
        [-0.5071,  0.3343,  0.9553,  1.0960]])
>>> indices
tensor([[ 1,  0,  2,  3],
        [ 3,  1,  0,  2],
        [ 0,  3,  1,  2]])
>>> sorted, indices = torch.sort(x, 0)
>>> sorted
tensor([[-0.5071, -0.2162,  0.6719, -0.5793],
        [ 0.0608,  0.0061,  0.9497,  0.3343],
        [ 0.6058,  0.9553,  1.0960,  2.3332]])
>>> indices
tensor([[ 2,  0,  0,  1],
        [ 0,  1,  1,  2],
        [ 1,  2,  2,  0]])
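To make the (values, indices) contract concrete without a PyTorch dependency, here is a plain-Python sketch for the 2-D, dim=-1 case. `sort_rows` is a hypothetical helper; torch.sort itself remains the reference, and this sketch does not reproduce its tie-breaking behavior.

```python
from typing import List, Tuple


def sort_rows(
    x: List[List[float]], descending: bool = False
) -> Tuple[List[List[float]], List[List[int]]]:
    """Sort each row of a 2-D list (the last dimension), returning (values, indices)."""
    values: List[List[float]] = []
    indices: List[List[int]] = []
    for row in x:
        # Positions of the row's elements, ordered by the value stored at each position.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=descending)
        indices.append(order)
        values.append([row[i] for i in order])  # values[j] == row[indices[j]]
    return values, indices


# First row of the example above, recovered from its sorted values and indices.
vals, idx = sort_rows([[0.0608, -0.2162, 0.6719, 2.3332]])
print(vals)  # [[-0.2162, 0.0608, 0.6719, 2.3332]]
print(idx)   # [[1, 0, 2, 3]]
```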
pytext.task.nop_decorator module
pytext.task.quantize module
pytext.task.serialize module
-
class
pytext.task.serialize.
CheckpointManager
[source]¶ Bases:
pytext.task.serialize.PyTextCheckpointManagerInterface
-
get_latest_checkpoint_path
() → str Return the most recently saved checkpoint path as a str.
Returns: checkpoint_path (str)
-
list
() → List[str] Return all existing checkpoint paths.
Returns: checkpoint_path_list (List[str]); list elements are in the order in which the checkpoints were saved
-
load
(load_path: str) Loads a checkpoint/snapshot from disk.
Parameters: load_path (str) – the file path from which to load
Returns: the deserialized state (dictionary) that was saved
-
save_checkpoint
(state, checkpoint_path) Serialize and save a checkpoint to the given path. State is a dictionary that represents all the data to be saved.
Parameters: - state – dictionary containing the data to be saved
- checkpoint_path – path of the file to save the checkpoint to
-
save_snapshot
(state, snapshot_path) Serialize and save a post-training model snapshot to the given path. State is a dictionary that represents all the data to be saved. Having a separate method for snapshots enables future optimizations, such as quantization, to be applied to snapshots.
Parameters: - state – Dictionary containing data to be saved
- snapshot_path – path of file to save snapshot
-
-
class
pytext.task.serialize.
PyTextCheckpointManagerInterface
[source]¶ Bases:
abc.ABC
CheckpointManager is a class abstraction for managing a training job’s checkpoints/snapshots with different IO and storage.
-
get_latest_checkpoint_path
() → str Return the most recently saved checkpoint path as a str.
Returns: checkpoint_path (str)
-
list
() → List[str] Return all existing checkpoint paths.
Returns: checkpoint_path_list (List[str]); list elements are in the order in which the checkpoints were saved
-
load
(load_path: str) Loads a checkpoint/snapshot from disk.
Parameters: load_path (str) – the file path from which to load
Returns: the deserialized state (dictionary) that was saved
-
save_checkpoint
(state, checkpoint_path) Serialize and save a checkpoint to the given path. State is a dictionary that represents all the data to be saved.
Parameters: - state – dictionary containing the data to be saved
- checkpoint_path – path of the file to save the checkpoint to
-
save_snapshot
(state, snapshot_path) Serialize and save a post-training model snapshot to the given path. State is a dictionary that represents all the data to be saved. Having a separate method for snapshots enables future optimizations, such as quantization, to be applied to snapshots.
Parameters: - state – Dictionary containing data to be saved
- snapshot_path – path of file to save snapshot
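The interface can be read as a small storage abstraction. Below is a hypothetical local-disk implementation written against a stand-in ABC that mirrors the documented method names (it does not import PyText). Pickle-based serialization, excluding snapshots from list(), and raising FileNotFoundError when nothing has been saved are all assumptions of this sketch.

```python
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class CheckpointManagerInterface(ABC):
    """Hypothetical mirror of the documented PyTextCheckpointManagerInterface methods."""

    @abstractmethod
    def save_checkpoint(self, state: Dict[str, Any], checkpoint_path: str) -> None: ...

    @abstractmethod
    def save_snapshot(self, state: Dict[str, Any], snapshot_path: str) -> None: ...

    @abstractmethod
    def load(self, load_path: str) -> Dict[str, Any]: ...

    @abstractmethod
    def list(self) -> List[str]: ...

    @abstractmethod
    def get_latest_checkpoint_path(self) -> str: ...


class LocalPickleCheckpointManager(CheckpointManagerInterface):
    """Local-disk backend using pickle (PyText's own manager may serialize differently)."""

    def __init__(self) -> None:
        self._checkpoints: List[str] = []  # checkpoint paths, in save order

    @staticmethod
    def _write(state: Dict[str, Any], path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(state, f)

    def save_checkpoint(self, state: Dict[str, Any], checkpoint_path: str) -> None:
        self._write(state, checkpoint_path)
        self._checkpoints.append(checkpoint_path)

    def save_snapshot(self, state: Dict[str, Any], snapshot_path: str) -> None:
        # Kept separate from save_checkpoint so snapshot-only steps
        # (e.g. quantization) could be added here later; snapshots are not tracked.
        self._write(state, snapshot_path)

    def load(self, load_path: str) -> Dict[str, Any]:
        with open(load_path, "rb") as f:
            return pickle.load(f)

    def list(self) -> List[str]:
        return self._checkpoints.copy()

    def get_latest_checkpoint_path(self) -> str:
        if not self._checkpoints:
            raise FileNotFoundError("no checkpoint has been saved yet")
        return self._checkpoints[-1]
```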
-
-
pytext.task.serialize.
generate_checkpoint_path
(config: pytext.config.pytext_config.PyTextConfig, identifier: str)[source]¶
-
pytext.task.serialize.
get_latest_checkpoint_path
(dir_path: Optional[str] = None) → str Get the latest checkpoint path.
Parameters: dir_path – the directory to scan for existing checkpoint files. Default: if None, the latest checkpoint path saved in memory is returned.
Returns: checkpoint_path
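For the dir_path branch, one plausible way to find the latest checkpoint is to pick the most recently modified checkpoint file. The sketch below assumes a hypothetical `.ckpt` suffix and uses modification time; PyText's own file naming and ordering rules may differ, and `latest_checkpoint_in_dir` is an illustrative name.

```python
import os
from typing import Optional


def latest_checkpoint_in_dir(dir_path: str, suffix: str = ".ckpt") -> Optional[str]:
    """Return the most recently modified '<name><suffix>' file in dir_path, or None."""
    candidates = [
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if name.endswith(suffix)
    ]
    candidates = [p for p in candidates if os.path.isfile(p)]
    if not candidates:
        return None
    # "Latest" is taken to mean the newest modification time.
    return max(candidates, key=os.path.getmtime)
```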
-
pytext.task.serialize.
load
(load_path: str, overwrite_config=None, rank=0, world_size=1) Load task, config and training state from a saved snapshot. By default, it constructs the task using the saved config, then loads metadata and model state.
If overwrite_config is specified, it constructs the task using overwrite_config instead, then loads metadata and model state.
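A sketch of the control flow described for load, assuming a snapshot stored as a pickled dict with hypothetical keys `config`, `model_state`, `metadata` and `training_state`, and a caller-supplied `build_task` factory. It only illustrates that the override replaces the saved config while metadata and model state still come from the snapshot; `load_sketch` is not a PyText function.

```python
import pickle
from typing import Any, Callable, Dict, Optional, Tuple


def load_sketch(
    load_path: str,
    build_task: Callable[[Dict[str, Any]], Any],
    overwrite_config: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, Dict[str, Any], Optional[Any]]:
    with open(load_path, "rb") as f:
        saved = pickle.load(f)
    # The saved config is the default; overwrite_config replaces it when given.
    config = overwrite_config if overwrite_config is not None else saved["config"]
    task = build_task(config)
    # Metadata and model state always come from the snapshot.
    task.metadata = saved.get("metadata")
    task.model_state = saved["model_state"]
    return task, config, saved.get("training_state")
```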
-
pytext.task.serialize.
save
(config: pytext.config.pytext_config.PyTextConfig, model: pytext.models.model.Model, meta: Optional[pytext.data.data_handler.CommonMetadata], tensorizers: Dict[str, pytext.data.tensorizers.Tensorizer], training_state: Optional[pytext.trainers.training_state.TrainingState] = None, identifier: Optional[str] = None) → str Save all stateful information of a training task: the original config, model state, metadata and, if training is not completed, the training state.
Parameters: - identifier (str) – used to identify a checkpoint within a training job; used as a suffix for the save path
- config (PyTextConfig) – contains all raw parameters/hyper-parameters for the training task
- model (Model) – the actual model in training
- training_state (TrainingState) – stateful information during training
Returns: str. If identifier is not specified, the state is saved to config.save_snapshot_path, consistent with the post-training snapshot; if specified, it is used to save a checkpoint during training and identifies checkpoints within the same training job.
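How identifier could map to a save path, as a minimal sketch: without an identifier the post-training snapshot path is used, and with one it becomes a suffix. The `-checkpoint-` separator and the helper `resolve_save_path` are hypothetical, not PyText's actual path format.

```python
from typing import Optional


def resolve_save_path(save_snapshot_path: str, identifier: Optional[str] = None) -> str:
    if identifier is None:
        # No identifier: same location as the post-training snapshot.
        return save_snapshot_path
    # With an identifier: a distinct path per checkpoint within the same training job.
    return f"{save_snapshot_path}-checkpoint-{identifier}"


print(resolve_save_path("/tmp/model.pt"))             # /tmp/model.pt
print(resolve_save_path("/tmp/model.pt", "epoch-3"))  # /tmp/model.pt-checkpoint-epoch-3
```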
-
pytext.task.serialize.
save_checkpoint
(f: io.IOBase, config: pytext.config.pytext_config.PyTextConfig, model: pytext.models.model.Model, meta: Optional[pytext.data.data_handler.CommonMetadata], tensorizers: Dict[str, pytext.data.tensorizers.Tensorizer], training_state: Optional[pytext.trainers.training_state.TrainingState] = None) → str[source]¶
pytext.task.task module
-
class
pytext.task.task.
TaskBase
(trainer: pytext.trainers.trainer.Trainer, data_handler: pytext.data.data_handler.DataHandler, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, exporter: Optional[pytext.exporters.exporter.ModelExporter])[source]¶ Bases:
pytext.config.component.Component
Task is the central place to define and wire up components for data processing, model training, metric reporting, etc. Task class has a Config class containing the config of each component in a descriptive way.
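The Config-driven wiring can be sketched with plain dataclasses. This is a hypothetical stand-in for PyText's own config machinery: `ModelConfig`, `TrainerConfig`, `TaskConfig`, `Model`, `Trainer` and `ToyTask` are illustrative names, and the real TaskBase wires more components (data handler, metric reporter, exporter).

```python
from dataclasses import dataclass, field


@dataclass
class ModelConfig:
    hidden_dim: int = 128


@dataclass
class TrainerConfig:
    epochs: int = 10


@dataclass
class TaskConfig:
    # One nested config per component, declared in one place.
    model: ModelConfig = field(default_factory=ModelConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)


class Model:
    def __init__(self, config: ModelConfig) -> None:
        self.hidden_dim = config.hidden_dim


class Trainer:
    def __init__(self, config: TrainerConfig) -> None:
        self.epochs = config.epochs


class ToyTask:
    def __init__(self, model: Model, trainer: Trainer) -> None:
        self.model = model
        self.trainer = trainer

    @classmethod
    def from_config(cls, config: TaskConfig) -> "ToyTask":
        # Build each component from its own sub-config, then wire them into one task.
        return cls(model=Model(config.model), trainer=Trainer(config.trainer))


task = ToyTask.from_config(TaskConfig(model=ModelConfig(hidden_dim=256)))
print(task.model.hidden_dim, task.trainer.epochs)  # 256 10
```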
-
export
(model, export_path, metric_channels=None, export_onnx_path=None) Wrapper method to export a PyTorch model to a Caffe2 model using Exporter.
Parameters: - export_path (str) – file path of the exported Caffe2 model
- metric_channels (List[Channel]) – outputs of the model’s execution graph
- export_onnx_path (str) – file path of the exported ONNX model
-
classmethod
format_prediction
(predictions, scores, context, target_meta) Format the predictions and scores from model output; by default, they are simply returned in a dict.
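A sketch of the default behavior, assuming one dict per example with hypothetical keys `prediction` and `score`; the key names actually used by TaskBase may differ, and `context`/`target_meta` are ignored in this sketch.

```python
from typing import Any, Dict, List, Sequence


def default_format_prediction(
    predictions: Sequence[Any],
    scores: Sequence[Any],
    context: Any = None,
    target_meta: Any = None,
) -> List[Dict[str, Any]]:
    # Pair each prediction with its score; no label-specific post-processing.
    return [{"prediction": p, "score": s} for p, s in zip(predictions, scores)]
```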
-
classmethod
from_config
(task_config, metadata=None, model_state=None, tensorizers=None, rank=1, world_size=0) Create the task from config, and optionally load metadata/model_state. This function creates components including DataHandler, Trainer, MetricReporter and Exporter, and wires them up.
Parameters: - task_config (Task.Config) – the config of the current task
- metadata – saved global context of this task, e.g. vocabulary; generated by DataHandler if it is None
- model_state – saved model parameters, loaded into the model when given
-
predict
(examples) Generates predictions using the PyTorch model. Unlike test(), this should be used when the examples do not have any true label/target.
Parameters: examples – examples in JSON format; input names should match the names specified in this task’s features config
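A sketch of preparing label-free examples for predict(), assuming one JSON object per line and a hypothetical features config whose only input is named `text`. `parse_examples` and `FEATURE_NAMES` are illustrative; the check simply enforces that every configured input name is present.

```python
import json
from typing import Any, Dict, Iterable, List, Set

# Hypothetical input names taken from a task's features config.
FEATURE_NAMES: Set[str] = {"text"}


def parse_examples(raw: Iterable[str], feature_names: Set[str] = FEATURE_NAMES) -> List[Dict[str, Any]]:
    examples: List[Dict[str, Any]] = []
    for line in raw:
        example = json.loads(line)
        # Prediction needs the inputs only; no label/target field is required.
        missing = feature_names - set(example)
        if missing:
            raise KeyError(f"example is missing inputs: {sorted(missing)}")
        examples.append(example)
    return examples
```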
-
test
(test_path)[source]¶ Wrapper method to compute test metrics on holdout blind test dataset.
Parameters: test_path (str) – test data file path
-
train
(train_config, rank=0, world_size=1, training_state=None) Wrapper method to train the model using a Trainer object.
Parameters: - train_config (PyTextConfig) – config for training
- rank (int) – for distributed training only, rank of the GPU, default is 0
- world_size (int) – for distributed training only, total number of GPUs to use, default is 1
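To illustrate what rank and world_size mean, here is a common strided split of examples across processes. This is an assumption for illustration, not PyText's distributed data loading: with world_size processes, process rank takes every world_size-th example starting at position rank. `shard_for_rank` is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_for_rank(examples: Sequence[T], rank: int = 0, world_size: int = 1) -> List[T]:
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Disjoint shards: together, ranks 0..world_size-1 cover every example once.
    return list(examples[rank::world_size])


print(shard_for_rank(range(10), rank=1, world_size=4))  # [1, 5, 9]
```

With the defaults (rank=0, world_size=1), a single process sees the whole dataset.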
-
-
class
pytext.task.task.
Task_Deprecated
(trainer: pytext.trainers.trainer.Trainer, data_handler: pytext.data.data_handler.DataHandler, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, exporter: Optional[pytext.exporters.exporter.ModelExporter])[source]¶ Bases:
pytext.task.task.TaskBase
pytext.task.tasks module
-
class
pytext.task.tasks.
BertPairRegressionTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶
-
class
pytext.task.tasks.
DocumentClassificationTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
DocumentRegressionTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
EnsembleTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
IntentSlotTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
LMTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
MaskedLMTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
NewBertClassificationTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶
-
class
pytext.task.tasks.
NewBertPairClassificationTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶
-
class
pytext.task.tasks.
PairwiseClassificationForDenseRetrievalTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: pytext.metric_reporters.classification_metric_reporter.ClassificationMetricReporter, trainer: pytext.trainers.trainer.TaskTrainer, trace_both_encoders: bool = True)[source]¶ Bases:
pytext.task.tasks.PairwiseClassificationTask
This task implements DPR (Dense Passage Retrieval) training in PyText. Code pointer: https://github.com/fairinternal/DPR/tree/master/dpr
-
class
pytext.task.tasks.
PairwiseClassificationTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: pytext.metric_reporters.classification_metric_reporter.ClassificationMetricReporter, trainer: pytext.trainers.trainer.TaskTrainer, trace_both_encoders: bool = True)[source]¶ Bases:
pytext.task.new_task.NewTask
-
classmethod
from_config
(config: pytext.task.tasks.PairwiseClassificationTask.Config, unused_metadata=None, model_state=None, tensorizers=None, rank=0, world_size=1) Create the task from config, and optionally load metadata/model_state. This function creates components including DataHandler, Trainer, MetricReporter and Exporter, and wires them up.
Parameters: - task_config (Task.Config) – the config of the current task
- metadata – saved global context of this task, e.g. vocabulary; generated by DataHandler if it is None
- model_state – saved model parameters, loaded into the model when given
-
-
class
pytext.task.tasks.
PairwiseRegressionTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: pytext.metric_reporters.classification_metric_reporter.ClassificationMetricReporter, trainer: pytext.trainers.trainer.TaskTrainer, trace_both_encoders: bool = True)[source]¶
-
class
pytext.task.tasks.
QueryDocumentPairwiseRankingTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
RoBERTaNERTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
SemanticParsingTask
(data: pytext.data.data.Data, model: pytext.models.semantic_parsers.rnng.rnng_parser.RNNGParser, metric_reporter: pytext.metric_reporters.compositional_metric_reporter.CompositionalMetricReporter, trainer: pytext.trainers.hogwild_trainer.HogwildTrainer)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
SeqNNTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
SequenceLabelingTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
SquadQATask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
-
class
pytext.task.tasks.
WordTaggingTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task.NewTask
Module contents
-
class
pytext.task.
NewTask
(data: pytext.data.data.Data, model: pytext.models.model.BaseModel, metric_reporter: Optional[pytext.metric_reporters.metric_reporter.MetricReporter] = None, trainer: Optional[pytext.trainers.trainer.TaskTrainer] = None)[source]¶ Bases:
pytext.task.new_task._NewTask
-
class
pytext.task.
Task_Deprecated
(trainer: pytext.trainers.trainer.Trainer, data_handler: pytext.data.data_handler.DataHandler, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, exporter: Optional[pytext.exporters.exporter.ModelExporter])[source]¶ Bases:
pytext.task.task.TaskBase
-
class
pytext.task.
TaskBase
(trainer: pytext.trainers.trainer.Trainer, data_handler: pytext.data.data_handler.DataHandler, model: pytext.models.model.Model, metric_reporter: pytext.metric_reporters.metric_reporter.MetricReporter, exporter: Optional[pytext.exporters.exporter.ModelExporter])[source]¶ Bases:
pytext.config.component.Component
Task is the central place to define and wire up components for data processing, model training, metric reporting, etc. Task class has a Config class containing the config of each component in a descriptive way.
-
export
(model, export_path, metric_channels=None, export_onnx_path=None) Wrapper method to export a PyTorch model to a Caffe2 model using Exporter.
Parameters: - export_path (str) – file path of the exported Caffe2 model
- metric_channels (List[Channel]) – outputs of the model’s execution graph
- export_onnx_path (str) – file path of the exported ONNX model
-
classmethod
format_prediction
(predictions, scores, context, target_meta) Format the predictions and scores from model output; by default, they are simply returned in a dict.
-
classmethod
from_config
(task_config, metadata=None, model_state=None, tensorizers=None, rank=1, world_size=0) Create the task from config, and optionally load metadata/model_state. This function creates components including DataHandler, Trainer, MetricReporter and Exporter, and wires them up.
Parameters: - task_config (Task.Config) – the config of the current task
- metadata – saved global context of this task, e.g. vocabulary; generated by DataHandler if it is None
- model_state – saved model parameters, loaded into the model when given
-
predict
(examples) Generates predictions using the PyTorch model. Unlike test(), this should be used when the examples do not have any true label/target.
Parameters: examples – examples in JSON format; input names should match the names specified in this task’s features config
-
test
(test_path)[source]¶ Wrapper method to compute test metrics on holdout blind test dataset.
Parameters: test_path (str) – test data file path
-
train
(train_config, rank=0, world_size=1, training_state=None) Wrapper method to train the model using a Trainer object.
Parameters: - train_config (PyTextConfig) – config for training
- rank (int) – for distributed training only, rank of the GPU, default is 0
- world_size (int) – for distributed training only, total number of GPUs to use, default is 1
-
-
pytext.task.
save
(config: pytext.config.pytext_config.PyTextConfig, model: pytext.models.model.Model, meta: Optional[pytext.data.data_handler.CommonMetadata], tensorizers: Dict[str, pytext.data.tensorizers.Tensorizer], training_state: Optional[pytext.trainers.training_state.TrainingState] = None, identifier: Optional[str] = None) → str Save all stateful information of a training task: the original config, model state, metadata and, if training is not completed, the training state.
Parameters: - identifier (str) – used to identify a checkpoint within a training job; used as a suffix for the save path
- config (PyTextConfig) – contains all raw parameters/hyper-parameters for the training task
- model (Model) – the actual model in training
- training_state (TrainingState) – stateful information during training
Returns: str. If identifier is not specified, the state is saved to config.save_snapshot_path, consistent with the post-training snapshot; if specified, it is used to save a checkpoint during training and identifies checkpoints within the same training job.
-
pytext.task.
load
(load_path: str, overwrite_config=None, rank=0, world_size=1) Load task, config and training state from a saved snapshot. By default, it constructs the task using the saved config, then loads metadata and model state.
If overwrite_config is specified, it constructs the task using overwrite_config instead, then loads metadata and model state.
-
pytext.task.
create_task
(task_config, metadata=None, model_state=None, tensorizers=None, rank=0, world_size=1) Create a task by looking up the task class in the registry and invoking that class's from_config() classmethod; see from_config() for more details.
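A minimal sketch of the registry lookup that create_task describes, assuming a hypothetical registry that maps each config class to its task class. `register_task`, `create_task_sketch`, `EchoConfig` and `EchoTask` are illustrative names, not PyText APIs.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: config class -> task class.
_TASK_REGISTRY: Dict[type, type] = {}


def register_task(config_cls: type) -> Callable[[type], type]:
    def decorator(task_cls: type) -> type:
        _TASK_REGISTRY[config_cls] = task_cls
        return task_cls
    return decorator


def create_task_sketch(task_config: Any, **kwargs: Any) -> Any:
    # Find the task class registered for this config's type, then defer to from_config.
    task_cls = _TASK_REGISTRY[type(task_config)]
    return task_cls.from_config(task_config, **kwargs)


class EchoConfig:
    pass


@register_task(EchoConfig)
class EchoTask:
    def __init__(self, config: EchoConfig, rank: int) -> None:
        self.config = config
        self.rank = rank

    @classmethod
    def from_config(cls, task_config: EchoConfig, rank: int = 0, world_size: int = 1) -> "EchoTask":
        return cls(task_config, rank)


task = create_task_sketch(EchoConfig(), rank=0, world_size=1)
print(type(task).__name__)  # EchoTask
```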