
aakhundov / tf-attend-infer-repeat

License: MIT
TensorFlow-based implementation of "Attend, Infer, Repeat" paper (Eslami et al., 2016, arXiv:1603.08575).

Programming Languages

python
139335 projects - #7 most used programming language

Projects that are alternatives of or similar to tf-attend-infer-repeat

Automatic speech recognition
End-to-end Automatic Speech Recognition for Mandarin and English in TensorFlow
Stars: ✭ 2,751 (+6152.27%)
Mutual labels:  rnn
Market-Trend-Prediction
A project for a knowledge-graph-building course. It leverages historical stock prices and integrates social media listening from customers to predict the market trend of the Dow Jones Industrial Average (DJIA).
Stars: ✭ 57 (+29.55%)
Mutual labels:  rnn
nemesyst
A generalised and highly customisable, hybrid-parallelism, database-based deep learning framework.
Stars: ✭ 17 (-61.36%)
Mutual labels:  rnn
deeplearning.ai notes
📓 Notes for Andrew Ng's courses on deep learning
Stars: ✭ 73 (+65.91%)
Mutual labels:  rnn
machine learning
Hands-on projects in machine learning, deep learning, and NLP
Stars: ✭ 123 (+179.55%)
Mutual labels:  rnn
STAR Network
[PAMI 2021] Gating Revisited: Deep Multi-layer RNNs That Can Be Trained
Stars: ✭ 16 (-63.64%)
Mutual labels:  rnn
Har Stacked Residual Bidir Lstms
Using deep stacked residual bidirectional LSTM cells (RNN) with TensorFlow, we do Human Activity Recognition (HAR), classifying the type of movement among 6 or 18 categories on 2 different datasets.
Stars: ✭ 250 (+468.18%)
Mutual labels:  rnn
Probabilistic-RNN-DA-Classifier
Probabilistic Dialogue Act Classification for the Switchboard Corpus using an LSTM model
Stars: ✭ 22 (-50%)
Mutual labels:  rnn
presidential-rnn
Project 4 for the Metis bootcamp. The objective was to generate a character-level RNN trained on Donald Trump's statements using Keras. Also generated Markov chains and a quick PyTorch RNN as a baseline. Attempted a semi-supervised GAN, but was unable to test it in time.
Stars: ✭ 26 (-40.91%)
Mutual labels:  rnn
Predicting-Next-Character-using-RNN
Uses RNN on the Nietzsche dataset
Stars: ✭ 15 (-65.91%)
Mutual labels:  rnn
Online-Signature-Verification
Online Handwriting Signature Verification using CNN + RNN.
Stars: ✭ 16 (-63.64%)
Mutual labels:  rnn
Selected Stories
An experimental web text editor that runs an LSTM model while you write to suggest new lines
Stars: ✭ 39 (-11.36%)
Mutual labels:  rnn
Motor-Imagery-Tasks-Classification-using-EEG-data
Implementation of Deep Neural Networks in Keras and Tensorflow to classify motor imagery tasks using EEG data
Stars: ✭ 67 (+52.27%)
Mutual labels:  rnn
deep-learning-notes
🧠👨‍💻Deep Learning Specialization • Lecture Notes • Lab Assignments
Stars: ✭ 20 (-54.55%)
Mutual labels:  rnn
training-charRNN
Training charRNN model for ml5js
Stars: ✭ 87 (+97.73%)
Mutual labels:  rnn
Deepjazz
Deep learning driven jazz generation using Keras & Theano!
Stars: ✭ 2,766 (+6186.36%)
Mutual labels:  rnn
rnn-theano
RNN(LSTM, GRU) in Theano with mini-batch training; character-level language models in Theano
Stars: ✭ 68 (+54.55%)
Mutual labels:  rnn
keras-utility-layer-collection
Collection of custom layers and utility functions for Keras which are missing in the main framework.
Stars: ✭ 63 (+43.18%)
Mutual labels:  rnn
ACT
Alternative approach for Adaptive Computation Time in TensorFlow
Stars: ✭ 16 (-63.64%)
Mutual labels:  rnn
TCN-TF
TensorFlow Implementation of TCN (Temporal Convolutional Networks)
Stars: ✭ 107 (+143.18%)
Mutual labels:  rnn

Attend, Infer, Repeat

Implementation of a continuous relaxation of the AIR framework proposed in "Attend, Infer, Repeat: Fast Scene Understanding with Generative Models" (Eslami et al., 2016). The work has been done in equal contribution with Alexander Prams. The model is implemented in TensorFlow.

  • multi_mnist.py needs to be run before training the model in order to generate the multi-MNIST dataset: 60,000 50x50-pixel images with 0, 1, or 2 random non-overlapping MNIST digits (a sketch of the placement logic follows this list).
  • training.py is a runnable script for training the model with the default hyperparameter configuration (passed to the constructor of the AIRModel class). While training takes place, its progress is written to sub-folders of the "air_results" folder: a complete snapshot of the source code in "source", periodic model checkpoints in "model", and rich TensorBoard summaries (including attention/reconstruction samples) in "summary".
  • demo.py is a live demo of the trained model's performance (using saved parameter values from the model folder) that allows drawing digits in a Python GUI and attending to/reconstructing them in real time.
  • embeddings.py generates TensorBoard projector summaries (in the "embeddings" folder) for low-dimensional t-SNE or PCA visualization of the 50-dimensional VAE latent space of attended/reconstructed digits against their ground-truth labels.
  • air/air_model.py contains the extensively configurable AIRModel class, which comprises the model implementation.
  • air/transformer.py is a Spatial Transformer implementation borrowed from the TensorFlow models repository.
  • The model folder contains a TensorFlow checkpoint with the parameter values of a model trained for 270k iterations with the default hyperparameter configuration specified in training.py. These parameter values are used in demo.py and embeddings.py.
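The dataset generation can be pictured with the following minimal sketch. It is not the repository's multi_mnist.py: the Keras MNIST loader and the make_canvas helper are illustrative assumptions, and only the placement logic (random positions, rejection of placements whose non-zero pixels would overlap) reflects the description above.

```python
# Illustrative sketch of multi-MNIST canvas generation (NOT the repository's multi_mnist.py):
# 0, 1, or 2 random 28x28 MNIST digits are placed on a 50x50 canvas so that their
# non-zero pixels do not overlap.
import numpy as np
from tensorflow.keras.datasets import mnist

(x_train, _), _ = mnist.load_data()
CANVAS, DIGIT = 50, 28

def make_canvas(num_digits, rng):
    canvas = np.zeros((CANVAS, CANVAS), dtype=np.float32)
    placed = 0
    while placed < num_digits:
        patch = x_train[rng.integers(len(x_train))] / 255.0
        top = rng.integers(CANVAS - DIGIT + 1)
        left = rng.integers(CANVAS - DIGIT + 1)
        region = canvas[top:top + DIGIT, left:left + DIGIT]
        # reject a placement whose non-zero pixels would overlap an already placed digit
        if np.any((region > 0) & (patch > 0)):
            continue
        canvas[top:top + DIGIT, left:left + DIGIT] = np.maximum(region, patch)
        placed += 1
    return canvas

rng = np.random.default_rng(0)
images = np.stack([make_canvas(rng.integers(3), rng) for _ in range(60000)])
```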

Noisy gradients of the discrete z_pres (a Bernoulli random variable sampled to predict the presence of another digit on the canvas: 1 meaning "yes", 0 meaning "no") caused severe stability issues in training the model. NVIL (Mnih & Gregor, 2014) was originally used to alleviate the problem of gradient noise, but it did not make the training process stable enough. Instead, a Concrete (Gumbel-Softmax) random variable (Maddison et al., 2016; Jang et al., 2016), a continuous relaxation of a discrete random variable, was employed to improve training stability.
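For illustration, a binary Concrete (relaxed Bernoulli) sample can be drawn as below. This is a generic TF 2.x-style sketch with an assumed function name, not the repository's code, which targets TF 1.x:

```python
import tensorflow as tf

def relaxed_bernoulli_sample(log_odds, temperature=1.0):
    """Draw a sample in (0, 1) from the binary Concrete distribution with the given log-odds."""
    u = tf.random.uniform(tf.shape(log_odds), minval=1e-6, maxval=1.0 - 1e-6)
    logistic_noise = tf.math.log(u) - tf.math.log(1.0 - u)  # Logistic(0, 1) noise
    # as temperature -> 0 this approaches a hard Bernoulli(sigmoid(log_odds)) sample
    return tf.sigmoid((log_odds + logistic_noise) / temperature)
```

TensorFlow Probability's RelaxedBernoulli distribution provides the same sampling path together with log_prob, which is convenient for the MC estimate of the Concrete KL-divergence described below.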

The discrete z_pres was replaced by a continuous analogue sampled from a Concrete distribution with temperature 1.0 and taking values between 0 and 1. Correspondingly, the original Bernoulli KL-divergence was replaced by an MC sample of the Concrete KL-divergence. Furthermore, two additional adaptations were made. First, VAE reconstructions are scaled by z_pres before being added to the reconstruction canvas. This pushes continuous samples towards 0 or 1 when the model wants to stop or to attend to another digit, respectively. Second, inspired by ACT (Graves, 2016), the stopping criterion was reformulated as the running sum of (1 - z_pres) values over time steps exceeding a configurable threshold (0.99 in the experiments). A threshold below 1 allows stopping during the very first time step, which is essential for empty images that should not be attended to at all. As a result, in the limit of Concrete z_pres samples taking the extreme values of 0 and 1, this relaxed model turns into the original AIR with discrete z_pres.
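A conceptual, eager-style sketch of the scaling and stopping rule described above is shown here. The reconstruct and step_fn names and signatures are hypothetical stand-ins; the repository builds the equivalent computation as a TensorFlow graph inside the AIRModel class:

```python
import tensorflow as tf

STOP_THRESHOLD = 0.99  # running sum of (1 - z_pres) beyond which decoding stops

def reconstruct(images, max_steps, step_fn):
    """step_fn(images, t) -> (z_pres, recon); both name and signature are hypothetical."""
    canvas = tf.zeros_like(images)             # [batch, 50 * 50] reconstruction canvas
    stop_sum = tf.zeros(tf.shape(images)[:1])  # per-example running sum of (1 - z_pres)
    for t in range(max_steps):
        z_pres, recon = step_fn(images, t)     # z_pres in (0, 1), a Concrete sample
        still_running = tf.cast(stop_sum < STOP_THRESHOLD, images.dtype)
        # reconstructions are scaled by z_pres, so samples near 0 contribute almost nothing
        canvas += (still_running * z_pres)[:, None] * recon
        stop_sum += 1.0 - z_pres
    return canvas
```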

After applying the continuous relaxation, 10 out of 10 training runs in a row converged to 98% digit count accuracy, on average within 25,000 iterations. All 10 runs were conducted for 300 epochs (276k iterations) with the default set of hyperparameters from training.py, among them: 256 LSTM cells, a learning rate of 10⁻⁴, gradient clipping with a global norm of 1.0, and a smooth exponential decay of the z_pres prior log-odds from 10⁴ to 10⁻⁹ during the first 40,000 iterations. The charts below show digit count accuracy for the entire validation set (top) and for its subsets of 0-, 1-, and 2-digit images, left to right (bottom):

[Charts: digit count accuracy on the full validation set (top) and on its 0-, 1-, and 2-digit subsets (bottom, left to right)]
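The smooth exponential decay of the z_pres prior log-odds mentioned above can be written as a geometric interpolation between the two endpoints. The functional form and name below are assumptions matching the stated endpoints and duration, not necessarily the repository's exact schedule:

```python
import tensorflow as tf

def z_pres_prior_log_odds(step, start=1e4, end=1e-9, anneal_steps=40000.0):
    """Geometric interpolation from `start` to `end` over the first `anneal_steps` iterations."""
    frac = tf.minimum(tf.cast(step, tf.float32) / anneal_steps, 1.0)
    return start * (end / start) ** frac  # equals `start` at step 0, `end` from step 40,000 on
```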

Samples of attention/reconstruction made by an AIR model trained with training.py (for each pair: original on the left, reconstruction on the right; the red attention window corresponds to the first time step, the green one to the second):

[Image: attention/reconstruction sample pairs with colored attention windows]
