Source code for pytext.optimizer.sparsifiers.blockwise_sparsifier

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from typing import List

import torch
import torch.nn as nn
from pytext.optimizer.sparsifiers.sparsifier import L0_projection_sparsifier


class BlockwiseMagnitudeSparsifier(L0_projection_sparsifier):
    """
    Running blockwise magnitude-based sparsification.

    Args:
        block_size: define the size of each block
        columnwise_blocking: define blocks columnwise if true
        starting_epoch: sparsification_condition returns true only after
            starting_epoch
        frequency: sparsification_condition returns true only if the number
            of steps divides frequency
        accumulate_mask: if true, the mask after each .sparsify() call is
            reused
        sparsity: percentage of zeros among the **UNPRUNED** parameters.

    Example of how the sparsifier works on a 2D matrix:

        [
             0  1  2  3  4
             5  6  7  8  9
            10 11 12 13 14
            15 16 17 18 19
            20 21 22 23 24
        ]

        with 1 x 3 blocks (3 consecutive entries within a row):

        [
            ********** *******
            * 0  1  2* * 3  4*
            ********** *******
            * 5  6  7* * 8  9*
            ********** *******
            *10 11 12* *13 14*
            ********** *******
            *15 16 17* *18 19*
            ********** *******
            *20 21 22* *23 24*
            ********** *******
        ]

        compute the l1 norm of each block and sort the norms; retain the
        blocks with the largest absolute values until the sparsity
        threshold is met.
    """
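
    # Worked trace of the 5 x 5 example above: with block_size=3 the 5
    # columns are zero-padded to 6, yielding 10 blocks. With sparsity=0.9
    # over the 25 unpruned entries, at most ceil(25 * 0.1) = 3 nonzeros
    # survive, i.e. ceil(3 / 3) = 1 block is kept -- the one with the
    # largest l1 norm, [20 21 22].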
    class Config(L0_projection_sparsifier.Config):
        block_size: int = 16
        columnwise_blocking: bool = False
        accumulate_mask: bool = False
        layerwise_pruning: bool = True
    def __init__(
        self,
        sparsity,
        starting_epoch,
        frequency,
        block_size,
        columnwise_blocking,
        accumulate_mask,
        layerwise_pruning,
    ):
        super().__init__(sparsity, starting_epoch, frequency, layerwise_pruning)
        self.block_size = block_size
        self.columnwise_blocking = columnwise_blocking
        self.accumulate_mask = accumulate_mask
        self._masks = None
        assert self.layerwise_pruning, "layerwise pruning is forced"
    @classmethod
    def from_config(cls, config: Config):
        return cls(
            config.sparsity,
            config.starting_epoch,
            config.frequency,
            config.block_size,
            config.columnwise_blocking,
            config.accumulate_mask,
            config.layerwise_pruning,
        )
    def get_sparsifiable_params(self, model, requires_name=False):
        # only learnable 2D parameters (e.g. linear weights) are sparsifiable
        named_params = [
            (n, p)
            for n, p in model.named_parameters()
            if p.requires_grad and len(p.shape) == 2
        ]
        sparsifiable_params_name = [n for n, _ in named_params]
        sparsifiable_params = [p for _, p in named_params]
        if requires_name:
            return sparsifiable_params_name, sparsifiable_params
        else:
            return sparsifiable_params
    def get_current_sparsity(self, model):
        sparsifiable_params = self.get_sparsifiable_params(model)
        sparsifiable_params_count = sum(p.numel() for p in sparsifiable_params)
        nonzero_params = sum(p.nonzero().size(0) for p in sparsifiable_params)
        return (
            sparsifiable_params_count - nonzero_params
        ) / sparsifiable_params_count
    def _padding_into_full_blocks(self, param):
        nrows, ncols = param.shape
        # pad the column dimension up to a multiple of block_size so the
        # parameter can be partitioned into full blocks
        ncols_pad = math.ceil(ncols / self.block_size) * self.block_size
        padded_param = param.new_zeros((nrows, ncols_pad))
        padded_param[:nrows, :ncols] = param
        return padded_param

    def _num_blocks_kept(self, param, mask):
        if mask is None:
            mask = param.new_ones(param.shape)
        # sparsity is defined over the currently unpruned entries only
        unpruned_param_sz = torch.nonzero(mask).size(0)
        max_num_nonzeros = math.ceil(unpruned_param_sz * (1 - self.sparsity))
        return math.ceil(max_num_nonzeros / self.block_size)

    def _compute_param_mask(
        self,
        param: torch.Tensor,
        pre_mask: torch.Tensor = None,
        columnwise_blocking: bool = False,
    ):
        if columnwise_blocking:
            # columnwise blocking is rowwise blocking on the transposed matrix
            return self._compute_param_mask(
                param.transpose(1, 0),
                pre_mask=(
                    pre_mask.transpose(1, 0) if pre_mask is not None else None
                ),
            ).transpose(1, 0)
        padded_param = self._padding_into_full_blocks(param)
        if pre_mask is not None:
            padded_mask = self._padding_into_full_blocks(pre_mask)
            # already-pruned entries must not contribute to block scores
            padded_param.data = padded_param.data * padded_mask
        # l1 norm of every contiguous run of block_size entries within a row
        block_l1norms = (
            torch.abs(padded_param).reshape(-1, 1, self.block_size).sum(dim=2)
        )
        max_num_blocks = self._num_blocks_kept(param, pre_mask)
        # the smallest l1 norm among the top-k blocks is the keep threshold
        topk_threshold = (
            torch.topk(block_l1norms.flatten(), max_num_blocks).values.min().item()
        )
        # broadcast the per-block keep decision back to entry granularity
        mask = (
            block_l1norms.repeat(1, 1, self.block_size).reshape(padded_param.shape)
            >= topk_threshold
        ).to(param.dtype)
        # strip the padding; re-apply pre_mask so pruned entries stay pruned
        if pre_mask is None:
            return mask[: param.size(0), : param.size(1)]
        else:
            return mask[: param.size(0), : param.size(1)] * pre_mask
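
    # Sketch of _compute_param_mask on a 2 x 4 parameter with block_size=2
    # and sparsity=0.5 (values chosen for illustration):
    #
    #   param = [[1, 2, 3, 4],
    #            [0, 0, 5, 6]]
    #
    #   block l1 norms: 3, 7, 0, 11
    #   _num_blocks_kept: ceil(8 * 0.5) = 4 nonzeros -> ceil(4 / 2) = 2 blocks
    #   topk_threshold = min(11, 7) = 7, so blocks [3, 4] and [5, 6] survive:
    #
    #   mask = [[0, 0, 1, 1],
    #           [0, 0, 1, 1]]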
    def get_masks(
        self, model: nn.Module, pre_masks: List[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        learnableparams = self.get_sparsifiable_params(model)
        if pre_masks:
            self._masks = pre_masks
        if self._masks:
            assert len(learnableparams) == len(
                self._masks
            ), "parameter count and mask count do not match"
            for m, w in zip(self._masks, learnableparams):
                # check only for non-empty masks
                if len(m.size()):
                    assert (
                        m.size() == w.size()
                    ), "parameter dimension and mask dimension do not match"
        if self._masks is not None:
            # sparsify 2D tensors only; emit an empty mask for unlearnable
            # and unsparsifiable params
            masks = [
                self._compute_param_mask(p, m, self.columnwise_blocking)
                if len(p.shape) == 2 and p.requires_grad
                else p.new_empty(())
                for p, m in zip(learnableparams, self._masks)
            ]
        else:
            # sparsify 2D tensors only; emit an empty mask for unlearnable
            # and unsparsifiable params
            masks = [
                self._compute_param_mask(
                    p, columnwise_blocking=self.columnwise_blocking
                )
                if len(p.shape) == 2 and p.requires_grad
                else p.new_empty(())
                for p in learnableparams
            ]
        if self.accumulate_mask:
            self._masks = masks
        return masks
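

if __name__ == "__main__":
    # Minimal usage sketch (hyperparameter values are illustrative only):
    # prune a single linear layer to roughly 90% sparsity with 1 x 4 blocks.
    sparsifier = BlockwiseMagnitudeSparsifier(
        sparsity=0.9,
        starting_epoch=2,
        frequency=100,
        block_size=4,
        columnwise_blocking=False,
        accumulate_mask=False,
        layerwise_pruning=True,
    )
    model = nn.Linear(16, 8)  # only the 2D weight is sparsified; the 1D bias is skipped
    masks = sparsifier.get_masks(model)
    with torch.no_grad():
        for p, m in zip(sparsifier.get_sparsifiable_params(model), masks):
            p.mul_(m)
    # block granularity rounds the kept count up, so the realized sparsity
    # lands slightly below the requested 0.9
    print(sparsifier.get_current_sparsity(model))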