
google / Flax

Licence: apache-2.0
Flax is a neural network library for JAX that is designed for flexibility.

Programming Languages

python
139335 projects - #7 most used programming language
Jupyter Notebook
11667 projects

Projects that are alternatives of or similar to Flax

chef-transformer
Chef Transformer 🍲.
Stars: ✭ 29 (-98.81%)
Mutual labels:  jax
bayex
Bayesian Optimization in JAX
Stars: ✭ 24 (-99.02%)
Mutual labels:  jax
treeo
A small library for creating and manipulating custom JAX Pytree classes
Stars: ✭ 29 (-98.81%)
Mutual labels:  jax
fedpa
Federated posterior averaging implemented in JAX
Stars: ✭ 38 (-98.45%)
Mutual labels:  jax
jaxdf
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
Stars: ✭ 50 (-97.96%)
Mutual labels:  jax
brax
Massively parallel rigidbody physics simulation on accelerator hardware.
Stars: ✭ 1,208 (-50.63%)
Mutual labels:  jax
jaxfg
Factor graphs and nonlinear optimization for JAX
Stars: ✭ 124 (-94.93%)
Mutual labels:  jax
Trax
Trax — Deep Learning with Clear Code and Speed
Stars: ✭ 6,666 (+172.42%)
Mutual labels:  jax
ML-Optimizers-JAX
Toy implementations of some popular ML optimizers using Python/JAX
Stars: ✭ 37 (-98.49%)
Mutual labels:  jax
score flow
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
Stars: ✭ 49 (-98%)
Mutual labels:  jax
madam
👩 Pytorch and Jax code for the Madam optimiser.
Stars: ✭ 46 (-98.12%)
Mutual labels:  jax
cr-sparse
Functional models and algorithms for sparse signal processing
Stars: ✭ 38 (-98.45%)
Mutual labels:  jax
SymJAX
Documentation:
Stars: ✭ 103 (-95.79%)
Mutual labels:  jax
robustness-vit
Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
Stars: ✭ 78 (-96.81%)
Mutual labels:  jax
jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Stars: ✭ 61 (-97.51%)
Mutual labels:  jax
wax-ml
A Python library for machine-learning and feedback loops on streaming data
Stars: ✭ 36 (-98.53%)
Mutual labels:  jax
get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
Stars: ✭ 229 (-90.64%)
Mutual labels:  jax
Transformers
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+2177.97%)
Mutual labels:  jax
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Stars: ✭ 4,197 (+71.52%)
Mutual labels:  jax
annotated-s4
Implementation of https://srush.github.io/annotated-s4
Stars: ✭ 133 (-94.56%)
Mutual labels:  jax

Flax: A neural network library and ecosystem for JAX designed for flexibility


Overview | Quick install | What does Flax look like? | Documentation

This README is a very short intro. To learn everything you need to know about Flax, see our full documentation.

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

Flax is used daily by hundreds of people across various Alphabet research departments, as well as by a growing community of open source projects.

The Flax team's mission is to serve the growing JAX neural network research ecosystem, both within Alphabet and in the broader community, and to explore the use cases where JAX shines. We use GitHub for almost all of our coordination and planning, and it is also where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests, and we hope to engage more and more with the needs of the broader ecosystem. Please let us know how we can help!

Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at [email protected].

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
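
As a minimal sketch of the "modify the training loop yourself" philosophy (our illustration, not an official Flax example), here is a hand-written SGD step for a single Dense layer using only model.init, model.apply and jax.value_and_grad; the data, loss and learning rate are placeholder assumptions:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=1)                   # any flax.linen module would do
x = jnp.ones((8, 4))                           # placeholder batch of 8 examples
y = jnp.zeros((8, 1))                          # placeholder targets
params = model.init(jax.random.PRNGKey(0), x)  # initialize the parameter pytree

def mse_loss(params, x, y):
  preds = model.apply(params, x)
  return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
  loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
  # Plain SGD written by hand; the flax.optim optimizers listed above can replace this.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  return new_params, loss

params, loss = train_step(params, x, y)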

Quick install

You will need Python 3.6 or later and a working JAX installation (with or without GPU support; see the JAX installation instructions). For a CPU-only version:

> pip install --upgrade pip # To support manylinux2010 wheels.
> pip install --upgrade jax jaxlib # CPU-only

Then install Flax from PyPI:

> pip install flax

To upgrade to the latest version of Flax, you can use:

> pip install --upgrade git+https://github.com/google/flax.git
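
To sanity-check the install, importing both packages and printing their versions is usually enough (the exact version numbers will differ on your machine):

> python -c "import jax, flax; print(jax.__version__, flax.__version__)"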

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, see our docs, in particular the broad intro to the Module abstraction. For additional concrete demonstrations of best practices, see our HOWTO guides.

from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
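
As a brief aside (not part of the original snippet), the variables returned by init form a nested pytree of parameters, so standard JAX tree utilities can inspect it, for example:

# Sketch: print the shape of every parameter created by model.init above.
print(jax.tree_util.tree_map(lambda p: p.shape, variables))
# roughly: {'params': {'Dense_0': {'bias': (12,), 'kernel': (10, 12)}, ...}}

The second example is a simple CNN:
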
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
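
A natural follow-up (our addition, not the README's) is to compile the forward pass with jax.jit so that repeated calls reuse the compiled computation:

# Sketch: JIT-compile the CNN forward pass defined above.
@jax.jit
def predict(variables, batch):
  return model.apply(variables, batch)

logits = predict(variables, batch)  # shape (32, 10): per-class log-probabilities

The third example is an auto-encoder built from the MLP defined earlier:
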
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(self.decoder_widths + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

# Widths are passed as tuples so that `self.decoder_widths + (input_dim,)` in setup() is a valid concatenation.
model = AutoEncoder(encoder_widths=(20, 10, 5),
                    decoder_widths=(5, 10, 20),
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.PRNGKey(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
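
As an illustrative extension (an assumption on our part, not part of the README), a mean-squared reconstruction loss and its gradients can be obtained by calling the full module through apply and differentiating with respect to the variables:

# Sketch: reconstruction loss and gradients for the auto-encoder above.
def recon_loss(variables, batch):
  recon = model.apply(variables, batch)  # full encode -> decode pass
  return jnp.mean((recon - batch) ** 2)

loss, grads = jax.value_and_grad(recon_loss)(variables, batch)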

🤗 Hugging Face

In-detail examples to train and evaluate a variety of Flax models for Natural Language Processing, Computer Vision, and Speech Recognition are actively maintained in the 🤗 Transformers repository.

As of October 2021, the 19 most-used Transformer architectures are supported in Flax and over 5000 pretrained checkpoints in Flax have been uploaded to the 🤗 Hub.

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.3.5},
  year = {2020},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Note that the project description data, including the texts, logos, images, and/or trademarks, for each open source project belongs to its rightful owner. If you wish to add or remove any projects, please contact us at [email protected].