lucidrains / Se3 Transformer Pytorch

Licence: mit
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.

Programming Languages

python

SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Example of equivariance (Colab notebook)
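To make the equivariance property concrete: a function f producing type 1 (vector) outputs is rotation-equivariant when f(x R) equals f(x) R for any rotation R. The sketch below measures that gap for an arbitrary callable. It uses a toy stand-in (`toy_displacement`, a hypothetical function that points each coordinate toward the centroid) rather than the actual model, and all helper names are illustrative only.

```python
import math
import random

def rotation_z(theta):
    """3x3 rotation matrix about the z-axis, as nested lists."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def apply_rotation(points, R):
    """Rotate each row vector p to p @ R."""
    return [[sum(p[k] * R[k][j] for k in range(3)) for j in range(3)] for p in points]

def toy_displacement(points):
    """Hypothetical stand-in for a type 1 output: vector from each point to the centroid."""
    n = len(points)
    centroid = [sum(p[j] for p in points) / n for j in range(3)]
    return [[centroid[j] - p[j] for j in range(3)] for p in points]

def equivariance_error(fn, points, R):
    """Largest absolute gap between fn(points @ R) and fn(points) @ R."""
    rotate_then_fn = fn(apply_rotation(points, R))
    fn_then_rotate = apply_rotation(fn(points), R)
    return max(
        abs(a - b)
        for row_a, row_b in zip(rotate_then_fn, fn_then_rotate)
        for a, b in zip(row_a, row_b)
    )

random.seed(0)
coors = [[random.gauss(0, 1) for _ in range(3)] for _ in range(32)]
R = rotation_z(math.radians(45))
err = equivariance_error(toy_displacement, coors, R)
print(f"max equivariance error: {err:.2e}")
```

The same comparison carries over to tensors: rotate the input coordinates, run the model, and compare against the rotated output of the unrotated input.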

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True,
    differentiable_coors = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atoms.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refined_coors = coors + model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates
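In the example above, the trailing dimension of each feature tensor follows the usual convention that a type l feature has 2l + 1 components (1 for type 0, 3 for type 1). A small sanity check on the shapes before calling the model might look like the sketch below. `check_feature_shapes` is a hypothetical helper, not part of the library, and it works on plain shape tuples so it can be fed `tuple(t.shape)` for each tensor.

```python
def expected_type_dim(degree):
    """Number of components of a type-`degree` feature: 2 * degree + 1."""
    return 2 * degree + 1

def check_feature_shapes(shapes, batch, num_points):
    """Validate a dict mapping degree keys ('0', '1', ...) to (b, n, d, m) shapes.

    Raises ValueError on the first mismatch, returns True otherwise.
    """
    for key, shape in shapes.items():
        degree = int(key)
        if len(shape) != 4:
            raise ValueError(f"type {key}: expected 4 dims (b, n, d, m), got {shape}")
        b, n, _, m = shape
        if (b, n) != (batch, num_points):
            raise ValueError(f"type {key}: batch/points mismatch in {shape}")
        if m != expected_type_dim(degree):
            raise ValueError(
                f"type {key}: last dim should be {expected_type_dim(degree)}, got {m}"
            )
    return True

# shapes matching the example above
shapes = {'0': (2, 32, 64, 1), '1': (2, 32, 64, 3)}
ok = check_feature_shapes(shapes, batch=2, num_points=32)
print(ok)
```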

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)
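The example above draws bond types at random. For a real molecule, bonds are undirected, so the per-molecule matrix of bond types would normally be symmetric. The sketch below builds such a matrix from a bond list. It assumes index 0 is reserved for "no bond", which is an illustrative convention and not something the library prescribes here; `bond_matrix` is likewise a hypothetical helper.

```python
def bond_matrix(num_atoms, bonds, no_bond=0):
    """Build a symmetric num_atoms x num_atoms matrix of bond-type indices.

    bonds: iterable of (i, j, bond_type) triples for undirected bonds.
    no_bond: index used for atom pairs without a bond (assumed convention).
    """
    matrix = [[no_bond] * num_atoms for _ in range(num_atoms)]
    for i, j, bond_type in bonds:
        matrix[i][j] = bond_type
        matrix[j][i] = bond_type  # mirror the entry so the matrix stays symmetric
    return matrix

# a 4-atom chain: single bond 0-1, double bond 1-2, single bond 2-3
bonds_single = bond_matrix(4, [(0, 1, 1), (1, 2, 2), (2, 3, 1)])
for row in bonds_single:
    print(row)
```

Stacking one such matrix per molecule and converting the result to a long tensor would give the (batch, n, n) shape used for `edges` above.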

Scaling (wip)

This section will list ongoing efforts to make SE3 Transformer scale a little better.

Firstly, I have added reversible networks. This allows me to add a little more depth before hitting the usual memory roadblocks. Equivariance preservation is demonstrated in the tests.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 20,
    dim = 32,
    dim_head = 32,
    heads = 4,
    depth = 12,             # 12 layers
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True,
    reversible = True       # set reversible to True
).cuda()

atoms = torch.randint(0, 4, (2, 32)).cuda()
coors = torch.randn(2, 32, 3).cuda()
mask  = torch.ones(2, 32).bool().cuda()

pred = model(atoms, coors, mask = mask, return_type = 0)

loss = pred.sum()
loss.backward()
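The memory saving from reversibility comes from the coupling scheme of Gomez et al. (cited below): a block's inputs can be recomputed exactly from its outputs, so intermediate activations need not be stored for the backward pass. The sketch below shows only that coupling on scalars, with placeholder functions `f` and `g`. How this library wires its attention and feedforward layers into the two streams is not shown here.

```python
import math

def reversible_forward(x1, x2, f, g):
    """Reversible residual coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def reversible_inverse(y1, y2, f, g):
    """Recover the inputs from the outputs by undoing the coupling in reverse order."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

# placeholder sub-layers; in a network these would be learned modules
f = math.tanh
g = lambda v: 0.5 * v

x1, x2 = 0.3, -1.2
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)
print(abs(r1 - x1), abs(r2 - x2))
```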

Todo:

  • [ ] Test to see if Performer maintains equivariance https://arxiv.org/abs/2009.14794
  • [ ] Chunking kernel calculation from basis vectors
  • [ ] When on float32, occasionally one would see breaking of equivariance. Figure out the root cause and see if only that part of the computation can be moved to float64 and back (see the linked issue)

Caching

By default, the basis vectors are cached. If you ever need to clear the cache, set the environment variable CLEAR_CACHE to any value when launching the script

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention
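The same cleanup can be scripted from Python, for example at the top of a training script. This is a minimal sketch using the directory and flag named above; `clear_cache_if_requested` is a hypothetical helper and is not how the library itself handles the flag internally.

```python
import os
import shutil
from pathlib import Path

# cache location as given above
DEFAULT_CACHE_DIR = Path.home() / ".cache.equivariant_attention"

def clear_cache_if_requested(cache_dir=DEFAULT_CACHE_DIR, env=os.environ):
    """Delete cache_dir when CLEAR_CACHE is set to any value.

    Returns True if a directory was removed, False otherwise.
    """
    if "CLEAR_CACHE" not in env:
        return False
    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        return False
    shutil.rmtree(cache_dir)
    return True
```

Calling `clear_cache_if_requested()` with no arguments checks the real environment and the default directory. Passing an explicit `cache_dir` and `env` makes it easy to try out on a scratch directory first.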

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{gomez2017reversible,
    title     = {The Reversible Residual Network: Backpropagation Without Storing Activations},
    author    = {Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse},
    year      = {2017},
    eprint    = {1707.04585},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}