
lucidrains / axial-attention

License: MIT
Implementation of Axial attention - attending to multi-dimensional data efficiently

Programming Languages

python
139335 projects - #7 most used programming language

Projects that are alternatives of or similar to axial-attention

amta-net
Asymmetric Multi-Task Attention Network for Prostate Bed Segmentation in CT Images
Stars: ✭ 26 (-89.39%)
Mutual labels:  attention-mechanism
organic-chemistry-reaction-prediction-using-NMT
organic chemistry reaction prediction using NMT with Attention
Stars: ✭ 30 (-87.76%)
Mutual labels:  attention-mechanism
En-transformer
Implementation of E(n)-Transformer, which extends the ideas of Welling's E(n)-Equivariant Graph Neural Network to attention
Stars: ✭ 131 (-46.53%)
Mutual labels:  attention-mechanism
Neural-Chatbot
A Neural Network based Chatbot
Stars: ✭ 68 (-72.24%)
Mutual labels:  attention-mechanism
Optic-Disc-Unet
Attention U-Net model with post-processing for retina optic disc segmentation
Stars: ✭ 77 (-68.57%)
Mutual labels:  attention-mechanism
ChangeFormer
Official PyTorch implementation of our IGARSS'22 paper: A Transformer-Based Siamese Network for Change Detection
Stars: ✭ 220 (-10.2%)
Mutual labels:  attention-mechanism
SA-DL
Sentiment Analysis with Deep Learning models. Implemented with Tensorflow and Keras.
Stars: ✭ 35 (-85.71%)
Mutual labels:  attention-mechanism
TS3000 TheChatBOT
It's a social networking chatbot trained on a Reddit dataset. It supports open-bounded queries, developed on the concept of Neural Machine Translation. Beware of it being sarcastic, just like its creator 😝 BTW, it uses the PyTorch framework and Python 3.
Stars: ✭ 20 (-91.84%)
Mutual labels:  attention-mechanism
LSTM-Attention
A Comparison of LSTMs and Attention Mechanisms for Forecasting Financial Time Series
Stars: ✭ 53 (-78.37%)
Mutual labels:  attention-mechanism
CIAN
Implementation of the Character-level Intra Attention Network (CIAN) for Natural Language Inference (NLI) upon SNLI and MultiNLI corpus
Stars: ✭ 17 (-93.06%)
Mutual labels:  attention-mechanism
memory-compressed-attention
Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"
Stars: ✭ 47 (-80.82%)
Mutual labels:  attention-mechanism
NARRE
This is our implementation of NARRE: Neural Attentional Regression with Review-level Explanations
Stars: ✭ 100 (-59.18%)
Mutual labels:  attention-mechanism
uniformer-pytorch
Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks, debuted in ICLR 2022
Stars: ✭ 90 (-63.27%)
Mutual labels:  attention-mechanism
STAM-pytorch
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification
Stars: ✭ 109 (-55.51%)
Mutual labels:  attention-mechanism
LanguageModel-using-Attention
Pytorch implementation of a basic language model using Attention in LSTM network
Stars: ✭ 27 (-88.98%)
Mutual labels:  attention-mechanism
Video-Description-with-Spatial-Temporal-Attention
[ACM MM 2017 & IEEE TMM 2020] This is the Theano code for the paper "Video Description with Spatial Temporal Attention"
Stars: ✭ 53 (-78.37%)
Mutual labels:  attention-mechanism
hexia
Mid-level PyTorch Based Framework for Visual Question Answering.
Stars: ✭ 24 (-90.2%)
Mutual labels:  attention-mechanism
rnn-text-classification-tf
Tensorflow implementation of Attention-based Bidirectional RNN text classification.
Stars: ✭ 26 (-89.39%)
Mutual labels:  attention-mechanism
S2VT-seq2seq-video-captioning-attention
S2VT (seq2seq) video captioning with Bahdanau & Luong attention implementation in Tensorflow
Stars: ✭ 18 (-92.65%)
Mutual labels:  attention-mechanism
dgcnn
Clean & Documented TF2 implementation of "An end-to-end deep learning architecture for graph classification" (M. Zhang et al., 2018).
Stars: ✭ 21 (-91.43%)
Mutual labels:  attention-mechanism

Axial Attention


Implementation of Axial Attention in PyTorch. A simple but powerful technique to attend to multi-dimensional data efficiently. It has worked wonders for me and many other researchers.

Simply add some positional encoding to your data and pass it into this handy class, specifying which dimension holds the embedding and how many axial dimensions to rotate through. All the permuting and reshaping will be taken care of for you.

This paper was actually rejected on the basis of being too simple, and yet it has since been used successfully in a number of applications, among them weather prediction and all-attention image segmentation. Just goes to show.
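For a sense of the efficiency gain: on an H × W image, full self-attention scores every pixel against every other pixel, while axial attention only scores each pixel against its own row and column. A rough, purely illustrative count (ignoring heads and constant factors):

# back-of-the-envelope count of attention scores for a 256 x 256 image
h = w = 256

full_scores  = (h * w) ** 2        # every pixel attends to every pixel
axial_scores = h * w * (h + w)     # each pixel attends along its row and its column

print(full_scores)                 # 4294967296
print(axial_scores)                # 33554432
print(full_scores / axial_scores)  # 128.0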

Install

$ pip install axial_attention

Usage

Image

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (2 for images, 3 for video, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)

Channel-last image latents

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (2 for images, 3 for video, or more)
)

attn(img) # (1, 20, 20, 512)
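For intuition, the same channel-last layout can be attended to with a deliberately naive, hand-rolled version of the idea: run attention along the width axis, then along the height axis, and sum the two results (mirroring sum_axial_out = True). This is only a sketch, single head with no learned projections, and it assumes PyTorch >= 2.0 for F.scaled_dot_product_attention; it is not the library's implementation.

import torch
import torch.nn.functional as F

def naive_axial_attention(x):
    # x: (batch, height, width, dim), channel-last
    b, h, w, d = x.shape

    # attend along the width axis: each row of w tokens attends within itself
    rows = F.scaled_dot_product_attention(x, x, x)                     # (b, h, w, d)

    # attend along the height axis: swap axes so height is the sequence axis
    xt = x.transpose(1, 2)                                             # (b, w, h, d)
    cols = F.scaled_dot_product_attention(xt, xt, xt).transpose(1, 2)  # (b, h, w, d)

    # sum the per-axis outputs, as with sum_axial_out = True
    return rows + cols

x = torch.randn(1, 20, 20, 512)
naive_axial_attention(x) # (1, 20, 20, 512)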

Video

import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (2 for images, 3 for video, or more)
)

attn(video) # (1, 5, 128, 256, 256)
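Since the module preserves the input shape, it can be dropped into a larger model as a residual block. A minimal sketch (the AxialResidualBlock wrapper below is hypothetical, not part of the library):

import torch
from torch import nn
from axial_attention import AxialAttention

class AxialResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = AxialAttention(
            dim = dim,
            dim_index = 2,       # channel dimension for (batch, frames, dim, height, width)
            heads = 8,
            num_dimensions = 3   # axial dimensions: frames, height, width
        )

    def forward(self, x):
        return x + self.attn(x)  # residual connection around shape-preserving attention

block = AxialResidualBlock(128)
video = torch.randn(1, 5, 128, 64, 64)
block(video) # (1, 5, 128, 64, 64)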

Image Transformer, with reversible network

import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 512, 512)

transformer(conv1x1(img)) # (1, 128, 512, 512)
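Because the transformer keeps the (batch, dim, height, width) shape, it can also serve as the backbone of a small model. A toy sketch (the pooling head and 10-way output are illustrative choices, not part of the library):

import torch
from torch import nn
from axial_attention import AxialImageTransformer

model = nn.Sequential(
    nn.Conv2d(3, 128, 1),                                             # 1x1 conv stem to reach the working dim
    AxialImageTransformer(dim = 128, depth = 6, reversible = True),   # shape-preserving backbone
    nn.AdaptiveAvgPool2d(1),                                          # global average pool over height and width
    nn.Flatten(),
    nn.Linear(128, 10)                                                # hypothetical 10-class head
)

img = torch.randn(1, 3, 128, 128)
model(img) # (1, 10)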

With axial positional embedding

import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

img = torch.randn(1, 512, 20, 20)

attn = AxialAttention(
    dim = 512,
    heads = 8,
    dim_index = 1
)

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    shape = (20, 20)
)

img = pos_emb(img)  # (1, 512, 20, 20)  - now positionally embedded
img = attn(img)     # (1, 512, 20, 20)
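Since both modules preserve the input shape, they also compose directly, for example in an nn.Sequential. A small sketch (the two-layer stack is just an illustration):

import torch
from torch import nn
from axial_attention import AxialAttention, AxialPositionalEmbedding

model = nn.Sequential(
    AxialPositionalEmbedding(dim = 512, shape = (20, 20)),  # adds learned per-axis positions
    AxialAttention(dim = 512, dim_index = 1, heads = 8),
    AxialAttention(dim = 512, dim_index = 1, heads = 8)
)

img = torch.randn(1, 512, 20, 20)
model(img) # (1, 512, 20, 20)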

Citation

@misc{ho2019axial,
    title  = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year   = {2019},
    archivePrefix = {arXiv}
}
@misc{wang2020axialdeeplab,
    title   = {Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation},
    author  = {Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen},
    year    = {2020},
    eprint  = {2003.07853},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{huang2019ccnet,
    title   = {Ccnet: Criss-cross attention for semantic segmentation},
    author  = {Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages   = {603--612},
    year    = {2019}
}