
kevinzakka / Recurrent Visual Attention

License: MIT
A PyTorch Implementation of "Recurrent Models of Visual Attention"

Programming Languages

python

Projects that are alternatives to or similar to Recurrent Visual Attention

mtad-gat-pytorch
PyTorch implementation of MTAD-GAT (Multivariate Time-Series Anomaly Detection via Graph Attention Networks) by Zhao et al. (2020, https://arxiv.org/abs/2009.02040).
Stars: ✭ 85 (-79.47%)
Mutual labels:  attention
Keras Transformer
Transformer implemented in Keras
Stars: ✭ 273 (-34.06%)
Mutual labels:  attention
Text Classification Models Pytorch
Implementation of State-of-the-art Text Classification Models in Pytorch
Stars: ✭ 379 (-8.45%)
Mutual labels:  attention
ai challenger 2018 sentiment analysis
Fine-grained Sentiment Analysis of User Reviews --- AI CHALLENGER 2018
Stars: ✭ 16 (-96.14%)
Mutual labels:  attention
Abd Net
[ICCV 2019] "ABD-Net: Attentive but Diverse Person Re-Identification" https://arxiv.org/abs/1908.01114
Stars: ✭ 272 (-34.3%)
Mutual labels:  attention
Deepxi
Deep Xi: A deep learning approach to a priori SNR estimation implemented in TensorFlow 2/Keras. For speech enhancement and robust ASR.
Stars: ✭ 304 (-26.57%)
Mutual labels:  attention
Attention
Code for several different attention mechanisms
Stars: ✭ 17 (-95.89%)
Mutual labels:  attention
Neural sp
End-to-end ASR/LM implementation with PyTorch
Stars: ✭ 408 (-1.45%)
Mutual labels:  attention
Mcan Vqa
Deep Modular Co-Attention Networks for Visual Question Answering
Stars: ✭ 273 (-34.06%)
Mutual labels:  attention
Ner Bert
BERT-NER (nert-bert) with google bert https://github.com/google-research.
Stars: ✭ 339 (-18.12%)
Mutual labels:  attention
Abcnn
Implementation of ABCNN (Attention-Based Convolutional Neural Network) in TensorFlow
Stars: ✭ 264 (-36.23%)
Mutual labels:  attention
Encoder decoder
Four styles of encoder decoder model by Python, Theano, Keras and Seq2Seq
Stars: ✭ 269 (-35.02%)
Mutual labels:  attention
Crnn attention ocr chinese
CRNN with attention for OCR, with added Chinese recognition
Stars: ✭ 315 (-23.91%)
Mutual labels:  attention
ResUNetPlusPlus
Official code for ResUNetplusplus for medical image segmentation (TensorFlow implementation) (IEEE ISM)
Stars: ✭ 69 (-83.33%)
Mutual labels:  attention
Nlp Tutorials
Simple implementations of NLP models. Tutorials are written in Chinese on my website https://mofanpy.com
Stars: ✭ 394 (-4.83%)
Mutual labels:  attention
Attention-Visualization
Visualization for simple attention and Google's multi-head attention.
Stars: ✭ 54 (-86.96%)
Mutual labels:  attention
Seq2seq Summarizer
Pointer-generator reinforced seq2seq summarization in PyTorch
Stars: ✭ 306 (-26.09%)
Mutual labels:  attention
Pytorch Original Transformer
My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing otherwise seemingly hard concepts. Currently includes IWSLT pretrained models.
Stars: ✭ 411 (-0.72%)
Mutual labels:  attention
Deep learning nlp
Keras, PyTorch, and NumPy Implementations of Deep Learning Architectures for NLP
Stars: ✭ 407 (-1.69%)
Mutual labels:  attention
Transformer Tensorflow
TensorFlow implementation of 'Attention Is All You Need (2017. 6)'
Stars: ✭ 319 (-22.95%)
Mutual labels:  attention

Recurrent Visual Attention

This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.


The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.

Model Description

In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
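The per-step loop described above can be sketched as a plain Python function. This is a framework-free illustration, not the repository's code: `run_episode` and the component callables (`sensor`, `glimpse_net`, `core_net`, `loc_net`, `action_net`) are hypothetical names standing in for the PyTorch modules listed below, and classifying only after the last glimpse follows the action network's description.

```python
from typing import Any, Callable, List, Tuple

Location = Tuple[float, float]


def run_episode(
    image: Any,
    sensor: Callable[[Any, Location], Any],       # glimpse sensor: (x, l) -> phi
    glimpse_net: Callable[[Any, Location], Any],  # (phi, l) -> glimpse feature g_t
    core_net: Callable[[Any, Any], Any],          # (g_t, h_{t-1}) -> h_t
    loc_net: Callable[[Any], Location],           # h_t -> next location l_t
    action_net: Callable[[Any], Any],             # h_T -> class prediction y
    h0: Any,
    l0: Location,
    num_glimpses: int,
) -> Tuple[Any, List[Location]]:
    """Unroll the agent for a fixed number of glimpses, then classify."""
    h, loc = h0, l0
    locations = [l0]
    for t in range(num_glimpses):
        phi = sensor(image, loc)   # observe the image around the current location
        g = glimpse_net(phi, loc)  # fuse the "what" and the "where"
        h = core_net(g, h)         # integrate the glimpse into the recurrent state
        if t < num_glimpses - 1:   # choose where to look next (unused after the last glimpse)
            loc = loc_net(h)
            locations.append(loc)
    return action_net(h), locations
```

With toy stand-ins such as `core_net=lambda g, h: h + g` the function runs as is; in the actual model each callable would be a trained network.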


  • glimpse sensor: a retina that extracts a foveated glimpse phi around location l from an image x. It encodes the region around l at high resolution but uses progressively lower resolution for pixels further from l, resulting in a compressed representation of the original image x.
  • glimpse network: a network that combines the "what" (phi) and the "where" (l) into a glimpse feature vector wg_t.
  • core network: an RNN whose internal state integrates information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector h_t that is updated at every time step t.
  • location network: uses the internal state h_t of the core network to produce the location coordinates l_t for the next time step.
  • action network: after a fixed number of time steps, uses the internal state h_t of the core network to produce the final output classification y.
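The glimpse sensor from the first bullet can be sketched in dependency-free Python. This is an illustrative assumption, not the repository's PyTorch code: the function names, the mapping of l from [-1, 1] to pixel coordinates, zero padding at the borders, average pooling as the downsampling step, and the defaults (`patch_size=8`, `num_patches=3`, `scale=2`) are all choices made for this example.

```python
def extract_patch(image, cx, cy, size):
    """Return a size x size crop centred at pixel (cx, cy), zero-padded outside the image."""
    h, w = len(image), len(image[0])
    top, left = cy - size // 2, cx - size // 2
    patch = []
    for r in range(top, top + size):
        row = []
        for c in range(left, left + size):
            inside = 0 <= r < h and 0 <= c < w
            row.append(image[r][c] if inside else 0.0)
        patch.append(row)
    return patch


def avg_pool(patch, factor):
    """Downsample a square patch by averaging non-overlapping factor x factor blocks."""
    n = len(patch) // factor
    return [
        [
            sum(patch[i * factor + a][j * factor + b]
                for a in range(factor) for b in range(factor)) / factor ** 2
            for j in range(n)
        ]
        for i in range(n)
    ]


def glimpse_sensor(image, loc, patch_size=8, num_patches=3, scale=2):
    """Foveated glimpse: num_patches crops centred on loc, each `scale` times
    wider than the previous one, all pooled down to patch_size x patch_size."""
    h, w = len(image), len(image[0])
    # Map normalised coordinates (x, y) in [-1, 1] to pixel coordinates.
    cx = int(round((loc[0] + 1) / 2 * w))
    cy = int(round((loc[1] + 1) / 2 * h))
    glimpse = []
    for i in range(num_patches):
        factor = scale ** i
        crop = extract_patch(image, cx, cy, patch_size * factor)
        glimpse.append(avg_pool(crop, factor))  # wider crop, same output grid
    return glimpse
```

For example, `glimpse_sensor(img, (0.0, 0.0))` on a 28x28 image returns three 8x8 grids summarising 8x8, 16x16 and 32x32 regions around the centre, so the detail per pixel drops with distance from l.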

Results

I tackled the 28x28 MNIST task with a RAM model that takes 6 glimpses of size 8x8 with a scale factor of 1.

Model                      Validation Error (%)   Test Error (%)
RAM, 6 glimpses of 8x8     1.1                    1.21

I haven't tuned the policy standard deviation (for example with a random search), so I expect the test error can be brought below 1%. I'll update the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST and the new Fashion-MNIST dataset when I get the time.
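The policy standard deviation belongs to the location policy: in the paper, the location network outputs the mean of a Gaussian with fixed variance, the next location is sampled from it, and that sampling step is trained with REINFORCE. The pure-Python sketch below shows the sampling and the single-sample REINFORCE gradient with respect to the mean. The function names, the clamping to [-1, 1] and the scalar baseline are assumptions for illustration, not the repository's API.

```python
import math
import random
from typing import Sequence, Tuple


def sample_location(mean: Sequence[float], std: float, rng: random.Random):
    """Draw l ~ N(mean, std^2 I). Returns the raw sample (used for the policy
    gradient) and a copy clamped to [-1, 1] (used to place the next glimpse)."""
    raw = tuple(rng.gauss(m, std) for m in mean)
    clamped = tuple(max(-1.0, min(1.0, v)) for v in raw)
    return raw, clamped


def gaussian_log_prob(x: Sequence[float], mean: Sequence[float], std: float) -> float:
    """log N(x | mean, std^2 I), summed over coordinates."""
    return sum(
        -((xi - mi) ** 2) / (2 * std ** 2) - math.log(std) - 0.5 * math.log(2 * math.pi)
        for xi, mi in zip(x, mean)
    )


def reinforce_grad_mean(sample, mean, std, reward, baseline) -> Tuple[float, ...]:
    """Single-sample REINFORCE estimate of dE[R]/d(mean):
    (R - b) * d/d(mean) log N(sample | mean, std^2) = (R - b) * (sample - mean) / std^2."""
    advantage = reward - baseline
    return tuple(advantage * (s - m) / std ** 2 for s, m in zip(sample, mean))
```

Since `sample - mean` has spread `std` and is divided by `std**2`, individual gradient estimates scale like `1 / std`; together with how much the agent explores, this is one reason the standard deviation is worth tuning.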

Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.


With the Adam optimizer, the accuracy reported in the paper can be reached in about 160 epochs.

Usage

The easiest way to start training your RAM variant is to edit the parameters in config.py and run the following command:

python main.py

To resume training, run:

python main.py --resume=True

Finally, to evaluate the checkpoint that achieved the best validation accuracy, run:

python main.py --is_train=False
