
titu1994 / Snapshot Ensembles

Licence: apache-2.0
Snapshot Ensemble in Keras

Programming Languages

python

Projects that are alternatives of or similar to Snapshot Ensembles

paperback
Paper backup generator suitable for long-term storage.
Stars: ✭ 517 (+78.89%)
Mutual labels:  paper
sympy-paper
Repo for the paper "SymPy: symbolic computing in python"
Stars: ✭ 42 (-85.47%)
Mutual labels:  paper
Papernote
paper note, including personal comments, introduction, code etc
Stars: ✭ 268 (-7.27%)
Mutual labels:  paper
ocbnn-public
General purpose library for BNNs, and implementation of OC-BNNs in our 2020 NeurIPS paper.
Stars: ✭ 31 (-89.27%)
Mutual labels:  paper
vehicle-trajectory-prediction
Behavior Prediction in Autonomous Driving
Stars: ✭ 23 (-92.04%)
Mutual labels:  paper
GuidedLabelling
Exploiting Saliency for Object Segmentation from Image Level Labels, CVPR'17
Stars: ✭ 35 (-87.89%)
Mutual labels:  paper
fake-news-detection
This repo is a collection of AWESOME things about fake news detection, including papers, code, etc.
Stars: ✭ 34 (-88.24%)
Mutual labels:  paper
Plotsquared
PlotSquared - Reinventing the plotworld
Stars: ✭ 284 (-1.73%)
Mutual labels:  paper
Awesome-Computer-Vision-Paper-List
This repository contains all the papers accepted in top conference of computer vision, with convenience to search related papers.
Stars: ✭ 248 (-14.19%)
Mutual labels:  paper
Papyrus
📄 Unofficial Dropbox Paper desktop app
Stars: ✭ 263 (-9%)
Mutual labels:  paper
MoCo
A pytorch reimplement of paper "Momentum Contrast for Unsupervised Visual Representation Learning"
Stars: ✭ 41 (-85.81%)
Mutual labels:  paper
PublicWeaklySupervised
(Machine) Learning to Do More with Less
Stars: ✭ 13 (-95.5%)
Mutual labels:  paper
paper-reading
Paragraph-by-paragraph close readings of classic and new deep learning papers
Stars: ✭ 6,633 (+2195.16%)
Mutual labels:  paper
Restoring-Extremely-Dark-Images-In-Real-Time
The project is the official implementation of our CVPR 2021 paper, "Restoring Extremely Dark Images in Real Time"
Stars: ✭ 79 (-72.66%)
Mutual labels:  paper
Awesome Ehr Deeplearning
Curated list of awesome papers for electronic health records(EHR) mining, machine learning, and deep learning.
Stars: ✭ 269 (-6.92%)
Mutual labels:  paper
CHyVAE
Code for our paper -- Hyperprior Induced Unsupervised Disentanglement of Latent Representations (AAAI 2019)
Stars: ✭ 18 (-93.77%)
Mutual labels:  paper
3PU pytorch
PyTorch implementation of "Patch-based Progressive 3D Point Set Upsampling"
Stars: ✭ 61 (-78.89%)
Mutual labels:  paper
Summary Of Recommender System Papers
A categorized summary of recommender system papers I have read; continuously updated…
Stars: ✭ 288 (-0.35%)
Mutual labels:  paper
Alae
[CVPR2020] Adversarial Latent Autoencoders
Stars: ✭ 3,178 (+999.65%)
Mutual labels:  paper
Glow
Code for reproducing results in "Glow: Generative Flow with Invertible 1x1 Convolutions"
Stars: ✭ 2,859 (+889.27%)
Mutual labels:  paper

Snapshot Ensembles in Keras

Implementation of the paper "Snapshot Ensembles: Train 1, Get M for Free" in Keras 1.1.1

Explanation

Snapshot ensembling is a method for obtaining multiple neural networks that can be ensembled at no additional training cost. This is achieved by letting a single neural network converge into several local minima along its optimization path and saving the model parameters at certain epochs; the saved weights are therefore "snapshots" of the model.

The repeated rapid convergence is realized using cosine annealing cycles as the learning rate schedule. It can be described by:
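With T total training epochs, M annealing cycles (one snapshot per cycle), initial learning rate α₀, and t the current epoch, the schedule from the paper is

    \alpha(t) = \frac{\alpha_0}{2} \left( \cos\!\left( \frac{\pi \,\operatorname{mod}(t - 1,\ \lceil T/M \rceil)}{\lceil T/M \rceil} \right) + 1 \right)

so the learning rate restarts at α₀ at the beginning of each cycle and is annealed towards 0 over ⌈T/M⌉ epochs.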

This scheduler produces a learning rate similar to the image below. Note that the learning rate never actually reaches 0; it just gets very close to it (~0.0005):

The reasoning behind a learning rate schedule that oscillates between such extreme values (0.1 to 5e-4, M times) is that multiple local minima exist when training a model. Continually reducing the learning rate can leave the model stuck in a less-than-optimal local minimum. Therefore, a very large learning rate is used to escape the current local minimum and attempt to find another, possibly better, one.

This is illustrated by the following image:

Figure 1: Left: Illustration of SGD optimization with a typical learning rate schedule. The model converges to a minimum at the end of training. Right: Illustration of Snapshot Ensembling optimization. The model undergoes several learning rate annealing cycles, converging to and escaping from multiple local minima. We take a snapshot at each minimum for test time ensembling.

Usage

The paper uses several models, such as ResNet-101, Wide Residual Networks, DenseNet-40, and DenseNet-100. While DenseNets are the highest-performing models in the paper, they are too large and take extremely long to train. Therefore, the currently trained model uses the Wide Residual Net (16-4) configuration. This model performs worse than the 34-4 version but trains several times faster.

The technique is simple to implement in Keras using a custom callback. These callbacks can be built with the SnapshotCallbackBuilder class in snapshot.py. Other models can simply use this callback builder to train in a similar manner.

To use snapshot ensembles in other models:

from snapshot import SnapshotCallbackBuilder

M = 5 # number of snapshots
nb_epoch = T = 200 # number of epochs
alpha_zero = 0.1 # initial learning rate
model_prefix = 'Model_'

snapshot = SnapshotCallbackBuilder(T, M, alpha_zero) 
...
model = Sequential()  # or model = Model(ip, output) -- any model that has been compiled

model.fit(trainX, trainY, nb_epoch=nb_epoch, callbacks=snapshot.get_callbacks(model_prefix=model_prefix))
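For intuition, the callbacks returned by the builder essentially combine a cosine-annealed learning rate schedule with a weight save at the end of every annealing cycle. The following is a minimal, illustrative sketch of that idea using standard Keras 1.x callbacks; the SaveSnapshot class and the weight file names are made up here, and the actual implementation in snapshot.py may differ in its details:

import numpy as np
from keras.callbacks import Callback, LearningRateScheduler

T, M, alpha_zero = 200, 5, 0.1        # total epochs, number of snapshots, initial LR
epochs_per_cycle = int(np.ceil(T / float(M)))

def snapshot_lr(epoch):
    # Cosine annealing, restarted every ceil(T/M) epochs (the schedule above)
    t = epoch % epochs_per_cycle
    return float(alpha_zero / 2.0 * (np.cos(np.pi * t / epochs_per_cycle) + 1))

class SaveSnapshot(Callback):
    # Save the weights at the end of every annealing cycle (illustrative file names)
    def on_epoch_end(self, epoch, logs={}):
        if (epoch + 1) % epochs_per_cycle == 0:
            index = (epoch + 1) // epochs_per_cycle
            self.model.save_weights('Model_%d.h5' % index, overwrite=True)

callbacks = [LearningRateScheduler(snapshot_lr), SaveSnapshot()]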

To train WRN or DenseNet models on CIFAR 10 or 100 (or use the pre-trained models):

  1. Download the 6 WRN-16-4 weights that are provided in the Release tab of the project and place them in the weights directory for CIFAR 10 or 100
  2. Run the train_cifar_10.py script to train the WRN-16-4 model on the CIFAR-10 dataset (not required, since weights are provided)
  3. Run the predict_cifar_10.py script to make an ensemble prediction.

Note the difference between using only the predictions of the single best model (92.70% accuracy) and the weighted ensemble of the snapshots (92.84% accuracy). The difference is minor, but still an improvement.

The improvement is minor because this model is far smaller than the WRN-34-4 model and is not trained on the CIFAR-100 or Tiny ImageNet datasets. According to the paper, models trained on more complex datasets such as CIFAR-100 and Tiny ImageNet obtain a greater boost from the ensemble.
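As a concrete picture of the ensembling step, the following is a hedged sketch of the weighted-average prediction over the snapshots (the models list and the ensemble_predict helper are hypothetical; predict_cifar_10.py implements the actual version):

import numpy as np

# `models` is assumed to be a list of the M snapshot models, each already built
# and loaded with one of the saved weight files via model.load_weights(...).
def ensemble_predict(models, X, weights=None):
    if weights is None:                                   # equal weights, as in the paper
        weights = [1.0 / len(models)] * len(models)
    preds = [model.predict(X) for model in models]        # each of shape (N, nb_classes)
    return np.sum([w * p for w, p in zip(weights, preds)], axis=0)

# yPred = ensemble_predict(models, testX)
# accuracy = 100 * np.mean(np.argmax(yPred, axis=1) == np.argmax(testY, axis=1))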

Parameters

Some parameters for WRN models from the paper:

  • M = 5
  • nb_epoch = 200
  • alpha_zero = 0.1
  • wrn_N = 2 (WRN-16-4) or 4 (WRN-28-8)
  • wrn_k = 4 (WRN-16-4) or 8 (WRN-28-8)

Some parameters for DenseNet models from the paper:

  • M = 6
  • nb_epoch = 300
  • alpha_zero = 0.2
  • dn_depth = 40 (DenseNet-40-12) or 100 (DenseNet-100-24)
  • dn_growth_rate = 12 (DenseNet-40-12) or 24 (DenseNet-100-24)

train_*.py

--M              : Number of snapshots that will be taken. Optimal range is in between 4 - 8. Default is 5
--nb_epoch       : Number of epochs to train the network. Default is 200
--alpha_zero     : Initial Learning Rate. Usually 0.1 or 0.2. Default is 0.1

--model          : Type of model to train. Can be "wrn" for Wide ResNets or "dn" for DenseNet

--wrn_N          : Number of WRN blocks. Computed as N = (n - 4) / 6. Default is 2.
--wrn_k          : Width factor of WRN. Default is 12.

--dn_depth       : Depth of DenseNet. Default is 40.
--dn_growth_rate : Growth rate of DenseNet. Default is 12.
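
For example, a WRN-16-4 run on CIFAR-10 with the defaults listed above (note that a depth-16 WRN gives N = (16 - 4) / 6 = 2) might be invoked as follows; the exact flag spellings are assumed to be as documented here:

python train_cifar_10.py --M 5 --nb_epoch 200 --alpha_zero 0.1 --model wrn --wrn_N 2 --wrn_k 4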

predict_*.py

--optimize       : Flag to optimize the ensemble weights. 
                   Default is 0 (Predict using optimized weights).
                   Set to 1 to optimize ensemble weights (test for num_tests times).
                   Set to -1 to predict using equal weights for all models (As given in the paper).
               
--num_tests      : Number of times the optimizations will be performed. Default is 20

--model          : Type of model to train. Can be "wrn" for Wide ResNets or "dn" for DenseNet

--wrn_N          : Number of WRN blocks. Computed as N = (n - 4) / 6. Default is 2.
--wrn_k          : Width factor of WRN. Default is 12.

--dn_depth       : Depth of DenseNet. Default is 40.
--dn_growth_rate : Growth rate of DenseNet. Default is 12.
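
For example, an ensemble prediction with weight optimization over 20 runs might look like the following (again assuming the flags exactly as documented):

python predict_cifar_10.py --optimize 1 --num_tests 20 --model wrn --wrn_N 2 --wrn_k 4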

Performance

  • Single Best: Describes the performance of the single best model.
  • Without Optimization: Describes the performance of the ensemble model with equal weights for all models
  • With Optimization: Describes the performance of the ensemble model with optimized weights, found by minimizing the log-loss score (a sketch of this optimization is shown below)
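
The weight optimization can be pictured as a constrained minimization of the ensemble's log loss over the per-model weights. Below is a hedged sketch of that idea using scipy and sklearn (the optimize_ensemble_weights helper and its inputs are hypothetical, not the repository's exact code):

import numpy as np
from scipy.optimize import minimize
from sklearn.metrics import log_loss

# Assumed inputs: `preds` is a list of M arrays of shape (N, nb_classes) holding each
# snapshot's softmax predictions on a held-out set, and `y_true` holds the true labels.
def optimize_ensemble_weights(preds, y_true, num_tests=20):
    M = len(preds)

    def ensemble_log_loss(weights):
        blended = np.sum([w * p for w, p in zip(weights, preds)], axis=0)
        return log_loss(y_true, blended)

    constraints = ({'type': 'eq', 'fun': lambda w: 1.0 - np.sum(w)},)   # weights sum to 1
    bounds = [(0.0, 1.0)] * M                                           # each weight in [0, 1]

    best = None
    for _ in range(num_tests):                     # restart from random initial weights
        w0 = np.random.dirichlet(np.ones(M))
        result = minimize(ensemble_log_loss, w0, method='SLSQP',
                          bounds=bounds, constraints=constraints)
        if best is None or result.fun < best.fun:
            best = result
    return best.x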

Requirements

  • Keras
  • Theano (tested) / TensorFlow (not tested; weights not available but can be converted)
  • scipy
  • h5py
  • sklearn
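
Assuming a Theano backend contemporary with Keras 1.1.1, the dependencies can be installed with something like:

pip install keras==1.1.1 theano scipy h5py scikit-learn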