lucidrains / enformer-pytorch

License: MIT
Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch

Programming Languages

python

Projects that are alternatives of or similar to enformer-pytorch

switchde
Inference of switch-like differential expression along single-cell trajectories
Stars: ✭ 19 (-86.99%)
Mutual labels:  genomics, gene-expression
h-transformer-1d
Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning
Stars: ✭ 121 (-17.12%)
Mutual labels:  transformer, attention-mechanism
Transformers-RL
An easy PyTorch implementation of "Stabilizing Transformers for Reinforcement Learning"
Stars: ✭ 107 (-26.71%)
Mutual labels:  transformer, attention-mechanism
Eeg Dl
A Deep Learning library for EEG Tasks (Signals) Classification, based on TensorFlow.
Stars: ✭ 165 (+13.01%)
Mutual labels:  transformer, attention-mechanism
graphsim
R package: Simulate Expression data from igraph network using mvtnorm (CRAN; JOSS)
Stars: ✭ 16 (-89.04%)
Mutual labels:  genomics, gene-expression
Linear Attention Transformer
Transformer based on a variant of attention that is linear complexity in respect to sequence length
Stars: ✭ 205 (+40.41%)
Mutual labels:  transformer, attention-mechanism
En-transformer
Implementation of E(n)-Transformer, which extends the ideas of Welling's E(n)-Equivariant Graph Neural Network to attention
Stars: ✭ 131 (-10.27%)
Mutual labels:  transformer, attention-mechanism
Eqtransformer
EQTransformer, a python package for earthquake signal detection and phase picking using AI.
Stars: ✭ 95 (-34.93%)
Mutual labels:  transformer, attention-mechanism
NLP-paper
🎨🎨 NLP (Natural Language Processing) tutorial 🎨🎨 https://dataxujing.github.io/NLP-paper/
Stars: ✭ 23 (-84.25%)
Mutual labels:  transformer, attention-mechanism
go enrichment
Transcripts annotation and GO enrichment Fisher tests
Stars: ✭ 24 (-83.56%)
Mutual labels:  genomics, gene-expression
Routing Transformer
Fully featured implementation of Routing Transformer
Stars: ✭ 149 (+2.05%)
Mutual labels:  transformer, attention-mechanism
visualization
a collection of visualization function
Stars: ✭ 189 (+29.45%)
Mutual labels:  transformer, attention-mechanism
Transformer In Generating Dialogue
An Implementation of 'Attention is all you need' with Chinese Corpus
Stars: ✭ 121 (-17.12%)
Mutual labels:  transformer, attention-mechanism
Self Attention Cv
Implementation of various self-attention mechanisms focused on computer vision. Ongoing repository.
Stars: ✭ 209 (+43.15%)
Mutual labels:  transformer, attention-mechanism
Overlappredator
[CVPR 2021, Oral] PREDATOR: Registration of 3D Point Clouds with Low Overlap.
Stars: ✭ 106 (-27.4%)
Mutual labels:  transformer, attention-mechanism
TianChi AIEarth
TianChi AIEarth Contest Solution
Stars: ✭ 57 (-60.96%)
Mutual labels:  transformer, attention-mechanism
Sockeye
Sequence-to-sequence framework with a focus on Neural Machine Translation based on Apache MXNet
Stars: ✭ 990 (+578.08%)
Mutual labels:  transformer, attention-mechanism
Se3 Transformer Pytorch
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. This specific repository is geared towards integration with eventual Alphafold2 replication.
Stars: ✭ 73 (-50%)
Mutual labels:  transformer, attention-mechanism
CrabNet
Predict materials properties using only the composition information!
Stars: ✭ 57 (-60.96%)
Mutual labels:  transformer, attention-mechanism
dodrio
Exploring attention weights in transformer-based models with linguistic knowledge.
Stars: ✭ 233 (+59.59%)
Mutual labels:  transformer, attention-mechanism

Enformer - PyTorch

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch. This repository also contains the means to fine-tune pretrained models for your downstream tasks. The original TensorFlow Sonnet code can be found here.

Install

$ pip install enformer-pytorch

Usage

import torch
from enformer_pytorch import Enformer

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)
    
seq = torch.randint(0, 5, (1, 196_608)) # for ACGTN, in that order (-1 for padding)
output = model(seq)

output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)
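
If your data starts out as raw nucleotide strings, a plain Python mapping is enough to build the index tensor above. The repository also ships string helpers (str_to_seq_indices is mentioned in the todo list below); the snippet here is only an illustrative sketch, not the library API

nucleotide_to_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4} # hypothetical mapping, for illustration only

dna = 'ACGT' * (196_608 // 4)                                   # any string of length 196_608
seq = torch.tensor([nucleotide_to_index[base] for base in dna]) # (196_608,)
seq = seq.unsqueeze(0)                                          # (1, 196_608)

output = model(seq)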

You can also pass the sequence in directly as a one-hot encoding, which must consist of float values

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)

output = model(one_hot)

output['human'] # (1, 896, 5313)
output['mouse'] # (1, 896, 1643)

Finally, one can fetch the embeddings, for fine-tuning and otherwise, by setting the return_embeddings flag to True on forward

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)

output, embeddings = model(one_hot, return_embeddings = True)

embeddings # (1, 896, 3072)
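
As a hedged illustration of what the embeddings can be used for (the HeadAdapterWrapper under Fine-tuning below is the supported route), you could stack a small downstream head of your own on top, e.g. a linear projection to however many new tracks you need

from torch import nn

custom_head = nn.Sequential( # hypothetical downstream head, for illustration only
    nn.Linear(3072, 128),    # 3072-dim embeddings -> 128 new tracks
    nn.Softplus()            # keep predictions non-negative
)

pred = custom_head(embeddings) # (1, 896, 128)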

For training, you can directly pass in the head and target to get the Poisson loss

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 200,
).cuda()

seq = torch.randint(0, 5, (196_608 // 2,)).cuda()
target = torch.randn(200, 5313).cuda()

loss = model(
    seq,
    head = 'human',
    target = target
)

loss.backward()

# after much training

corr_coef = model(
    seq,
    head = 'human',
    target = target,
    return_corr_coef = True
)

corr_coef # Pearson R, used as a metric in the paper
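
For reference, a standalone Pearson correlation between predictions and targets can be computed as in the sketch below. This is a hedged illustration of the metric itself, not the repository's internal implementation

def pearson_corr_coef(pred, target, eps = 1e-8):
    # center along the position dimension, then correlate per track
    pred_centered = pred - pred.mean(dim = -2, keepdim = True)
    target_centered = target - target.mean(dim = -2, keepdim = True)
    numer = (pred_centered * target_centered).sum(dim = -2)
    denom = pred_centered.norm(dim = -2) * target_centered.norm(dim = -2)
    return numer / denom.clamp(min = eps)

# pearson_corr_coef(predictions, target) on (200, 5313) tensors returns (5313,) per-track correlations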

Pretrained Model

Warning: the pretrained models so far have not hit the mark of what was presented in the paper. If you would like to help out, please join this Discord; replication efforts are ongoing.

To use a pretrained model (which may not yet match the quality of the one in the paper), simply use the from_pretrained method (powered by HuggingFace):

from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-preview")

# do your fine-tuning

This is made possible thanks to HuggingFace's custom model feature. All Enformer checkpoints can be found on the hub.
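
Since the model plugs into HuggingFace's standard machinery, a fine-tuned copy should be savable and reloadable with the usual save_pretrained / from_pretrained pair. A hedged sketch, with a hypothetical local directory

from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-preview")

# ... fine-tune ...

model.save_pretrained('./my-finetuned-enformer')            # hypothetical output directory
model = Enformer.from_pretrained('./my-finetuned-enformer') # reload later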

You can also override parameters such as target_length on load, if you are working with shorter sequence lengths

from enformer_pytorch import Enformer

model = Enformer.from_pretrained('EleutherAI/enformer-preview', target_length = 128, dropout_rate = 0.1)

# do your fine-tuning

To save memory when fine-tuning a large Enformer model, turn on checkpointing

from enformer_pytorch import Enformer

enformer = Enformer.from_pretrained('EleutherAI/enformer-preview', use_checkpointing = True)

# finetune enformer on a limited budget

Fine-tuning

This repository also allows for easy fine-tuning of Enformer.

Fine-tuning on new tracks

import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import HeadAdapterWrapper

enformer = Enformer.from_hparams(
    dim = 1536,
    depth = 1,
    heads = 8,
    target_length = 200,
)
    
model = HeadAdapterWrapper(
    enformer = enformer,
    num_tracks = 128
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()
target = torch.randn(1, 200, 128).cuda()  # 128 tracks

loss = model(seq, target = target)
loss.backward()

Fine-tuning on contextual data (cell type, transcription factor, etc.)

import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import ContextAdapterWrapper

enformer = Enformer.from_hparams(
    dim = 1536,
    depth = 1,
    heads = 8,
    target_length = 200,
)
    
model = ContextAdapterWrapper(
    enformer = enformer,
    context_dim = 1024
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda()  # 4 tracks
context = torch.randn(4, 1024).cuda()   # 4 contexts for the different 'tracks'

loss = model(
    seq,
    context = context,
    target = target
)

loss.backward()

Finally, there is also a way to use attention aggregation from a set of context embeddings (or a single context embedding). Simply use the ContextAttentionAdapterWrapper

import torch
from enformer_pytorch import Enformer
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper

enformer = Enformer.from_hparams(
    dim = 1536,
    depth = 1,
    heads = 8,
    target_length = 200,
)
    
model = ContextAttentionAdapterWrapper(
    enformer = enformer,
    context_dim = 1024,
    heads = 8,              # number of heads in the cross attention
    dim_head = 64           # dimension per head
).cuda()

seq = torch.randint(0, 5, (1, 196_608 // 2,)).cuda()

target = torch.randn(1, 200, 4).cuda()      # 4 tracks
context = torch.randn(4, 16, 1024).cuda()   # 4 contexts for the different 'tracks', each with 16 tokens

context_mask = torch.ones(4, 16).bool().cuda() # optional context mask; in this example, all context tokens are included

loss = model(
    seq,
    context = context,
    context_mask = context_mask,
    target = target
)

loss.backward()

Data

You can use the GenomeIntervalDataset to easily fetch sequences of any length from a .bed file, with a greater context length dynamically computed if specified

import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

ds = GenomeIntervalDataset(
    bed_file = './sequences.bed',                       # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = './hg38.ml.fa',                        # path to fasta file
    filter_df_fn = filter_train,                        # filter dataframe function
    return_seq_indices = True,                          # return nucleotide indices (ACGTN) or one hot encodings
    shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
    context_length = 196_608,
    # this can be longer than the interval designated in the .bed file,
    # in which case it will take care of lengthening the interval on either side
    # as well as proper padding if at the end of the chromosomes
    chr_bed_to_fasta_map = {
        'chr1': 'chromosome1',  # if the chromosome name in the .bed file is different than the key name in the fasta file, you can rename them on the fly
        'chr2': 'chromosome2',
        'chr3': 'chromosome3',
        # etc etc
    }
)

model = Enformer.from_hparams(
    dim = 1536,
    depth = 11,
    heads = 8,
    output_heads = dict(human = 5313, mouse = 1643),
    target_length = 896,
)

seq = ds[0] # (196608,)
pred = model(seq, head = 'human') # (896, 5313)
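
As a point of reference, the sequences.bed file above could contain tab-separated rows like the following (hypothetical coordinates; the fourth column is the split label that the column_4 filter keys on)

chr1    10000000    10131072    train
chr1    20000000    20131072    valid
chr2    30000000    30131072    test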

To return the random shift value, as well as whether the reverse complement was applied (in case you need to reverse the corresponding ChIP-seq target data), just set return_augs = True when initializing the GenomeIntervalDataset

import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset

filter_train = lambda df: df.filter(pl.col('column_4') == 'train')

ds = GenomeIntervalDataset(
    bed_file = './sequences.bed',                       # bed file - columns 0, 1, 2 must be <chromosome>, <start position>, <end position>
    fasta_file = './hg38.ml.fa',                        # path to fasta file
    filter_df_fn = filter_train,                        # filter dataframe function
    return_seq_indices = True,                          # return nucleotide indices (ACGTN) or one hot encodings
    shift_augs = (-2, 2),                               # random shift augmentations from -2 to +2 basepairs
    rc_aug = True,                                      # use reverse complement augmentation with 50% probability
    context_length = 196_608,
    return_augs = True                                  # return the augmentation meta data
)

seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
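
A standard torch DataLoader can then batch the dataset for training. A minimal sketch, reusing the model from the previous example and assuming the default collation of the (seq, rand_shift_val, rc_bool) tuples is acceptable

from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size = 2, shuffle = True)

for seq, rand_shift_val, rc_bool in loader:
    # seq is (2, 196_608) - use the augmentation metadata to shift / reverse
    # the corresponding targets before computing your loss
    pred = model(seq, head = 'human') # (2, 896, 5313)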

Appreciation

Special thanks go out to EleutherAI for providing the resources to retrain the model in an acceptable amount of time.

Todo

  • script to load weights from trained tensorflow enformer model to pytorch model
  • add loss wrapper with poisson loss
  • move the metrics code over to pytorch as well
  • train enformer model
  • build context manager for fine-tuning with unfrozen enformer but with frozen batchnorm
  • allow for plain fine-tune with fixed static context
  • allow for fine tuning with only unfrozen layernorms (technique from fine tuning transformers; see the sketch after this list)
  • fix handling of 'N' in sequence, figure out representation of N in basenji barnyard
  • take care of shift augmentation in GenomeIntervalDataset
  • speed up str_to_seq_indices
  • add to EleutherAI huggingface (done thanks to Niels)
  • offer some basic training utils, as gradient accumulation will be needed for fine tuning
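
One of the items above, fine-tuning with only the layernorms unfrozen, can already be approximated in plain PyTorch while first-class support is pending. A hedged sketch, not a feature of this repository

import torch
from torch import nn
from enformer_pytorch import Enformer

enformer = Enformer.from_pretrained('EleutherAI/enformer-preview')

# freeze everything ...
for param in enformer.parameters():
    param.requires_grad = False

# ... then unfreeze only the layernorm parameters
for module in enformer.modules():
    if isinstance(module, nn.LayerNorm):
        for param in module.parameters():
            param.requires_grad = True

optim = torch.optim.Adam((p for p in enformer.parameters() if p.requires_grad), lr = 1e-4)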

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
@misc{liu2022convnet,
    title   = {A ConvNet for the 2020s},
    author  = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
    year    = {2022},
    eprint  = {2201.03545},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}