
JuliusKunze / Jaxnet

Licence: apache-2.0
Concise deep learning for JAX

Programming Languages

python
139335 projects - #7 most used programming language

Projects that are alternatives to or similar to Jaxnet

Knet.jl
Koç University deep learning framework.
Stars: ✭ 1,260 (+636.84%)
Mutual labels:  data-science, neural-networks
Sigmoidal ai
Python, Data Science, Machine Learning and Deep Learning tutorials - Sigmoidal
Stars: ✭ 103 (-39.77%)
Mutual labels:  data-science, neural-networks
Vvedenie Mashinnoe Obuchenie
📝 A curated collection of machine learning resources
Stars: ✭ 1,282 (+649.71%)
Mutual labels:  data-science, neural-networks
Ai Platform
An open-source platform for automating tasks using machine learning models
Stars: ✭ 61 (-64.33%)
Mutual labels:  data-science, neural-networks
Fixy
Our goal is to build an open-source spelling assistant/checker that solves many different problems in the Turkish NLP literature at once, proposes unique approaches, and addresses the shortcomings of existing work. It corrects spelling errors in users' texts with a deep learning approach and, by also performing semantic analysis, detects and fixes the errors that arise in that context.
Stars: ✭ 165 (-3.51%)
Mutual labels:  data-science, neural-networks
Mit Deep Learning
Tutorials, assignments, and competitions for MIT Deep Learning related courses.
Stars: ✭ 8,912 (+5111.7%)
Mutual labels:  data-science, neural-networks
Codesearchnet
Datasets, tools, and benchmarks for representation learning of code.
Stars: ✭ 1,378 (+705.85%)
Mutual labels:  data-science, neural-networks
Sciblog support
Support content for my blog
Stars: ✭ 694 (+305.85%)
Mutual labels:  data-science, neural-networks
Uncertainty Metrics
An easy-to-use interface for measuring uncertainty and robustness.
Stars: ✭ 145 (-15.2%)
Mutual labels:  data-science, neural-networks
Learn Machine Learning
Learn to Build a Machine Learning Application from Top Articles
Stars: ✭ 116 (-32.16%)
Mutual labels:  data-science, neural-networks
Mckinsey Smartcities Traffic Prediction
Adventure into using multi attention recurrent neural networks for time-series (city traffic) for the 2017-11-18 McKinsey IronMan (24h non-stop) prediction challenge
Stars: ✭ 49 (-71.35%)
Mutual labels:  data-science, neural-networks
Ml Workspace
🛠 All-in-one web-based IDE specialized for machine learning and data science.
Stars: ✭ 2,337 (+1266.67%)
Mutual labels:  data-science, neural-networks
Machine Learning From Scratch
Succinct Machine Learning algorithm implementations from scratch in Python, solving real-world problems (Notebooks and Book). Examples of Logistic Regression, Linear Regression, Decision Trees, K-means clustering, Sentiment Analysis, Recommender Systems, Neural Networks and Reinforcement Learning.
Stars: ✭ 42 (-75.44%)
Mutual labels:  data-science, neural-networks
Dltk
Deep Learning Toolkit for Medical Image Analysis
Stars: ✭ 1,249 (+630.41%)
Mutual labels:  data-science, neural-networks
Pyclustering
pyclustering is a Python/C++ data mining library.
Stars: ✭ 806 (+371.35%)
Mutual labels:  data-science, neural-networks
Tageditor
🏖TagEditor - Annotation tool for spaCy
Stars: ✭ 92 (-46.2%)
Mutual labels:  data-science, neural-networks
Test Tube
Python library to easily log experiments and parallelize hyperparameter search for neural networks
Stars: ✭ 663 (+287.72%)
Mutual labels:  data-science, neural-networks
Deep learning and the game of go
Code and other material for the book "Deep Learning and the Game of Go"
Stars: ✭ 677 (+295.91%)
Mutual labels:  data-science, neural-networks
Keras Contrib
Keras community contributions
Stars: ✭ 1,532 (+795.91%)
Mutual labels:  data-science, neural-networks
Autograd.jl
Julia port of the Python autograd package.
Stars: ✭ 147 (-14.04%)
Mutual labels:  data-science, neural-networks

JAXnet

JAXnet is a deep learning library based on JAX. JAXnet's functional API provides unique benefits over TensorFlow2, Keras and PyTorch, while maintaining user-friendliness, modularity and scalability:

  • More robustness through immutable weights, no global compute graph.
  • GPU-compiled numpy code for networks, training loops, pre- and postprocessing.
  • Regularization and reparametrization of any module or whole networks in one line.
  • No global random state, flexible random key control.

If you already know stax, read this.

Modularity

net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), log_softmax)

creates a neural net model from predefined modules.
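
Modules compose freely: a Sequential is itself just another module and can be reused inside a larger model. For illustration (encoder and classifier are hypothetical names, reusing Dense, relu and log_softmax from above):

encoder = Sequential(Dense(1024), relu)                   # a reusable sub-network
classifier = Sequential(encoder, Dense(4), log_softmax)   # composes like any other module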

Extensibility

Define your own modules using @parametrized functions. You can reuse other modules:

from jax import numpy as jnp
from jaxnet import parametrized

@parametrized
def loss(inputs, targets):
    return -jnp.mean(net(inputs) * targets)

All modules are composed in this way. jax.numpy mirrors numpy: if you know how to use numpy, you already know most of JAXnet. Compare this to TensorFlow2/Keras:

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Lambda

net = Sequential([Dense(1024, 'relu'), Dense(1024, 'relu'), Dense(4), Lambda(tf.nn.log_softmax)])

def loss(inputs, targets):
    return -tf.reduce_mean(net(inputs) * targets)

Notice that Lambda layers are not needed in JAXnet: relu and log_softmax are plain Python functions.
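
Because activations are ordinary functions on arrays, defining your own needs no special layer type. A sketch (my_relu and small_net are illustrative names, not part of JAXnet):

from jax import numpy as jnp

def my_relu(x):
    return jnp.maximum(x, 0)  # element-wise max with 0, works on any array

small_net = Sequential(Dense(128), my_relu, Dense(4), log_softmax)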

Immutable weights

Unlike TensorFlow2/Keras, JAXnet has no global compute graph. Modules like net and loss do not contain mutable weights. Instead, weights live in separate, immutable objects. They are initialized with init_parameters, given example inputs and a random key:

from jax.random import PRNGKey

def next_batch(): return jnp.zeros((3, 784)), jnp.zeros((3, 4))

params = loss.init_parameters(*next_batch(), key=PRNGKey(0))

print(params.sequential.dense2.bias)  # [-0.01101029, -0.00749435, -0.00952365,  0.00493979]

Instead of mutating weights inline, optimizers return updated versions of weights. They are returned as part of a new optimizer state, and can be retrieved via get_parameters:

opt = optimizers.Adam()
state = opt.init(params)
for _ in range(10):
    state = opt.update(loss.apply, state, *next_batch()) # accelerate with jit=True

trained_params = opt.get_parameters(state)

apply evaluates a network:

test_loss = loss.apply(trained_params, *test_batch) # accelerate with jit=True

GPU support and compilation

JAX allows functional numpy/scipy code to be accelerated. Make it run on GPU by replacing your numpy import with jax.numpy. Compile a function by decorating it with jit. This will free your function from slow Python interpretation, parallelize operations where possible and optimize your compute graph. It provides speed and scalability at the level of TensorFlow2 or PyTorch.
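
For example, any pure numpy-style function can be compiled this way (a minimal sketch in plain JAX, not specific to JAXnet; standardize is an illustrative name):

from jax import jit, numpy as jnp

@jit  # compiled with XLA on first call; runs on GPU/TPU if available
def standardize(x):
    return (x - jnp.mean(x)) / jnp.std(x)

standardize(jnp.arange(10.))  # later calls reuse the compiled code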

Because weights are immutable, whole training loops can be compiled and run on the GPU (demo). jit makes your training as fast as mutating weights inline, and weights never leave the GPU. You can write functional code without worrying about performance.

You can easily accelerate numpy/scipy pre-/postprocessing code in the same way (demo).

Regularization and reparametrization

In JAXnet, regularizing a model can be done in one line (demo):

loss = L2Regularized(loss, scale=.1)

loss is now just another module that can be used as above. Reparametrized layers are one-liners, too (see API). JAXnet allows regularizing or reparametrizing any module or subnetwork without changing its code. This is possible because modules do not instantiate any variables. Instead, each module provides a function (apply) that takes parameters as an argument. This function can be wrapped to build layers like L2Regularized.
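
To illustrate the idea (a rough sketch with assumed names, not JAXnet's actual L2Regularized implementation), such a wrapper only needs the apply function and the parameter tree:

from jax import numpy as jnp
from jax.tree_util import tree_leaves

def l2_regularized_apply(loss_apply, params, *inputs, scale=.1):
    penalty = scale * sum(jnp.sum(leaf ** 2) for leaf in tree_leaves(params))
    return loss_apply(params, *inputs) + penalty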

In contrast, TensorFlow2/Keras/PyTorch have mutable variables baked into their model API. They therefore require:

  • Regularization arguments at the layer level, with separate code needed for each layer.
  • Reparametrization arguments at the layer level, with separate implementations for each layer.

Random key control

JAXnet does not have global random state. Random keys are provided explicitly, making code deterministic and independent of previously executed code by default. This can help debugging and is more flexible (demo). Read more on random numbers in JAX here.
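
For example, independent randomness is obtained by splitting keys explicitly (plain JAX, the mechanism JAXnet builds on):

from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)        # derive a fresh, independent key
noise = random.normal(subkey, (3, 4))  # the same key always yields the same numbers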

Step-by-step debugging

JAXnet allows step-by-step debugging with concrete values like any plain Python function (when jit compilation is not used).
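
For instance (a hypothetical illustration reusing net, next_batch and PRNGKey from above), you can print intermediate values or set a breakpoint inside a module and call it like ordinary Python:

@parametrized
def verbose_loss(inputs, targets):
    outputs = net(inputs)
    print(outputs.shape)             # concrete shapes and values when jit is off
    # import pdb; pdb.set_trace()    # steps through like any Python function
    return -jnp.mean(outputs * targets)

verbose_params = verbose_loss.init_parameters(*next_batch(), key=PRNGKey(0))
verbose_loss.apply(verbose_params, *next_batch())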

API and demos

Find more details on the API here.

See JAXnet in action in your browser: MNIST Classifier, MNIST VAE, OCR with RNNs, ResNet, WaveNet, PixelCNN++ and Policy Gradient RL.

Installation

This is a preview. Expect breaking changes! Python 3.6 or higher is supported. Install with

pip3 install jaxnet

To use GPU, first install the right version of jaxlib.

Questions

Please feel free to create an issue on GitHub.

Note that the project description data, including the texts, logos, images, and/or trademarks, for each open source project belongs to its rightful owner. If you wish to add or remove any projects, please contact us at [email protected].