
lucidrains / En-transformer

License: MIT
Implementation of E(n)-Transformer, which extends the ideas of Welling's E(n)-Equivariant Graph Neural Network to attention

Programming Languages

python
139335 projects - #7 most used programming language

Projects that are alternatives of or similar to En-transformer

Awesome Bert Nlp
A curated list of NLP resources focused on BERT, attention mechanism, Transformer networks, and transfer learning.
Stars: ✭ 567 (+332.82%)
Mutual labels:  transformer, attention-mechanism
Overlappredator
[CVPR 2021, Oral] PREDATOR: Registration of 3D Point Clouds with Low Overlap.
Stars: ✭ 106 (-19.08%)
Mutual labels:  transformer, attention-mechanism
Sockeye
Sequence-to-sequence framework with a focus on Neural Machine Translation based on Apache MXNet
Stars: ✭ 990 (+655.73%)
Mutual labels:  transformer, attention-mechanism
Pytorch Original Transformer
My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing otherwise seemingly hard concepts. Currently includes IWSLT pretrained models.
Stars: ✭ 411 (+213.74%)
Mutual labels:  transformer, attention-mechanism
Linear Attention Transformer
Transformer based on a variant of attention that is linear complexity in respect to sequence length
Stars: ✭ 205 (+56.49%)
Mutual labels:  transformer, attention-mechanism
Transformer Tts
A Pytorch Implementation of "Neural Speech Synthesis with Transformer Network"
Stars: ✭ 418 (+219.08%)
Mutual labels:  transformer, attention-mechanism
Eqtransformer
EQTransformer, a python package for earthquake signal detection and phase picking using AI.
Stars: ✭ 95 (-27.48%)
Mutual labels:  transformer, attention-mechanism
linformer
Implementation of Linformer for Pytorch
Stars: ✭ 119 (-9.16%)
Mutual labels:  transformer, attention-mechanism
Eeg Dl
A Deep Learning library for EEG Tasks (Signals) Classification, based on TensorFlow.
Stars: ✭ 165 (+25.95%)
Mutual labels:  transformer, attention-mechanism
Routing Transformer
Fully featured implementation of Routing Transformer
Stars: ✭ 149 (+13.74%)
Mutual labels:  transformer, attention-mechanism
Neural sp
End-to-end ASR/LM implementation with PyTorch
Stars: ✭ 408 (+211.45%)
Mutual labels:  transformer, attention-mechanism
Transformers-RL
An easy PyTorch implementation of "Stabilizing Transformers for Reinforcement Learning"
Stars: ✭ 107 (-18.32%)
Mutual labels:  transformer, attention-mechanism
Transformer
A TensorFlow Implementation of the Transformer: Attention Is All You Need
Stars: ✭ 3,646 (+2683.21%)
Mutual labels:  transformer, attention-mechanism
Nmt Keras
Neural Machine Translation with Keras
Stars: ✭ 501 (+282.44%)
Mutual labels:  transformer, attention-mechanism
galerkin-transformer
[NeurIPS 2021] Galerkin Transformer: a linear attention without softmax
Stars: ✭ 111 (-15.27%)
Mutual labels:  transformer, attention-mechanism
Se3 Transformer Pytorch
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.
Stars: ✭ 73 (-44.27%)
Mutual labels:  transformer, attention-mechanism
Transformer-in-Transformer
An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches
Stars: ✭ 40 (-69.47%)
Mutual labels:  transformer, attention-mechanism
pynmt
a simple and complete pytorch implementation of neural machine translation system
Stars: ✭ 13 (-90.08%)
Mutual labels:  transformer, attention-mechanism
Transformer In Generating Dialogue
An Implementation of 'Attention is all you need' with Chinese Corpus
Stars: ✭ 121 (-7.63%)
Mutual labels:  transformer, attention-mechanism
Self Attention Cv
Implementation of various self-attention mechanisms focused on computer vision. Ongoing repository.
Stars: ✭ 209 (+59.54%)
Mutual labels:  transformer, attention-mechanism

E(n)-Equivariant Transformer

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention.

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,               # number of layers
    dim_head = 64,           # dimension per head
    heads = 8,               # number of heads
    edge_dim = 4,            # dimension of edge feature
    neighbors = 64,          # attend only to each coordinate's N nearest neighbors - set to 0 to turn off
    talking_heads = True,    # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
    checkpoint = True,       # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
    use_cross_product = True # use cross product vectors (idea by @MattMcPartlon)
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.ones(1, 1024).bool()

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)
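
Because the layers are E(n)-equivariant, rigidly transforming the input coordinates should transform the output coordinates the same way while leaving the output features untouched. Below is a minimal sanity-check sketch (not taken from the repository's tests; it assumes only the inputs and outputs shown above, and uses a hand-built rotation rather than any helper from the library):

import math
import torch
from en_transformer import EnTransformer

# a small model, attending over all nodes (neighbors = 0)
model = EnTransformer(
    dim = 32,
    depth = 1,
    dim_head = 16,
    heads = 2,
    neighbors = 0
)

feats = torch.randn(1, 16, 32)
coors = torch.randn(1, 16, 3)

# a rotation about the z-axis plus a random translation
theta = 0.7
R = torch.tensor([
    [math.cos(theta), -math.sin(theta), 0.],
    [math.sin(theta),  math.cos(theta), 0.],
    [0., 0., 1.]
])
t = torch.randn(1, 1, 3)

feats1, coors1 = model(feats, coors)
feats2, coors2 = model(feats, coors @ R + t)

# loose tolerances to allow for floating point error
assert torch.allclose(feats1, feats2, atol = 1e-4)          # features are invariant
assert torch.allclose(coors1 @ R + t, coors2, atol = 1e-4)  # coordinates are equivariant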

Letting the network take care of both atomic and bond type embeddings

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,       # number of unique nodes, say atoms
    rel_pos_emb = True,    # set this to true if your sequence is not an unordered set. it will accelerate convergence
    num_edge_tokens = 5,   # number of unique edges, say bond types
    dim = 128,
    edge_dim = 16,
    depth = 3,
    heads = 4,
    dim_head = 32,
    neighbors = 8
)

atoms = torch.randint(0, 10, (1, 16))    # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3)            # atomic spatial coordinates

feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 128), (1, 16, 3)
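
In practice the token indices just need to be consistent integer ids for your atom and bond vocabularies. A tiny hypothetical mapping (the element and bond vocabularies here are illustrative, not from the repository):

import torch

ATOM_VOCAB = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'H': 4}           # example atom vocabulary
BOND_VOCAB = {'none': 0, 'single': 1, 'double': 2, 'aromatic': 3}

elements = ['N', 'C', 'C', 'O']                                  # a 4-atom fragment
atoms = torch.tensor([[ATOM_VOCAB[e] for e in elements]])        # (1, 4) atom token ids

bonds = torch.zeros(1, 4, 4, dtype = torch.long)                 # default: no bond
bonds[0, 0, 1] = bonds[0, 1, 0] = BOND_VOCAB['single']
bonds[0, 1, 2] = bonds[0, 2, 1] = BOND_VOCAB['single']
bonds[0, 2, 3] = bonds[0, 3, 2] = BOND_VOCAB['double']

# with coordinates from your structure, the call is the same as above:
# feats_out, coors_out = model(atoms, coors, edges = bonds)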

If you would like to only attend to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    neighbors = 0,
    only_sparse_neighbors = True,    # must be set to true
    num_adj_degrees = 2,             # the number of adjacency degrees to derive from the 1st-degree neighbors passed in
    adj_dim = 8                      # dimension of the adjacency-degree embedding, passed in as edge information
)

atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)

# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
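
The single-chain adjacency above is just a stand-in. If connectivity instead comes as a list of bonded atom-index pairs, the boolean N x N matrix can be assembled directly; adjacency_from_bonds below is a hypothetical helper, not part of the library:

import torch

def adjacency_from_bonds(num_atoms, bond_pairs):
    # bond_pairs: iterable of (i, j) index pairs for bonded atoms
    adj_mat = torch.zeros(num_atoms, num_atoms, dtype = torch.bool)
    for i, j in bond_pairs:
        adj_mat[i, j] = True
        adj_mat[j, i] = True  # bonds are undirected
    return adj_mat

# e.g. a 4-atom fragment with bonds 0-1, 1-2, 1-3
adj_mat = adjacency_from_bonds(4, [(0, 1), (1, 2), (1, 3)])

# then, as before: feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat)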

Edges

If you need to pass in continuous edges

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    edge_dim = 4,
    neighbors = 0,
    only_sparse_neighbors = True
)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)

Example

To run a protein backbone coordinate denoising toy task, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise.py
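
The toy task amounts to corrupting coordinates with Gaussian noise and training the network to recover them. The sketch below illustrates the idea with synthetic data; the actual denoise.py pulls protein backbones from sidechainnet and uses its own model size and hyperparameters:

import torch
import torch.nn.functional as F
from torch.optim import Adam
from en_transformer import EnTransformer

model = EnTransformer(dim = 32, depth = 2, heads = 2, dim_head = 16, neighbors = 8)
opt = Adam(model.parameters(), lr = 1e-4)

# synthetic stand-in for per-residue features and backbone coordinates
feats = torch.randn(1, 64, 32)
coors = torch.randn(1, 64, 3)

for _ in range(100):
    noised_coors = coors + torch.randn_like(coors) * 0.1  # corrupt the coordinates
    _, denoised_coors = model(feats, noised_coors)        # predict the clean coordinates
    loss = F.mse_loss(denoised_coors, coors)
    loss.backward()
    opt.step()
    opt.zero_grad()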

Todo

Citations

@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}