#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
from typing import List
import torch
import torch.nn as nn
from pytext.common.constants import Stage
from pytext.config import ConfigBase
from pytext.config.component import Component, ComponentType
from pytext.models.crf import CRF
from pytext.models.model import Model
class Sparsifier(Component):
    """
    Base component for sparsifiers.

    ``sparsify``, ``sparsification_condition`` and ``get_sparsifiable_params``
    are no-op hooks here (they accept anything and return None); only
    ``get_current_sparsity`` has a concrete implementation.
    """

    __COMPONENT_TYPE__ = ComponentType.SPARSIFIER
    __EXPANSIBLE__ = True

    class Config(ConfigBase):
        pass

    def sparsify(self, *args, **kwargs):
        """No-op hook; returns None."""
        pass

    def sparsification_condition(self, *args, **kwargs):
        """No-op hook; returns None."""
        pass

    def get_sparsifiable_params(self, *args, **kwargs):
        """No-op hook; returns None."""
        pass

    def get_current_sparsity(self, model: Model) -> float:
        """
        Return the fraction of trainable parameter entries that are exactly zero.

        Args:
            model: model whose parameters with ``requires_grad=True`` are counted

        Return:
            (number of zero entries) / (number of entries) over all trainable
            parameters, or 0.0 when the model has no trainable entries
        """
        trainable_params = 0
        nonzero_params = 0
        # single pass so numel and nonzero counts cover the same parameters
        for p in model.parameters():
            if not p.requires_grad:
                continue
            trainable_params += p.numel()
            nonzero_params += p.nonzero().size(0)
        # avoid ZeroDivisionError for models without trainable entries
        if trainable_params == 0:
            return 0.0
        return (trainable_params - nonzero_params) / trainable_params
