
henry-prior / jax-rl

License: MIT
JAX implementations of core Deep RL algorithms

Programming Languages

python
139335 projects - #7 most used programming language
Makefile
30231 projects

Projects that are alternatives of or similar to jax-rl

omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-29.51%)
Mutual labels:  deep-reinforcement-learning, flax, sac, jax, soft-actor-critic
Deep-Reinforcement-Learning-With-Python
Master classic RL, deep RL, distributional RL, inverse RL, and more using OpenAI Gym and TensorFlow with extensive Math
Stars: ✭ 222 (+263.93%)
Mutual labels:  deep-reinforcement-learning, sac, actor-critic, td3
Meta-SAC
Auto-tune the Entropy Temperature of Soft Actor-Critic via Metagradient - 7th ICML AutoML workshop 2020
Stars: ✭ 19 (-68.85%)
Mutual labels:  deep-reinforcement-learning, sac, mujoco, soft-actor-critic
proto
Proto-RL: Reinforcement Learning with Prototypical Representations
Stars: ✭ 67 (+9.84%)
Mutual labels:  sac, mujoco, soft-actor-critic
Pytorch sac
PyTorch implementation of Soft Actor-Critic (SAC)
Stars: ✭ 174 (+185.25%)
Mutual labels:  deep-reinforcement-learning, actor-critic, mujoco
Paddle-RLBooks
Paddle-RLBooks is a reinforcement learning code study guide based on pure PaddlePaddle.
Stars: ✭ 113 (+85.25%)
Mutual labels:  sac, actor-critic, td3
Pytorch A2c Ppo Acktr Gail
PyTorch implementation of Advantage Actor Critic (A2C), Proximal Policy Optimization (PPO), Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR) and Generative Adversarial Imitation Learning (GAIL).
Stars: ✭ 2,632 (+4214.75%)
Mutual labels:  deep-reinforcement-learning, actor-critic, mujoco
Tianshou
An elegant PyTorch deep reinforcement learning library.
Stars: ✭ 4,109 (+6636.07%)
Mutual labels:  sac, mujoco, td3
pomdp-baselines
Simple (but often Strong) Baselines for POMDPs in PyTorch - ICML 2022
Stars: ✭ 162 (+165.57%)
Mutual labels:  deep-reinforcement-learning, sac, td3
LWDRLC
Lightweight deep RL library for continuous control.
Stars: ✭ 14 (-77.05%)
Mutual labels:  deep-reinforcement-learning, sac, td3
Rainy
☔ Deep RL agents with PyTorch☔
Stars: ✭ 39 (-36.07%)
Mutual labels:  deep-reinforcement-learning, sac, td3
Drq
DrQ: Data regularized Q
Stars: ✭ 268 (+339.34%)
Mutual labels:  deep-reinforcement-learning, actor-critic, mujoco
Pytorch sac ae
PyTorch implementation of Soft Actor-Critic + Autoencoder(SAC+AE)
Stars: ✭ 94 (+54.1%)
Mutual labels:  deep-reinforcement-learning, actor-critic, mujoco
Advanced Deep Learning And Reinforcement Learning Deepmind
🎮 Advanced Deep Learning and Reinforcement Learning at UCL & DeepMind | YouTube videos 👉
Stars: ✭ 121 (+98.36%)
Mutual labels:  deep-reinforcement-learning, deepmind
Baby A3c
A high-performance Atari A3C agent in 180 lines of PyTorch
Stars: ✭ 144 (+136.07%)
Mutual labels:  deep-reinforcement-learning, actor-critic
Reinforcementlearning Atarigame
PyTorch LSTM RNN for reinforcement learning to play Atari games from OpenAI Universe, using Google DeepMind's Asynchronous Advantage Actor-Critic (A3C) algorithm, which is more efficient than DQN and largely supersedes it. Can play many games.
Stars: ✭ 118 (+93.44%)
Mutual labels:  deep-reinforcement-learning, actor-critic
Minimalrl
Implementations of basic RL algorithms with minimal lines of codes! (pytorch based)
Stars: ✭ 2,051 (+3262.3%)
Mutual labels:  deep-reinforcement-learning, sac
Machine Learning Is All You Need
🔥🌟《Machine Learning 格物志》: ML + DL + RL basic codes and notes by sklearn, PyTorch, TensorFlow, Keras & the most important, from scratch!💪 This repository is ALL You Need!
Stars: ✭ 173 (+183.61%)
Mutual labels:  deep-reinforcement-learning, actor-critic
Hierarchical Actor Critic Hac Pytorch
PyTorch implementation of Hierarchical Actor Critic (HAC) for OpenAI gym environments
Stars: ✭ 116 (+90.16%)
Mutual labels:  deep-reinforcement-learning, actor-critic
Hands On Intelligent Agents With Openai Gym
Code for Hands On Intelligent Agents with OpenAI Gym book to get started and learn to build deep reinforcement learning agents using PyTorch
Stars: ✭ 189 (+209.84%)
Mutual labels:  deep-reinforcement-learning, actor-critic

jax-rl

Core Deep Reinforcement Learning algorithms using JAX for improved performance relative to PyTorch and TensorFlow. Control tasks rely on the DeepMind Control Suite or OpenAI Gym. DeepMind has recently open-sourced the MuJoCo physics engine, which is a dependency of this repo. If you haven't already set up MuJoCo, see the download site and copy the unzipped folder to a .mujoco folder in your home directory.
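
Once MuJoCo is in place, a quick sanity check is to load a dm_control environment directly. This is a minimal sketch, assuming MuJoCo was unpacked into ~/.mujoco and dm_control is installed:

```python
# Minimal check that MuJoCo and dm_control are set up correctly.
from dm_control import suite

env = suite.load(domain_name="cartpole", task_name="swingup")
time_step = env.reset()
print(time_step.observation)  # prints the initial cartpole observation dict
```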

Current implementations

  • TD3
  • SAC
  • MPO

Environment and Testing

This repo uses poetry for package and dependency management. To build a local environment with all necessary packages, run:

make init

To test local changes run:

make test

Run

To run each algorithm on the cartpole swingup task with the DeepMind Control Suite as the environment backend, run the following from the base directory:

python jax_rl/main_dm_control.py --policy TD3 --max_timesteps 100000
python jax_rl/main_dm_control.py --policy SAC --max_timesteps 100000
python jax_rl/main_dm_control.py --policy MPO --max_timesteps 100000

To use the OpenAI Gym environment backend, use jax_rl/main_gym.py instead.

Results

Speed

As one would hope, the time per training step is significantly lower with JAX than with other leading deep learning frameworks. The following comparison shows the time in seconds per 1000 training steps with the same hyperparameters. For those interested in hardware, these were all run on the same machine: a mid-2019 15-inch MacBook Pro (2.3 GHz Intel Core i9, 16 GB RAM).

        JAX             PyTorch
TD3     2.35 ± 0.17     6.93 ± 0.16
SAC     5.57 ± 0.07     32.37 ± 1.32
MPO     39.19 ± 1.09    107.51 ± 3.56

Performance

Evaluation of the deterministic policy (acting according to the mean of the policy distribution for SAC and MPO) every 5000 training steps for each algorithm. Important parameters are held constant for all algorithms, including a batch size of 256 per training step, 10000 samples added to the replay buffer with uniform random action sampling before training, and 250000 total steps in the environment. A rough sketch of this evaluation loop is shown below.
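
The following is an illustrative sketch of that protocol for the dm_control backend; policy.select_action_deterministic is a placeholder name, not necessarily this repo's actual API:

```python
# Evaluate a deterministic policy (mean action for SAC/MPO) over a few episodes
# and return the average episodic return.
def evaluate(policy, env, num_episodes=10):
    returns = []
    for _ in range(num_episodes):
        time_step = env.reset()
        episode_return = 0.0
        while not time_step.last():
            action = policy.select_action_deterministic(time_step.observation)
            time_step = env.step(action)
            episode_return += time_step.reward
        returns.append(episode_return)
    return sum(returns) / num_episodes
```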

Notes on MPO Implementation

Because JAX gives us direct access to the Jacobian of the dual function, I've opted to use scipy.optimize.minimize on the temperature parameter instead of taking a single gradient step per iteration. In my testing this gives much greater stability with only a marginal increase in time per iteration.
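
A rough sketch of this idea, assuming the standard MPO temperature dual; the names (temperature_dual, eps_eta, q_values, solve_temperature) are illustrative, not this repo's actual API:

```python
# Hypothetical sketch: optimize the MPO temperature (dual variable) with
# scipy.optimize.minimize, using JAX to supply the gradient.
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

def temperature_dual(eta, q_values, eps_eta):
    # q_values: [num_actions, batch] array of sampled Q-values.
    # Dual of the KL-constrained E-step: eta * eps + eta * mean(log mean_a exp(Q / eta)).
    logsumexp = jax.scipy.special.logsumexp(q_values / eta, axis=0)
    return eta * eps_eta + eta * jnp.mean(logsumexp - jnp.log(q_values.shape[0]))

dual_and_grad = jax.jit(jax.value_and_grad(temperature_dual))

def solve_temperature(q_values, eps_eta, eta_init=1.0):
    def fun(eta):
        value, grad = dual_and_grad(eta[0], q_values, eps_eta)
        # Converting to NumPy here forces JAX to finish its pending work,
        # which is why a naive profile attributes the time to scipy.
        return np.float64(value), np.array([np.float64(grad)])

    result = minimize(fun, np.array([eta_init]), jac=True,
                      method="SLSQP", bounds=[(1e-6, None)])
    return float(result.x[0])
```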

One important aspect to note if you are benchmarking these two approaches is that a standard profiler will be misleading. Most of the time will show up in the call to scipy.optimize.minimize, but this is due to how JAX calls work internally. JAX does not wait for an operation to complete when it is called; instead it immediately returns a DeviceArray whose value will be filled in once the dispatched computation finishes. If this object is passed into another JAX method, the same process repeats and control returns to Python. Only when Python attempts to access the value of a DeviceArray does it have to wait for the computation to complete. Because scipy.optimize.minimize passes the values of the parameter and the gradient to compiled Fortran routines, this step forces the program to wait for all previous JAX calls to complete. To get a more accurate comparison, compare the total time per training step. To read more about how asynchronous dispatch works in JAX, see the JAX documentation on asynchronous dispatch.
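
For example, a minimal way to time training steps that accounts for asynchronous dispatch is to block on the result before stopping the clock. The train_step below is a stand-in computation, not this repo's actual update function:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def train_step(w, batch):
    # Stand-in computation (not a real RL update); the point is the timing pattern.
    return 0.999 * w + 1e-3 * (batch @ batch.T)

w = jnp.ones((256, 256))
batch = jnp.ones((256, 64))
train_step(w, batch).block_until_ready()  # exclude compilation time from the measurement

start = time.perf_counter()
for _ in range(1000):
    w = train_step(w, batch)
w.block_until_ready()  # wait for all dispatched work before stopping the clock
print(f"{time.perf_counter() - start:.2f} s per 1000 steps")
```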

I've run a quick comparison of the two approaches, following the same procedure as in the Speed section above.

        Sequential Least Squares    Gradient Descent
MPO     39.19 ± 1.09                38.26 ± 2.74