
VainF / Torch Pruning

License: MIT
A PyTorch toolkit for structured neural network pruning and layer dependency maintenance.

Programming Languages

python

Projects that are alternatives of or similar to Torch Pruning

Kd lib
A Pytorch Knowledge Distillation library for benchmarking and extending works in the domains of Knowledge Distillation, Pruning, and Quantization.
Stars: ✭ 173 (-10.36%)
Mutual labels:  model-compression, pruning
Awesome Ml Model Compression
Awesome machine learning model compression research papers, tools, and learning material.
Stars: ✭ 166 (-13.99%)
Mutual labels:  model-compression, pruning
torch-model-compression
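An automated toolkit for analyzing and modifying the structure of PyTorch models, including a library of model compression algorithms that analyze model structure automatically.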
Stars: ✭ 126 (-34.72%)
Mutual labels:  pruning, model-compression
ATMC
[NeurIPS'2019] Shupeng Gui, Haotao Wang, Haichuan Yang, Chen Yu, Zhangyang Wang, Ji Liu, “Model Compression with Adversarial Robustness: A Unified Optimization Framework”
Stars: ✭ 41 (-78.76%)
Mutual labels:  pruning, model-compression
Paddleslim
PaddleSlim is an open-source library for deep model compression and architecture search.
Stars: ✭ 677 (+250.78%)
Mutual labels:  model-compression, pruning
DS-Net
(CVPR 2021, Oral) Dynamic Slimmable Network
Stars: ✭ 204 (+5.7%)
Mutual labels:  pruning, model-compression
Regularization-Pruning
[ICLR'21] PyTorch code for our paper "Neural Pruning via Growing Regularization"
Stars: ✭ 44 (-77.2%)
Mutual labels:  pruning, model-compression
Awesome Ai Infrastructures
Infrastructures™ for Machine Learning Training/Inference in Production.
Stars: ✭ 223 (+15.54%)
Mutual labels:  model-compression, pruning
Filter Pruning Geometric Median
Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration (CVPR 2019 Oral)
Stars: ✭ 338 (+75.13%)
Mutual labels:  model-compression, pruning
Soft Filter Pruning
Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks
Stars: ✭ 291 (+50.78%)
Mutual labels:  model-compression, pruning
SViTE
[NeurIPS'21] "Chasing Sparsity in Vision Transformers: An End-to-End Exploration" by Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang
Stars: ✭ 50 (-74.09%)
Mutual labels:  pruning, model-compression
Awesome Pruning
A curated list of neural network pruning resources.
Stars: ✭ 1,017 (+426.94%)
Mutual labels:  model-compression, pruning
Model Optimization
A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
Stars: ✭ 992 (+413.99%)
Mutual labels:  model-compression, pruning
Micronet
micronet, a model compression and deployment library. Compression: (1) quantization: quantization-aware training (QAT), high-bit (>2b) (DoReFa / Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference), low-bit (≤2b) / ternary and binary (TWN/BNN/XNOR-Net); post-training quantization (PTQ), 8-bit (TensorRT); (2) pruning: normal, regular, and group convolutional channel pruning; (3) group convolution structure; (4) batch-normalization fusion for quantization. Deployment: TensorRT, fp32/fp16/int8 (PTQ calibration), op-adapt (upsample), dynamic_shape.
Stars: ✭ 1,232 (+538.34%)
Mutual labels:  model-compression, pruning
Collaborative Distillation
PyTorch code for our CVPR'20 paper "Collaborative Distillation for Ultra-Resolution Universal Style Transfer"
Stars: ✭ 138 (-28.5%)
Mutual labels:  model-compression
Keras compressor
Model Compression CLI Tool for Keras.
Stars: ✭ 160 (-17.1%)
Mutual labels:  model-compression
Condensa
Programmable Neural Network Compression
Stars: ✭ 129 (-33.16%)
Mutual labels:  model-compression
Pretrained Language Model
Pretrained language model and its related optimization techniques developed by Huawei Noah's Ark Lab.
Stars: ✭ 2,033 (+953.37%)
Mutual labels:  model-compression
Pruning
Code for "Co-Evolutionary Compression for Unpaired Image Translation" (ICCV 2019) and "SCOP: Scientific Control for Reliable Neural Network Pruning" (NeurIPS 2020).
Stars: ✭ 159 (-17.62%)
Mutual labels:  model-compression
Microexpnet
MicroExpNet: An Extremely Small and Fast Model For Expression Recognition From Frontal Face Images
Stars: ✭ 121 (-37.31%)
Mutual labels:  model-compression

Torch-Pruning

A PyTorch toolkit for structured neural network pruning and layer dependency maintenance

This tool automatically detects and handles layer dependencies (channel consistency) during pruning. It supports various network architectures such as DenseNet, ResNet, and Inception. See examples/test_models.py for more supported models.

How it works

This package runs your model with fake inputs and collects forward information, much like torch.jit tracing. It then builds a dependency graph that describes the computational graph. When a pruning function (e.g. torch_pruning.prune_conv) is applied to a layer through DependencyGraph.get_pruning_plan, the package traverses the whole graph to fix inconsistent modules such as BN. If your model contains torch.split or torch.cat, the pruning indices are automatically mapped to the correct positions.
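
The traversal described above can be pictured with a toy graph. The sketch below is not the package's implementation: the node names only mimic the first layers of resnet18, and the edges in DEPENDENCIES, as well as the helper collect_plan, are hand-written assumptions for illustration.

```python
from collections import deque

# Toy dependency graph (an assumption, written by hand): an edge a -> b means
# "if channels are removed at a, the same channels must be removed at b".
DEPENDENCIES = {
    "conv1": ["bn1"],
    "bn1": ["layer1.0.conv1", "add1"],        # feeds a conv and a residual add
    "layer1.0.conv1": [],                      # only its input channels change
    "add1": ["layer1.0.bn2", "layer1.1.conv1"],
    "layer1.0.bn2": ["layer1.0.conv2"],        # other operand of the residual add
    "layer1.0.conv2": [],
    "layer1.1.conv1": [],
}


def collect_plan(graph, root, idxs):
    """Breadth-first walk returning every (node, idxs) pair touched by pruning root."""
    plan, seen, queue = [], {root}, deque([root])
    while queue:
        node = queue.popleft()
        plan.append((node, list(idxs)))
        for nxt in graph.get(node, []):
            if nxt not in seen:       # visit each node once, even in graphs with joins
                seen.add(nxt)
                queue.append(nxt)
    return plan


plan = collect_plan(DEPENDENCIES, "conv1", [2, 6, 9])
for node, idxs in plan:
    print(node, idxs)
```

A real dependency graph also has to record which pruning function applies at each node (output channels versus input channels), which this sketch leaves out.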

Tip: save the whole model object (architecture + weights) rather than the weights only, because the layer shapes of a pruned model no longer match the original model definition:

# save a pruned model
# torch.save(model.state_dict(), 'model.pth') # weights only
torch.save(model, 'model.pth') # obj (arch) + weights

# load a pruned model
model = torch.load('model.pth') # no load_state_dict

Dependency                            Example
Conv-Conv                             AlexNet
Conv-FC (Global Pooling or Flatten)   ResNet, VGG
Skip Connection                       ResNet
Concatenation                         DenseNet, ASPP
Split                                 torch.chunk

Known Issues:

  • When groups>1, only depthwise conv is supported, i.e. groups=in_channels=out_channels.
  • Custom operations, e.g. subclasses of torch.autograd.Function, are treated as element-wise ops.

Installation

pip install torch_pruning # v0.2.4

Quickstart

A minimal example

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy()

# 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9]
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
print(pruning_plan)

# 4. execute this plan (prune the model)
pruning_plan.exec()

Pruning conv1 of resnet18 affects several other layers. Let's inspect the pruning plan (with pruning_idxs=[2, 6, 9]):

-------------
[ <DEP: prune_conv => prune_conv on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))>, Index=[2, 6, 9], NumPruned=441]
[ <DEP: prune_conv => prune_batchnorm on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_batchnorm on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => prune_conv on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_batchnorm on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[ <DEP: prune_batchnorm => prune_conv on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=3456]
[ <DEP: _prune_elementwise_op => prune_related_conv on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False))>, Index=[2, 6, 9], NumPruned=384]
11211 parameters will be pruned
-------------
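
The NumPruned values in the plan can be reproduced by hand from the layer shapes printed above. The sketch below is plain arithmetic and does not call the package. It assumes that a BatchNorm layer contributes one weight and one bias per pruned channel, which is consistent with the reported value of 6. The helper names are made up for this example.

```python
idxs = [2, 6, 9]
n = len(idxs)  # number of pruned channels


def conv_out(in_ch, k):
    """Parameters removed when n filters (output channels) are dropped; bias=False."""
    return n * in_ch * k * k


def conv_in(out_ch, k):
    """Parameters removed when n input channels are dropped from every filter."""
    return out_ch * n * k * k


def batchnorm():
    """Assumed: one weight and one bias per pruned channel."""
    return n * 2


counts = {
    "conv1": conv_out(3, 7),
    "bn1": batchnorm(),
    "layer1.0.conv1": conv_in(64, 3),
    "layer1.0.bn2": batchnorm(),
    "layer1.0.conv2": conv_out(64, 3),
    "layer1.1.conv1": conv_in(64, 3),
    "layer1.1.bn2": batchnorm(),
    "layer1.1.conv2": conv_out(64, 3),
    "layer2.0.conv1": conv_in(128, 3),
    "layer2.0.downsample.0": conv_in(128, 1),
}
total = sum(counts.values())
print(counts)
print(total, "parameters will be pruned")
```

The element-wise ops in the plan hold no parameters, which is why they report NumPruned=0 and are omitted here.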

Low-level pruning functions

Without DependencyGraph, we have to handle the broken dependencies manually.

tp.prune_conv( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm( model.bn1, idxs=[2,6,9] )
tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )
...

Customized Layers

Please refer to 'examples/customize_layer.py' for pruning customized layers with this package. A detailed tutorial is on the way!

Layer Dependency

During structured pruning, we need to maintain the channel consistency between different layers.

A Simple Case
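
As a minimal sketch of the simple case, assume a conv1 -> bn -> conv2 chain. Removing output channels from conv1 also shrinks bn and the input channels of conv2. The code tracks weight shapes only, and the helper prune_chain and the example shapes are illustrative assumptions, not part of the package.

```python
def prune_chain(shapes, idxs):
    """Return the weight shapes of a conv1 -> bn -> conv2 chain after removing
    the output channels idxs of conv1. Conv shapes are (out, in, kh, kw)."""
    n = len(set(idxs))
    out1, in1, kh1, kw1 = shapes["conv1"]
    out2, in2, kh2, kw2 = shapes["conv2"]
    # The chain must be channel-consistent before pruning.
    assert out1 == shapes["bn"][0] == in2, "chain is not channel-consistent"
    return {
        "conv1": (out1 - n, in1, kh1, kw1),  # fewer filters
        "bn": (out1 - n,),                   # fewer per-channel parameters
        "conv2": (out2, in2 - n, kh2, kw2),  # same filters, fewer input channels
    }


before = {"conv1": (64, 3, 3, 3), "bn": (64,), "conv2": (128, 64, 3, 3)}
after = prune_chain(before, [2, 6, 9])
print(after)
```

Pruning conv1 alone would leave bn and conv2 expecting 64 channels while receiving 61, which is the inconsistency the dependency graph is meant to prevent.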

More Complicated Cases

The layer dependency becomes much more complicated when the model contains skip connections or concatenations.

Residual Block:

Concatenation:
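
With concatenation, a channel index in one branch lands at a shifted position in the concatenated tensor, so pruning indices have to be offset. The sketch below shows one way to do that mapping in plain Python. The function names and branch widths are assumptions for illustration and do not describe the package's internals.

```python
def map_to_concat(branch_channels, branch, idxs):
    """Map channel indices of one branch to positions in the concatenated tensor."""
    offset = sum(branch_channels[:branch])  # channels contributed by earlier branches
    return [offset + i for i in idxs]


def map_from_concat(branch_channels, idxs):
    """Split indices of the concatenated tensor back into per-branch indices."""
    per_branch = [[] for _ in branch_channels]
    for i in idxs:
        offset = 0
        for b, width in enumerate(branch_channels):
            if i < offset + width:
                per_branch[b].append(i - offset)
                break
            offset += width
    return per_branch


widths = [16, 32]  # two branches concatenated along the channel axis
to_cat = map_to_concat(widths, 1, [0, 5])
back = map_from_concat(widths, [3, 16, 21])
print(to_cat, back)
```

The first function covers pruning a layer that feeds a concatenation. The second covers pruning a layer that consumes the concatenated tensor, where the indices must be pushed back to the producing branches.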

See the paper "Pruning Filters for Efficient ConvNets" for more details.

Example: ResNet18 on Cifar10

1. Train the model

cd examples
python prune_resnet18_cifar10.py --mode train # 11.1M, Acc=0.9248

2. Pruning and fine-tuning

python prune_resnet18_cifar10.py --mode prune --round 1 --total_epochs 30 --step_size 20 # 4.5M, Acc=0.9229
python prune_resnet18_cifar10.py --mode prune --round 2 --total_epochs 30 --step_size 20 # 1.9M, Acc=0.9207
python prune_resnet18_cifar10.py --mode prune --round 3 --total_epochs 30 --step_size 20 # 0.8M, Acc=0.9176
python prune_resnet18_cifar10.py --mode prune --round 4 --total_epochs 30 --step_size 20 # 0.4M, Acc=0.9102
python prune_resnet18_cifar10.py --mode prune --round 5 --total_epochs 30 --step_size 20 # 0.2M, Acc=0.9011
...
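
To read the trade-off more easily, the parameter counts and accuracies from the comments above can be turned into compression ratios and accuracy drops. The snippet is only arithmetic on those reported numbers. Round 0 stands for the unpruned model from step 1.

```python
# (round, parameters in millions, accuracy), copied from the command comments
results = [
    (0, 11.1, 0.9248),
    (1, 4.5, 0.9229),
    (2, 1.9, 0.9207),
    (3, 0.8, 0.9176),
    (4, 0.4, 0.9102),
    (5, 0.2, 0.9011),
]

base_params, base_acc = results[0][1], results[0][2]
summary = []
for rnd, params, acc in results:
    ratio = base_params / params          # how many times smaller than the baseline
    drop = (base_acc - acc) * 100         # accuracy lost, in percentage points
    summary.append((rnd, ratio, drop))
    print(f"round {rnd}: {params:>4}M params, {ratio:5.1f}x smaller, "
          f"accuracy drop {drop:.2f} points")
```

Because the parameter counts are rounded to one decimal place in the source, the ratios for the later rounds are approximate.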