class L0_projection_sparsifier(Sparsifier):
    """
    L0 projection-based (unstructured) sparsification: retain the
    largest-magnitude entries of the trainable parameters and zero the rest.

    Args:
        sparsity (float): the desired sparsity, must be in [0, 1]
        starting_epoch (int): masks are only computed once
            ``state.epoch >= starting_epoch``; must be >= 1
        frequency (int): masks are only computed when ``state.step_counter``
            is a multiple of this value; must be >= 1
        layerwise_pruning (bool): if True the magnitude threshold is chosen
            separately for each parameter tensor, otherwise a single threshold
            is chosen over all parameter tensors together
        accumulate_mask (bool): if True the computed masks are stored and
            re-applied (to weights and gradients) on every ``sparsify`` call
    """

    class Config(Sparsifier.Config):
        sparsity: float = 0.9
        starting_epoch: int = 2
        frequency: int = 1
        layerwise_pruning: bool = True
        accumulate_mask: bool = False

    def __init__(
        self,
        sparsity,
        starting_epoch,
        frequency,
        layerwise_pruning=True,
        accumulate_mask=False,
    ):
        assert 0 <= sparsity <= 1
        self.sparsity = sparsity
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency
        self.layerwise_pruning = layerwise_pruning
        self.accumulate_mask = accumulate_mask
        # masks from the previous get_masks call (or pre_masks); None until then
        self._masks = None

    @classmethod
    def from_config(cls, config: Config):
        """Build a sparsifier from the fields of ``config``."""
        return cls(
            config.sparsity,
            config.starting_epoch,
            config.frequency,
            config.layerwise_pruning,
            config.accumulate_mask,
        )

    def sparsification_condition(self, state):
        """
        Return True when new masks should be computed: the stage is TRAIN, the
        epoch has reached ``starting_epoch`` and the step counter is a multiple
        of ``frequency``.
        """
        return (
            state.stage == Stage.TRAIN
            and state.epoch >= self.starting_epoch
            and state.step_counter % self.frequency == 0
        )

    def sparsify(self, state):
        """
        obtain a mask and apply the mask to sparsify
        """
        model = state.model
        # compute new mask when conditions are True
        if self.sparsification_condition(state):
            masks = self.get_masks(model)
            # apply the computed mask; self.accumulate_mask is handled separately
            if not self.accumulate_mask:
                self.apply_masks(model, masks)
        # if self.accumulate_mask is True, apply the existing mask regardless
        # of the stage
        if self.accumulate_mask and self._masks is not None:
            self.apply_masks(model, self._masks)

    def get_sparsifiable_params(self, model: Model):
        """Return the list of model parameters that have ``requires_grad`` set."""
        return [p for p in model.parameters() if p.requires_grad]

    def apply_masks(self, model: Model, masks: List[torch.Tensor]):
        """
        apply given masks to zero-out learnable weights in model

        Args:
            model: Model
            masks: one mask per trainable parameter, in ``model.parameters()``
                order; 0-dim masks are treated as placeholders and skipped
        """
        learnableparams = self.get_sparsifiable_params(model)
        assert len(learnableparams) == len(masks)
        for m, w in zip(masks, learnableparams):
            # get_masks emits a 0-dim placeholder when nothing was selectable
            if len(m.size()):
                assert m.size() == w.size()
                w.data *= m
                # if accumulate_mask, remove a param permanently by also
                # removing its gradient; the gradient may not exist yet
                # (e.g. before the first backward pass), so skip it then
                if self.accumulate_mask and w.grad is not None:
                    w.grad.data *= m

    def get_masks(
        self, model: Model, pre_masks: List[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        """
        Note: this function returns the masks only but does not sparsify or
        modify the weights

        prune x% of weights among the weights with "1" in pre_masks

        Args:
            model: Model
            pre_masks: list of FloatTensors where "1" means the weight is
                retained and "0" means the weight is pruned

        Return:
            masks: List[torch.Tensor], intersection of new masks and pre_masks,
            so that "1" only if the weight is selected after new masking and
            pre_mask. With ``sparsity == 1`` every mask is all zeros.
        """
        learnableparams = self.get_sparsifiable_params(model)
        if pre_masks:
            self._masks = pre_masks
        if self._masks is None:
            # retain everything if no pre_masks given
            self._masks = [torch.ones_like(p) for p in learnableparams]

        assert len(learnableparams) == len(self._masks)
        for m, w in zip(self._masks, learnableparams):
            if len(m.size()):
                assert m.size() == w.size()

        if self.layerwise_pruning:
            masks = []
            for m, param in zip(self._masks, learnableparams):
                weights_abs = torch.abs(param.data)
                # absolute value of weights selected from existing masks
                weights_abs_masked_flat = torch.flatten(weights_abs[m.bool()])
                total_size = weights_abs_masked_flat.numel()
                if total_size > 0:
                    # ceil instead of floor()/int() so that any sparsity < 1
                    # keeps at least one element
                    max_num_nonzeros = math.ceil(total_size * (1 - self.sparsity))
                    if max_num_nonzeros > 0:
                        # only prune among the weights selected by existing masks
                        topkval = (
                            torch.topk(weights_abs_masked_flat, max_num_nonzeros)
                            .values.min()
                            .item()
                        )
                        keep = weights_abs >= topkval
                    else:
                        # sparsity == 1: topk with k == 0 has no minimum,
                        # so prune everything explicitly
                        keep = torch.zeros_like(weights_abs, dtype=torch.bool)
                    # intersection of the new mask and existing masks,
                    # mask == 1 retain, mask == 0 pruned
                    mask = keep.float() * m
                else:
                    # NOTE(review): uninitialised 0-dim placeholder; apply_masks
                    # skips it, but with accumulate_mask it is stored and later
                    # used as an index -- confirm this is intended
                    mask = param.new_empty(())
                masks.append(mask)
        else:
            # flattened weights of each parameter that have _masks as True
            masked_flat = [
                torch.flatten(p.data[m.bool()])
                for m, p in zip(self._masks, learnableparams)
            ]
            total_size = sum(t.numel() for t in masked_flat)
            # ceil instead of floor()/int() so that any sparsity < 1 keeps at
            # least one element
            max_num_nonzeros = math.ceil(total_size * (1 - self.sparsity))
            if max_num_nonzeros > 0:
                # globally select the top-k th weight among weights selected
                # from _masks (k > 0 implies masked_flat is non-empty)
                topkval = (
                    torch.topk(
                        torch.abs(torch.cat(masked_flat, dim=0)), max_num_nonzeros
                    )
                    .values.min()
                    .item()
                )
            else:
                # nothing to keep: sparsity == 1 or no selectable weights
                topkval = None
            # intersection of the new mask and _masks,
            # mask == 1 retain, mask == 0 pruned
            masks = []
            for m, p in zip(self._masks, learnableparams):
                if p.numel() == 0:
                    masks.append(p.new_empty(()))
                    continue
                weights_abs = torch.abs(p.data)
                if topkval is None:
                    keep = torch.zeros_like(weights_abs, dtype=torch.bool)
                else:
                    keep = weights_abs >= topkval
                masks.append(keep.float() * m)

        if self.accumulate_mask:
            self._masks = masks

        return masks
class CRF_SparsifierBase(Sparsifier):
    """
    Common scheduling check and parameter lookup for sparsifiers that act on
    the transition matrix of a CRF module.

    Reads ``self.starting_epoch`` and ``self.frequency``; no ``__init__`` is
    defined here to set them.
    """

    class Config(Sparsifier.Config):
        starting_epoch: int = 1
        frequency: int = 1

    def sparsification_condition(self, state):
        """
        Return True when the stage is not TRAIN, the epoch has reached
        ``starting_epoch`` and the step counter is a multiple of ``frequency``.

        NOTE(review): unlike L0_projection_sparsifier this never fires during
        Stage.TRAIN -- confirm that this is intended.
        """
        return (
            not (state.stage == Stage.TRAIN)
            and state.epoch >= self.starting_epoch
            and state.step_counter % self.frequency == 0
        )

    def get_sparsifiable_params(self, model: nn.Module):
        """
        Return ``transitions.data`` of the first CRF module found in
        ``model.modules()``, or None when the model contains no CRF module.
        """
        first_crf = next(
            (module for module in model.modules() if isinstance(module, CRF)), None
        )
        if first_crf is None:
            return None
        return first_crf.transitions.data

    def get_transition_sparsity(self, transition):
        """Return the fraction of entries of ``transition`` that are zero."""
        total = transition.numel()
        num_zeros = total - transition.nonzero().size(0)
        return num_zeros / total
class CRF_L1_SoftThresholding(CRF_SparsifierBase):
    """
    implement l1 regularization:
        min Loss(x, y, CRFparams) + lambda_l1 * ||CRFparams||_1

    and solve the optimization problem via (stochastic) proximal gradient-based
    method i.e., soft-thresholding

    param_updated = sign(CRFparams) * max ( abs(CRFparams) - lambda_l1, 0)

    Args:
        lambda_l1 (float): l1 regularization weight; the applied threshold is
            ``lambda_l1`` times the optimizer's learning rate
        starting_epoch (int): first epoch at which thresholding may run; >= 1
        frequency (int): thresholding may run when ``state.step_counter`` is a
            multiple of this value; >= 1
    """

    class Config(CRF_SparsifierBase.Config):
        lambda_l1: float = 0.001

    def __init__(self, lambda_l1: float, starting_epoch: int, frequency: int):
        self.lambda_l1 = lambda_l1
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency

    @classmethod
    def from_config(cls, config: Config):
        """Build a sparsifier from the fields of ``config``."""
        return cls(config.lambda_l1, config.starting_epoch, config.frequency)

    def sparsify(self, state):
        """
        Soft-threshold the CRF transition matrix of ``state.model`` in place
        and print the resulting sparsity.

        Does nothing unless ``sparsification_condition(state)`` holds. Requires
        the optimizer to have exactly one param group, whose ``"lr"`` scales
        the threshold.
        """
        if not self.sparsification_condition(state):
            return
        model = state.model
        transition_matrix = self.get_sparsifiable_params(model)
        assert (
            len(state.optimizer.param_groups) == 1
        ), "different learning rates for multiple param groups not supported"
        lrs = state.optimizer.param_groups[0]["lr"]
        threshold = self.lambda_l1 * lrs
        shrunk = torch.sign(transition_matrix) * torch.clamp(
            torch.abs(transition_matrix) - threshold, min=0
        )
        # write the result back into the tensor returned by
        # get_sparsifiable_params; rebinding the local name would leave the
        # model's transitions untouched
        transition_matrix.copy_(shrunk)
        current_sparsity = self.get_transition_sparsity(transition_matrix)
        print(f"sparsity of CRF transition matrix: {current_sparsity}")
class CRF_MagnitudeThresholding(CRF_SparsifierBase):
    """
    magnitude-based (equivalent to projection onto l0 constraint set)
    sparsification on CRF transition matrix. Preserving the top-k elements
    either rowwise or columnwise until sparsity constraint is met.

    Args:
        sparsity (float): the desired sparsity, must be in [0, 1]
        starting_epoch (int): first epoch at which thresholding may run; >= 1
        frequency (int): thresholding may run when ``state.step_counter`` is a
            multiple of this value; >= 1
        grouping (str): "row" or "column"; the group within which the top-k
            magnitudes are kept
    """

    class Config(CRF_SparsifierBase.Config):
        sparsity: float = 0.9
        grouping: str = "row"

    def __init__(self, sparsity, starting_epoch, frequency, grouping):
        assert 0 <= sparsity <= 1
        self.sparsity = sparsity
        assert starting_epoch >= 1
        self.starting_epoch = starting_epoch
        assert frequency >= 1
        self.frequency = frequency
        assert (
            grouping == "row" or grouping == "column"
        ), "grouping needs to be row or column"
        self.grouping = grouping

    @classmethod
    def from_config(cls, config: Config):
        """Build a sparsifier from the fields of ``config``."""
        return cls(
            config.sparsity, config.starting_epoch, config.frequency, config.grouping
        )

    def sparsify(self, state):
        """
        Zero, in place, every entry of the CRF transition matrix of
        ``state.model`` whose magnitude is below the k-th largest magnitude of
        its row (or column), then print the resulting sparsity.

        Does nothing unless ``sparsification_condition(state)`` holds. With
        ``sparsity == 1`` the whole matrix is zeroed.
        """
        if not self.sparsification_condition(state):
            return
        model = state.model
        transition_matrix = self.get_sparsifiable_params(model)
        num_rows, num_cols = transition_matrix.shape
        # "row": top-k within each row (along dim 1, group size num_cols);
        # "column": top-k within each column (along dim 0, group size num_rows)
        if self.grouping == "row":
            dim, group_size = 1, num_cols
        else:
            dim, group_size = 0, num_rows
        max_num_nonzeros = math.ceil(group_size * (1 - self.sparsity))
        if max_num_nonzeros > 0:
            trans_abs = torch.abs(transition_matrix)
            topkvals = (
                torch.topk(trans_abs, k=max_num_nonzeros, dim=dim)
                .values.min(dim=dim, keepdim=True)
                .values
            )
            # trans_abs < topkvals is a broadcasted comparison
            transition_matrix[trans_abs < topkvals] = 0.0
        else:
            # sparsity == 1: topk with k == 0 has no minimum to threshold on,
            # so prune every entry explicitly
            transition_matrix.zero_()
        current_sparsity = self.get_transition_sparsity(transition_matrix)
        print(f"sparsity of CRF transition matrix: {current_sparsity}")