
lucidrains / Linear Attention Transformer

License: MIT
Transformer based on a variant of attention that has linear complexity with respect to sequence length

Programming Languages

python

Projects that are alternatives of or similar to Linear Attention Transformer

Routing Transformer
Fully featured implementation of Routing Transformer
Stars: ✭ 149 (-27.32%)
Mutual labels:  artificial-intelligence, attention-mechanism, transformer
Self Attention Cv
Implementation of various self-attention mechanisms focused on computer vision. Ongoing repository.
Stars: ✭ 209 (+1.95%)
Mutual labels:  artificial-intelligence, attention-mechanism, transformer
Se3 Transformer Pytorch
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This repository is geared toward integration with an eventual Alphafold2 replication.
Stars: ✭ 73 (-64.39%)
Mutual labels:  artificial-intelligence, attention-mechanism, transformer
Point Transformer Pytorch
Implementation of the Point Transformer layer, in Pytorch
Stars: ✭ 199 (-2.93%)
Mutual labels:  artificial-intelligence, attention-mechanism
Mixture Of Experts
A Pytorch implementation of Sparsely-Gated Mixture of Experts, for massively increasing the parameter count of language models
Stars: ✭ 68 (-66.83%)
Mutual labels:  artificial-intelligence, transformer
Simplednn
SimpleDNN is a lightweight, open-source machine learning library written in Kotlin, designed to support relevant neural network architectures in natural language processing tasks
Stars: ✭ 81 (-60.49%)
Mutual labels:  artificial-intelligence, attention-mechanism
Awesome Bert Nlp
A curated list of NLP resources focused on BERT, attention mechanism, Transformer networks, and transfer learning.
Stars: ✭ 567 (+176.59%)
Mutual labels:  attention-mechanism, transformer
Reformer Pytorch
Reformer, the efficient Transformer, in Pytorch
Stars: ✭ 1,644 (+701.95%)
Mutual labels:  artificial-intelligence, attention-mechanism
Eqtransformer
EQTransformer, a Python package for earthquake signal detection and phase picking using AI.
Stars: ✭ 95 (-53.66%)
Mutual labels:  attention-mechanism, transformer
Lambda Networks
Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute
Stars: ✭ 1,497 (+630.24%)
Mutual labels:  artificial-intelligence, attention-mechanism
Transformer In Generating Dialogue
An implementation of 'Attention Is All You Need' with a Chinese corpus
Stars: ✭ 121 (-40.98%)
Mutual labels:  attention-mechanism, transformer
Perceiver Pytorch
Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
Stars: ✭ 130 (-36.59%)
Mutual labels:  artificial-intelligence, attention-mechanism
Global Self Attention Network
A Pytorch implementation of Global Self-Attention Network, a fully-attention backbone for vision tasks
Stars: ✭ 64 (-68.78%)
Mutual labels:  artificial-intelligence, attention-mechanism
Sockeye
Sequence-to-sequence framework based on Apache MXNet, with a focus on Neural Machine Translation
Stars: ✭ 990 (+382.93%)
Mutual labels:  attention-mechanism, transformer
Isab Pytorch
An implementation of (Induced) Set Attention Block, from the Set Transformers paper
Stars: ✭ 21 (-89.76%)
Mutual labels:  artificial-intelligence, attention-mechanism
Conformer
Implementation of the convolutional module from the Conformer paper, for use in Transformers
Stars: ✭ 103 (-49.76%)
Mutual labels:  artificial-intelligence, transformer
Slot Attention
Implementation of Slot Attention from GoogleAI
Stars: ✭ 168 (-18.05%)
Mutual labels:  artificial-intelligence, attention-mechanism
Nmt Keras
Neural Machine Translation with Keras
Stars: ✭ 501 (+144.39%)
Mutual labels:  attention-mechanism, transformer
Performer Pytorch
An implementation of Performer, a linear attention-based transformer, in Pytorch
Stars: ✭ 546 (+166.34%)
Mutual labels:  artificial-intelligence, attention-mechanism
Overlappredator
[CVPR 2021, Oral] PREDATOR: Registration of 3D Point Clouds with Low Overlap.
Stars: ✭ 106 (-48.29%)
Mutual labels:  attention-mechanism, transformer

Linear Attention Transformer


A fully featured Transformer that mixes (QKᵀ)V local attention with Q(KᵀV) global attention (scales linearly with respect to sequence length) for efficient long-range language modeling.
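
For intuition, here is a minimal sketch of the two attention orderings, following the efficient-attention formulation cited below; it is an illustration of why the Q(KᵀV) ordering scales linearly, not the repository's actual implementation.

import torch

def softmax_attention(q, k, v):
    # standard (QKᵀ)V ordering: materializes an n×n score matrix, so cost is O(n²) in sequence length
    scores = torch.einsum('bnd,bmd->bnm', q, k) / q.shape[-1] ** 0.5
    return torch.einsum('bnm,bmd->bnd', scores.softmax(dim = -1), v)

def linear_attention(q, k, v):
    # Q(KᵀV) ordering: contracts keys with values first into a d×d context matrix,
    # so cost is O(n·d²), linear in sequence length
    q = q.softmax(dim = -1)
    k = k.softmax(dim = -2)                 # normalize keys over the sequence dimension
    context = torch.einsum('bnd,bne->bde', k, v)
    return torch.einsum('bnd,bde->bne', q, context)

q = k = v = torch.randn(1, 8192, 64)
linear_attention(q, k, v)  # (1, 8192, 64)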

Install

$ pip install linear-attention-transformer

Usage

Language model

import torch
from linear_attention_transformer import LinearAttentionTransformerLM

model = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    causal = True,                  # auto-regressive or not
    ff_dropout = 0.1,               # dropout for feedforward
    attn_layer_dropout = 0.1,       # dropout right after self-attention layer
    attn_dropout = 0.1,             # dropout post-attention
    emb_dim = 128,                  # embedding factorization, to save on memory
    dim_head = 128,                 # fixes the dimension of each head, decoupling it from the embedding dimension and the number of heads
    blindspot_size = 64,            # gives the q(kv) attention a blindspot of 64 tokens back in the causal case, in exchange for roughly an order of magnitude of memory savings. should be paired with local attention whose window size is at least this value. setting this to 1 allows full q(kv) attention over the past
    n_local_attn_heads = 4,         # number of local attention heads for (qk)v attention. can also be a tuple, specifying the number of local attention heads at each depth
    local_attn_window_size = 128,   # receptive field of the local attention
    reversible = True,              # use reversible nets, from Reformer paper
    ff_chunks = 2,                  # feedforward chunking, from Reformer paper
    ff_glu = True,                  # use GLU variant for feedforward
    attend_axially = False          # folds the sequence by the local attention window size and performs an extra strided attention, followed by a feedforward, using the cheap q(kv) attention
).cuda()

x = torch.randint(0, 20000, (1, 8192)).cuda()
model(x) # (1, 8192, 20000)
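
For orientation, a minimal training step for this language model could look like the sketch below; the optimizer choice and the shift-by-one target construction are illustrative assumptions, not part of the library.

import torch.nn.functional as F

# hypothetical training step for the model defined above
opt = torch.optim.Adam(model.parameters(), lr = 1e-4)

seq = torch.randint(0, 20000, (1, 8193)).cuda()
inp, labels = seq[:, :-1], seq[:, 1:]            # shift by one for next-token targets

logits = model(inp)                              # (1, 8192, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), labels)
loss.backward()
opt.step()
opt.zero_grad()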

Transformer

import torch
from linear_attention_transformer import LinearAttentionTransformer

model = LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
).cuda()

x = torch.randn(1, 8192, 512).cuda()
model(x) # (1, 8192, 512)

Encoder / decoder

import torch
from linear_attention_transformer import LinearAttentionTransformerLM

enc = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    one_kv_head = True,
    reversible = True,
    n_local_attn_heads = 4,
    return_embeddings = True
).cuda()

dec = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    causal = True,
    max_seq_len = 4096,
    one_kv_head = True,
    reversible = True,
    receives_context = True,
    n_local_attn_heads = 4
).cuda()

src = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool().cuda()

tgt = torch.randint(0, 20000, (1, 4096)).cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

context = enc(src, input_mask = src_mask)
logits = dec(tgt, context = context, input_mask = tgt_mask, context_mask = src_mask)
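
A greedy decoding loop on top of this pair could then look like the following sketch; the choice of token 0 as a start token and the fixed 64-step budget are illustrative assumptions, and no such helper ships with the library.

# hypothetical greedy decoding, reusing the context computed above
out = torch.zeros((1, 1), dtype = torch.long).cuda()   # assumes token 0 acts as <bos>

for _ in range(64):                                    # illustrative decoding budget
    logits = dec(out, context = context, context_mask = src_mask)
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    out = torch.cat((out, next_token), dim = 1)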

Linformer

Linformer is another linear-complexity attention variant, championed by Facebook AI. It only works with non-autoregressive models of a fixed sequence length. If your problem satisfies those constraints, you may want to try it out.

from linear_attention_transformer import LinearAttentionTransformerLM, LinformerSettings

settings = LinformerSettings(k = 256)

enc = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    one_kv_head = True,
    linformer_settings = settings
).cuda()
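
For intuition, Linformer's trick is to project the length-n key and value sequences down to a fixed length k before attending, which is also why the sequence length must be fixed. Below is a minimal sketch of that idea, not this repository's internals.

import torch

# sketch of the Linformer idea: project length-n keys / values down to a
# fixed k, so the attention matrix is n×k rather than n×n
n, d, k = 4096, 64, 256
q  = torch.randn(1, n, d)
kv = torch.randn(1, n, d)

proj = torch.randn(n, k) / n ** 0.5              # a learned parameter in the real model

keys   = torch.einsum('bnd,nk->bkd', kv, proj)
values = torch.einsum('bnd,nk->bkd', kv, proj)

scores = torch.einsum('bnd,bkd->bnk', q, keys) / d ** 0.5
out = torch.einsum('bnk,bkd->bnd', scores.softmax(dim = -1), values)  # (1, n, d)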

You can also use Linformer for the contextual attention layer, if the contextual keys are of a fixed sequence length.

from linear_attention_transformer import LinearAttentionTransformerLM, LinformerContextSettings

settings = LinformerContextSettings(
  seq_len = 2048,
  k = 256
)

dec = LinearAttentionTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 6,
    max_seq_len = 4096,
    causal = True,
    one_kv_head = True,
    context_linformer_settings = settings,
    receives_context = True
).cuda()

Images

This repository also contains a concise implementation of this efficient attention for images.

import torch
from linear_attention_transformer.images import ImageLinearAttention

attn = ImageLinearAttention(
  chan = 32,
  heads = 8,
  key_dim = 64       # can be decreased to 32 for more memory savings
)

img = torch.randn(1, 32, 256, 256)
attn(img) # (1, 32, 256, 256)
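
For example, the module can be dropped into a convolutional network as a residual layer. The surrounding block below is a hypothetical sketch; only ImageLinearAttention itself comes from the repository.

import torch
from torch import nn
from linear_attention_transformer.images import ImageLinearAttention

# hypothetical conv block with linear attention as a residual layer
class ConvAttnBlock(nn.Module):
    def __init__(self, chan = 32):
        super().__init__()
        self.conv = nn.Conv2d(chan, chan, 3, padding = 1)
        self.attn = ImageLinearAttention(chan = chan, heads = 8, key_dim = 32)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return x + self.attn(x)                 # residual connection around attention

block = ConvAttnBlock(chan = 32)
img = torch.randn(1, 32, 128, 128)
block(img) # (1, 32, 128, 128)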

Citations

@inproceedings{katharopoulos-et-al-2020,
    author    = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title     = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year      = {2020},
    url       = {https://arxiv.org/abs/2006.16236}
}
@article{shen2019efficient,
    author    = {Zhuoran Shen and
                 Mingyuan Zhang and
                 Haiyu Zhao and
                 Shuai Yi and
                 Hongsheng Li},
    title     = {Efficient Attention: Attention with Linear Complexities},
    journal   = {CoRR},
    volume    = {abs/1812.01243},
    year      = {2018},
    url       = {http://arxiv.org/abs/1812.01243}
}
@misc{shazeer2019fast,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    year    = {2019},
    eprint  = {1911.02150},
    archivePrefix = {arXiv}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{wang2020linformer,
    title   = {Linformer: Self-Attention with Linear Complexity},
    author  = {Sinong Wang and Belinda Z. Li and Madian Khabsa and Han Fang and Hao Ma},
    year    = {2020},
    eprint  = {2006.04768}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}