
IraKorshunova / bruno

Licence: MIT license
a deep recurrent model for exchangeable data

Programming Languages

python

Projects that are alternatives of or similar to bruno

pomdp-baselines
Simple (but often Strong) Baselines for POMDPs in PyTorch - ICML 2022
Stars: ✭ 162 (+376.47%)
Mutual labels:  recurrent-neural-networks
GPJax
A didactic Gaussian process package for researchers in Jax.
Stars: ✭ 159 (+367.65%)
Mutual labels:  gaussian-processes
handson-ml
Jupyter notebooks containing the examples and exercises from the book "Hands-On Machine Learning".
Stars: ✭ 285 (+738.24%)
Mutual labels:  recurrent-neural-networks
One-Shot-Learning-with-Siamese-Networks
Implementation of One Shot Learning using Convolutional Siamese Networks on Omniglot Dataset
Stars: ✭ 129 (+279.41%)
Mutual labels:  omniglot
keras-ordered-neurons
Ordered Neurons LSTM
Stars: ✭ 29 (-14.71%)
Mutual labels:  recurrent-neural-networks
GPBoost
Combining tree-boosting with Gaussian process and mixed effects models
Stars: ✭ 360 (+958.82%)
Mutual labels:  gaussian-processes
sequence labeling tf
Sequence Labeling in Tensorflow
Stars: ✭ 18 (-47.06%)
Mutual labels:  recurrent-neural-networks
VariationalNeuralAnnealing
A variational implementation of classical and quantum annealing using recurrent neural networks for the purpose of solving optimization problems.
Stars: ✭ 21 (-38.24%)
Mutual labels:  recurrent-neural-networks
Music-Style-Transfer
Source code for "Transferring the Style of Homophonic Music Using Recurrent Neural Networks and Autoregressive Model"
Stars: ✭ 16 (-52.94%)
Mutual labels:  recurrent-neural-networks
fin
finance
Stars: ✭ 38 (+11.76%)
Mutual labels:  recurrent-neural-networks
ganbert
Enhancing the BERT training with Semi-supervised Generative Adversarial Networks
Stars: ✭ 205 (+502.94%)
Mutual labels:  few-shot-learning
pytorch-meta-dataset
A non-official 100% PyTorch implementation of META-DATASET benchmark for few-shot classification
Stars: ✭ 39 (+14.71%)
Mutual labels:  few-shot-learning
LSTM-Time-Series-Analysis
Using LSTM network for time series forecasting
Stars: ✭ 41 (+20.59%)
Mutual labels:  recurrent-neural-networks
mmn
Moore Machine Networks (MMN): Learning Finite-State Representations of Recurrent Policy Networks
Stars: ✭ 39 (+14.71%)
Mutual labels:  recurrent-neural-networks
bitcoin-prediction
bitcoin prediction algorithms
Stars: ✭ 21 (-38.24%)
Mutual labels:  recurrent-neural-networks
datasetsome
Some APIs related to dataset processing
Stars: ✭ 37 (+8.82%)
Mutual labels:  omniglot
Name-NationalOrigin-Classifier
Using a recurrent neural network in TensorFlow to predict national origin by last name.
Stars: ✭ 30 (-11.76%)
Mutual labels:  recurrent-neural-networks
stanford-cs231n-assignments-2020
This repository contains my solutions to the assignments for Stanford's CS231n "Convolutional Neural Networks for Visual Recognition" (Spring 2020).
Stars: ✭ 84 (+147.06%)
Mutual labels:  recurrent-neural-networks
TrendinessOfTrends
The Trendiness of Trends
Stars: ✭ 14 (-58.82%)
Mutual labels:  gaussian-processes
Reservoir
Code for Reservoir computing (Echo state network)
Stars: ✭ 40 (+17.65%)
Mutual labels:  recurrent-neural-networks

BRUNO: A Deep Recurrent Model for Exchangeable Data

This is the official code for reproducing the main results from our NIPS'18 paper:

I. Korshunova, J. Degrave, F. Huszár, Y. Gal, A. Gretton, J. Dambre
BRUNO: A Deep Recurrent Model for Exchangeable Data
arxiv.org/abs/1802.07535

and from our NIPS'18 Bayesian Deep Learning workshop paper:

I. Korshunova, Y. Gal, J. Dambre, A. Gretton
Conditional BRUNO: A Deep Recurrent Process for Exchangeable Labelled Data
bayesiandeeplearning.org/2018/papers/40.pdf

Requirements

The code was used with the following settings:

  • python3
  • tensorflow-gpu==1.7.0
  • scikit-image==0.13.1
  • numpy==1.14.2
  • scipy==1.0.0

Datasets

Below we list the files required for each dataset. They should be stored in a data/ directory inside the project folder.

MNIST

Download from yann.lecun.com/exdb/mnist/

data/train-images-idx3-ubyte.gz
data/train-labels-idx1-ubyte.gz
data/t10k-images-idx3-ubyte.gz
data/t10k-labels-idx1-ubyte.gz

Fashion MNIST

Download from github.com/zalandoresearch/fashion-mnist

data/fashion_mnist/train-images-idx3-ubyte.gz
data/fashion_mnist/train-labels-idx1-ubyte.gz
data/fashion_mnist/t10k-images-idx3-ubyte.gz
data/fashion_mnist/t10k-labels-idx1-ubyte.gz
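
Both MNIST and Fashion MNIST are distributed as gzipped IDX files, so a quick way to confirm that a download is intact is to read the file header. The sketch below is not part of this repository; the helper name read_idx_header is hypothetical, and it relies only on the standard IDX layout (two zero bytes, a data-type code, the number of dimensions, then one big-endian 32-bit size per dimension).

```python
import gzip
import struct


def read_idx_header(path):
    """Return (dtype_code, shape) read from the header of a gzipped IDX file.

    Hypothetical helper for sanity-checking downloads; not part of the repo.
    """
    with gzip.open(path, "rb") as f:
        # Magic number: two zero bytes, a data-type code, the number of dimensions.
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0:
            raise ValueError("not an IDX file: %s" % path)
        # Each dimension size is a big-endian unsigned 32-bit integer.
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    return dtype_code, shape


# Example call (assumes the file has been downloaded as listed above):
# read_idx_header("data/train-images-idx3-ubyte.gz")
```

For the MNIST training images, the data-type code should be 8 (unsigned byte) and the shape should have three dimensions: number of images, rows, and columns. The label files should report a single dimension.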

Omniglot

Download and unzip files from github.com/brendenlake/omniglot/tree/master/python

data/images_background
data/images_evaluation

Download the .pkl files from github.com/renmengye/few-shot-ssl-public#omniglot. These are used to make the train/validation/test split.

data/train_vinyals_aug90.pkl
data/test_vinyals_aug90.pkl
data/val_vinyals_aug90.pkl

Run utils.py to preprocess the Omniglot images. This should produce the following files:

data/omniglot_x_train.npy
data/omniglot_y_train.npy
data/omniglot_x_test.npy
data/omniglot_y_test.npy
data/omniglot_valid_classes.npy

CIFAR-10

This dataset is downloaded automatically the first time a CIFAR-10 model is run.

data/cifar/cifar-10-batches-py
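
Before training, it can help to check that the files listed in this section are where the code expects them. The following is a minimal sketch, not part of the repository: the function name missing_files and the dataset keys are hypothetical, and the paths are simply copied from the lists in this section. CIFAR-10 is left out because it is fetched by the code itself.

```python
import os

# Paths copied from the dataset lists in this section,
# relative to the project folder. The dictionary keys are illustrative.
EXPECTED = {
    "mnist": [
        "data/train-images-idx3-ubyte.gz",
        "data/train-labels-idx1-ubyte.gz",
        "data/t10k-images-idx3-ubyte.gz",
        "data/t10k-labels-idx1-ubyte.gz",
    ],
    "fashion_mnist": [
        "data/fashion_mnist/train-images-idx3-ubyte.gz",
        "data/fashion_mnist/train-labels-idx1-ubyte.gz",
        "data/fashion_mnist/t10k-images-idx3-ubyte.gz",
        "data/fashion_mnist/t10k-labels-idx1-ubyte.gz",
    ],
    "omniglot_raw": [
        "data/images_background",
        "data/images_evaluation",
        "data/train_vinyals_aug90.pkl",
        "data/test_vinyals_aug90.pkl",
        "data/val_vinyals_aug90.pkl",
    ],
    "omniglot_preprocessed": [
        "data/omniglot_x_train.npy",
        "data/omniglot_y_train.npy",
        "data/omniglot_x_test.npy",
        "data/omniglot_y_test.npy",
        "data/omniglot_valid_classes.npy",
    ],
}


def missing_files(dataset, project_dir="."):
    """Return the expected paths for `dataset` that do not exist yet."""
    return [
        path
        for path in EXPECTED[dataset]
        if not os.path.exists(os.path.join(project_dir, path))
    ]
```

Calling missing_files("omniglot_raw") from the project folder returns an empty list once the Omniglot downloads are in place; any path it does return still needs to be downloaded or generated.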

Training and testing

The config_rnn directory contains a configuration file for every model we used in the paper, along with several testing scripts. Below are examples of how to train and test Omniglot models.

Training (supports multiple GPUs)

CUDA_VISIBLE_DEVICES=0,1 python3 -m config_rnn.train  --config_name bn2_omniglot_tp --nr_gpu 2

Fine-tuning (to be used on one GPU only)

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.train_finetune  --config_name bn2_omniglot_tp_ft_1s_20w

Generating samples

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_samples  --config_name bn2_omniglot_tp_ft_1s_20w

Few-shot classification

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp --seq_len 2 --batch_size 20
CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp_ft_1s_20w --seq_len 2 --batch_size 20

Here, batch_size = k and seq_len = n + 1 test the model in a k-way, n-shot setting, so the commands above evaluate 20-way, 1-shot classification.
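
The mapping from a k-way, n-shot setting to the two flags can be written down as a small helper. This is only a sketch: the function few_shot_command is hypothetical and not part of the repository, it uses only the flags shown in the commands above, and whether a given configuration supports other values of k and n is an assumption to verify against the config files.

```python
import shlex


def few_shot_command(config_name, k_way, n_shot, gpu=0):
    """Build the few-shot test command for a k-way, n-shot evaluation.

    Hypothetical helper: batch_size = k and seq_len = n + 1, as described above.
    """
    if k_way < 1 or n_shot < 1:
        raise ValueError("k_way and n_shot must be positive")
    args = [
        "python3", "-m", "config_rnn.test_few_shot_omniglot",
        "--config_name", config_name,
        "--seq_len", str(n_shot + 1),   # n support examples plus one query
        "--batch_size", str(k_way),     # one sequence per candidate class
    ]
    return "CUDA_VISIBLE_DEVICES=%d " % gpu + " ".join(shlex.quote(a) for a in args)


# 20-way, 1-shot reproduces the flags of the first command above.
print(few_shot_command("bn2_omniglot_tp", k_way=20, n_shot=1))
```

The comments on the two flags reflect the k-way, n-shot reading given above, not an inspection of the test script itself.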

Citation

Please cite our paper when using this code for your research. If you have any questions, please send me an email at [email protected]

@incollection{bruno2018,
    title = {BRUNO: A Deep Recurrent Model for Exchangeable Data},
    author = {Korshunova, Iryna and Degrave, Jonas and Huszar, Ferenc and Gal, Yarin and Gretton, Arthur and Dambre, Joni},
    booktitle = {Advances in Neural Information Processing Systems 31},
    year = {2018}
}