WeiChengTseng / Pytorch-PCGrad

License: BSD-3-Clause
PyTorch reimplementation of "Gradient Surgery for Multi-Task Learning"

PyTorch-PCGrad

This repository provides a PyTorch 1.6.0 reimplementation of "Gradient Surgery for Multi-Task Learning".

Setup

Install the required packages via:

pip install -r requirements.txt

Usage

import torch
import torch.nn as nn
import torch.optim as optim
from pcgrad import PCGrad

# wrap your favorite optimizer (`net` is your model)
optimizer = PCGrad(optim.Adam(net.parameters()))
losses = [...]  # a list of per-task losses
assert len(losses) == num_tasks  # one loss per task
optimizer.pc_backward(losses)  # compute the per-task gradients and apply the gradient modification
optimizer.step()  # apply the gradient step
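
For intuition about what the gradient modification does, here is a minimal sketch of the projection step described in the paper, written in plain Python over flat lists of floats. It is not this repository's code: the names `dot` and `pcgrad_combine` are hypothetical, and the sketch sums the projected gradients as in the paper's algorithm, which may differ from how `pc_backward` merges them.

```python
import random


def dot(u, v):
    """Inner product of two equal-length gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def pcgrad_combine(task_grads, rng=None):
    """Combine per-task gradients with PCGrad-style projection.

    task_grads: list of flat gradient vectors, one per task.
    rng: optional random.Random controlling the order in which the
         other tasks are visited (hypothetical parameter).
    """
    rng = rng or random.Random(0)
    projected = []
    for i, g_i in enumerate(task_grads):
        g_pc = list(g_i)  # work on a copy; the originals stay untouched
        others = [j for j in range(len(task_grads)) if j != i]
        rng.shuffle(others)  # the paper visits the other tasks in random order
        for j in others:
            g_j = task_grads[j]
            inner = dot(g_pc, g_j)
            if inner < 0:  # conflict: the gradients point in opposing directions
                # Remove the component of g_pc that lies along g_j.
                scale = inner / dot(g_j, g_j)
                g_pc = [a - scale * b for a, b in zip(g_pc, g_j)]
        projected.append(g_pc)
    # Sum the projected gradients coordinate-wise to get the final update.
    return [sum(parts) for parts in zip(*projected)]


if __name__ == "__main__":
    # Two conflicting task gradients (their inner product is negative).
    print(pcgrad_combine([[1.0, 0.0], [-1.0, 1.0]]))
```

Gradients that do not conflict pass through unchanged, so in that case the result is simply their sum.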

Training

  • Multi-MNIST: Run the training script with the following command. Part of the implementation is adapted from https://github.com/intel-isl/MultiObjectiveOptimization

    python main_multi_mnist.py
    

    The result is shown below.

    Method                 left-digit   right-digit
    Jointly Training       90.30        90.01
    PCGrad (this repo.)    95.00        92.00
    PCGrad (official)      96.58        95.50
  • Cifar100-MTL: coming soon

Reference

Please cite as:

@article{yu2020gradient,
  title={Gradient surgery for multi-task learning},
  author={Yu, Tianhe and Kumar, Saurabh and Gupta, Abhishek and Levine, Sergey and Hausman, Karol and Finn, Chelsea},
  journal={arXiv preprint arXiv:2001.06782},
  year={2020}
}

@misc{Pytorch-PCGrad,
  author = {Wei-Cheng Tseng},
  title = {WeiChengTseng/Pytorch-PCGrad},
  url = {https://github.com/WeiChengTseng/Pytorch-PCGrad.git},
  year = {2020}
}