
omron-sinicx / ShinRL

Licence: other
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives (Deep RL Workshop 2021)

Programming Languages

Jupyter Notebook
11667 projects
Python
139335 projects - #7 most used programming language

Projects that are alternatives of or similar to ShinRL

Trax
Trax — Deep Learning with Clear Code and Speed
Stars: ✭ 6,666 (+22120%)
Mutual labels:  jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (+43.33%)
Mutual labels:  jax
revisiting rainbow
Revisiting Rainbow
Stars: ✭ 71 (+136.67%)
Mutual labels:  jax
Foolbox
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
Stars: ✭ 2,108 (+6926.67%)
Mutual labels:  jax
Transformers
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Stars: ✭ 55,742 (+185706.67%)
Mutual labels:  jax
parallel-non-linear-gaussian-smoothers
Companion code in JAX for the paper Parallel Iterated Extended and Sigma-Point Kalman Smoothers.
Stars: ✭ 17 (-43.33%)
Mutual labels:  jax
Datasets
TFDS is a collection of datasets ready to use with TensorFlow, Jax, ...
Stars: ✭ 3,094 (+10213.33%)
Mutual labels:  jax
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+2903.33%)
Mutual labels:  jax
graphsignal
Graphsignal Python agent
Stars: ✭ 158 (+426.67%)
Mutual labels:  jax
jax-models
Unofficial JAX implementations of deep learning research papers
Stars: ✭ 108 (+260%)
Mutual labels:  jax
Jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Stars: ✭ 15,579 (+51830%)
Mutual labels:  jax
Einops
Deep learning operations reinvented (for pytorch, tensorflow, jax and others)
Stars: ✭ 4,022 (+13306.67%)
Mutual labels:  jax
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (+103.33%)
Mutual labels:  jax
Flax
Flax is a neural network library for JAX that is designed for flexibility.
Stars: ✭ 2,447 (+8056.67%)
Mutual labels:  jax
mlp-gpt-jax
A GPT, made only of MLPs, in Jax
Stars: ✭ 53 (+76.67%)
Mutual labels:  jax
Pyprobml
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
Stars: ✭ 4,197 (+13890%)
Mutual labels:  jax
ADAM
ADAM implements a collection of algorithms for calculating rigid-body dynamics in Jax, CasADi, PyTorch, and Numpy.
Stars: ✭ 51 (+70%)
Mutual labels:  jax
GPJax
A didactic Gaussian process package for researchers in Jax.
Stars: ✭ 159 (+430%)
Mutual labels:  jax
rA9
JAX-based Spiking Neural Network framework
Stars: ✭ 60 (+100%)
Mutual labels:  jax
jax-cfd
Computational Fluid Dynamics in JAX
Stars: ✭ 399 (+1230%)
Mutual labels:  jax

Status: Under development (expect bug fixes and major updates)

ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives

ShinRL is an open-source JAX library specialized for the evaluation of reinforcement learning (RL) algorithms from both theoretical and practical perspectives. Please take a look at the paper for details. Try ShinRL at experiments/QuickStart.ipynb.

QuickStart

import gym
from shinrl import DiscreteViSolver
import matplotlib.pyplot as plt

# make an env & a config
env = gym.make("ShinPendulum-v0")
config = DiscreteViSolver.DefaultConfig(explore="eps_greedy", approx="nn", steps_per_epoch=10000)

# make & run a solver
mixins = DiscreteViSolver.make_mixins(env, config)
dqn_solver = DiscreteViSolver.factory(env, config, mixins)
dqn_solver.run()

# plot performance
returns = dqn_solver.scalars["Return"]
plt.plot(returns["x"], returns["y"])

# plot learned q-values (action == 0)
q0 = dqn_solver.data["Q"][:, 0]
env.plot_S(q0, title="Learned")

(Example output of the QuickStart code.)

Key Modules

(Overview of ShinRL's key modules.)

🔬 ShinEnv for Oracle Analysis

  • ShinEnv provides small environments with oracle methods that can compute exact quantities.
  • Some environments support continuous action spaces and image observations, as the table below shows.
  • See the tutorial for details: experiments/Tutorials/ShinEnvTutorial.ipynb.

| Environment | Discrete action | Continuous action | Image observation | Tuple observation |
| ShinMaze | ✔️ | ❌ | ✔️ | ❌ |
| ShinMountainCar-v0 | ✔️ | ✔️ | ✔️ | ✔️ |
| ShinPendulum-v0 | ✔️ | ✔️ | ✔️ | ✔️ |
| ShinCartPole-v0 | ✔️ | ✔️ | ❌ | ✔️ |
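For example, the oracle interface lets you compare learned values against exact ones. The sketch below is an assumption-laden illustration: the method name calc_optimal_q is taken from the ShinEnv tutorial and should be checked there for the exact signature, while plot_S appears in the QuickStart above.

import gym
import shinrl  # assumption: importing shinrl registers the Shin* environments with gym

env = gym.make("ShinPendulum-v0")

# Exact optimal Q-values of the underlying MDP (assumed oracle method; see the tutorial).
q_opt = env.calc_optimal_q()

# plot_S visualizes a per-state quantity over the discretized state space,
# the same helper used in the QuickStart.
env.plot_S(q_opt[:, 0], title="Optimal (action == 0)")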

🏭 Flexible Solver by MixIn

  • A Solver solves an environment with a specified algorithm.
  • A "mixin" is a class that defines and implements a single feature. ShinRL's solvers are instantiated by combining several mixins (see the sketch below).
  • See the tutorial for details: experiments/Tutorials/SolverTutorial.ipynb.

(MixIn composition diagram.)
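As a concrete illustration of the mixin design, the same DiscreteViSolver can be assembled into tabular value iteration or into DQN purely by changing the config. This is a minimal sketch built from the QuickStart API; the approx and explore fields follow the algorithm table in the next section.

import gym
from shinrl import DiscreteViSolver

env = gym.make("ShinPendulum-v0")

# Tabular dynamic programming: oracle exploration, no function approximation.
vi_config = DiscreteViSolver.DefaultConfig(approx="tabular", explore="oracle")
vi_mixins = DiscreteViSolver.make_mixins(env, vi_config)
vi_solver = DiscreteViSolver.factory(env, vi_config, vi_mixins)

# DQN: epsilon-greedy exploration with a neural network. A different set of
# mixins is selected from the very same solver class.
dqn_config = DiscreteViSolver.DefaultConfig(approx="nn", explore="eps_greedy")
dqn_mixins = DiscreteViSolver.make_mixins(env, dqn_config)
dqn_solver = DiscreteViSolver.factory(env, dqn_config, dqn_mixins)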

Implemented Popular Algorithms

  • The table below lists the implemented popular algorithms.
  • Note that it does not list all the implemented algorithms (e.g., the DDP¹ version of the DQN algorithm). See the make_mixins functions of the solvers for the implemented variants.
  • Note that the implemented algorithms may differ from the original implementations for simplicity (e.g., Discrete SAC). See the solvers' source code for details.
| Algorithm | Solver | Configuration | Type¹ |
| Value Iteration (VI) | DiscreteViSolver | approx == "tabular" & explore == "oracle" | TDP |
| Policy Iteration (PI) | DiscretePiSolver | approx == "tabular" & explore == "oracle" | TDP |
| Conservative Value Iteration (CVI) | DiscreteViSolver | approx == "tabular" & explore == "oracle" & er_coef != 0 & kl_coef != 0 | TDP |
| Tabular Q Learning | DiscreteViSolver | approx == "tabular" & explore != "oracle" | TRL |
| SARSA | DiscretePiSolver | approx == "tabular" & explore != "oracle" & eps_decay_target_pol > 0 | TRL |
| Deep Q Network (DQN) | DiscreteViSolver | approx == "nn" & explore != "oracle" | DRL |
| Soft DQN | DiscreteViSolver | approx == "nn" & explore != "oracle" & er_coef != 0 | DRL |
| Munchausen-DQN | DiscreteViSolver | approx == "nn" & explore != "oracle" & er_coef != 0 & kl_coef != 0 | DRL |
| Double-DQN | DiscreteViSolver | approx == "nn" & explore != "oracle" & use_double_q == True | DRL |
| Discrete Soft Actor Critic | DiscretePiSolver | approx == "nn" & explore != "oracle" & er_coef != 0 | DRL |
| Deep Deterministic Policy Gradient (DDPG) | ContinuousDdpgSolver | approx == "nn" & explore != "oracle" | DRL |

¹ Algorithm Type:

  • TDP (approx == "tabular" & explore == "oracle"): Tabular Dynamic Programming algorithms. No exploration, no approximation, and the complete MDP specification is given.
  • TRL (approx == "tabular" & explore != "oracle"): Tabular Reinforcement Learning algorithms. No approximation, but the dynamics and reward functions are unknown.
  • DDP (approx == "nn" & explore == "oracle"): Deep Dynamic Programming algorithms. The same as TDP, except that neural networks approximate the computed values.
  • DRL (approx == "nn" & explore != "oracle"): Deep Reinforcement Learning algorithms. The same as TRL, except that neural networks approximate the computed values.
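To make the Configuration column concrete, here is how the Munchausen-DQN row could be instantiated. This is a sketch only: the coefficient values below are illustrative placeholders, not tuned recommendations.

import gym
from shinrl import DiscreteViSolver

env = gym.make("ShinPendulum-v0")

# Munchausen-DQN per the table: approx == "nn", explore != "oracle",
# er_coef != 0 and kl_coef != 0.
config = DiscreteViSolver.DefaultConfig(
    approx="nn",
    explore="eps_greedy",
    er_coef=0.01,  # entropy coefficient; illustrative value only
    kl_coef=0.01,  # KL coefficient; illustrative value only
)
mixins = DiscreteViSolver.make_mixins(env, config)
mdqn_solver = DiscreteViSolver.factory(env, config, mixins)
mdqn_solver.run()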

Installation

git clone git@github.com:omron-sinicx/ShinRL.git
cd ShinRL
pip install -e .
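A quick post-install sanity check (a sketch; it assumes that importing shinrl registers the Shin* environments with gym, as the QuickStart implies):

import gym
import shinrl  # assumed to register the Shin* environments on import

env = gym.make("ShinPendulum-v0")
env.reset()
print(env.action_space, env.observation_space)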

Test

cd ShinRL
make test

Format

cd ShinRL
make format

Docker

cd ShinRL
docker-compose up

Citation

# NeurIPS Deep RL Workshop 2021 version (pytorch branch)
@inproceedings{toshinori2021shinrl,
    author = {Kitamura, Toshinori and Yonetani, Ryo},
    title = {ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives},
    year = {2021},
    booktitle = {Proceedings of the NeurIPS Deep RL Workshop},
}

# Arxiv version (commit 2d3da)
@article{toshinori2021shinrlArxiv,
    author = {Kitamura, Toshinori and Yonetani, Ryo},
    title = {ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives},
    year = {2021},
    url = {https://arxiv.org/abs/2112.04123},
    journal={arXiv preprint arXiv:2112.04123},
}