orobix / Prototypical Networks For Few Shot Learning Pytorch

License: MIT
Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in PyTorch

Prototypical Networks for Few shot Learning in PyTorch

A simple alternative implementation of Prototypical Networks for Few Shot Learning (paper, code) in PyTorch.

Prototypical Networks

As shown in the reference paper, Prototypical Networks are trained to embed sample features in a vector space. At each episode (iteration), a number of samples from a subset of classes are selected and sent through the model. For each class c in the subset, the features of n_support samples are used to estimate the class prototype (their barycentre in the vector space); the distances between the remaining n_query samples and their class barycentre can then be minimized.
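The episode described above can be written compactly. The notation below follows the reference paper rather than this repo's code: S_k is the support set of class k (so |S_k| = n_support), f_φ is the embedding model, and d is a distance function (squared Euclidean in the paper).

```latex
\begin{align}
\mathbf{c}_k &= \frac{1}{|S_k|} \sum_{(\mathbf{x}_i, y_i) \in S_k} f_\phi(\mathbf{x}_i) \\
p_\phi(y = k \mid \mathbf{x}) &=
  \frac{\exp\bigl(-d(f_\phi(\mathbf{x}), \mathbf{c}_k)\bigr)}
       {\sum_{k'} \exp\bigl(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'})\bigr)} \\
J(\phi) &= -\log p_\phi(y = k \mid \mathbf{x})
  \quad \text{for a query point } \mathbf{x} \text{ whose true class is } k
\end{align}
```

Minimizing J pulls each query embedding towards the barycentre of its own class and away from the other prototypes in the episode.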

T-SNE

After training, you can compute the t-SNE for the features generated by the model (not done in this repo; more information about t-SNE here). This is a sample as shown in the paper.

[Figure: t-SNE visualization from the reference paper]

Omniglot Dataset

Kudos to @ludc for his contribution: https://github.com/pytorch/vision/pull/46. We will switch to the official dataset once it is added to torchvision, provided that doesn't imply big changes to the code.

Dataset splits

We implemented the Vinyals splitting method as in [Matching Networks for One Shot Learning]. That should be the same method used in the paper (in fact, I downloaded the split files from the "official" repo). We then apply the same rotations described there. This way, results obtained by running this code should be comparable with those reported in the reference paper.

Prototypical Batch Sampler

As described in its PyDoc, this class is used to generate the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of labels for the dataset; the sampler then infers the total number of classes and creates a set of indexes for each class in the dataset. At each episode, the sampler selects n_classes random classes and returns (n_support + n_query) sample indexes for each of the selected classes.
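The behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the repo's class: the name `EpisodicBatchSampler` and its argument names are hypothetical, and plain Python lists stand in for tensors.

```python
import random
from collections import defaultdict


class EpisodicBatchSampler:
    """Yield one list of dataset indexes per episode.

    Hypothetical stand-in for the repo's sampler; names are illustrative.
    """

    def __init__(self, labels, classes_per_it, num_samples, iterations, seed=None):
        self.classes_per_it = classes_per_it  # classes drawn per episode
        self.num_samples = num_samples        # n_support + n_query per class
        self.iterations = iterations          # episodes per epoch
        self.rng = random.Random(seed)
        # Group the dataset indexes by class label.
        self.indexes_per_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indexes_per_class[label].append(idx)
        self.classes = sorted(self.indexes_per_class)

    def __len__(self):
        return self.iterations

    def __iter__(self):
        for _ in range(self.iterations):
            batch = []
            # Pick random classes, then random samples within each class.
            for c in self.rng.sample(self.classes, self.classes_per_it):
                batch.extend(self.rng.sample(self.indexes_per_class[c], self.num_samples))
            yield batch


# Toy dataset: 10 classes with 20 samples each.
labels = [i // 20 for i in range(200)]
sampler = EpisodicBatchSampler(labels, classes_per_it=3, num_samples=5 + 5, iterations=4, seed=0)
episodes = list(sampler)
```

An object of this shape can be passed as the `batch_sampler` argument of `torch.utils.data.DataLoader`, which accepts any iterable that yields lists of indexes.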

Prototypical Loss

Compute the loss as in the cited paper, mostly inspired by this code by one of its authors.

In prototypical_loss.py, both a loss function and a loss class à la PyTorch are implemented.

The function takes as input the batch output from the model, the samples' ground truths, and the number n_support of samples to be used as support samples. Episode classes are inferred from the target list, and n_support samples are randomly extracted for each class. Their class barycentres are computed, followed by the distance of each remaining sample's embedding from each class barycentre, and finally the probability of each sample belonging to each episode class. The loss is then computed from these probabilities for the query samples (the negative log-probability of the true class), as usual in classification problems.
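To make the arithmetic explicit, here is a dependency-free sketch of the same computation on plain Python lists. It is not the code in prototypical_loss.py, which operates on PyTorch tensors; the function signature is an assumption, squared Euclidean distance is assumed as in the paper, and the first n_support samples of each class are used as support instead of a random draw.

```python
import math


def prototypical_loss(embeddings, targets, n_support):
    """Return (loss, accuracy) for one episode.

    embeddings: list of feature vectors (lists of floats), one per sample
    targets:    list of class labels aligned with embeddings
    n_support:  number of samples per class used to build the prototype
    """
    prototypes, queries = {}, []
    for c in sorted(set(targets)):
        idxs = [i for i, t in enumerate(targets) if t == c]
        # Simplification: the first n_support indexes act as support samples.
        support, query = idxs[:n_support], idxs[n_support:]
        dim = len(embeddings[support[0]])
        # Prototype = barycentre of the support embeddings.
        prototypes[c] = [sum(embeddings[i][d] for i in support) / n_support
                         for d in range(dim)]
        queries.extend((i, c) for i in query)

    total_nll, correct = 0.0, 0
    for i, true_c in queries:
        # Squared Euclidean distance from the query to every prototype.
        dists = {c: sum((a - b) ** 2 for a, b in zip(embeddings[i], p))
                 for c, p in prototypes.items()}
        # Numerically stable log of the softmax normalizer over -distances.
        m = max(-d for d in dists.values())
        log_z = m + math.log(sum(math.exp(-d - m) for d in dists.values()))
        total_nll += dists[true_c] + log_z  # equals -log p(true class | query)
        correct += min(dists, key=dists.get) == true_c
    return total_nll / len(queries), correct / len(queries)


# Two well-separated classes: 2 support samples and 1 query sample each.
embeddings = [[0.0, 0.0], [0.2, 0.0], [0.1, 0.1],
              [5.0, 5.0], [5.2, 5.0], [5.1, 4.9]]
targets = [0, 0, 0, 1, 1, 1]
loss, acc = prototypical_loss(embeddings, targets, n_support=2)
```

With clusters this far apart, every query lands next to its own prototype, so the accuracy is 1 and the loss is close to zero.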

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: the root directory where the dataset is stored, defaults to '../dataset'

  • nepochs: number of epochs to train for, defaults to 100

  • learning_rate: learning rate for the model, defaults to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step, defaults to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma, defaults to 0.5

  • iterations: number of episodes per epoch, defaults to 100

  • classes_per_it_tr: number of random classes per episode for training, defaults to 60

  • num_support_tr: number of samples per class to use as support for training, defaults to 5

  • num_query_tr: number of samples per class to use as query for training, defaults to 5

  • classes_per_it_val: number of random classes per episode for validation, defaults to 5

  • num_support_val: number of samples per class to use as support for validation, defaults to 5

  • num_query_val: number of samples per class to use as query for validation, defaults to 15

  • manual_seed: input for the manual seed initializations, defaults to 7

  • cuda: enables CUDA (store true flag)

Running the command without arguments will train the model with the default hyperparameter values (producing the results reported in the Performances section).
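As a worked example of what the defaults imply, the sketch below computes the episode sizes and the StepLR decay from the values listed above. The helper names are hypothetical, and it assumes the scheduler is stepped once per epoch.

```python
def step_lr(epoch, base_lr=0.001, step=20, gamma=0.5):
    """Learning rate in effect during `epoch` (0-based) under a StepLR schedule."""
    return base_lr * gamma ** (epoch // step)


def samples_per_episode(classes_per_it, num_support, num_query):
    """Total number of samples drawn for one episode."""
    return classes_per_it * (num_support + num_query)


train_batch = samples_per_episode(60, 5, 5)   # training defaults
val_batch = samples_per_episode(5, 5, 15)     # validation defaults
schedule = [step_lr(e) for e in (0, 19, 20, 40, 99)]
```

Under these assumptions, a training episode holds 600 samples, a validation episode holds 100, and the learning rate halves every 20 epochs, from 0.001 down to 0.0000625 for the last 20 of the 100 epochs.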

Performances

We are trying to reproduce the performance reported in the reference paper; we will update our best results here.

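| Model           | 1-shot (5-way Acc.) | 5-shot (5-way Acc.) | 1-shot (20-way Acc.) | 5-shot (20-way Acc.) |
|-----------------|---------------------|---------------------|----------------------|----------------------|
| Reference Paper | 98.8%               | 99.7%               | 96.0%                | 98.9%                |
| This repo       | 98.5%**             | 99.6%*              | 95.1%°               | 98.6%°°              |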

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

Helpful links

.bib citation

Cite the paper as follows (copied from arXiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).
