
lucidrains / g-mlp-pytorch

License: MIT
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Programming Languages

python
139335 projects - #7 most used programming language

Projects that are alternatives of or similar to g-mlp-pytorch

digitRecognition
Implementation of a digit recognition using my Neural Network with the MNIST data set.
Stars: ✭ 21 (-94.52%)
Mutual labels:  multilayer-perceptron
MLP
A multilayer perceptron in JavaScript
Stars: ✭ 15 (-96.08%)
Mutual labels:  multilayer-perceptron
dl-relu
Deep Learning using Rectified Linear Units (ReLU)
Stars: ✭ 20 (-94.78%)
Mutual labels:  multilayer-perceptron
pydata-london-2018
Slides and notebooks for my tutorial at PyData London 2018
Stars: ✭ 22 (-94.26%)
Mutual labels:  multilayer-perceptron
mlp-gpt-jax
A GPT, made only of MLPs, in Jax
Stars: ✭ 53 (-86.16%)
Mutual labels:  multilayer-perceptron

gMLP - Pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Install

$ pip install g-mlp-pytorch
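
Alternatively, to install from source (assuming the repository's standard Python packaging):

$ git clone https://github.com/lucidrains/g-mlp-pytorch
$ cd g-mlp-pytorch
$ pip install -e .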

Usage

For masked language modelling

import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    circulant_matrix = True,      # use circulant weight matrix for a linear increase in parameters with respect to sequence length
    act = nn.Tanh()               # activation for spatial gate (defaults to identity)
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
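
A minimal training-step sketch, continuing the example above. The 15% masking rate and the use of token id 0 as the mask token are illustrative assumptions, not part of the library:

import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

tokens = torch.randint(1, 20000, (1, 256))           # stand-in for real data; id 0 reserved as the (assumed) mask token
mask = torch.rand(tokens.shape) < 0.15               # mask 15% of positions
corrupted = tokens.masked_fill(mask, 0)              # replace masked positions with the mask token

logits = model(corrupted)                            # (1, 256, 20000)
loss = F.cross_entropy(logits[mask], tokens[mask])   # compute loss only on the masked positions
loss.backward()
optimizer.step()
optimizer.zero_grad()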

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
logits = model(img) # (1, 1000)
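
The model returns raw, unnormalized logits. A common follow-up (standard PyTorch, not specific to this library) is to convert them to class probabilities and read off the top predictions:

import torch.nn.functional as F

probs = F.softmax(logits, dim = -1)       # (1, 1000) class probabilities
top5 = probs.topk(5, dim = -1)            # values and indices of the 5 most likely classes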

You can also add a tiny amount of single-headed attention to boost performance, as described in the paper (the aMLP variant), by passing one extra keyword argument, attn_dim. This applies to both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
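
Since the added attention is deliberately tiny, the parameter overhead is modest. A quick way to inspect it on your machine (a sketch using standard PyTorch parameter counting, comparing the two configurations above):

from g_mlp_pytorch import gMLPVision

def count_params(m):
    return sum(p.numel() for p in m.parameters())

base = gMLPVision(image_size = 256, patch_size = 16, num_classes = 1000, dim = 512, depth = 6)
amlp = gMLPVision(image_size = 256, patch_size = 16, num_classes = 1000, dim = 512, depth = 6, attn_dim = 64)

print(count_params(amlp) - count_params(base))  # extra parameters contributed by the tiny attention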

Non-square images and patch sizes

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = (256, 128),
    patch_size = (16, 8),
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 128)
pred = model(img) # (1, 1000)
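
As a sanity check on the shapes (assuming the standard ViT-style patch embedding, where each image dimension must be divisible by the corresponding patch dimension):

image_size = (256, 128)
patch_size = (16, 8)

# number of patches along each axis
patches_h = image_size[0] // patch_size[0]   # 256 / 16 = 16
patches_w = image_size[1] // patch_size[1]   # 128 / 8  = 16
print(patches_h * patches_w)                 # 256 patches in total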

Experimental

An independent researcher proposes a multi-headed approach for gMLPs in a blog post on Zhihu. To use it, simply set heads to a value greater than 1.

import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    causal = True,
    circulant_matrix = True,
    heads = 4 # 4 heads
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x) # (1, 256, 20000)
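
With causal = True, the model can also be used autoregressively. Below is a hedged greedy-decoding sketch, continuing the example above; it assumes the model expects inputs of exactly seq_len tokens, that with causal masking the logits at position t depend only on tokens up to t, and that token id 0 can serve as padding (an illustrative choice, not part of the library):

import torch

@torch.no_grad()
def greedy_fill(model, prompt, seq_len = 256, pad_id = 0):
    # start from a fully padded sequence and copy in the prompt
    x = torch.full((1, seq_len), pad_id, dtype = torch.long)
    x[:, :prompt.shape[-1]] = prompt
    # fill one position per step; causality makes the future padding irrelevant
    for t in range(prompt.shape[-1], seq_len):
        logits = model(x)                            # (1, seq_len, num_tokens)
        x[:, t] = logits[:, t - 1].argmax(dim = -1)  # greedy next-token pick
    return x

prompt = torch.randint(0, 20000, (1, 10))
generated = greedy_fill(model, prompt)  # (1, 256)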

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = aug,
    year         = 2021,
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}