All Projects → xgastaldi → Shake Shake

xgastaldi / Shake Shake

Licence: bsd-3-clause
2.86% and 15.85% on CIFAR-10 and CIFAR-100

Programming Languages

lua
6591 projects

Projects that are alternatives of or similar to Shake Shake

Pyramidnet
Torch implementation of the paper "Deep Pyramidal Residual Networks" (https://arxiv.org/abs/1610.02915).
Stars: ✭ 121 (-58.42%)
Mutual labels:  resnet, torch7
dbcollection
A collection of popular datasets for deep learning.
Stars: ✭ 26 (-91.07%)
Mutual labels:  torch7
pytorch2keras
PyTorch to Keras model convertor
Stars: ✭ 788 (+170.79%)
Mutual labels:  resnet
Retinal-Disease-Diagnosis-With-Residual-Attention-Networks
Using Residual Attention Networks to diagnose retinal diseases in medical images
Stars: ✭ 14 (-95.19%)
Mutual labels:  resnet
Keras-CIFAR10
practice on CIFAR10 with Keras
Stars: ✭ 25 (-91.41%)
Mutual labels:  resnet
alpha-zero
AlphaZero implementation for Othello, Connect-Four and Tic-Tac-Toe based on "Mastering the game of Go without human knowledge" and "Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm" by DeepMind.
Stars: ✭ 68 (-76.63%)
Mutual labels:  resnet
resnet.torch
an updated version of fb.resnet.torch with many changes.
Stars: ✭ 35 (-87.97%)
Mutual labels:  resnet
Pytorch Image Classification
Tutorials on how to implement a few key architectures for image classification using PyTorch and TorchVision.
Stars: ✭ 272 (-6.53%)
Mutual labels:  resnet
jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Stars: ✭ 61 (-79.04%)
Mutual labels:  resnet
Faster-RCNN-TensorFlow
TensorFlow implementation of Faster RCNN for Object Detection
Stars: ✭ 13 (-95.53%)
Mutual labels:  resnet
DE resnet unet hyb
Depth estimation from RGB images using fully convolutional neural networks.
Stars: ✭ 40 (-86.25%)
Mutual labels:  resnet
Gradient-Samples
Samples for TensorFlow binding for .NET by Lost Tech
Stars: ✭ 53 (-81.79%)
Mutual labels:  resnet
i3d-tensorflow
Inflated 3D ConvNets for video understanding
Stars: ✭ 46 (-84.19%)
Mutual labels:  resnet
pyro-vision
Computer vision library for wildfire detection
Stars: ✭ 33 (-88.66%)
Mutual labels:  resnet
Grad Cam Tensorflow
tensorflow implementation of Grad-CAM (CNN visualization)
Stars: ✭ 261 (-10.31%)
Mutual labels:  resnet
cifar-tensorflow
No description or website provided.
Stars: ✭ 18 (-93.81%)
Mutual labels:  resnet
wideresnet-tensorlayer
Wide Residual Networks implemented in TensorLayer and TensorFlow.
Stars: ✭ 44 (-84.88%)
Mutual labels:  resnet
flexible-yolov5
More readable and flexible yolov5 with more backbone(resnet, shufflenet, moblienet, efficientnet, hrnet, swin-transformer) and (cbam,dcn and so on), and tensorrt
Stars: ✭ 282 (-3.09%)
Mutual labels:  resnet
Awesome Computer Vision Models
A list of popular deep learning models related to classification, segmentation and detection problems
Stars: ✭ 278 (-4.47%)
Mutual labels:  resnet
Resnetcam Keras
Keras implementation of a ResNet-CAM model
Stars: ✭ 269 (-7.56%)
Mutual labels:  resnet

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. This arxiv paper is an extension of Shake-Shake regularization of 3-branch residual networks which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims at helping deep learning practitioners faced with an overfit problem. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best single shot published results on CIFAR-10 and CIFAR-100 by reaching test errors of 2.86% and 15.85%.

shake-shake

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that we keep, for the backward pass, the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, we apply the same scaling coefficient for all the images in the mini-batch. "Image" means that, for each residual block, we apply a different scaling coefficient for each image in the mini-batch. The numbers in the Table below represent the average of 3 runs except for the 96d models which were run 5 times.

Forward Backward Level 26 2x32d 26 2x64d 26 2x96d 26 2x112d
Even Even n\a 4.27 3.76 3.58 -
Even Shake Batch 4.44 - -
Shake Keep Batch 4.11 - - -
Shake Even Batch 3.47 3.30 - -
Shake Shake Batch 3.67 3.07 - -
Even Shake Image 4.11 - - -
Shake Keep Image 4.09 - - -
Shake Even Image 3.47 3.20 - -
Shake Shake Image 3.55 2.98 2.86 2.821

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)

Other results

Cifar-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05%1

SVHN:
26 2x96d: 1.4%1

Reduced SVHN:
26 2x96d: 12.32%1

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake
git clone https://github.com/xgastaldi/shake-shake.git
  1. Copy the elements in the shake-shake folder and paste them in the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  2. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, please change the batch size and the corresponding learning rate:

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true
  1. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds a log

train.lua
Ln 36-38 58-60 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'

The main model is in shakeshake.lua. The residual block model is in shakeshakeblock.lua. mulconstantslices.lua is just an extension of nn.mulconstant that multiplies elements of a vector with image slices of a mini-batch tensor. shakeshaketable.lua contains the method used for CIFAR-100 since the ResNeXt code uses a table implementation instead of a module version.

Reimplementations

Pytorch
https://github.com/hysts/pytorch_shake_shake

Tensorflow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.

Note that the project description data, including the texts, logos, images, and/or trademarks, for each open source project belongs to its rightful owner. If you wish to add or remove any projects, please contact us at [email protected].