
chen0040 / java-reinforcement-learning

Licence: MIT license
Package provides Java implementations of reinforcement learning algorithms such as Q-Learn, R-Learn, SARSA and Actor-Critic

Programming Languages

java
shell
Batchfile

Projects that are alternatives to or similar to java-reinforcement-learning

Reinforcement Learning With Tensorflow
Simple Reinforcement learning tutorials, 莫烦Python (Morvan Python) Chinese-language AI tutorials
Stars: ✭ 6,948 (+7620%)
Mutual labels:  q-learning, sarsa, actor-critic, sarsa-lambda
Paddle-RLBooks
Paddle-RLBooks is a reinforcement learning code study guide based on pure PaddlePaddle.
Stars: ✭ 113 (+25.56%)
Mutual labels:  q-learning, sarsa, actor-critic
Explorer
Explorer is a PyTorch reinforcement learning framework for exploring new ideas.
Stars: ✭ 54 (-40%)
Mutual labels:  q-learning, actor-critic
RL
Reinforcement Learning Demos
Stars: ✭ 66 (-26.67%)
Mutual labels:  q-learning, sarsa
Deep-Reinforcement-Learning-With-Python
Master classic RL, deep RL, distributional RL, inverse RL, and more using OpenAI Gym and TensorFlow with extensive Math
Stars: ✭ 222 (+146.67%)
Mutual labels:  q-learning, actor-critic
Reinforcement learning tutorial with demo
Reinforcement Learning Tutorial with Demo: DP (Policy and Value Iteration), Monte Carlo, TD Learning (SARSA, QLearning), Function Approximation, Policy Gradient, DQN, Imitation, Meta Learning, Papers, Courses, etc.
Stars: ✭ 442 (+391.11%)
Mutual labels:  q-learning, actor-critic
Dissecting Reinforcement Learning
Python code, PDFs and resources for the series of posts on Reinforcement Learning which I published on my personal blog
Stars: ✭ 512 (+468.89%)
Mutual labels:  q-learning, actor-critic
Easy Rl
Chinese-language reinforcement learning tutorial; read online at: https://datawhalechina.github.io/easy-rl/
Stars: ✭ 3,004 (+3237.78%)
Mutual labels:  q-learning, sarsa
reinforcement learning with Tensorflow
Minimal implementations of reinforcement learning algorithms by Tensorflow
Stars: ✭ 28 (-68.89%)
Mutual labels:  actor-critic
DRL in CV
A course on Deep Reinforcement Learning in Computer Vision. Visit Website:
Stars: ✭ 59 (-34.44%)
Mutual labels:  q-learning
yarll
Combining deep learning and reinforcement learning.
Stars: ✭ 84 (-6.67%)
Mutual labels:  sarsa
Learningx
Deep & Classical Reinforcement Learning + Machine Learning Examples in Python
Stars: ✭ 241 (+167.78%)
Mutual labels:  q-learning
Implicit-Q-Learning
PyTorch implementation of the implicit Q-learning algorithm (IQL)
Stars: ✭ 27 (-70%)
Mutual labels:  q-learning
king-pong
Deep Reinforcement Learning Pong Agent, King Pong, he's the best
Stars: ✭ 23 (-74.44%)
Mutual labels:  q-learning
ReinforcementLearning Sutton-Barto Solutions
Solutions and figures for problems from Reinforcement Learning: An Introduction Sutton&Barto
Stars: ✭ 20 (-77.78%)
Mutual labels:  sarsa
Grid royale
A life simulation for exploring social dynamics
Stars: ✭ 252 (+180%)
Mutual labels:  q-learning
off-policy-continuous-control
[DeepRL Workshop, NeurIPS-21] Recurrent Off-policy Baselines for Memory-based Continuous Control (RDPG, RTD3 and RSAC)
Stars: ✭ 29 (-67.78%)
Mutual labels:  actor-critic
Master-Thesis
Deep Reinforcement Learning in Autonomous Driving: the A3C algorithm used to make a car learn to drive in TORCS; Python 3.5, Tensorflow, tensorboard, numpy, gym-torcs, ubuntu, latex
Stars: ✭ 33 (-63.33%)
Mutual labels:  actor-critic
LearnSnake
🐍 AI that learns to play Snake using Q-Learning (Reinforcement Learning)
Stars: ✭ 69 (-23.33%)
Mutual labels:  q-learning
Flow-Shop-Scheduling-Based-On-Reinforcement-Learning-Algorithm
Operations Research Application Project - Flow Shop Scheduling Based On Reinforcement Learning Algorithm
Stars: ✭ 73 (-18.89%)
Mutual labels:  q-learning

java-reinforcement-learning

This package provides Java implementations of reinforcement learning algorithms as described in the book "Reinforcement Learning: An Introduction" by Sutton and Barto.


Features

The following reinforcement learning algorithms are implemented:

  • R-Learn
  • Q-Learn
  • Q-Learn with eligibility trace
  • SARSA
  • SARSA with eligibility trace
  • Actor-Critic
  • Actor-Critic with eligibility trace

The package also supports a number of action-selection strategies:

  • soft-max
  • epsilon-greedy
  • greedy
  • Gibbs-soft-max
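Roughly speaking, epsilon-greedy and greedy selection behave as in the following minimal, library-independent sketch over one row of a Q table. The class `EpsilonGreedySketch` is hypothetical and is not how this package implements either strategy: with probability `epsilon` it picks a uniformly random action, otherwise the action with the highest estimate, so plain greedy selection is the special case `epsilon = 0`.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical, library-independent sketch of epsilon-greedy action selection.
class EpsilonGreedySketch {

    // Index of the largest value in qRow; the first such index wins ties.
    static int argMax(double[] qRow) {
        int best = 0;
        for (int a = 1; a < qRow.length; a++) {
            if (qRow[a] > qRow[best]) {
                best = a;
            }
        }
        return best;
    }

    // With probability epsilon, explore (uniformly random action);
    // otherwise exploit the current best estimate. epsilon = 0 is pure greedy.
    static int select(double[] qRow, double epsilon, Random random) {
        if (random.nextDouble() < epsilon) {
            return random.nextInt(qRow.length);
        }
        return argMax(qRow);
    }

    public static void main(String[] args) {
        double[] qRow = {0.1, 0.9, 0.3};
        int[] counts = new int[qRow.length];
        Random random = new Random(42);
        for (int i = 0; i < 1000; i++) {
            counts[select(qRow, 0.1, random)]++;
        }
        // Action 1 should dominate; the others appear only through exploration.
        System.out.println(Arrays.toString(counts));
    }
}
```

A larger `epsilon` trades exploitation for exploration; many implementations also decay it over time.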


Install

Add the following dependency to your POM file:

<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-reinforcement-learning</artifactId>
  <version>1.0.5</version>
</dependency>

Application Samples

Application samples for this library can be found in the following repositories:

Usage

Create Agent

A reinforcement learning agent, such as a Q-Learn agent, can be created with the following Java code:

import com.github.chen0040.rl.learning.qlearn.QAgent;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

The agent created above has 100 states and 10 actions to choose from.

For Q-Learn and SARSA, eligibility traces can be enabled by calling the following, where lambda is the trace-decay parameter:

agent.enableEligibilityTrace(lambda);
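Eligibility traces let a single TD error update every recently visited state-action pair instead of only the last one. The following is a minimal sketch of one SARSA(λ) step with accumulating traces, written from the textbook description rather than from this library's code; the class `SarsaLambdaSketch` and its fields are hypothetical. Watkins's Q(λ) additionally resets the traces after exploratory actions, which is not shown here.

```java
// Hypothetical sketch of one SARSA(lambda) step with accumulating eligibility traces.
class SarsaLambdaSketch {
    final double[][] q;      // action values Q(s, a)
    final double[][] trace;  // eligibility traces e(s, a)
    final double alpha;      // step size
    final double gamma;      // discount factor
    final double lambda;     // trace-decay parameter

    SarsaLambdaSketch(int stateCount, int actionCount, double alpha, double gamma, double lambda) {
        this.q = new double[stateCount][actionCount];
        this.trace = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.gamma = gamma;
        this.lambda = lambda;
    }

    void update(int s, int a, double reward, int nextS, int nextA) {
        // TD error with the on-policy (SARSA) target.
        double delta = reward + gamma * q[nextS][nextA] - q[s][a];
        // Mark the pair just visited as eligible.
        trace[s][a] += 1.0;
        // Every pair moves in proportion to its trace; then all traces decay.
        for (int i = 0; i < q.length; i++) {
            for (int j = 0; j < q[i].length; j++) {
                q[i][j] += alpha * delta * trace[i][j];
                trace[i][j] *= gamma * lambda;
            }
        }
    }

    public static void main(String[] args) {
        SarsaLambdaSketch withTraces = new SarsaLambdaSketch(3, 2, 0.5, 1.0, 1.0);
        withTraces.update(0, 0, 0.0, 1, 0);
        withTraces.update(1, 0, 1.0, 2, 0);
        // The reward on the second step also reaches the first pair.
        System.out.println("Q(0,0) = " + withTraces.q[0][0]);
    }
}
```

With `lambda = 0` the traces vanish after each step and the update reduces to one-step SARSA.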

Select Action

At each time step, an action can be selected by the agent by calling:

int actionId = agent.selectAction().getIndex();

If you want to limit the actions available in each state (for example, when the problem only allows certain actions in certain states), call:

Set<Integer> actionsAvailableAtCurrentState = world.getActionsAvailable(agent);
int actionTaken = agent.selectAction(actionsAvailableAtCurrentState).getIndex();
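Conceptually, restricting the choice to the available actions means that actions outside the set are never considered. A minimal sketch for the greedy case, assuming the Q values of the current state are held in a plain `double[]`; the class `RestrictedGreedySketch` is hypothetical and not part of the package:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical sketch: greedy selection limited to the actions available in the current state.
class RestrictedGreedySketch {
    static int selectGreedy(double[] qRow, Set<Integer> availableActions) {
        if (availableActions.isEmpty()) {
            throw new IllegalArgumentException("no action is available in this state");
        }
        int best = -1;
        for (int a : availableActions) {
            // Actions outside the set are never considered, even if their Q value is higher.
            if (best == -1 || qRow[a] > qRow[best]) {
                best = a;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] qRow = {5.0, 1.0, 3.0};
        Set<Integer> available = new HashSet<>(Arrays.asList(1, 2));
        System.out.println(selectGreedy(qRow, available)); // action 0 is excluded, so 2 is chosen
    }
}
```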

The agent can also switch to a different action-selection policy from the com.github.chen0040.rl.actionselection package. For example, the following code switches the action-selection policy to soft-max:

agent.getLearner().setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName());
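Soft-max (Boltzmann) selection picks action a with probability proportional to exp(Q(s,a)/τ), so higher-valued actions are favoured without ever excluding the others; a large temperature τ makes the choice nearly uniform and a small τ makes it nearly greedy. The sketch below is a hypothetical, library-independent illustration; the class `SoftMaxSketch` and the temperature parameter `tau` are assumptions and do not describe `SoftMaxActionSelectionStrategy` itself.

```java
import java.util.Random;

// Hypothetical sketch of soft-max (Boltzmann) action selection with temperature tau.
class SoftMaxSketch {
    static double[] probabilities(double[] qRow, double tau) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : qRow) {
            max = Math.max(max, v);
        }
        double[] p = new double[qRow.length];
        double sum = 0.0;
        for (int a = 0; a < qRow.length; a++) {
            // Subtracting the max avoids overflow in Math.exp without changing the ratios.
            p[a] = Math.exp((qRow[a] - max) / tau);
            sum += p[a];
        }
        for (int a = 0; a < p.length; a++) {
            p[a] /= sum;
        }
        return p;
    }

    static int select(double[] qRow, double tau, Random random) {
        double[] p = probabilities(qRow, tau);
        double r = random.nextDouble();
        double cumulative = 0.0;
        for (int a = 0; a < p.length; a++) {
            cumulative += p[a];
            if (r < cumulative) {
                return a;
            }
        }
        return p.length - 1; // only reached if rounding leaves the cumulative sum below r
    }

    public static void main(String[] args) {
        double[] qRow = {0.0, 1.0, 2.0};
        // A lower temperature concentrates probability on the best action.
        System.out.println(java.util.Arrays.toString(probabilities(qRow, 10.0)));
        System.out.println(java.util.Arrays.toString(probabilities(qRow, 0.1)));
    }
}
```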

State-Action Update

Once the world state has been updated as a result of the agent's selected action, the agent's internal state-action Q matrix can be updated by calling:

int newStateId = world.update(agent, actionTaken);
double reward = world.reward(agent);

agent.update(actionTaken, newStateId, reward);
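The samples in this README leave `world` to the user. To make the update step concrete, the sketch below pairs a hypothetical five-state chain world (move left or right, reward 1 for reaching the right end) with the textbook one-step Q-learning rule Q(s,a) ← Q(s,a) + α[r + γ max_a' Q(s',a') − Q(s,a)], which is the kind of update a Q-Learn agent performs. The class name, the world, and the step sizes (α = 0.1, γ = 0.9) are all assumptions for illustration; this is not the library's code.

```java
import java.util.Random;

// Hypothetical sketch: a toy chain world and the one-step Q-learning update.
class QLearningSketch {
    // States 0..4; action 0 moves left, action 1 moves right.
    // Entering state 4 pays reward 1 and ends the episode.
    static final int STATES = 5;
    static final int ACTIONS = 2;

    static int step(int state, int action) {
        int next = (action == 1) ? state + 1 : state - 1;
        return Math.max(0, Math.min(STATES - 1, next)); // walls at both ends
    }

    static double reward(int newState) {
        return newState == STATES - 1 ? 1.0 : 0.0;
    }

    // Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
    static void update(double[][] q, int s, int a, int nextS, double r, double alpha, double gamma) {
        double maxNext = Double.NEGATIVE_INFINITY;
        for (double v : q[nextS]) {
            maxNext = Math.max(maxNext, v);
        }
        q[s][a] += alpha * (r + gamma * maxNext - q[s][a]);
    }

    static double[][] train(int episodes, long seed) {
        double[][] q = new double[STATES][ACTIONS]; // the terminal row is never updated and stays 0
        Random random = new Random(seed);
        for (int e = 0; e < episodes; e++) {
            int s = 0;
            for (int t = 0; t < 100 && s != STATES - 1; t++) {
                // Uniformly random behaviour: Q-learning is off-policy, so it still
                // estimates the values of the greedy policy.
                int a = random.nextInt(ACTIONS);
                int next = step(s, a);
                update(q, s, a, next, reward(next), 0.1, 0.9);
                s = next;
            }
        }
        return q;
    }

    public static void main(String[] args) {
        double[][] q = train(2000, 1L);
        for (int s = 0; s < STATES; s++) {
            System.out.printf("state %d: left=%.3f right=%.3f%n", s, q[s][0], q[s][1]);
        }
    }
}
```

With these settings the "right" values of states 3, 2, 1 and 0 should approach 1, 0.9, 0.81 and 0.729.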

Sample code

Sample code for R-Learn

import com.github.chen0040.rl.learning.rlearn.RAgent;
import java.util.Random;

int stateCount = 100;
int actionCount = 10;
RAgent agent = new RAgent(stateCount, actionCount);

// `world` stands for your own environment object; it is not provided by this library.
Random random = new Random();
agent.start(random.nextInt(stateCount));
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}
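R-Learn is the average-reward method for undiscounted continuing tasks described in the first edition of Sutton and Barto's book: instead of discounting, it subtracts a running estimate ρ of the average reward per step. A minimal tabular sketch of that update is shown below; the class `RLearningSketch` and the step sizes `alpha` and `beta` are hypothetical, and the sketch is not taken from `RAgent`.

```java
// Hypothetical sketch of tabular R-learning (average-reward, no discounting).
class RLearningSketch {
    final double[][] q;
    final double alpha; // step size for Q
    final double beta;  // step size for the average-reward estimate
    double rho = 0.0;   // running estimate of the average reward per step

    RLearningSketch(int stateCount, int actionCount, double alpha, double beta) {
        this.q = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.beta = beta;
    }

    static double max(double[] row) {
        double m = Double.NEGATIVE_INFINITY;
        for (double v : row) {
            m = Math.max(m, v);
        }
        return m;
    }

    void update(int s, int a, double reward, int nextS) {
        double maxNext = max(q[nextS]);
        // Relative-value update: the average reward rho replaces discounting.
        q[s][a] += alpha * (reward - rho + maxNext - q[s][a]);
        // rho is adjusted only when the action taken is greedy in s.
        if (q[s][a] == max(q[s])) {
            rho += beta * (reward - rho + maxNext - max(q[s]));
        }
    }

    public static void main(String[] args) {
        RLearningSketch learner = new RLearningSketch(2, 2, 0.5, 0.1);
        learner.update(0, 0, 1.0, 1);
        System.out.println("rho = " + learner.rho + ", Q(0,0) = " + learner.q[0][0]);
    }
}
```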

Alternatively, you can use RLearner if you want to learn after the episode has ended:

import com.github.chen0040.rl.learning.rlearn.RLearner;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Records one transition so that it can be replayed after the episode.
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
RLearner agent = new RLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);
    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Replay the recorded moves from last to first.
for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, world.getActionsAvailableAtState(move.newState), move.reward);
}

Sample code for Q-Learn

import com.github.chen0040.rl.learning.qlearn.QAgent;
import java.util.Random;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}

Alternatively, you can use QLearner if you want to learn after the episode has ended:

import com.github.chen0040.rl.learning.qlearn.QLearner;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Records one transition so that it can be replayed after the episode.
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
QLearner agent = new QLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);
    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Replay the recorded moves from last to first.
for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, move.reward);
}
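The episode-based samples replay the recorded moves from last to first. One reason this helps is illustrated by the hypothetical sketch below: when only the final move is rewarded, a backward pass propagates that reward all the way to the first state in a single sweep, whereas a forward pass only updates the last pair. The sketch uses a plain `double[][]` Q table, a three-move episode on a four-state chain, α = 0.5 and γ = 0.9, all of which are assumptions.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: replaying one episode forward versus backward with Q-learning updates.
class ReplayOrderSketch {
    static final class Move {
        final int oldState;
        final int action;
        final int newState;
        final double reward;

        Move(int oldState, int action, int newState, double reward) {
            this.oldState = oldState;
            this.action = action;
            this.newState = newState;
            this.reward = reward;
        }
    }

    // Three moves to the right (action 1) on a 4-state chain; only the last one is rewarded.
    static List<Move> episode() {
        List<Move> moves = new ArrayList<>();
        for (int s = 0; s < 3; s++) {
            moves.add(new Move(s, 1, s + 1, s == 2 ? 1.0 : 0.0));
        }
        return moves;
    }

    static double[][] replay(List<Move> moves, boolean backward) {
        double alpha = 0.5, gamma = 0.9;
        double[][] q = new double[4][2];
        for (int k = 0; k < moves.size(); k++) {
            Move m = moves.get(backward ? moves.size() - 1 - k : k);
            double maxNext = Math.max(q[m.newState][0], q[m.newState][1]);
            q[m.oldState][m.action] += alpha * (m.reward + gamma * maxNext - q[m.oldState][m.action]);
        }
        return q;
    }

    public static void main(String[] args) {
        List<Move> moves = episode();
        System.out.println("forward  Q(0,right) = " + replay(moves, false)[0][1]);
        System.out.println("backward Q(0,right) = " + replay(moves, true)[0][1]);
    }
}
```

Here the forward pass leaves Q(0, right) at 0, while the backward pass already gives it about 0.101.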

Sample code for SARSA

import com.github.chen0040.rl.learning.sarsa.SarsaAgent;
import java.util.Random;

int stateCount = 100;
int actionCount = 10;
SarsaAgent agent = new SarsaAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}
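SARSA differs from Q-Learn mainly in its target: it bootstraps from the action the agent actually takes next rather than from the best next action, which is why the `SarsaLearner` sample passes the next move's action to `update`. A minimal, hypothetical sketch of both targets on a plain `double[][]` table (class and method names are illustrative, not the library's):

```java
// Hypothetical sketch contrasting the SARSA and Q-learning targets on a plain Q table.
class SarsaSketch {
    // SARSA (on-policy): bootstrap from the action a' the agent actually takes next.
    static void sarsaUpdate(double[][] q, int s, int a, double r, int nextS, int nextA,
                            double alpha, double gamma) {
        q[s][a] += alpha * (r + gamma * q[nextS][nextA] - q[s][a]);
    }

    // Q-learning (off-policy): bootstrap from the best next action instead.
    static void qLearningUpdate(double[][] q, int s, int a, double r, int nextS,
                                double alpha, double gamma) {
        double maxNext = Double.NEGATIVE_INFINITY;
        for (double v : q[nextS]) {
            maxNext = Math.max(maxNext, v);
        }
        q[s][a] += alpha * (r + gamma * maxNext - q[s][a]);
    }

    public static void main(String[] args) {
        double[][] sarsa = {{0.0, 0.0}, {1.0, 0.0}};
        double[][] qLearning = {{0.0, 0.0}, {1.0, 0.0}};
        // The next action is exploratory (action 1), although action 0 is best in state 1.
        sarsaUpdate(sarsa, 0, 0, 0.0, 1, 1, 0.5, 1.0);
        qLearningUpdate(qLearning, 0, 0, 0.0, 1, 0.5, 1.0);
        System.out.println("SARSA Q(0,0) = " + sarsa[0][0] + ", Q-learning Q(0,0) = " + qLearning[0][0]);
    }
}
```

Because SARSA evaluates the exploratory action it will actually take, it learns the value of the behaviour policy, while Q-learning learns the value of the greedy policy.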

Alternatively, you can use SarsaLearner if you want to learn after the episode has ended:

import com.github.chen0040.rl.learning.sarsa.SarsaLearner;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Records one transition so that it can be replayed after the episode.
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
SarsaLearner agent = new SarsaLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);
    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

for (int i = moves.size() - 1; i >= 0; --i) {
    Move currentMove = moves.get(i);
    // SARSA needs the action taken in the next state; the final move has no successor,
    // so this sample reuses its own action there.
    Move nextMove = (i == moves.size() - 1) ? currentMove : moves.get(i + 1);
    agent.update(currentMove.oldState, currentMove.action, currentMove.newState, nextMove.action, currentMove.reward);
}

Sample code for Actor Critic Model

import com.github.chen0040.rl.learning.actorcritic.ActorCriticAgent;
import com.github.chen0040.rl.utils.Vec;
import java.util.Random;

int stateCount = 100;
int actionCount = 10;
ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount);
// State values passed to the agent at each update.
Vec stateValues = new Vec(stateCount);

Random random = new Random();
agent.start(random.nextInt(stateCount));
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    // For illustration, the state values are refreshed with random numbers at every step.
    System.out.println("World state values changed ...");
    for (int stateId = 0; stateId < stateCount; ++stateId) {
        stateValues.set(stateId, random.nextDouble());
    }

    agent.update(actionId, newStateId, reward, stateValues);
}
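In the textbook tabular actor-critic, a critic estimates state values V(s) and an actor keeps action preferences p(s,a) that are turned into a Gibbs soft-max policy; both are moved by the TD error δ = r + γV(s') − V(s). The sample above supplies the state values from outside, so the sketch below, which keeps its own critic table, is only a hypothetical illustration of the idea; `ActorCriticSketch`, `alpha`, `beta` and `gamma` are assumptions, not the library's API.

```java
// Hypothetical tabular actor-critic sketch: critic V(s), actor preferences p(s, a).
class ActorCriticSketch {
    final double[] v;       // critic: state-value estimates
    final double[][] pref;  // actor: action preferences
    final double alpha;     // actor step size
    final double beta;      // critic step size
    final double gamma;     // discount factor

    ActorCriticSketch(int stateCount, int actionCount, double alpha, double beta, double gamma) {
        this.v = new double[stateCount];
        this.pref = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.beta = beta;
        this.gamma = gamma;
    }

    // Gibbs soft-max policy over the preferences of state s.
    double[] policy(int s) {
        double max = Double.NEGATIVE_INFINITY;
        for (double p : pref[s]) {
            max = Math.max(max, p);
        }
        double[] pi = new double[pref[s].length];
        double sum = 0.0;
        for (int a = 0; a < pi.length; a++) {
            pi[a] = Math.exp(pref[s][a] - max);
            sum += pi[a];
        }
        for (int a = 0; a < pi.length; a++) {
            pi[a] /= sum;
        }
        return pi;
    }

    // Applies one step and returns the TD error.
    double update(int s, int a, double reward, int nextS) {
        double delta = reward + gamma * v[nextS] - v[s];
        v[s] += beta * delta;         // critic: move V(s) toward the TD target
        pref[s][a] += alpha * delta;  // actor: a positive error makes action a more likely in s
        return delta;
    }

    public static void main(String[] args) {
        ActorCriticSketch ac = new ActorCriticSketch(2, 2, 0.5, 0.5, 0.9);
        double delta = ac.update(0, 1, 1.0, 1);
        System.out.println("TD error = " + delta + ", policy(0) = " + java.util.Arrays.toString(ac.policy(0)));
    }
}
```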

Alternatively, you can use ActorCriticLearner if you want to learn after the episode has ended:

import com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner;
import com.github.chen0040.rl.utils.Vec;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Records one transition so that it can be replayed after the episode.
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
ActorCriticLearner agent = new ActorCriticLearner(stateCount, actionCount);
// State values used by the critic (placeholders here, as in the ActorCriticAgent sample).
Vec stateValues = new Vec(stateCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {
    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);
    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    // Assumes the update overload that takes the state values as a Function<Integer, Double>.
    agent.update(move.oldState, move.action, move.newState, move.reward, stateValues::get);
}

Save and Load RL models

To save a trained RL model (for example, a QLearner) as JSON:

QLearner learner = new QLearner(stateCount, actionCount);
train(learner); // placeholder for your own training loop
String json = learner.toJson();

To load the trained RL model from json:

QLearner learner = QLearner.fromJson(json);
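If you need to persist a tabular model without the library's JSON support, a simple alternative is to write the Q matrix as text. The sketch below (hypothetical class `QTableTextSketch`, one comma-separated row per state, assuming a non-empty table with at least one action per state) relies on `Double.toString` producing a representation that `Double.parseDouble` maps back to the same value. It is an illustration only and neither reads nor writes the format produced by `toJson()`.

```java
import java.util.Arrays;

// Hypothetical sketch: saving and loading a Q matrix as plain text, one comma-separated row per state.
class QTableTextSketch {
    static String toText(double[][] q) {
        StringBuilder sb = new StringBuilder();
        for (double[] row : q) {
            for (int a = 0; a < row.length; a++) {
                if (a > 0) {
                    sb.append(',');
                }
                // Double.toString emits enough digits for parseDouble to recover the exact value.
                sb.append(Double.toString(row[a]));
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    static double[][] fromText(String text) {
        String[] lines = text.split("\n");
        double[][] q = new double[lines.length][];
        for (int s = 0; s < lines.length; s++) {
            String[] cells = lines[s].split(",");
            q[s] = new double[cells.length];
            for (int a = 0; a < cells.length; a++) {
                q[s][a] = Double.parseDouble(cells[a]);
            }
        }
        return q;
    }

    public static void main(String[] args) {
        double[][] q = {{0.1, -2.5}, {1e-300, 3.0}};
        String text = toText(q);
        System.out.print(text);
        System.out.println("round trip exact: " + Arrays.deepEquals(q, fromText(text)));
    }
}
```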