
n2cholas / jax-resnet

License: MIT License
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).

Programming Languages

Python
139,335 projects; the #7 most used programming language

Projects that are alternatives to or similar to jax-resnet

get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
Stars: ✭ 229 (+275.41%)
Mutual labels:  flax, jax
score flow
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
Stars: ✭ 49 (-19.67%)
Mutual labels:  flax, jax
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Stars: ✭ 4,197 (+6780.33%)
Mutual labels:  flax, jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-29.51%)
Mutual labels:  flax, jax
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+1377.05%)
Mutual labels:  flax, jax
chef-transformer
Chef Transformer 🍲
Stars: ✭ 29 (-52.46%)
Mutual labels:  flax, jax
Transformers
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+91280.33%)
Mutual labels:  flax, jax
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (+0%)
Mutual labels:  flax, jax
jax-models
Unofficial JAX implementations of deep learning research papers
Stars: ✭ 108 (+77.05%)
Mutual labels:  flax, jax
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Stars: ✭ 114 (+86.89%)
Mutual labels:  flax, jax
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
Stars: ✭ 80 (+31.15%)
Mutual labels:  flax, jax
Retinal-Disease-Diagnosis-With-Residual-Attention-Networks
Using Residual Attention Networks to diagnose retinal diseases in medical images
Stars: ✭ 14 (-77.05%)
Mutual labels:  resnet
brax
Massively parallel rigidbody physics simulation on accelerator hardware.
Stars: ✭ 1,208 (+1880.33%)
Mutual labels:  jax
treeo
A small library for creating and manipulating custom JAX Pytree classes
Stars: ✭ 29 (-52.46%)
Mutual labels:  jax
Gradient-Samples
Samples for TensorFlow binding for .NET by Lost Tech
Stars: ✭ 53 (-13.11%)
Mutual labels:  resnet
flaxOptimizers
A collection of optimizers, some arcane others well known, for Flax.
Stars: ✭ 21 (-65.57%)
Mutual labels:  flax
Keras-CIFAR10
practice on CIFAR10 with Keras
Stars: ✭ 25 (-59.02%)
Mutual labels:  resnet
pyro-vision
Computer vision library for wildfire detection
Stars: ✭ 33 (-45.9%)
Mutual labels:  resnet
Faster-RCNN-TensorFlow
TensorFlow implementation of Faster RCNN for Object Detection
Stars: ✭ 13 (-78.69%)
Mutual labels:  resnet
bayex
Bayesian Optimization in JAX
Stars: ✭ 24 (-60.66%)
Mutual labels:  jax

JAX ResNet - Implementations and Checkpoints for ResNet Variants


A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet (Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al. 2020), and ResNeSt (Zhang et al. 2020). The code is modular so you can mix and match the various stem, residual, and bottleneck implementations.

Installation

You can install this package from PyPI:

pip install jax-resnet

Or directly from GitHub:

pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git

Usage

See the bottom of jax-resnet/resnet.py for the available aliases/options for the ResNet variants (all models are in Flax).
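
For example, building one of these variants from scratch works like any other Flax Linen module. A minimal sketch; the ResNet50 alias and the n_classes keyword are assumed to match the definitions at the bottom of resnet.py:

import jax
import jax.numpy as jnp
from jax_resnet import ResNet50  # alias defined in resnet.py

model = ResNet50(n_classes=10)  # e.g. for a 10-class problem
# Flax Linen modules are initialized from a PRNG key and a sample input.
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3)))
out = model.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)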

Pretrained checkpoints from torch.hub are available for the following networks:

  • ResNet [18, 34, 50, 101, 152]
  • WideResNet [50, 101]
  • ResNeXt [50, 101]
  • ResNeSt [50-Fast, 50, 101, 200, 269]

The models are tested to produce the same intermediate activations and outputs as the torch.hub implementations, except for ResNeSt-50 Fast: its activations do not match exactly, but its final accuracy does.

A pretrained checkpoint for ResNetD-50 is available from fast.ai. The activations do not match exactly, but the final accuracy matches.

import jax.numpy as jnp
from jax_resnet import pretrained_resnest

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
out = model.apply(variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.
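
For this pretrained ImageNet model, out is a (32, 1000) array of class logits, one row per image in the batch.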

You must install PyTorch yourself (see the official PyTorch installation instructions) to use these checkpoint-loading functions.

Transfer Learning

To extract a subset of the model, you can use Sequential(model.layers[start:end]).

The slice_variables function (found in common.py) lets you extract the corresponding subset of the variables dict. Check out its docstring for more information.
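
Putting the two together, a feature extractor can be sliced off a pretrained model roughly as follows. This is a sketch: the Sequential and slice_variables imports from common.py, and the start/end arguments to slice_variables, are assumptions based on the descriptions above:

import jax.numpy as jnp
from jax_resnet import pretrained_resnest
from jax_resnet.common import Sequential, slice_variables  # assumed import path

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()

# Drop the final classification layer to get a feature extractor.
end = len(model.layers) - 1
backbone = Sequential(model.layers[0:end])
backbone_vars = slice_variables(variables, 0, end)  # matching variable subset

features = backbone.apply(backbone_vars,
                          jnp.ones((1, 224, 224, 3)),
                          mutable=False)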

Checkpoint Accuracies

The top 1 and top 5 accuracies reported below are on the ImageNet2012 validation split. The data was preprocessed as in the official PyTorch example.
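
That preprocessing resizes the short side to 256 pixels, takes a 224x224 center crop, and normalizes with the ImageNet channel statistics. A plain-Python sketch; the preprocess helper below is illustrative and not part of this package:

import numpy as np
from PIL import Image

# ImageNet channel statistics used by the official PyTorch example.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(path, resize=256, crop=224):
    img = Image.open(path).convert('RGB')
    # Scale so the short side is `resize` pixels, then center-crop.
    w, h = img.size
    scale = resize / min(w, h)
    img = img.resize((round(w * scale), round(h * scale)), Image.BILINEAR)
    w, h = img.size
    left, top = (w - crop) // 2, (h - crop) // 2
    img = img.crop((left, top, left + crop, top + crop))
    x = np.asarray(img, dtype=np.float32) / 255.0
    return (x - MEAN) / STD  # HWC layout, matching the models' expected input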

Model        Size   Top 1    Top 5
ResNet       18     69.75%   89.06%
             34     73.29%   91.42%
             50     76.13%   92.86%
             101    77.37%   93.53%
             152    78.30%   94.04%
Wide ResNet  50     78.48%   94.08%
             101    78.88%   94.29%
ResNeXt      50     77.60%   93.70%
             101    79.30%   94.51%
ResNet-D     50     77.57%   93.85%

The ResNeSt validation data was preprocessed as in zhanghang1989/ResNeSt.

Model         Size   Crop Size   Top 1    Top 5
ResNeSt-Fast  50     224         80.53%   95.34%
ResNeSt       50     224         81.05%   95.42%
              101    256         82.82%   96.32%
              200    320         83.84%   96.86%
              269    416         84.53%   96.98%

