Transformers - 🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
Einops - Deep learning operations reinvented (for PyTorch, TensorFlow, JAX, and others)
Thinc - 🔮 A refreshing functional take on deep learning, compatible with your favorite libraries
Jax - Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Foolbox - A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
Flax - Flax is a neural network library for JAX that is designed for flexibility.
Trax - Deep Learning with Clear Code and Speed
Pyprobml - Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Datasets - TFDS is a collection of datasets ready to use with TensorFlow, JAX, ...
jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
treeo - A small library for creating and manipulating custom JAX Pytree classes
score_flow - Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
annotated-s4 - Implementation of https://srush.github.io/annotated-s4
brax - Massively parallel rigid-body physics simulation on accelerator hardware.
get-started-with-JAX - The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
bayex - Bayesian Optimization in JAX
jaxdf - A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
cr-sparse - Functional models and algorithms for sparse signal processing
koclip - KoCLIP: Korean port of OpenAI CLIP, in Flax
madam - 👩 PyTorch and JAX code for the Madam optimiser.
fedpa - Federated posterior averaging implemented in JAX
robustness-vit - Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
wax-ml - A Python library for machine learning and feedback loops on streaming data
jaxfg - Factor graphs and nonlinear optimization for JAX
dm_pix - PIX is an image processing library in JAX, for JAX.
efficientnet-jax - EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc. in JAX with Flax Linen and Objax
ShinRL - ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives (Deep RL Workshop 2021)
GPJax - A didactic Gaussian process package for researchers in JAX.
uvadlc_notebooks - Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
rA9 - JAX-based Spiking Neural Network framework
jax-models - Unofficial JAX implementations of deep learning research papers
jax-cfd - Computational Fluid Dynamics in JAX
jax-rl - JAX implementations of core Deep RL algorithms
ADAM - ADAM implements a collection of algorithms for calculating rigid-body dynamics in JAX, CasADi, PyTorch, and NumPy.
omd - JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
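The Jax entry describes its core idea as composable transformations: differentiate, vectorize, and JIT-compile. Most of the projects listed here are built on that idea.

The following is a minimal sketch of the three transformations stacked on one function. The toy squared-error loss, the weights, and the data are hypothetical and were chosen only for illustration. The only library calls used are `jax.grad`, `jax.vmap`, and `jax.jit`.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical toy objective: squared error of a linear model on one example.
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: map the per-example gradient over axis 0 of x and y, sharing w (in_axes=None).
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA (CPU, GPU, or TPU backend).
fast_batched_grads = jax.jit(batched_grads)

# Hypothetical example data: three examples with two features each.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 3.0])

g = fast_batched_grads(w, xs, ys)
print(g.shape)  # (3, 2): one gradient with respect to w per example
```

Each transformation returns an ordinary Python function, so the caller chooses the order of composition. Here `jit` wraps the already-vectorized gradient, so the whole per-example gradient computation is traced and compiled on the first call for these input shapes.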