Source code for pytext.optimizer.sparsifiers.sparsifier

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import json
import math
import os
import sys
from enum import Enum
from typing import List

import numpy as np
import torch
import torch.nn as nn
from pytext.common.constants import Stage
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType
from pytext.models.crf import CRF
from pytext.models.model import Model
from pytext.utils import timing
from pytext.utils.file_io import PathManager


class State(Enum):
    ANALYSIS = "Analysis"
    OTHERS = "Others"


class Sparsifier(Component):
    __COMPONENT_TYPE__ = ComponentType.SPARSIFIER
    __EXPANSIBLE__ = True

    class Config(ConfigBase):
        pass

    def sparsify(self, *args, **kwargs):
        pass

    def sparsification_condition(self, *args, **kwargs):
        pass

    def get_sparsifiable_params(self, *args, **kwargs):
        pass

    def initialize(self, *args, **kwargs):
        pass

    def op_pre_epoch(self, *args, **kwargs):
        pass

    def save_model_state_for_all_rank(self):
        return False

    def get_current_sparsity(self, model: Model) -> float:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        nonzero_params = sum(
            p.nonzero().size(0) for p in model.parameters() if p.requires_grad
        )
        return (trainable_params - nonzero_params) / trainable_params


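# Illustrative sketch (not part of the original module): how the sparsity ratio
# computed by Sparsifier.get_current_sparsity behaves on a toy nn.Module. The toy
# model and the helper name below are hypothetical and exist only for exposition.
def _example_current_sparsity():
    toy = nn.Linear(4, 2, bias=False)
    with torch.no_grad():
        toy.weight.zero_()      # 8 trainable weights, all zero
        toy.weight[0, 0] = 1.0  # keep a single nonzero entry
    trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
    nonzero = sum(p.nonzero().size(0) for p in toy.parameters() if p.requires_grad)
    return (trainable - nonzero) / trainable  # 7 / 8 = 0.875

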
class L0_projection_sparsifier(Sparsifier):
    """
    L0 projection-based (unstructured) sparsification

    Args:
        weights (torch.Tensor): input weight matrix
        sparsity (float32): the desired sparsity [0-1]

    """

    class Config(Sparsifier.Config):
        sparsity: float = 0.9
        starting_epoch: int = 2
        frequency: int = 1
        layerwise_pruning: bool = True
        accumulate_mask: bool = False

    def __init__(
        self,
        sparsity,
        starting_epoch,
        frequency,
        layerwise_pruning=True,
        accumulate_mask=False,
    ):
        assert 0 <= sparsity <= 1
        self.sparsity = sparsity
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency
        self.layerwise_pruning = layerwise_pruning
        self.accumulate_mask = accumulate_mask
        self._masks = None

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.sparsity,
            config.starting_epoch,
            config.frequency,
            config.layerwise_pruning,
            config.accumulate_mask,
        )

    def sparsification_condition(self, state):
        return (
            state.stage == Stage.TRAIN
            and state.epoch >= self.starting_epoch
            and state.step_counter % self.frequency == 0
        )

    def sparsify(self, state):
        """
        obtain a mask and apply the mask to sparsify
        """
        model = state.model
        # compute a new mask when the conditions are met
        if self.sparsification_condition(state):
            masks = self.get_masks(model)
            # apply the computed mask; self.accumulate_mask is handled separately
            if not self.accumulate_mask:
                self.apply_masks(model, masks)

        # if self.accumulate_mask is True, apply the existing mask regardless of Stage
        if self.accumulate_mask and self._masks is not None:
            self.apply_masks(model, self._masks)

    def get_sparsifiable_params(self, model: Model):
        sparsifiable_params = [p for p in model.parameters() if p.requires_grad]
        return sparsifiable_params

    def apply_masks(self, model: Model, masks: List[torch.Tensor]):
        """
        apply given masks to zero-out learnable weights in model
        """
        learnableparams = self.get_sparsifiable_params(model)
        assert len(learnableparams) == len(masks)
        for m, w in zip(masks, learnableparams):
            if len(m.size()):
                assert m.size() == w.size()
                w.data *= m.clone()
                # if accumulate_mask, remove a param permanently by also removing
                # its gradient
                if self.accumulate_mask:
                    w.grad.data *= m.clone()

    def get_masks(
        self, model: Model, pre_masks: List[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        """
        Note: this function only returns the masks; it does not sparsify or
        modify the weights.

        prune x% of weights among the weights with "1" in pre_masks

        Args:
            model: Model
            pre_masks: list of FloatTensors where "1" means the weight is
            retained and "0" means the weight is pruned

        Return:
            masks: List[torch.Tensor], intersection of new masks and pre_masks,
            so that "1" only if the weight is selected by both the new masking
            and pre_masks
        """
        learnableparams = self.get_sparsifiable_params(model)
        if pre_masks:
            self._masks = pre_masks
        if self._masks is None:
            # retain everything if no pre_masks given
            self._masks = [torch.ones_like(p) for p in learnableparams]
        assert len(learnableparams) == len(self._masks)
        for m, w in zip(self._masks, learnableparams):
            if len(m.size()):
                assert m.size() == w.size()

        if self.layerwise_pruning:
            masks = []
            for m, param in zip(self._masks, learnableparams):
                weights_abs = torch.abs(param.data).to(param.device)
                # absolute value of the weights selected by the existing masks
                weights_abs_masked_flat = torch.flatten(weights_abs[m.bool()])
                total_size = weights_abs_masked_flat.numel()
                if total_size > 0:
                    # use ceil() instead of floor() or int() because at least one
                    # element in the tensor is required to be selected
                    max_num_nonzeros = math.ceil(total_size * (1 - self.sparsity))
                    # only prune among the weights selected by the existing masks
                    topkval = (
                        torch.topk(weights_abs_masked_flat, max_num_nonzeros)
                        .values.min()
                        .item()
                    )
                    # intersection of the new mask and the pre-existing masks;
                    # mask == 1 retained, mask == 0 pruned
                    mask = (weights_abs >= topkval).float() * m
                else:
                    mask = param.new_empty(())
                masks.append(mask)
        else:
            # concatenated, flattened tensor of learnableparams where _masks is True
            learnableparams_masked_flat = torch.cat(
                [
                    torch.flatten(p[m.bool()])
                    for m, p in zip(self._masks, learnableparams)
                ],
                dim=0,
            )
            # use ceil() instead of floor() or int() because at least one element
            # in the tensor is required to be selected
            max_num_nonzeros = math.ceil(
                learnableparams_masked_flat.numel() * (1 - self.sparsity)
            )
            # select globally the top-k-th weight among the weights selected by _masks
            topkval = (
                torch.topk(torch.abs(learnableparams_masked_flat), max_num_nonzeros)
                .values.min()
                .item()
            )
            # intersection of the new mask and _masks;
            # mask == 1 retained, mask == 0 pruned
            masks = [
                (torch.abs(p.data) >= topkval).float() * m
                if p.numel() > 0
                else p.new_empty(())
                for m, p in zip(self._masks, learnableparams)
            ]

        if self.accumulate_mask:
            self._masks = masks

        return masks


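# Illustrative sketch (not part of the original module): the per-layer magnitude
# top-k rule used in get_masks above, applied to a single tensor. The numbers and
# the helper name are hypothetical; this keeps the top ceil(n * (1 - sparsity))
# entries by absolute value and zeroes out the rest.
def _example_layerwise_topk_mask(sparsity=0.75):
    weights = torch.tensor([[0.1, -0.8, 0.05, 0.3], [-0.02, 0.6, -0.4, 0.01]])
    weights_abs = torch.abs(weights)
    max_num_nonzeros = math.ceil(weights_abs.numel() * (1 - sparsity))  # -> 2
    topkval = torch.topk(weights_abs.flatten(), max_num_nonzeros).values.min().item()
    mask = (weights_abs >= topkval).float()  # keeps only -0.8 and 0.6
    return weights * mask

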
class CRF_SparsifierBase(Sparsifier):
    class Config(Sparsifier.Config):
        starting_epoch: int = 1
        frequency: int = 1

    def sparsification_condition(self, state):
        if state.stage == Stage.TRAIN:
            return False

        return (
            state.epoch >= self.starting_epoch
            and state.step_counter % self.frequency == 0
        )

    def get_sparsifiable_params(self, model: nn.Module):
        for m in model.modules():
            if isinstance(m, CRF):
                return m.transitions.data

    def get_transition_sparsity(self, transition):
        nonzero_params = transition.nonzero().size(0)
        return (transition.numel() - nonzero_params) / transition.numel()


class CRF_L1_SoftThresholding(CRF_SparsifierBase):
    """
    Implement L1 regularization:
        min Loss(x, y, CRFparams) + lambda_l1 * ||CRFparams||_1
    and solve the optimization problem via a (stochastic) proximal gradient-based
    method, i.e., soft-thresholding:
        param_updated = sign(CRFparams) * max(abs(CRFparams) - lambda_l1, 0)
    """

    class Config(CRF_SparsifierBase.Config):
        lambda_l1: float = 0.001

    def __init__(self, lambda_l1: float, starting_epoch: int, frequency: int):
        self.lambda_l1 = lambda_l1
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency

    @classmethod
    def from_config(cls, config: Config):
        return cls(config.lambda_l1, config.starting_epoch, config.frequency)

    def sparsify(self, state):
        if not self.sparsification_condition(state):
            return
        model = state.model
        transition_matrix = self.get_sparsifiable_params(model)
        transition_matrix_abs = torch.abs(transition_matrix)
        assert (
            len(state.optimizer.param_groups) == 1
        ), "different learning rates for multiple param groups not supported"
        lrs = state.optimizer.param_groups[0]["lr"]
        threshold = self.lambda_l1 * lrs
        transition_matrix = torch.sign(transition_matrix) * torch.max(
            (transition_matrix_abs - threshold),
            transition_matrix.new_zeros(transition_matrix.shape),
        )
        current_sparsity = self.get_transition_sparsity(transition_matrix)
        print(f"sparsity of CRF transition matrix: {current_sparsity}")


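# Illustrative sketch (not part of the original module): the soft-thresholding
# update used above, applied to a toy transition matrix. lambda_l1, the learning
# rate, and the helper name are hypothetical values chosen for exposition.
def _example_soft_thresholding(lambda_l1=0.001, lr=10.0):
    transitions = torch.tensor([[0.02, -0.5], [0.003, 0.2]])
    threshold = lambda_l1 * lr  # -> 0.01
    shrunk = torch.sign(transitions) * torch.max(
        torch.abs(transitions) - threshold, torch.zeros_like(transitions)
    )
    # entries with |w| <= threshold become exactly 0; the rest shrink toward 0
    return shrunk  # approximately [[0.01, -0.49], [0.0, 0.19]]

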
class CRF_MagnitudeThresholding(CRF_SparsifierBase):
    """
    Magnitude-based sparsification (equivalent to projection onto an L0
    constraint set) of the CRF transition matrix. Preserves the top-k elements
    either rowwise or columnwise until the sparsity constraint is met.
    """

    class Config(CRF_SparsifierBase.Config):
        sparsity: float = 0.9
        grouping: str = "row"

    def __init__(self, sparsity, starting_epoch, frequency, grouping):
        assert 0 <= sparsity <= 1
        self.sparsity = sparsity
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency
        assert (
            grouping == "row" or grouping == "column"
        ), "grouping needs to be row or column"
        self.grouping = grouping

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.sparsity, config.starting_epoch, config.frequency, config.grouping
        )

    def sparsify(self, state):
        if not self.sparsification_condition(state):
            return
        model = state.model
        transition_matrix = self.get_sparsifiable_params(model)
        num_rows, num_cols = transition_matrix.shape
        trans_abs = torch.abs(transition_matrix)
        if self.grouping == "row":
            max_num_nonzeros = math.ceil(num_cols * (1 - self.sparsity))
            topkvals = (
                torch.topk(trans_abs, k=max_num_nonzeros, dim=1)
                .values.min(dim=1, keepdim=True)
                .values
            )
        else:
            max_num_nonzeros = math.ceil(num_rows * (1 - self.sparsity))
            topkvals = (
                torch.topk(trans_abs, k=max_num_nonzeros, dim=0)
                .values.min(dim=0, keepdim=True)
                .values
            )

        # trans_abs < topkvals is a broadcasted comparison
        transition_matrix[trans_abs < topkvals] = 0.0
        current_sparsity = self.get_transition_sparsity(transition_matrix)
        print(f"sparsity of CRF transition matrix: {current_sparsity}")


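# Illustrative sketch (not part of the original module): the rowwise magnitude
# thresholding used above on a toy 2x4 transition matrix. The sparsity value and
# the helper name are hypothetical.
def _example_rowwise_magnitude_thresholding(sparsity=0.5):
    transitions = torch.tensor([[0.9, -0.1, 0.4, 0.2], [-0.3, 0.8, -0.05, 0.6]])
    trans_abs = torch.abs(transitions)
    max_num_nonzeros = math.ceil(transitions.shape[1] * (1 - sparsity))  # 2 per row
    # per-row threshold: the smallest of each row's top-k magnitudes
    topkvals = (
        torch.topk(trans_abs, k=max_num_nonzeros, dim=1)
        .values.min(dim=1, keepdim=True)
        .values
    )
    transitions[trans_abs < topkvals] = 0.0  # broadcasted comparison zeroes the rest
    return transitions  # [[0.9, 0.0, 0.4, 0.0], [0.0, 0.8, 0.0, 0.6]]

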
class SensitivityAnalysisSparsifier(Sparsifier):
    class Config(Sparsifier.Config):
        pre_train_model_path: str = ""
        analyzed_sparsity: float = 0.8
        # we don't use all of the eval data for analysis, only a portion of it.
        max_analysis_batches: int = 0
        # allow the user to skip pruning for some weights. Here we set the max
        # number of weight tensors that can be skipped for pruning.
        max_skipped_weight: int = 0
        # if we already did sensitivity analysis before
        pre_analysis_path: str = ""
        sparsity: float = 0.8
        # if we use iterative pruning
        iterative_pruning: bool = True
        # the total number of pruning iterations for iterative pruning, where
        # we incrementally increase the sparsity at each iteration
        pruning_iterations: int = 2
        # the ratio of the start sparsity to the final sparsity
        start_sparsity_ratio: float = 0.5

    def __init__(
        self,
        pre_train_model_path,
        analyzed_sparsity,
        max_analysis_batches,
        max_skipped_weight,
        pre_analysis_path,
        sparsity,
        iterative_pruning,
        pruning_iterations,
        start_sparsity_ratio,
    ):
        assert PathManager.exists(
            pre_train_model_path
        ), "The pre-trained model must exist"
        self.pre_train_model_path = pre_train_model_path
        self.param_dict = None
        assert (
            0.0 <= analyzed_sparsity <= 1.0
        ), "Analyzed sparsity needs to be in the range of [0, 1]"
        self.analyzed_sparsity = analyzed_sparsity
        self.max_analysis_batches = max_analysis_batches
        self.max_skipped_weight = max_skipped_weight
        self.require_mask_parameters = []
        self.pre_analysis_path = pre_analysis_path
        assert (
            0.0 <= sparsity <= 1.0
        ), "Pruning sparsity needs to be in the range of [0, 1]"
        self.sparsity = sparsity
        self._masks = None
        self.analysis_state = State.OTHERS
        self.iterative_pruning = iterative_pruning
        # members used for iterative pruning
        if self.iterative_pruning:
            assert (
                pruning_iterations > 1
            ), "iterative pruning should contain at least two pruning iterations"
            self.pruning_iterations = pruning_iterations
            self.start_sparsity = start_sparsity_ratio * sparsity
            self.end_sparsity = self.sparsity
            self.epochs_per_iter = 0
            self.sparsity_increment = 0.0

    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.pre_train_model_path,
            config.analyzed_sparsity,
            config.max_analysis_batches,
            config.max_skipped_weight,
            config.pre_analysis_path,
            config.sparsity,
            config.iterative_pruning,
            config.pruning_iterations,
            config.start_sparsity_ratio,
        )

    def get_sparsifiable_params(self, model):
        param_dict = {}
        for module_name, m in model.named_modules():
            # iterate over all named modules; only consider the parameters
            # of nn.Linear layers
            if isinstance(m, nn.Linear):
                # module_name: module.xxx
                # param_name: module.xxx.weight
                # we only check the weight tensor
                param_name = module_name + ".weight"
                param_dict[param_name] = m.weight

        return param_dict

    def get_mask_for_param(self, param, sparsity):
        """
        generate the prune mask for one weight tensor.
        """
        n = int(sparsity * param.nelement())

        if n > 0:
            # If n > 0, we need to remove n parameters; the threshold is picked
            # from the smallest absolute parameter values via kthvalue.
            threshold = float(param.abs().flatten().kthvalue(n - 1)[0])
        else:
            # If n == 0, it means all parameters need to be kept. Because the
            # absolute parameter value >= 0, setting the threshold to -1 ensures
            # param.abs().ge(threshold) is True for all the parameters.
            threshold = -1.0

        # the mask indicates the weights that need to be kept
        mask = param.abs().ge(threshold).float()

        return mask

    def layer_wise_analysis(
        self, param_name, param_dict, trainer, state, eval_data, metric_reporter
    ):
        # perform pruning for the target param with param_name
        if param_name is None:
            prunable_param_shape = None
        else:
            prunable_param = param_dict[param_name]
            # include the shape information for better analysis
            prunable_param_shape = list(prunable_param.shape)
            mask = self.get_mask_for_param(prunable_param, self.analyzed_sparsity)
            with torch.no_grad():
                param_dict[param_name].data.mul_(mask)

        # get the eval_metric for the pruned model
        with torch.no_grad():
            # set the number of batches of eval data for analysis
            analysis_data = eval_data
            if self.max_analysis_batches > 0:
                analysis_data = itertools.islice(eval_data, self.max_analysis_batches)
            eval_metric = trainer.run_epoch(state, analysis_data, metric_reporter)
        current_metric = metric_reporter.get_model_select_metric(eval_metric)
        if metric_reporter.lower_is_better:
            current_metric = -current_metric

        return current_metric, prunable_param_shape

    def find_params_to_prune(self, metric_dict, max_skip_weight_num):
        require_mask_parameters = sorted(
            metric_dict.keys(), reverse=True, key=lambda param: metric_dict[param]
        )
        metric_sensitivities_by_param = [
            metric_dict[p] for p in require_mask_parameters
        ]
        skipped_weight_num = 0
        while skipped_weight_num < max_skip_weight_num:
            # calculate the mean and standard deviation
            mean_ = np.mean(metric_sensitivities_by_param[:-skipped_weight_num])
            std_ = np.std(metric_sensitivities_by_param[:-skipped_weight_num])
            # skip pruning of the parameter if its metric sensitivity is
            # less than mean_ - 3 * std_; otherwise break.
            if (
                metric_sensitivities_by_param[-skipped_weight_num - 1]
                >= mean_ - 3 * std_
            ):
                break
            skipped_weight_num += 1

        require_mask_parameters = require_mask_parameters[:-skipped_weight_num]
        # return how many weights were skipped during this iteration
        return require_mask_parameters, skipped_weight_num

    def sensitivity_analysis(
        self, trainer, state, eval_data, metric_reporter, train_config
    ):
        """
        Analyze the sensitivity of each weight tensor to the metric. Prune the
        weight tensors one by one and evaluate the metric when the corresponding
        weight tensor is pruned.
        Args:
            trainer (Trainer): the trainer used to run evaluation epochs
            state (TrainingState): the state of the current training
            eval_data (BatchIterator): batch iterator of evaluation data
            metric_reporter (MetricReporter): compute metric based on training
            output and report results to console, file, etc.
            train_config (PyTextConfig): training config

        Returns:
            metric_dict: a dict that maps each weight tensor name to its
            sensitivity to the metric.
        """
        print("Analyzed_sparsity: {}".format(self.analyzed_sparsity))
        print("Evaluation metric_reporter: {}".format(type(metric_reporter).__name__))
        output_path = (
            os.path.dirname(train_config.task.metric_reporter.output_path)
            + "/sensitivity_analysis_sparsifier.ckp"
        )

        # param_dict: the dict that maps the parameter name to the weight tensor
        self.param_dict = self.get_sparsifiable_params(state.model)

        # set model to evaluation mode
        state.stage = Stage.EVAL
        state.model.eval(Stage.EVAL)

        metric_dict = {}
        all_param_list = [None] + list(self.param_dict.keys())
        print("All prunable parameters", all_param_list)

        # print the sensitivity results for each weight
        print("#" * 40)
        print("save the analysis result to: ", output_path)
        print("Pruning Sensitivity Test: param / shape / eval metric")

        # iterate through all_param_list to test pruning sensitivity
        for param_name in all_param_list:
            print("=" * 40)
            print("Testing {}".format(param_name))
            state.model.load_state_dict(self.loaded_model["model_state"])

            current_metric, prunable_param_shape = self.layer_wise_analysis(
                param_name, self.param_dict, trainer, state, eval_data, metric_reporter
            )
            if param_name is None:
                baseline_metric = current_metric
            metric_dict[param_name] = current_metric - baseline_metric

        print("#" * 40)

        # remove the baseline metric from the analysis results
        if None in metric_dict:
            del metric_dict[None]

        # write the test result into the checkpoint
        if state.rank == 0:
            with PathManager.open(output_path, "w") as fp:
                json.dump(metric_dict, fp)

        return metric_dict

    def sparsification_condition(self, state):
        return state.stage == Stage.TRAIN

    def apply_masks(self, model: Model, masks: List[torch.Tensor]):
        """
        apply given masks to zero-out learnable weights in model
        """
        learnable_params = self.get_required_sparsifiable_params(model)
        assert len(learnable_params) == len(masks)
        for m, w in zip(masks, learnable_params):
            if len(m.size()):
                assert m.size() == w.size()
                w.data *= m

    def get_current_sparsity(self, model: Model) -> float:
        trainable_params = sum(
            module.weight.data.numel()
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
        )
        nonzero_params = sum(
            module.weight.data.nonzero().size(0)
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
        )
        return (trainable_params - nonzero_params) / trainable_params

    def sparsify(self, state):
        """
        apply the mask to sparsify the weight tensor
        """
        # do not sparsify the weight tensor during the analysis
        if self.analysis_state == State.ANALYSIS:
            return
        model = state.model
        # apply the mask when the sparsification condition holds
        if self.sparsification_condition(state):
            # apply the computed mask to sparsify the weight
            self.apply_masks(model, self._masks)

    def get_required_sparsifiable_params(self, model: Model):
        # param_dict contains all sparsifiable parameters; select the required
        # weights. If we reload the analysis result from a file, we need to
        # compute param_dict again.
        if self.param_dict is None:
            self.param_dict = self.get_sparsifiable_params(model)

        return [self.param_dict[p] for p in self.require_mask_parameters]

    def get_masks(self, model: Model) -> List[torch.Tensor]:
        """
        Note: this function returns the masks for each weight tensor that is
        required to be pruned

        prune x% of weight items within each weight tensor; in the mask
        (self._masks), "1" indicates a retained weight and "0" indicates a
        pruned weight

        Args:
            model: Model

        Return:
            masks: List[torch.Tensor], the prune mask for the weight of all layers
        """
        learnable_params = self.get_required_sparsifiable_params(model)
        masks = []
        for param in learnable_params:
            mask = self.get_mask_for_param(param, self.sparsity)
            masks.append(mask)

        return masks

    def load_analysis_from_path(self):
        assert PathManager.isfile(self.pre_analysis_path), "{} is not a file".format(
            self.pre_analysis_path
        )
        with PathManager.open(self.pre_analysis_path, "r") as fp:
            metric_dict = json.load(fp)

        return metric_dict

    @timing.time("sparsifier initialize")
    def initialize(self, trainer, state, eval_data, metric_reporter, train_config):
        assert self.pre_train_model_path, "must have a pre-trained model"
        # load the pre-trained model
        print("load the pre-trained model from: " + self.pre_train_model_path)
        self.loaded_model = torch.load(
            self.pre_train_model_path, map_location=torch.device("cpu")
        )

        # if the user specifies an analysis file, load it from that path
        if self.pre_analysis_path:
            metric_dict = self.load_analysis_from_path()
        else:
            self.analysis_state = State.ANALYSIS
            metric_dict = self.sensitivity_analysis(
                trainer, state, eval_data, metric_reporter, train_config
            )
            # the analysis is finished, so the sparsifier can apply the prune mask.
            self.analysis_state = State.OTHERS

        # skip some of the weight tensors from pruning. The user can specify
        # max_skipped_weight, which limits the max number of weight tensors
        # that can be skipped.
        self.require_mask_parameters, skipped_weight_num = self.find_params_to_prune(
            metric_dict, self.max_skipped_weight
        )

        for p in self.require_mask_parameters:
            print(p, " ", metric_dict[p])
        print("#" * 40)
        sys.stdout.flush()

        print(str(skipped_weight_num) + " weight tensors are skipped for pruning")

        if self.iterative_pruning:
            assert (
                trainer.config.early_stop_after == 0
            ), "Cannot set early stop for iterative pruning"
            assert (
                trainer.config.epochs % self.pruning_iterations == 0
            ), "total training epochs should be divisible by the pruning iterations"
            self.epochs_per_iter = trainer.config.epochs // self.pruning_iterations
            # initialize the sparsity as self.start_sparsity and calculate the
            # sparsity increment of each pruning iteration.
            self.sparsity_increment = (self.end_sparsity - self.start_sparsity) / (
                self.pruning_iterations - 1
            )
            self.sparsity = self.start_sparsity
            print(
                "sparsity start from: ",
                self.sparsity,
                " increment of: ",
                self.sparsity_increment,
            )

        # prune starting from the pre-trained weights
        state.model.load_state_dict(self.loaded_model["model_state"])

        # initialize and generate the pruning mask. We don't want to generate
        # the mask at every step; otherwise, it would be time inefficient.
        self._masks = self.get_masks(state.model)

    def increase_sparsity(self, state):
        self.sparsity += self.sparsity_increment
        print("sparsity increased to: ", self.sparsity)

    def save_model_state_for_all_rank(self):
        # all machines should save the best model of a pruning iteration
        # if we use iterative pruning
        return self.iterative_pruning

    def _should_update_sparsity(self, epoch):
        return (
            self.iterative_pruning
            and epoch % self.epochs_per_iter == 0
            and epoch > 0
        )

    def op_pre_epoch(self, trainer, state):
        """
        Note: invoke this function at the beginning of each pruning iteration.
        Each pruning iteration contains several epochs. In this function, we:
        1. update the sparsity, 2. reload the best model from the previous
        iteration, 3. generate the prune mask, and 4. apply the mask to prune
        the weights of the model with the increased sparsity.
        """
        # check whether we need to update the pruning sparsity this epoch
        if self._should_update_sparsity(state.epoch):
            # initialize the best model metric as None at the beginning of each
            # iteration. This makes sure the best_model is chosen from the
            # previous iteration instead of from the entire training run.
            state.best_model_metric = None
            # load the best model from the previous pruning iteration
            assert state.best_model_state is not None
            trainer.load_best_model(state)
            # the sparsity is initialized as start_sparsity and increased every iteration
            self.increase_sparsity(state)
            # starting from the second iteration, generate a new mask with the
            # increased sparsity
            self._masks = self.get_masks(state.model)
            self.apply_masks(state.model, self._masks)


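# Illustrative sketch (not part of the original module): how the iterative-pruning
# schedule used by SensitivityAnalysisSparsifier evolves the sparsity across
# iterations. The numbers and the helper name are hypothetical.
def _example_iterative_pruning_schedule(
    sparsity=0.8, start_sparsity_ratio=0.5, pruning_iterations=4, epochs=8
):
    start_sparsity = start_sparsity_ratio * sparsity  # 0.4
    increment = (sparsity - start_sparsity) / (pruning_iterations - 1)
    epochs_per_iter = epochs // pruning_iterations  # 2 epochs per pruning iteration
    schedule = [
        round(start_sparsity + i * increment, 4) for i in range(pruning_iterations)
    ]
    return epochs_per_iter, schedule  # (2, [0.4, 0.5333, 0.6667, 0.8])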