
DarshanDeshpande / jax-models

License: Apache-2.0
Unofficial JAX implementations of deep learning research papers

Programming Languages

Python
139,335 projects - #7 most used programming language

Projects that are alternatives to or similar to jax-models

chef-transformer
Chef Transformer 🍲
Stars: ✭ 29 (-73.15%)
Mutual labels:  transformers, flax, jax
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Stars: ✭ 114 (+5.56%)
Mutual labels:  flax, jax
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
Stars: ✭ 80 (-25.93%)
Mutual labels:  flax, jax
score flow
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
Stars: ✭ 49 (-54.63%)
Mutual labels:  flax, jax
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+734.26%)
Mutual labels:  flax, jax
jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Stars: ✭ 61 (-43.52%)
Mutual labels:  flax, jax
get-started-with-JAX
The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
Stars: ✭ 229 (+112.04%)
Mutual labels:  flax, jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-60.19%)
Mutual labels:  flax, jax
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Stars: ✭ 4,197 (+3786.11%)
Mutual labels:  flax, jax
Transformers
πŸ€— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+51512.96%)
Mutual labels:  flax, jax
robustness-vit
Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
Stars: ✭ 78 (-27.78%)
Mutual labels:  transformers, jax
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (-43.52%)
Mutual labels:  flax, jax
KoBERT-Transformers
KoBERT on πŸ€— Huggingface Transformers πŸ€— (with Bug Fixed)
Stars: ✭ 162 (+50%)
Mutual labels:  transformers
KB-ALBERT
A Korean ALBERT model specialized for the economy/finance domain, provided by KB Kookmin Bank
Stars: ✭ 215 (+99.07%)
Mutual labels:  transformers
COCO-LM
[NeurIPS 2021] COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining
Stars: ✭ 109 (+0.93%)
Mutual labels:  transformers
graphsignal
Graphsignal Python agent
Stars: ✭ 158 (+46.3%)
Mutual labels:  jax
parallel-non-linear-gaussian-smoothers
Companion code in JAX for the paper Parallel Iterated Extended and Sigma-Point Kalman Smoothers.
Stars: ✭ 17 (-84.26%)
Mutual labels:  jax
LIT
[AAAI 2022] This is the official PyTorch implementation of "Less is More: Pay Less Attention in Vision Transformers"
Stars: ✭ 79 (-26.85%)
Mutual labels:  transformers
nlp-papers
Must-read papers on Natural Language Processing (NLP)
Stars: ✭ 87 (-19.44%)
Mutual labels:  transformers
Nlp Architect
A model library for exploring state-of-the-art deep learning topologies and techniques for optimizing Natural Language Processing neural networks
Stars: ✭ 2,768 (+2462.96%)
Mutual labels:  transformers

JAX Models


Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were originally published without code or whose code was written in frameworks other than JAX. The goal of this project is to build a collection of the models, layers, activations, and other utilities most commonly used in research. All papers and all derived or translated code are cited in either the README or the docstrings. If you find a missing citation, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT: Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2022)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)
  9. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (Ze Liu et al., 2021)
  10. Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions (Wenhai Wang et al., 2021)
  11. Going deeper with Image Transformers (Hugo Touvron et al., 2021)
  12. Visual Attention Network (Meng-Hao Guo et al., 2022)

Available layers for out-of-the-box integration (a short usage sketch follows the list):
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2016)
  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)
  3. Depthwise Convolution (François Chollet, 2017)
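To make the integration concrete, below is a minimal, self-contained Flax sketch of the DropPath (stochastic depth) idea. This is an illustrative reimplementation under stated assumptions, not the library's own code: the module name and the "drop_path" rng collection are chosen for the example and need not match the jax_models API.

import jax
import jax.numpy as jnp
import flax.linen as nn

class DropPath(nn.Module):
    # Stochastic depth: randomly zero entire residual branches per sample.
    drop_rate: float = 0.1
    deterministic: bool = False

    @nn.compact
    def __call__(self, x):
        if self.deterministic or self.drop_rate == 0.0:
            return x
        keep_prob = 1.0 - self.drop_rate
        rng = self.make_rng("drop_path")
        # One Bernoulli draw per example, broadcast over all other axes.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=shape)
        # Rescale survivors so the expected activation is unchanged.
        return jnp.where(mask, x / keep_prob, 0.0)

When such a layer is active, the PRNG key for its collection is supplied at call time, e.g. model.apply(variables, x, rngs={"drop_path": key}).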

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

The use of a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be installed directly via pip:

pip install jax-models

If you wish to use the latest development version, you can clone the repository instead:

git clone https://github.com/DarshanDeshpande/jax-models.git
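After cloning, the prerequisites can be installed from the repository root, reusing the requirements.txt step shown above:

cd jax-models
pip install -r requirements.txt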

Usage

To list all available model architectures:

from jax_models import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models import load_model

model = load_model('swin-tiny-224', attach_head=True, num_classes=1000, dropout=0.0, pretrained=True)

Note: It is necessary to pass attach_head=True and num_classes while loading pretrained models.
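For completeness, here is a hedged sketch of a forward pass. It assumes load_model returns a standard Flax module, so the generic init/apply pattern applies; the exact call signature (for example, a deterministic flag) may differ per model, and pretrained weight handling is not shown, so pretrained=False is used here.

import jax
import jax.numpy as jnp
from jax_models import load_model

# Illustrative only: randomly initialized weights, generic Flax workflow.
model = load_model('swin-tiny-224', attach_head=True, num_classes=1000, dropout=0.0, pretrained=False)

# One 224x224 RGB image in NHWC layout, matching the model name.
dummy = jnp.zeros((1, 224, 224, 3))
variables = model.init(jax.random.PRNGKey(0), dummy)
logits = model.apply(variables, dummy)
print(logits.shape)  # expected: (1, 1000)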

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training or inference, or if any citation is missing.

You can contribute to jax_models by supporting me with compute resources or by contributing your own resources to provide pretrained weights.

If you wish to donate to this initiative, please drop me a mail here.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out for any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn
