
lmbxmu / FilterSketch

Licence: other
PyTorch implementation of our paper "Filter Sketch for Network Pruning", accepted by IEEE TNNLS, 2021.


Filter Sketch for Network Pruning (Link).

Pruning neural network models via filter sketch.
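The README does not spell out what "sketching" a layer involves. As a rough illustration only, the sketch below applies Frequent Directions (Liberty, 2013), a standard deterministic matrix-sketching algorithm, to a convolution layer whose filters are flattened into a matrix with one row per filter. It compresses those rows into fewer rows whose Gram matrix B^T B stays close to W^T W. This is a generic NumPy sketch under that assumption, not the code in this repository; the helper `frequent_directions` and the layer shapes are hypothetical.

```python
import numpy as np


def frequent_directions(A: np.ndarray, ell: int) -> np.ndarray:
    """Hypothetical helper: ell x d Frequent Directions sketch B of n x d A.

    Guarantees B.T @ B <= A.T @ A (PSD order) and
    ||A.T @ A - B.T @ B||_2 <= ||A||_F**2 / ell.
    """
    _, d = A.shape
    if not 1 <= ell <= d:
        raise ValueError("ell must satisfy 1 <= ell <= A.shape[1]")
    B = np.zeros((ell, d))
    next_free = 0  # index of the next all-zero row of B
    for row in A:
        if next_free == ell:
            # Sketch is full: rotate onto its singular directions and shrink
            # every squared singular value by the smallest one.
            _, s, vt = np.linalg.svd(B, full_matrices=False)
            delta = s[-1] ** 2
            B = np.sqrt(np.maximum(s**2 - delta, 0.0))[:, None] * vt
            next_free = ell - 1  # the last row is now exactly zero
        B[next_free] = row
        next_free += 1
    return B


# Hypothetical conv layer: 64 filters of shape 32 x 3 x 3.
rng = np.random.default_rng(0)
weight = rng.standard_normal((64, 32, 3, 3))
W = weight.reshape(weight.shape[0], -1)  # one row per filter: 64 x 288
B = frequent_directions(W, ell=38)       # keep about 60% as many rows
err = np.linalg.norm(W.T @ W - B.T @ B, 2)
print(B.shape, err <= np.linalg.norm(W, "fro") ** 2 / 38)  # prints: (38, 288) True
```

Frequent Directions needs `ell` to be at most the flattened filter length; how the paper turns a sketch into a pruned layer, and where `--weight_norm_method` comes in, is not covered by this illustration.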

Tips

If you have any problems, feel free to contact the authors by email: [email protected] or [email protected]. Please avoid posting GitHub issues where possible, since I may not receive GitHub's notification emails and could miss the posted issues.

Citation

If you find FilterSketch useful in your research, please consider citing:

@article{lin2020filter,
  title={Filter Sketch for Network Pruning},
  author={Lin, Mingbao and Ji, Rongrong and Li, Shaojie and Ye, Qixiang and Tian, Yonghong and Liu, Jianzhuang and Tian, Qi},
  journal={arXiv preprint arXiv:2001.08514},
  year={2020}
}

Pre-trained Models

We provide the pre-trained models used in our paper.

CIFAR-10

| ResNet56 | ResNet110 | GoogLeNet |

ImageNet

| ResNet50 |

Result Models

We provide the pruned models from our experiments, along with their training logs and configurations.

| Model | Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Top-5 Accuracy | Download |
|---|---|---|---|---|---|---|---|
| ResNet56 | CIFAR-10 | [0.6]*27 | 73.36M (41.5%) | 0.50M (41.2%) | 93.19% | - | Link |
| ResNet110 | CIFAR-10 | [0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3 | 92.84M (63.3%) | 0.69M (59.9%) | 93.44% | - | Link |
| GoogLeNet | CIFAR-10 | [0.25]*9 | 0.59B (61.1%) | 2.61M (57.6%) | 94.88% | - | Link |
| ResNet50 | ImageNet | [0.2]*16 | 0.93B (77.3%) | 7.18M (71.8%) | 69.43% | 89.23% | Link |
| ResNet50 | ImageNet | [0.4]*16 | 1.51B (63.1%) | 10.40M (59.2%) | 73.04% | 91.18% | Link |
| ResNet50 | ImageNet | [0.6]*16 | 2.23B (45.5%) | 14.53M (43.0%) | 74.68% | 92.17% | Link |
| ResNet50 | ImageNet | [0.7]*16 | 2.64B (35.5%) | 16.95M (33.5%) | 75.22% | 92.41% | Link |
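The percentages in parentheses are reductions relative to the unpruned model. Assuming the usual definition, prune rate = 1 - pruned / original, every row implies the size of the original network, which gives a quick consistency check. The sketch below copies the ResNet50 ImageNet rows from the table; `implied_original` is a hypothetical helper for illustration.

```python
# (FLOPs in billions, FLOPs prune rate, params in millions, params prune rate)
# copied from the ResNet50 / ImageNet rows of the table.
RESNET50_ROWS = {
    "[0.2]*16": (0.93, 0.773, 7.18, 0.718),
    "[0.4]*16": (1.51, 0.631, 10.40, 0.592),
    "[0.6]*16": (2.23, 0.455, 14.53, 0.430),
    "[0.7]*16": (2.64, 0.355, 16.95, 0.335),
}


def implied_original(pruned: float, prune_rate: float) -> float:
    """Hypothetical helper: invert prune_rate = 1 - pruned / original (assumed)."""
    return pruned / (1.0 - prune_rate)


for rate, (flops, flops_pr, params, params_pr) in RESNET50_ROWS.items():
    print(
        f"{rate}: original ~ {implied_original(flops, flops_pr):.2f}B FLOPs, "
        f"{implied_original(params, params_pr):.1f}M params"
    )
```

All four rows imply roughly 4.1B FLOPs and 25.5M parameters for the unpruned ResNet50, so they agree with one another up to rounding of the published percentages.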

Performance of FilterSketch using ResNet-56 under different compression rates.

| Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Download |
|---|---|---|---|---|---|
| CIFAR-10 | [0.1]*27 | 11.43M (91.0%) | 0.08M (90.45%) | 87.38% | Link |
| CIFAR-10 | [0.2]*27 | 24.54M (80.6%) | 0.16M (81.0%) | 90.19% | Link |
| CIFAR-10 | [0.3]*27 | 35.61M (71.9%) | 0.25M (70.6%) | 91.65% | Link |
| CIFAR-10 | [0.4]*27 | 48.72M (61.5%) | 0.33M (61.1%) | 92.00% | Link |
| CIFAR-10 | [0.5]*27 | 63.78M (49.6%) | 0.43M (49.8%) | 92.29% | Link |
| CIFAR-10 | [0.6]*27 | 73.36M (41.5%) | 0.50M (41.2%) | 93.19% | Link |
| CIFAR-10 | [0.7]*27 | 87.31M (31.0%) | 0.59M (31.1%) | 93.36% | Link |
| CIFAR-10 | [0.8]*27 | 98.40M (22.3%) | 0.68M (20.8%) | 93.40% | Link |
| CIFAR-10 | [0.9]*27 | 111.5M (11.9%) | 0.75M (11.3%) | 93.44% | Link |
| CIFAR-10 | [0.9]*3+[0.1]*10+[0.1]*10+[0.6]*4 | 32.47M (74.4%) | 0.24M (71.8%) | 91.20% | Link |
| CIFAR-10 | [0.7]*3+[0.4]*10+[0.4]*10+[0.9]*4 | 62.63M (50.5%) | 0.48M (43.3%) | 92.94% | Link |
| CIFAR-10 | [0.8]*3+[0.5]*10+[0.8]*10+[0.9]*4 | 88.05M (30.4%) | 0.68M (20.6%) | 93.65% | Link |

Running Code

The code has been tested with PyTorch 1.3 and CUDA 10.0 on Ubuntu 16.04.

Filter Sketch

Run the following command to sketch a model on CIFAR-10:

python sketch_cifar.py \
--data_set cifar10 \
--data_path ../data/cifar10/ \
--sketch_model ./experiment/pretrain/resnet56.pt \
--job_dir ./experiment/resnet56/sketch/ \
--arch resnet \
--cfg resnet56 \
--lr 0.01 \
--lr_decay_step 50 100 \
--num_epochs 150 \
--gpus 0 \
--sketch_rate '[0.6]*27' \
--weight_norm_method l2

Run the following command to sketch a model on ImageNet:

python sketch_imagenet.py \
--data_set imagenet \
--data_path ../data/imagenet/ \
--sketch_model ./experiment/pretrain/resnet50.pth \
--job_dir ./experiment/resnet50/sketch/ \
--arch resnet \
--cfg resnet50 \
--lr 0.1 \
--lr_decay_step 30 60 \
--num_epochs 90 \
--gpus 0 \
--sketch_rate '[0.6]*16' \
--weight_norm_method l2
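Both commands decay the learning rate in steps (`--lr_decay_step 50 100` over 150 epochs for CIFAR-10, `30 60` over 90 epochs for ImageNet). The decay factor is not documented here; the sketch below assumes a factor of 0.1 at each listed epoch, a common convention for multi-step schedules, purely to show what the flags imply. `lr_at_epoch` is a hypothetical helper, not part of this repository.

```python
from bisect import bisect_right


def lr_at_epoch(epoch, base_lr, decay_epochs, gamma=0.1):
    """Hypothetical helper: learning rate for a 0-based epoch, multi-step decay.

    The rate is multiplied by `gamma` (an assumed value, not documented by
    the repository) once for every decay epoch that has been reached.
    """
    num_decays = bisect_right(sorted(decay_epochs), epoch)
    return base_lr * gamma**num_decays


# CIFAR-10 command: --lr 0.01 --lr_decay_step 50 100 --num_epochs 150
for epoch in (0, 49, 50, 99, 100, 149):
    print(epoch, lr_at_epoch(epoch, 0.01, [50, 100]))
```

Under this assumption the CIFAR-10 run trains at 0.01, 0.001 and 0.0001 for 50 epochs each.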

Test Our Performance

Run the command below to verify our pruned models:

python test.py \
--data_set cifar10 \
--data_path ../data/cifar10 \
--arch resnet \
--cfg resnet56 \
--sketch_model ./experiment/result/sketch_resnet56.pt \
--sketch_rate '[0.6]*27' \
--gpus 0
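The result tables report Top-1 accuracy, plus Top-5 accuracy on ImageNet. For readers unfamiliar with the metric, a minimal NumPy sketch of top-k accuracy over a batch of logits follows. It is a generic illustration; `topk_accuracy` is a hypothetical helper, not a function from `test.py`.

```python
import numpy as np


def topk_accuracy(logits: np.ndarray, labels: np.ndarray, k: int = 1) -> float:
    """Hypothetical helper: % of samples whose label is among the top-k logits."""
    # Indices of the k largest logits per row (their order does not matter).
    topk = np.argpartition(logits, -k, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return 100.0 * hits.mean()


logits = np.array([
    [2.0, 1.0, 0.1],   # predicts class 0
    [0.2, 0.3, 2.5],   # predicts class 2, class 1 is second
    [1.5, 1.4, 0.0],   # predicts class 0, class 1 is second
])
labels = np.array([0, 1, 1])
print(topk_accuracy(logits, labels, k=1), topk_accuracy(logits, labels, k=2))
```

Only the first sample is correct at k=1 (one third), while every label falls within the two highest logits, so top-2 accuracy is 100%.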

Get FLOPs and Params

To calculate the FLOPs and parameters of a model, install the thop Python package and then run get_flops_params.py:

pip install thop
python get_flops_params.py \
--data_set cifar10 \
--input_image_size 32 \
--arch resnet \
--cfg resnet56 \
--sketch_rate '[0.6]*27'
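thop counts operations by running the model, but the effect of filter pruning on both numbers can be worked out by hand for a pair of convolutions. The sketch below uses the usual count for a bias-free k x k convolution: c_out * c_in * k * k parameters, and that many multiply-accumulates (MACs) per output position. The layer sizes (3x3, 16-to-16 convolutions on a 32 x 32 feature map, as in the first stage of a CIFAR ResNet), the rounding of kept filters and the helper `conv_cost` are illustrative assumptions; whether the repository's FLOPs figures are MAC counts is not stated here.

```python
def conv_cost(c_in: int, c_out: int, k: int, h_out: int, w_out: int):
    """Hypothetical helper: params and MACs of a bias-free k x k convolution."""
    params = c_out * c_in * k * k
    macs = params * h_out * w_out  # one MAC per weight per output position
    return params, macs


# One residual block: conv1 (16 -> 16) then conv2 (16 -> 16), 3x3, 32x32 output.
full = [conv_cost(16, 16, 3, 32, 32), conv_cost(16, 16, 3, 32, 32)]

# Keep 60% of conv1's filters (round(0.6 * 16) = 10; the rounding is assumed).
# conv1 loses output channels and conv2 loses the matching input channels.
kept = round(0.6 * 16)
pruned = [conv_cost(16, kept, 3, 32, 32), conv_cost(kept, 16, 3, 32, 32)]

full_params = sum(p for p, _ in full)
pruned_params = sum(p for p, _ in pruned)
full_macs = sum(m for _, m in full)
pruned_macs = sum(m for _, m in pruned)
print(f"params: {full_params} -> {pruned_params} "
      f"({1 - pruned_params / full_params:.1%} reduction)")
print(f"MACs: {full_macs} -> {pruned_macs}")
```

The first line prints `params: 4608 -> 2880 (37.5% reduction)`. MACs shrink by the same 37.5% because every term scales with the kept channel count, which helps explain why the FLOPs and params prune rates in the tables stay close to each other.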

Remarks

The number of sketch rates required by each network (the length of the expanded --sketch_rate list) is as follows:

| Model | CIFAR-10 | ImageNet |
|---|---|---|
| ResNet56 | 27 | - |
| ResNet110 | 54 | - |
| GoogLeNet | 9 | - |
| ResNet50 | - | 16 |
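Sketch rates are passed as Python-style list expressions such as `[0.6]*27` or `[0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3`, and the expanded list must have the length given in the table. How the scripts parse this string is not shown here; the sketch below is a hypothetical, `eval`-free parser that expands such expressions and checks the length. `parse_sketch_rate`, `check_sketch_rate` and `REQUIRED_RATES` are illustrative names, and the (0, 1] range check assumes each rate is a retained fraction, as the `--sketch_rate` help text suggests.

```python
import re

# Hypothetical table of required rate counts (copied from the table above).
REQUIRED_RATES = {"resnet56": 27, "resnet110": 54, "googlenet": 9, "resnet50": 16}

# One term looks like "[0.6]*27" or "[0.5,0.7]*3".
_TERM = re.compile(r"\[([0-9.]+(?:,[0-9.]+)*)\]\*(\d+)")


def parse_sketch_rate(expr: str) -> list[float]:
    """Hypothetical helper: expand '[0.9]*3+[0.4]*24' into a list of floats."""
    rates: list[float] = []
    for term in expr.replace(" ", "").split("+"):
        match = _TERM.fullmatch(term)
        if match is None:
            raise ValueError(f"unsupported term: {term!r}")
        values = [float(v) for v in match.group(1).split(",")]
        rates.extend(values * int(match.group(2)))
    return rates


def check_sketch_rate(expr: str, cfg: str) -> list[float]:
    """Parse `expr` and check it against the count required by `cfg`."""
    rates = parse_sketch_rate(expr)
    if len(rates) != REQUIRED_RATES[cfg]:
        raise ValueError(f"{cfg} needs {REQUIRED_RATES[cfg]} rates, got {len(rates)}")
    if not all(0.0 < r <= 1.0 for r in rates):
        raise ValueError("each rate should lie in (0, 1]")  # assumed range
    return rates


print(len(check_sketch_rate("[0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3", "resnet110")))  # 54
```

Avoiding `eval` keeps a malformed or malicious command-line string from executing arbitrary code, at the cost of accepting only the `[values]*count` terms used in this README.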

Other Arguments

optional arguments:
  -h, --help            show this help message and exit
  --gpus GPUS [GPUS ...]
                        Select the GPU ids to use. default:[0]
  --data_set DATA_SET   Select the dataset to train on. default:cifar10
  --data_path DATA_PATH
                        The directory where the input is stored.
                        default:/home/lishaojie/data/cifar10/
  --job_dir JOB_DIR     The directory where the summaries will be stored.
                        default:./experiments
  --arch ARCH           Architecture of the model. default:resnet
  --cfg CFG             Detailed architecture of the model. default:resnet56
  --num_epochs NUM_EPOCHS
                        The num of epochs to train. default:150
  --train_batch_size TRAIN_BATCH_SIZE
                        Batch size for training. default:128
  --eval_batch_size EVAL_BATCH_SIZE
                        Batch size for validation. default:100
  --momentum MOMENTUM   Momentum for MomentumOptimizer. default:0.9
  --lr LR               Learning rate for train. default:1e-2
  --lr_decay_step LR_DECAY_STEP [LR_DECAY_STEP ...]
                        The epochs at which to decay the learning rate.
                        default:50, 100
  --weight_decay WEIGHT_DECAY
                        The weight decay of the loss. default:5e-4
  --start_conv START_CONV
                        The index of the conv layer at which to start
                        sketching; indices start from 0. default:1
  --sketch_rate SKETCH_RATE
                        The proportion of filters retained in each
                        convolutional layer after sketching. default:None
  --sketch_model SKETCH_MODEL
                        Path to the model to be sketched. default:None
  --weight_norm_method WEIGHT_NORM_METHOD
                        Select the weight norm method. default:None
                        Options: l2
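These options follow standard `argparse` conventions, with `nargs="+"` for the list-valued `--gpus` and `--lr_decay_step`. The sketch below rebuilds a subset of them with their documented defaults to show how a command line turns into values. It is a reconstruction from the help text above, not the repository's own parser, and keeping `--sketch_rate` as a raw string is an assumption.

```python
import argparse

# Hypothetical reconstruction of a subset of the documented options.
parser = argparse.ArgumentParser(description="Subset of the FilterSketch options")
parser.add_argument("--gpus", type=int, nargs="+", default=[0])
parser.add_argument("--data_set", default="cifar10")
parser.add_argument("--arch", default="resnet")
parser.add_argument("--cfg", default="resnet56")
parser.add_argument("--num_epochs", type=int, default=150)
parser.add_argument("--train_batch_size", type=int, default=128)
parser.add_argument("--eval_batch_size", type=int, default=100)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--lr", type=float, default=1e-2)
parser.add_argument("--lr_decay_step", type=int, nargs="+", default=[50, 100])
parser.add_argument("--weight_decay", type=float, default=5e-4)
parser.add_argument("--start_conv", type=int, default=1)
parser.add_argument("--sketch_rate", type=str, default=None)  # assumed raw string, e.g. "[0.6]*16"
parser.add_argument("--sketch_model", default=None)
parser.add_argument("--weight_norm_method", choices=["l2"], default=None)

args = parser.parse_args([
    "--data_set", "imagenet", "--cfg", "resnet50",
    "--lr_decay_step", "30", "60", "--sketch_rate", "[0.6]*16",
])
print(args.gpus, args.lr_decay_step, args.sketch_rate, args.num_epochs)
```

Running it prints `[0] [30, 60] [0.6]*16 150`: options left off the command line, such as `--gpus` and `--num_epochs`, fall back to their documented defaults.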