
lucidrains / mlp-gpt-jax

License: MIT
A GPT, made only of MLPs, in Jax

Programming Languages

python
139,335 projects - #7 most used programming language

Projects that are alternatives to or similar to mlp-gpt-jax

Transformers
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+105073.58%)
Mutual labels:  language-model, jax
gap-text2sql
GAP-text2SQL: Learning Contextual Representations for Semantic Parsing with Generation-Augmented Pre-Training
Stars: ✭ 83 (+56.6%)
Mutual labels:  language-model
digitRecognition
Implementation of a digit recognition using my Neural Network with the MNIST data set.
Stars: ✭ 21 (-60.38%)
Mutual labels:  multilayer-perceptron
CharLM
Character-aware Neural Language Model implemented by PyTorch
Stars: ✭ 32 (-39.62%)
Mutual labels:  language-model
KB-ALBERT
A Korean ALBERT model specialized for the economy/finance domain, provided by KB Kookmin Bank
Stars: ✭ 215 (+305.66%)
Mutual labels:  language-model
jax-cfd
Computational Fluid Dynamics in JAX
Stars: ✭ 399 (+652.83%)
Mutual labels:  jax
rnn-theano
RNN(LSTM, GRU) in Theano with mini-batch training; character-level language models in Theano
Stars: ✭ 68 (+28.3%)
Mutual labels:  language-model
LanguageModel-using-Attention
Pytorch implementation of a basic language model using Attention in LSTM network
Stars: ✭ 27 (-49.06%)
Mutual labels:  language-model
g-mlp-pytorch
Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch
Stars: ✭ 383 (+622.64%)
Mutual labels:  multilayer-perceptron
calm
Context Aware Language Models
Stars: ✭ 29 (-45.28%)
Mutual labels:  language-model
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (+15.09%)
Mutual labels:  jax
asr24
24-hour Automatic Speech Recognition
Stars: ✭ 27 (-49.06%)
Mutual labels:  language-model
jax-models
Unofficial JAX implementations of deep learning research papers
Stars: ✭ 108 (+103.77%)
Mutual labels:  jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-18.87%)
Mutual labels:  jax
lm-scorer
📃 Language Model based sentences scoring library
Stars: ✭ 264 (+398.11%)
Mutual labels:  language-model
Vaaku2Vec
Language Modeling and Text Classification in Malayalam Language using ULMFiT
Stars: ✭ 68 (+28.3%)
Mutual labels:  language-model
parallel-non-linear-gaussian-smoothers
Companion code in JAX for the paper Parallel Iterated Extended and Sigma-Point Kalman Smoothers.
Stars: ✭ 17 (-67.92%)
Mutual labels:  jax
personality-prediction
Experiments for automated personality detection using Language Models and psycholinguistic features on various famous personality datasets including the Essays dataset (Big-Five)
Stars: ✭ 109 (+105.66%)
Mutual labels:  language-model
dasher-web
Dasher text entry in HTML, CSS, JavaScript, and SVG
Stars: ✭ 34 (-35.85%)
Mutual labels:  language-model
swig-srilm
SWIG Wrapper for the SRILM toolkit
Stars: ✭ 33 (-37.74%)
Mutual labels:  language-model

MLP GPT - Jax

A GPT, made only of MLPs, in Jax. The specific MLPs used are gMLPs with Spatial Gating Units.

A working Pytorch implementation is also available.
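For intuition, here is a minimal sketch of a causal Spatial Gating Unit written with Haiku. It follows the idea described in the gMLP paper (split the hidden channels, normalize the gate half, mix it across the sequence with a causally masked learned projection, then gate the other half), but the module name, initializers, and shapes are illustrative assumptions rather than this library's actual implementation.

import jax.numpy as jnp
import haiku as hk

class SpatialGatingUnit(hk.Module):
    # Illustrative sketch only; must be called inside hk.transform.
    def __init__(self, seq_len, name=None):
        super().__init__(name=name)
        self.seq_len = seq_len

    def __call__(self, x):
        # x: (seq_len, dim_ff) - split channels into a residual half and a gate half
        res, gate = jnp.split(x, 2, axis=-1)
        gate = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(gate)

        # learned token-mixing weights over the sequence axis, started near identity
        w = hk.get_parameter('w', (self.seq_len, self.seq_len),
                             init=hk.initializers.Constant(0.))
        b = hk.get_parameter('b', (self.seq_len,), init=jnp.ones)

        # lower-triangular mask keeps the gating autoregressive (causal)
        mask = jnp.tril(jnp.ones((self.seq_len, self.seq_len)))
        gate = (w * mask) @ gate + b[:, None]

        return res * gate

In the full model, a unit like this sits inside each gMLP block between two channel-wise projections, which is what lets the network mix information across positions without attention.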

Install

$ pip install mlp-gpt-jax

Usage

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)  # random token ids, shape (seq_len,)

params = model.init(next(rng), seq)                 # initialize parameters
logits = model.apply(params, next(rng), seq)        # (1024, 20000)
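Because TransformedMLPGpt yields a pure Haiku init/apply pair, training can be set up with ordinary JAX transforms. Below is a hedged sketch of a next-token cross-entropy loss built on the apply call above; the loss_fn name and the shift-by-one labeling are illustrative assumptions, not part of the library.

import jax
import jax.numpy as jnp

def loss_fn(params, rng_key, seq):
    logits = model.apply(params, rng_key, seq)            # (1024, 20000)
    log_probs = jax.nn.log_softmax(logits[:-1], axis=-1)  # predictions from positions 0..1022
    targets = seq[1:]                                     # next-token labels
    nll = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)
    return -nll.mean()

loss, grads = jax.value_and_grad(loss_fn)(params, next(rng), seq)

The grads pytree has the same structure as params, so it can be passed directly to any JAX optimizer.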

To use the tiny attention (also made autoregressive with a causal mask), just set attn_dim to the head dimension you'd like to use. 64 was recommended in the paper.

from jax import random
from haiku import PRNGSequence
from mlp_gpt_jax import TransformedMLPGpt

model = TransformedMLPGpt(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    attn_dim = 64     # head dimension for the tiny attention
)

rng = PRNGSequence(0)
seq = random.randint(next(rng), (1024,), 0, 20000)

params = model.init(next(rng), seq)
logits = model.apply(params, next(rng), seq) # (1024, 20000)
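As with any Haiku-transformed model, the apply function is pure and can be compiled with jax.jit for repeated forward passes; this is standard JAX usage rather than a feature of this package.

import jax

fast_apply = jax.jit(model.apply)            # compile once, reuse across calls
logits = fast_apply(params, next(rng), seq)  # same (1024, 20000) output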

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}