Top 42 JAX open-source projects

Einops
Deep learning operations reinvented (for PyTorch, TensorFlow, JAX and others)
JAX
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
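A minimal sketch of how these transformations compose, using only core JAX calls (`jax.grad`, `jax.jit`, `jax.vmap`). The linear-model loss and the toy data are illustrative choices, not anything taken from the JAX repository.

```python
import jax
import jax.numpy as jnp

# Illustrative loss: mean squared error of a linear model pred = x @ w.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# grad: differentiate with respect to the first argument (w).
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA (CPU, GPU, or TPU).
fast_grad = jax.jit(grad_loss)

# vmap: map over a batch of weight vectors (axis 0 of w), sharing x and y.
batched_grad = jax.vmap(fast_grad, in_axes=(0, None, None))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
ws = jnp.zeros((3, 2))  # three candidate weight vectors

grads = batched_grad(ws, x, y)
print(grads.shape)  # (3, 2)
```

Each transformation returns an ordinary Python function, which is why they can be stacked in any order.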
Foolbox
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
Flax
Flax is a neural network library for JAX that is designed for flexibility.
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Datasets
TFDS is a collection of datasets ready to use with TensorFlow, JAX, ...
jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
treeo
A small library for creating and manipulating custom JAX Pytree classes
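treeo's own API is not described in this listing. As a rough sketch of what a custom pytree class involves in plain JAX, the example below uses `jax.tree_util.register_pytree_node_class` (core JAX, not treeo). The `Linear` class is hypothetical: its arrays are pytree leaves, and its `name` is static metadata.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    """Hypothetical parameter container: w and b are leaves, name is static."""

    def __init__(self, w, b, name="linear"):
        self.w = w
        self.b = b
        self.name = name

    def tree_flatten(self):
        # Children are transformed by JAX; aux_data is carried along unchanged.
        return (self.w, self.b), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        w, b = children
        return cls(w, b, name=aux_data)

    def __call__(self, x):
        return x @ self.w + self.b

layer = Linear(jnp.ones((2, 1)), jnp.zeros((1,)))

# tree_map applies a function to every leaf and rebuilds a Linear instance.
scaled = jax.tree_util.tree_map(lambda p: 2.0 * p, layer)

# Because Linear is a pytree, jax.grad returns gradients with the same structure.
def loss(model, x):
    return jnp.sum(model(x) ** 2)

grads = jax.grad(loss)(layer, jnp.array([[1.0, 2.0]]))
print(type(grads).__name__, grads.w.shape)
```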
score flow
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
annotated-s4
Implementation of https://srush.github.io/annotated-s4
brax
Massively parallel rigid-body physics simulation on accelerator hardware.
get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
jaxdf
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
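jaxdf's own discretization API is not described here. The general idea of a differentiable numerical simulator can be sketched in plain JAX instead: a 1-D heat equation stepped with explicit finite differences on a periodic grid, with `jax.grad` differentiating a loss on the final state with respect to the diffusion coefficient. The names `heat_step` and `simulate`, as well as `dx`, `dt`, and the step count, are illustrative assumptions, not jaxdf's API.

```python
import jax
import jax.numpy as jnp

def heat_step(u, kappa, dx=0.1, dt=0.001):
    # Explicit update of u_t = kappa * u_xx; jnp.roll gives periodic boundaries.
    lap = (jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1)) / dx**2
    return u + dt * kappa * lap

def simulate(kappa, u0, steps=100):
    # lax.scan runs the time loop while remaining differentiable and jit-friendly.
    def body(u, _):
        return heat_step(u, kappa), None
    u_final, _ = jax.lax.scan(body, u0, None, length=steps)
    return u_final

def loss(kappa, u0, target):
    return jnp.mean((simulate(kappa, u0) - target) ** 2)

x = jnp.arange(10) * 0.1              # 10 grid points, spacing dx = 0.1
u0 = jnp.sin(2.0 * jnp.pi * x)
target = simulate(1.0, u0)            # synthetic "observation" made with kappa = 1

# Gradient of the mismatch with respect to the physical parameter kappa.
g_true = jax.grad(loss)(1.0, u0, target)
g_low = jax.grad(loss)(0.5, u0, target)
print(g_true, g_low)
```

With these settings `dt * kappa / dx**2` stays well below 0.5, the stability limit of this explicit scheme, for the values of `kappa` used here.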
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
madam
👩 PyTorch and JAX code for the Madam optimiser.
fedpa
Federated posterior averaging implemented in JAX
robustness-vit
Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
wax-ml
A Python library for machine learning and feedback loops on streaming data
jaxfg
Factor graphs and nonlinear optimization for JAX
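jaxfg's own API is not shown in this listing. As a plain-JAX sketch of the underlying idea, the example below treats a tiny factor graph as a stacked residual vector and minimizes it with Gauss-Newton, using `jax.jacfwd` for the Jacobian. The two 2-D variables and the prior, between, and range factors are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Two 2-D variables x0, x1 packed into one vector [x0x, x0y, x1x, x1y].
def residuals(v):
    x0, x1 = v[:2], v[2:]
    prior = x0 - jnp.array([0.0, 0.0])          # prior factor: x0 near the origin
    between = (x1 - x0) - jnp.array([1.0, 0.0])  # between factor: x1 is 1 unit right of x0
    range_ = jnp.linalg.norm(x1)[None] - 1.0     # nonlinear range factor: |x1| = 1
    return jnp.concatenate([prior, between, range_])

@jax.jit
def gauss_newton_step(v):
    r = residuals(v)
    J = jax.jacfwd(residuals)(v)                 # 5x4 Jacobian via forward-mode AD
    delta = jnp.linalg.solve(J.T @ J, -J.T @ r)  # solve the normal equations
    return v + delta

v = jnp.array([0.1, -0.1, 0.5, 0.5])             # initial guess
for _ in range(10):
    v = gauss_newton_step(v)
print(v)
```

Dedicated factor-graph libraries typically exploit the sparsity of `J` rather than forming a dense `J.T @ J` as this sketch does.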
dm pix
PIX is an image processing library in JAX, for JAX.
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc. in JAX w/ Flax Linen and Objax
ShinRL
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives (Deep RL Workshop 2021)
GPJax
A didactic Gaussian process package for researchers in JAX.
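GPJax's own classes are not described in this listing. For orientation, the example below shows exact Gaussian process regression in plain JAX: a squared-exponential kernel, a Cholesky factorization, and the standard posterior mean and covariance formulas. It is not GPJax's API, and the kernel hyperparameters, noise level, and data are illustrative assumptions.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel between two sets of 1-D inputs.
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)

def gp_posterior(x_train, y_train, x_test, noise=1e-4):
    # K + noise*I is factored once with Cholesky and reused for mean and covariance.
    K = rbf_kernel(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    K_s = rbf_kernel(x_train, x_test)
    K_ss = rbf_kernel(x_test, x_test)
    L = jnp.linalg.cholesky(K)                      # lower-triangular factor
    alpha = cho_solve((L, True), y_train)           # (K + noise*I)^-1 y
    mean = K_s.T @ alpha
    v = solve_triangular(L, K_s, lower=True)
    cov = K_ss - v.T @ v
    return mean, cov

x_train = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y_train = jnp.sin(x_train)
x_test = jnp.array([-2.0, 0.0, 2.0, 10.0])          # last point is far from the data
mean, cov = gp_posterior(x_train, y_train, x_test)
print(mean, jnp.diag(cov))
```

Near the training inputs the posterior mean follows the data and the variance collapses; far away (x = 10) it reverts to the prior mean of zero and prior variance of one.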
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
jax-cfd
Computational Fluid Dynamics in JAX
parallel-non-linear-gaussian-smoothers
Companion code in JAX for the paper Parallel Iterated Extended and Sigma-Point Kalman Smoothers.
ADAM
ADAM implements a collection of algorithms for calculating rigid-body dynamics in JAX, CasADi, PyTorch, and NumPy.