
rwightman / efficientnet-jax

License: Apache-2.0
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax

Programming Languages

Python
Dockerfile

Projects that are alternatives of or similar to efficientnet-jax

Pytorch Image Models
PyTorch image models, scripts, pretrained weights -- ResNet, ResNeXT, EfficientNet, EfficientNetV2, NFNet, Vision Transformer, MixNet, MobileNet-V3/V2, RegNet, DPN, CSPNet, and more
Stars: ✭ 15,232 (+13261.4%)
Mutual labels:  mixnet, mobilenetv3, efficientnet
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
Stars: ✭ 80 (-29.82%)
Mutual labels:  flax, jax
TPU-MobilenetSSD
Edge TPU Accelerator / Multi-TPU + MobileNet-SSD v2 + Python + Async + LattePandaAlpha/RaspberryPi3/LaptopPC
Stars: ✭ 82 (-28.07%)
Mutual labels:  tpu, mobilenetv2
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+690.35%)
Mutual labels:  flax, jax
Tensorrtx
Implementation of popular deep learning networks with TensorRT network definition API
Stars: ✭ 3,456 (+2931.58%)
Mutual labels:  mobilenetv2, mobilenetv3
chef-transformer
Chef Transformer 🍲
Stars: ✭ 29 (-74.56%)
Mutual labels:  flax, jax
score flow
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
Stars: ✭ 49 (-57.02%)
Mutual labels:  flax, jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-62.28%)
Mutual labels:  flax, jax
Transformers
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+48796.49%)
Mutual labels:  flax, jax
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Stars: ✭ 4,197 (+3581.58%)
Mutual labels:  flax, jax
awesome-computer-vision-models
A list of popular deep learning models related to classification, segmentation and detection problems
Stars: ✭ 419 (+267.54%)
Mutual labels:  mixnet, efficientnet
jax-models
Unofficial JAX implementations of deep learning research papers
Stars: ✭ 108 (-5.26%)
Mutual labels:  flax, jax
TensorMONK
A collection of deep learning models (PyTorch implementation)
Stars: ✭ 21 (-81.58%)
Mutual labels:  mobilenetv2, efficientnet
get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
Stars: ✭ 229 (+100.88%)
Mutual labels:  flax, jax
jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Stars: ✭ 61 (-46.49%)
Mutual labels:  flax, jax
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (-46.49%)
Mutual labels:  flax, jax
MixNet-PyTorch
Concise, Modular, Human-friendly PyTorch implementation of MixNet with Pre-trained Weights.
Stars: ✭ 16 (-85.96%)
Mutual labels:  mixnet, efficientnet
MobilePose
Light-weight Single Person Pose Estimator
Stars: ✭ 588 (+415.79%)
Mutual labels:  mobilenetv2
tensorflow-yolov4
YOLOv4 Implemented in Tensorflow 2.
Stars: ✭ 136 (+19.3%)
Mutual labels:  tpu
efficientdet
PyTorch Implementation of the state-of-the-art model for object detection EfficientDet [pre-trained weights provided]
Stars: ✭ 21 (-81.58%)
Mutual labels:  efficientnet

EfficientNet JAX - Flax Linen and Objax

Acknowledgements

Verification of training code was made possible with Cloud TPUs via Google's TPU Research Cloud (TRC) (https://www.tensorflow.org/tfrc)

Intro

This is very much a giant steaming work in progress. JAX, jaxlib, and the NN libraries I'm using are shifting week to week.

This code base currently supports the capabilities summarized under 'Where are we at' below.

This is essentially an adaptation of my PyTorch EfficientNet generator code (https://github.com/rwightman/gen-efficientnet-pytorch and also found in https://github.com/rwightman/pytorch-image-models) to JAX.

I started this to

  • learn JAX by working with familiar code / models as a starting point,
  • figure out which JAX modelling interface libraries ('frameworks') I liked,
  • compare the training / inference runtime traits of non-trivial models across combinations of PyTorch, JAX, GPU and TPU in order to drive cost optimizations for scaling up of future projects

Where are we at:

  • Training works on a single node, multi-GPU and TPU v3-8, for the Flax Linen variants w/ a Tensorflow Datasets based pipeline
  • The Objax and Flax Linen (nn.compact) variants of the models are working for inference (see the sketch after this list)
  • Weights are ported from PyTorch (my timm training) and Tensorflow (original paper author releases) and are organized in a zoo of sorts (borrowed PyTorch code)
  • Tensorflow and PyTorch data pipeline based validation scripts work with the models and weights. For the PT pipeline with PT models and the TF pipeline with TF models, the results are pretty much exact.
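As a rough illustration of the nn.compact style used by the Flax Linen variants, here is a generic sketch of the pattern (my own minimal example, not this repo's actual model code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ConvBnAct(nn.Module):
    """A minimal conv + batchnorm + swish block in the nn.compact style."""
    features: int
    stride: int = 1

    @nn.compact
    def __call__(self, x, training: bool = False):
        x = nn.Conv(self.features, (3, 3), strides=(self.stride, self.stride),
                    padding='SAME', use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        return nn.swish(x)

# Inference-style usage: initialize variables, then apply with batch stats frozen.
block = ConvBnAct(features=32, stride=2)
x = jnp.ones((1, 224, 224, 3))
variables = block.init(jax.random.PRNGKey(0), x)
out = block.apply(variables, x)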

TODO:

  • Fix model weight inits (working for Flax Linen variants)
  • Fix dropout/drop path impl and other training specifics (verified for Flax Linen variants)
  • Add more instructions / help in the README on how to get an optimal environment with JAX up and running (with GPU support)
  • Add basic training code. The main point of this is to scale up training.
  • Add a more advanced data augmentation pipeline
  • Training on lots of GPUs
  • Training on lots of TPUs

Some odd things:

  • Objax layers are reimplemented to make my initial work easier, scratch some itches, and be more consistent with PyTorch (because why not?)
  • Flax Linen layers are by default fairly consistent with Tensorflow (left as is)
  • I use wrappers around Flax Linen layers for some argument consistency and reduced visual noise (no redundant tuples)
  • I made a 'LIKE' padding mode, sort of like 'SAME' but different, hence the name. It calculates symmetric, PyTorch-style padding (see the sketch after this list).
  • Models with Tensorflow 'SAME' padding and TF origin weights are prefixed with tf_. Models with PyTorch trained weights and symmetric PyTorch style padding ('LIKE' here) are prefixed with pt_.
  • I use pt and tf to refer to PyTorch and Tensorflow for both the models and the environments. The two do not need to be used together: pt models with 'LIKE' padding will run fine in a Tensorflow based environment and vice versa. I did this to show the flexibility here, that one can use JAX models with PyTorch data pipelines and datasets or with Tensorflow based data pipelines and TFDS.
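To make the wrapper and 'LIKE' padding ideas above concrete, here is a small sketch of what such helpers could look like. This is my own simplification; the names like_padding and ConvLike are made up for illustration and are not the actual code in this repo:

import flax.linen as nn

def like_padding(kernel_size: int, stride: int = 1, dilation: int = 1):
    # Symmetric (low, high) padding for one spatial dim: static, PyTorch-style
    # padding rather than Tensorflow's dynamically computed 'SAME'.
    pad = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return (pad, pad)

class ConvLike(nn.Module):
    # A thin wrapper over nn.Conv that takes plain ints (no redundant tuples)
    # and applies 'LIKE' padding by default.
    features: int
    kernel_size: int = 3
    stride: int = 1

    @nn.compact
    def __call__(self, x):
        pad = like_padding(self.kernel_size, self.stride)
        return nn.Conv(
            self.features,
            (self.kernel_size, self.kernel_size),
            strides=(self.stride, self.stride),
            padding=(pad, pad),  # one (low, high) pair per spatial dim
            use_bias=False)(x)

# e.g. ConvLike(features=32, kernel_size=3, stride=2) inside an nn.compact model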

Models

Supported models and their papers

Models, by config name, with valid pretrained weights that should be working here:

pt_mnasnet_100
pt_semnasnet_100
pt_mobilenetv2_100
pt_mobilenetv2_110d
pt_mobilenetv2_120d
pt_mobilenetv2_140
pt_fbnetc_100
pt_spnasnet_100
pt_efficientnet_b0
pt_efficientnet_b1
pt_efficientnet_b2
pt_efficientnet_b3
tf_efficientnet_b0
tf_efficientnet_b1
tf_efficientnet_b2
tf_efficientnet_b3
tf_efficientnet_b4
tf_efficientnet_b5
tf_efficientnet_b6
tf_efficientnet_b7
tf_efficientnet_b8
tf_efficientnet_b0_ap
tf_efficientnet_b1_ap
tf_efficientnet_b2_ap
tf_efficientnet_b3_ap
tf_efficientnet_b4_ap
tf_efficientnet_b5_ap
tf_efficientnet_b6_ap
tf_efficientnet_b7_ap
tf_efficientnet_b8_ap
tf_efficientnet_b0_ns
tf_efficientnet_b1_ns
tf_efficientnet_b2_ns
tf_efficientnet_b3_ns
tf_efficientnet_b4_ns
tf_efficientnet_b5_ns
tf_efficientnet_b6_ns
tf_efficientnet_b7_ns
tf_efficientnet_l2_ns_475
tf_efficientnet_l2_ns
pt_efficientnet_es
pt_efficientnet_em
tf_efficientnet_es
tf_efficientnet_em
tf_efficientnet_el
pt_efficientnet_lite0
tf_efficientnet_lite0
tf_efficientnet_lite1
tf_efficientnet_lite2
tf_efficientnet_lite3
tf_efficientnet_lite4
pt_mixnet_s
pt_mixnet_m
pt_mixnet_l
pt_mixnet_xl
tf_mixnet_s
tf_mixnet_m
tf_mixnet_l
pt_mobilenetv3_large_100
tf_mobilenetv3_large_075
tf_mobilenetv3_large_100
tf_mobilenetv3_large_minimal_100
tf_mobilenetv3_small_075
tf_mobilenetv3_small_100
tf_mobilenetv3_small_minimal_100

Environment

Working with JAX, I've found the best approach for a well-performing, GPU-compatible environment is to use Docker containers based on the latest NVIDIA NGC releases. Getting local conda/pip venvs or Tensorflow docker containers working well with good GPU performance, proper NCCL distributed support, etc. has been challenging or flaky. I use a CPU-only JAX install in a conda env for dev/debugging.

Dockerfiles

There are several container definitions in docker/. They use NGC containers as their parent image, so you'll need to be set up to pull NGC containers: https://www.nvidia.com/en-us/gpu-cloud/containers/. I'm currently using recent NGC containers w/ CUDA 11.1 support; the host system will need a very recent NVIDIA driver to support this, but doesn't need a matching CUDA 11.1 / cuDNN 8 install.

Current dockerfiles:

  • pt_git.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
  • pt_pip.Dockerfile - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.
  • tf_git.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax.
  • tf_pip.Dockerfile - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax.

The 'git' containers take some time to build jaxlib. They pull the master branches of all the respective repos, so they are up to the bleeding edge but more likely to hit the regressions or incompatibilities that go with that. The pip containers are quite a bit quicker to get up and running, being based on the latest pip releases of all the repos.

Docker Usage (GPU)

  1. Make sure you have a recent version of docker and the NVIDIA Container Toolkit setup (https://github.com/NVIDIA/nvidia-docker)
  2. Build the container: docker build -f docker/tf_pip.Dockerfile -t jax_tf_pip .
  3. Run the container, ideally mapping jeffnet and datasets (ImageNet) into the container:
    • For tf containers, docker run --gpus all -it -v /path/to/tfds/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_tf_pip
    • For pt containers, docker run --gpus all -it -v /path/to/imagenet/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_pt_pip
  4. Model validation w/ pretrained weights (once inside running container):
    • For tf, in workspace/jeffnet, python tf_linen_validate.py /data/ --model tf_efficientnet_b0_ns
    • For pt, in workspace/jeffnet, python pt_objax_validate.py /data/validation --model pt_efficientnet_b0
  5. Training (within container)
    • In workspace/jeffnet, python tf_linen_train.py --config train_configs/tf_efficientnet_b0-gpu_24gb_x2.py --config.data_dir /data (see the config sketch after this list)
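The --config / --config.data_dir flags in step 5 follow the ml_collections config_flags pattern. As a hypothetical sketch of what a train_configs/*.py file might provide (field names and values here are illustrative guesses, not the repo's actual config):

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    config.model = 'tf_efficientnet_b0'  # any config name from the model list above
    config.data_dir = ''                 # overridden on the command line via --config.data_dir
    config.batch_size = 256              # illustrative values only
    config.num_epochs = 450
    config.lr = 0.016
    return config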

TPU

I've successfully used this codebase on TPU VM environments as is. Any of the tpu_x8 training configs should work out of the box on a v3-8 TPU. I have not tackled training with TPU Pods.
