srush / annotated-s4

Licence: other
Implementation of https://srush.github.io/annotated-s4

Programming Languages

Python
139,335 projects - #7 most used programming language
Makefile
30,231 projects

Projects that are alternatives to or similar to annotated-s4

uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+577.44%)
Mutual labels:  jax
robustness-vit
Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
Stars: ✭ 78 (-41.35%)
Mutual labels:  jax
ML-Optimizers-JAX
Toy implementations of some popular ML optimizers using Python/JAX
Stars: ✭ 37 (-72.18%)
Mutual labels:  jax
ShinRL
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives (Deep RL Workshop 2021)
Stars: ✭ 30 (-77.44%)
Mutual labels:  jax
wax-ml
A Python library for machine-learning and feedback loops on streaming data
Stars: ✭ 36 (-72.93%)
Mutual labels:  jax
madam
👩 Pytorch and Jax code for the Madam optimiser.
Stars: ✭ 46 (-65.41%)
Mutual labels:  jax
mlp-gpt-jax
A GPT, made only of MLPs, in Jax
Stars: ✭ 53 (-60.15%)
Mutual labels:  jax
brax
Massively parallel rigidbody physics simulation on accelerator hardware.
Stars: ✭ 1,208 (+808.27%)
Mutual labels:  jax
chef-transformer
Chef Transformer 🍲
Stars: ✭ 29 (-78.2%)
Mutual labels:  jax
jaxdf
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
Stars: ✭ 50 (-62.41%)
Mutual labels:  jax
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Stars: ✭ 114 (-14.29%)
Mutual labels:  jax
jaxfg
Factor graphs and nonlinear optimization for JAX
Stars: ✭ 124 (-6.77%)
Mutual labels:  jax
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
Stars: ✭ 80 (-39.85%)
Mutual labels:  jax
GPJax
A didactic Gaussian process package for researchers in Jax.
Stars: ✭ 159 (+19.55%)
Mutual labels:  jax
bayex
Bayesian Optimization in JAX
Stars: ✭ 24 (-81.95%)
Mutual labels:  jax
rA9
JAX-based Spiking Neural Network framework
Stars: ✭ 60 (-54.89%)
Mutual labels:  jax
fedpa
Federated posterior averaging implemented in JAX
Stars: ✭ 38 (-71.43%)
Mutual labels:  jax
SymJAX
Documentation:
Stars: ✭ 103 (-22.56%)
Mutual labels:  jax
get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
Stars: ✭ 229 (+72.18%)
Mutual labels:  jax
cr-sparse
Functional models and algorithms for sparse signal processing
Stars: ✭ 38 (-71.43%)
Mutual labels:  jax

Experiments

MNIST Sequence Modeling

# Default arguments
python -m s4.train --dataset mnist --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64

QuickDraw Sequence Modeling

# Default arguments
python -m s4.train --dataset quickdraw --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64

# "Run in a day" variant
python -m s4.train --dataset quickdraw --model s4 --epochs 1 --bsz 512 --d_model 256 --ssm_n 64 --p_dropout 0.05

MNIST Classification

# Default arguments
python -m s4.train --dataset mnist-classification --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64

(Default Arguments, as shown above): best accuracy of 97.76% in 10 epochs @ 40s/epoch on a TitanRTX.

CIFAR-10 Classification

## Adding a Cubic Decay Schedule for the last 70% of training

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64 --lr 1e-2 --lr_schedule

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64 --lr 1e-2 --lr_schedule
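
For reference, a schedule of this shape can be written with optax: hold the peak rate for the first 30% of steps, then decay it cubically to zero over the remaining 70%. The snippet below is only a sketch of the idea, not the repository's exact implementation; the step counts and the 1e-2 peak rate are placeholders.

# Illustrative sketch of a cubic decay schedule (not the repo's exact code)
import optax

total_steps = 100 * 390                 # e.g. 100 epochs of CIFAR-10 at bsz=128 (~390 batches/epoch)
hold_steps = int(0.3 * total_steps)     # constant LR for the first 30% of training

schedule = optax.join_schedules(
    schedules=[
        optax.constant_schedule(1e-2),
        # Cubic (power=3) decay from 1e-2 to 0 over the remaining 70% of steps.
        optax.polynomial_schedule(
            init_value=1e-2,
            end_value=0.0,
            power=3,
            transition_steps=total_steps - hold_steps,
        ),
    ],
    boundaries=[hold_steps],
)

optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)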

## After Fixing S4-Custom Optimization & Dropout2D (both implemented inline for now; flags can be added if desired)

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64 --lr 1e-2

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64 --lr 1e-2
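
"S4-custom optimization" here means giving the SSM kernel parameters their own optimizer settings instead of a single AdamW for everything. One way to express that in optax is multi_transform; the sketch below is for illustration only, and the parameter names ("B", "C", "D", "log_step") are assumptions, not necessarily the repository's exact layout.

# Illustrative sketch of per-parameter optimization (not the repo's exact code)
import optax
from flax.traverse_util import flatten_dict, unflatten_dict

SSM_PARAM_NAMES = {"B", "C", "D", "log_step"}  # assumed names of the SSM kernel parameters

def label_params(params):
    # Assumes `params` is a nested dict of arrays; label each leaf "ssm" or "regular".
    flat = flatten_dict(params)
    labels = {path: "ssm" if path[-1] in SSM_PARAM_NAMES else "regular" for path in flat}
    return unflatten_dict(labels)

optimizer = optax.multi_transform(
    transforms={
        "ssm": optax.adam(1e-3),                          # smaller LR, no weight decay on the SSM kernel
        "regular": optax.adamw(1e-2, weight_decay=0.01),  # standard AdamW everywhere else
    },
    param_labels=label_params,
)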

## Before Fixing S4-Custom Optimization...

# Default arguments (100 epochs for CIFAR)
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 128 --d_model 128 --ssm_n 64

# S4 replication from central repository
python -m s4.train --dataset cifar-classification --model s4 --epochs 100 --bsz 64 --d_model 512 --ssm_n 64

Adding a Schedule:

  • (LR 1e-2 w/ Replication Args -- "big" model): 71.55% (still running, 39 epochs) @ 3m16s/epoch on a TitanRTX
  • (LR 1e-2 w/ Default Args -- smaller model): 71.92% @ 36s/epoch on a TitanRTX

After Fixing Dropout2D (w/ Optimization in Place):

  • (LR 1e-2 w/ Replication Args -- "big" model): 70.68% (still running, 47 epochs) @ 3m17s/epoch on a TitanRTX
  • (LR 1e-2 w/ Default Args -- smaller model): 68.20% @ 36s/epoch on a TitanRTX
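
For context on the Dropout2D fix: torch.nn.Dropout2d zeroes entire feature channels rather than individual activations. In Flax the analogous "tied" dropout can be expressed by broadcasting the mask over the sequence axis, roughly as sketched below; the (batch, length, features) layout and module structure are assumptions for illustration, not the repository's exact code.

# Illustrative sketch of channel-tied (Dropout2d-style) dropout in Flax
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    features: int = 128
    rate: float = 0.05

    @nn.compact
    def __call__(self, x, *, training: bool):
        # x: (batch, length, features). broadcast_dims=(1,) reuses one mask across the
        # length axis, so a dropped feature is dropped at every position in the sequence.
        x = nn.Dense(self.features)(x)
        x = nn.Dropout(rate=self.rate, broadcast_dims=(1,), deterministic=not training)(x)
        return x

# Usage sketch
key = jax.random.PRNGKey(0)
model = Block()
x = jnp.ones((2, 16, 32))
variables = model.init({"params": key, "dropout": key}, x, training=True)
y = model.apply(variables, x, training=True, rngs={"dropout": key})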

After Fixing Optimization, Before Fixing Dropout2D:

  • (LR 1e-2 w/ Default Args -- smaller model): 67.14% @ 36s/epoch on a TitanRTX

Before Fixing S4 Optimization -- AdamW w/ LR 1e-3 for ALL Parameters:

  • (Default Arguments): best accuracy of 63.51% @ 46s/epoch on a TitanRTX
  • (S4 Replication Args): best accuracy of 66.44% @ 3m11s/epoch on a TitanRTX
    • Possible reasons for failure to meet replication: the LR schedule (decay on plateau, sketched below) and the custom LR per parameter.
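
"Decay on plateau" refers to shrinking the learning rate once validation loss stops improving. A minimal, framework-free sketch of that logic (the factor and patience values are placeholders):

# Illustrative sketch of a decay-on-plateau schedule
class ReduceOnPlateau:
    def __init__(self, lr, factor=0.2, patience=5):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best, self.bad_epochs = float("inf"), 0

    def step(self, val_loss):
        # Call once per epoch with the validation loss; returns the LR to use next epoch.
        if val_loss < self.best:
            self.best, self.bad_epochs = val_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

# Usage sketch: feed the returned rate into the optimizer before each epoch.
sched = ReduceOnPlateau(lr=1e-2)
lr = sched.step(val_loss=0.81)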

Quickstart (Development)

We have two requirements files for this project: requirements-cpu.txt, tailored to CPU-only machines, and requirements-gpu.txt for GPU installs.

CPU-Only (MacOS, Linux)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-cpu.txt

# Set up pre-commit
pre-commit install

GPU (CUDA > 11 & CUDNN > 8.2)

# Set up virtual/conda environment of your choosing & activate...
pip install -r requirements-gpu.txt

# Set up pre-commit
pre-commit install

Dependencies from Scratch

In case the above requirements files don't work, here are the commands used to install the dependencies from scratch.

CPU-Only

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install --upgrade "jax[cpu]"
pip install flax
pip install torch torchvision torchaudio

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install

GPU (CUDA > 11, CUDNN > 8.2)

Note: CUDNN > 8.2 is critical for compilation without warnings, and a GPU with at least the Turing architecture is needed for full efficiency.

# Set up virtual/conda environment of your choosing & activate... then install the following:
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install flax
pip install torch==1.10.1+cpu torchvision==0.11.2+cpu torchaudio==0.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html

# Defaults
pip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm

# Set up pre-commit
pre-commit install
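
After either install path, a quick way to confirm that the GPU build of JAX is actually active (rather than a silent CPU fallback):

# Should print "gpu" and list at least one CUDA device
python -c "import jax; print(jax.default_backend(), jax.devices())"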