pemami4911 / Neural Combinatorial Rl Pytorch

License: MIT
PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning https://arxiv.org/abs/1611.09940

Programming Languages

Python

Projects that are alternatives of or similar to Neural Combinatorial Rl Pytorch

Seq2seq Summarizer
Pointer-generator reinforced seq2seq summarization in PyTorch
Stars: ✭ 306 (-6.99%)
Mutual labels:  reinforcement-learning, seq2seq
Text summurization abstractive methods
Multiple implementations of abstractive text summarization, using Google Colab
Stars: ✭ 359 (+9.12%)
Mutual labels:  reinforcement-learning, seq2seq
Mlds2018spring
Machine Learning and having it Deep and Structured (MLDS) in 2018 spring
Stars: ✭ 124 (-62.31%)
Mutual labels:  reinforcement-learning, seq2seq
Baby Steps Of Rl Ja
Sample code for the book "Reinforcement Learning with Python: From Introduction to Practice"
Stars: ✭ 302 (-8.21%)
Mutual labels:  reinforcement-learning
Conv seq2seq
A TensorFlow implementation of Fairseq Convolutional Sequence to Sequence Learning (Gehring et al. 2017)
Stars: ✭ 304 (-7.6%)
Mutual labels:  seq2seq
Openai lab
An experimentation framework for Reinforcement Learning using OpenAI Gym, TensorFlow, and Keras.
Stars: ✭ 313 (-4.86%)
Mutual labels:  reinforcement-learning
Nl2bash
Generating bash commands from natural language https://arxiv.org/abs/1802.08979
Stars: ✭ 325 (-1.22%)
Mutual labels:  seq2seq
Pytorch Trpo
PyTorch implementation of Trust Region Policy Optimization
Stars: ✭ 303 (-7.9%)
Mutual labels:  reinforcement-learning
Pythonlinearnonlinearcontrol
PythonLinearNonLinearControl is a library implementing linear and nonlinear control theories in Python.
Stars: ✭ 318 (-3.34%)
Mutual labels:  reinforcement-learning
Elf
ELF: a platform for game research with an AlphaGoZero/AlphaZero reimplementation
Stars: ✭ 3,240 (+884.8%)
Mutual labels:  reinforcement-learning
A3c trading
Trading with recurrent actor-critic reinforcement learning
Stars: ✭ 305 (-7.29%)
Mutual labels:  reinforcement-learning
Gdrl
Grokking Deep Reinforcement Learning
Stars: ✭ 304 (-7.6%)
Mutual labels:  reinforcement-learning
Reinforcement Learning
Learn Deep Reinforcement Learning in 60 days! Lectures & Code in Python. Reinforcement Learning + Deep Learning
Stars: ✭ 3,329 (+911.85%)
Mutual labels:  reinforcement-learning
Tetris Deep Q Learning Pytorch
Deep Q-learning for playing Tetris
Stars: ✭ 322 (-2.13%)
Mutual labels:  reinforcement-learning
Deeprl Tensorflow2
🐋 Simple implementations of various popular Deep Reinforcement Learning algorithms using TensorFlow2
Stars: ✭ 319 (-3.04%)
Mutual labels:  reinforcement-learning
Self Driving Truck
Self-Driving Truck in Euro Truck Simulator 2, trained via Reinforcement Learning
Stars: ✭ 307 (-6.69%)
Mutual labels:  reinforcement-learning
Dynamic Seq2seq
A Chinese seq2seq chatbot
Stars: ✭ 303 (-7.9%)
Mutual labels:  seq2seq
Seq2seq chatbot
A TensorFlow implementation of a simple seq2seq-based dialogue system with embedding, attention, and beam search; the dataset is Cornell Movie Dialogs
Stars: ✭ 308 (-6.38%)
Mutual labels:  seq2seq
Reward Learning Rl
[RSS 2019] End-to-End Robotic Reinforcement Learning without Reward Engineering
Stars: ✭ 310 (-5.78%)
Mutual labels:  reinforcement-learning
Raisimlib
RaiSim, a physics engine for robotics and AI research
Stars: ✭ 328 (-0.3%)
Mutual labels:  reinforcement-learning

neural-combinatorial-rl-pytorch

PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

I have implemented the basic RL pretraining model with greedy decoding from the paper. An implementation of the supervised learning baseline model is available here. Instead of a critic network, I obtained the TSP results below using an exponential moving average critic; the critic network is simply commented out in the code for now. From correspondence with a few others, the exponential moving average critic appears to significantly improve results.
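For illustration, a minimal sketch of an exponential moving average baseline standing in for a learned critic might look like the following (the class name, decay value, and usage are assumptions, not the repo's actual code):

```python
import torch

class EMABaseline:
    """Exponential moving average of batch rewards, used in place of a
    learned critic when forming the REINFORCE advantage."""

    def __init__(self, decay=0.9):  # decay rate is an assumed hyperparameter
        self.decay = decay
        self.value = None

    def update(self, rewards):
        # rewards: 1-D tensor of rewards for the current batch
        batch_mean = rewards.mean().item()
        if self.value is None:
            self.value = batch_mean  # initialize from the first batch
        else:
            self.value = self.decay * self.value + (1 - self.decay) * batch_mean
        return self.value

baseline = EMABaseline()
rewards = torch.rand(128)                       # stand-in batch of rewards
advantage = rewards - baseline.update(rewards)  # drives the policy gradient
```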

During training, my implementation uses a stochastic decoding policy in the pointer network, realized via PyTorch's torch.multinomial(). At test time, it decodes with beam search (not yet finished; only a single beam, i.e., greedy decoding, is currently supported).
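As a rough sketch of the two decoding modes for a single pointer step (the shapes are illustrative, and masking of already-selected inputs is omitted):

```python
import torch

# Illustrative shapes: batch of 128, inputs of length 20
logits = torch.randn(128, 20)          # pointer scores for one decode step
probs = torch.softmax(logits, dim=1)

# Training: sample the next pointer stochastically
sampled_idx = torch.multinomial(probs, num_samples=1)  # shape (128, 1)

# Testing with a single beam: greedy decoding (argmax)
greedy_idx = probs.argmax(dim=1, keepdim=True)         # shape (128, 1)
```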

Currently, there is support for a sorting task and the planar symmetric Euclidean TSP.
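For the TSP, the objective is the total Euclidean length of the decoded tour; an RL reward would typically be the negative of this length, so shorter tours score higher. A minimal sketch of that computation, under assumed names and shapes (the repo's actual version lives in tsp_task.py):

```python
import torch

def tour_length(points, tour):
    """points: (n, 2) tensor of city coordinates.
    tour: length-n LongTensor permutation chosen by the pointer network.
    Returns the total Euclidean length of the closed tour."""
    ordered = points[tour]                        # cities in visit order
    nxt = torch.roll(ordered, shifts=-1, dims=0)  # successor (wraps around)
    return (ordered - nxt).norm(dim=1).sum()

# Sanity check: a unit square visited in order has tour length 4.0
square = torch.tensor([[0., 0.], [0., 1.], [1., 1.], [1., 0.]])
print(tour_length(square, torch.tensor([0, 1, 2, 3])))  # tensor(4.)
```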

See main.sh for an example of how to run the code.

Use the --load_path $LOAD_PATH and --is_train False flags to load a saved model.

To load a saved model and view the pointer network's attention layer, also use the --plot_attention True flag.

Please feel free to notify me if you encounter any errors, or to submit a pull request to improve this implementation.

Adding other tasks

This implementation can be extended to support other combinatorial optimization problems. See sorting_task.py and tsp_task.py for examples of how to add a task. The key requirement is a dataset class and a reward function that takes in a sample solution, selected by the pointer network from the input, and returns a scalar reward. For the sorting task, the agent receives a reward proportional to the length of the longest strictly increasing subsequence in the decoded output (e.g., [1, 3, 5, 2, 4] -> 3/5 = 0.6).
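As a sketch of that sorting reward, assuming the decoded output arrives as a plain Python list (the actual implementation is in sorting_task.py):

```python
def sort_reward(output):
    """Fraction of the decoded output covered by its longest strictly
    increasing subsequence, e.g. [1, 3, 5, 2, 4] -> 3/5 = 0.6."""
    n = len(output)
    # lis[i] = length of the longest strictly increasing subsequence
    # ending at index i (classic O(n^2) dynamic program)
    lis = [1] * n
    for i in range(1, n):
        for j in range(i):
            if output[j] < output[i]:
                lis[i] = max(lis[i], lis[j] + 1)
    return max(lis) / n

assert sort_reward([1, 3, 5, 2, 4]) == 0.6
assert sort_reward([0, 1, 2, 3]) == 1.0  # perfectly sorted input
```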

Dependencies

  • Python 3.6 (should work with any version >= 3.4)
  • PyTorch 0.2 or 0.3
  • tqdm
  • matplotlib
  • tensorboard_logger

PyTorch 0.4 compatibility is available on branch pytorch-0.4.

TSP Results

Results for 1 random seed over 50 epochs (each epoch is 10,000 batches of size 128). After each epoch, I validated performance on 1,000 held-out graphs. I used the same hyperparameters as the paper, as can be seen in main.sh. The dashed line shows the value reported in Table 2 of Bello et al. for comparison. The log-scale x-axis for the training reward shows how quickly the tour length drops early in training.

[Figures: TSP 20 Train, TSP 20 Val, TSP 50 Train, TSP 50 Val]

Sort Results

I trained a model on sort10 for 4 epochs of 1,000,000 randomly generated samples, then tested it on a dataset of 10,000 samples. To evaluate generalization, I also tested the same model on sort15 and sort20.

Test results on 10,000 samples (a reward of 1.0 means the network perfectly sorted the input):

task     average reward   variance
sort10   0.9966           0.0005
sort15   0.7484           0.0177
sort20   0.5586           0.0060

Example prediction on sort10:

input: [4, 7, 5, 0, 3, 2, 6, 8, 9, 1]
output: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Attention visualization

Plot the pointer network's attention layer with the argument --plot_attention True.

TODO

  • [ ] Add RL pretraining-Sampling
  • [ ] Add RL pretraining-Active Search
  • [ ] Active Search
  • [ ] Asynchronous training a la A3C
  • [X] Refactor USE_CUDA variable
  • [ ] Finish implementing beam search decoding to support > 1 beam
  • [ ] Add support for variable length inputs

Acknowledgements

Special thanks to the repos devsisters/neural-combinatorial-rl-tensorflow and MaximumEntropy/Seq2Seq-PyTorch for getting me started, and to @ricgama for figuring out that weird bug with clone().
