chrisby / torchMTL

Licence: MIT license
A lightweight module for Multi-Task Learning in pytorch.

Projects that are alternatives of or similar to torchMTL

CPG
Steven C. Y. Hung, Cheng-Hao Tu, Cheng-En Wu, Chien-Hung Chen, Yi-Ming Chan, and Chu-Song Chen, "Compacting, Picking and Growing for Unforgetting Continual Learning," Thirty-third Conference on Neural Information Processing Systems, NeurIPS 2019
Stars: ✭ 91 (+8.33%)
Mutual labels:  multi-task-learning
agegenderLMTCNN
Jia-Hong Lee, Yi-Ming Chan, Ting-Yen Chen, and Chu-Song Chen, "Joint Estimation of Age and Gender from Unconstrained Face Images using Lightweight Multi-task CNN for Mobile Applications," IEEE International Conference on Multimedia Information Processing and Retrieval, MIPR 2018
Stars: ✭ 39 (-53.57%)
Mutual labels:  multi-task-learning
AESRC2020
A deep accent recognition network
Stars: ✭ 35 (-58.33%)
Mutual labels:  mtl
emmental
A deep learning framework for building multimodal multi-task learning systems.
Stars: ✭ 93 (+10.71%)
Mutual labels:  multi-task-learning
NabaztagHackKit
A simple SDK to get your hands dirty with Nabaztag
Stars: ✭ 28 (-66.67%)
Mutual labels:  mtl
cups-rl
Customisable Unified Physical Simulations (CUPS) for Reinforcement Learning. Experiments run on the ai2thor environment (http://ai2thor.allenai.org/) e.g. using A3C, RainbowDQN and A3C_GA (Gated Attention multi-modal fusion) for Task-Oriented Language Grounding (tasks specified by natural language instructions) e.g. "Pick up the Cup or else"
Stars: ✭ 38 (-54.76%)
Mutual labels:  multi-task-learning
Fine-Grained-or-Not
Code release for Your “Flamingo” is My “Bird”: Fine-Grained, or Not (CVPR 2021 Oral)
Stars: ✭ 32 (-61.9%)
Mutual labels:  multi-task-learning
DeepSegmentor
A PyTorch implementation of the DeepCrack and RoadNet projects.
Stars: ✭ 152 (+80.95%)
Mutual labels:  multi-task-learning
amta-net
Asymmetric Multi-Task Attention Network for Prostate Bed Segmentation in CT Images
Stars: ✭ 26 (-69.05%)
Mutual labels:  multi-task-learning
temporal-depth-segmentation
Source code (train/test) accompanying the paper entitled "Veritatem Dies Aperit - Temporally Consistent Depth Prediction Enabled by a Multi-Task Geometric and Semantic Scene Understanding Approach" in CVPR 2019 (https://arxiv.org/abs/1903.10764).
Stars: ✭ 20 (-76.19%)
Mutual labels:  multi-task-learning
OmiEmbed
Multi-task deep learning framework for multi-omics data analysis
Stars: ✭ 16 (-80.95%)
Mutual labels:  multi-task-learning
NeuralMerger
Yi-Min Chou, Yi-Ming Chan, Jia-Hong Lee, Chih-Yi Chiu, Chu-Song Chen, "Unifying and Merging Well-trained Deep Neural Networks for Inference Stage," International Joint Conference on Artificial Intelligence (IJCAI), 2018
Stars: ✭ 20 (-76.19%)
Mutual labels:  multi-task-learning
Pytorch-PCGrad
PyTorch reimplementation of "Gradient Surgery for Multi-Task Learning"
Stars: ✭ 179 (+113.1%)
Mutual labels:  multi-task-learning
HyperFace-TensorFlow-implementation
HyperFace
Stars: ✭ 68 (-19.05%)
Mutual labels:  multi-task-learning
multi-task-learning
Multi-task learning for smile detection and age and gender classification on the GENKI4k and IMDB-Wiki datasets.
Stars: ✭ 154 (+83.33%)
Mutual labels:  multi-task-learning
deep recommenders
Deep Recommenders
Stars: ✭ 214 (+154.76%)
Mutual labels:  multi-task-learning
FOCAL-ICLR
Code for FOCAL Paper Published at ICLR 2021
Stars: ✭ 35 (-58.33%)
Mutual labels:  multi-task-learning
EasyRec
A framework for large scale recommendation algorithms.
Stars: ✭ 599 (+613.1%)
Mutual labels:  multi-task-learning
Mask-YOLO
Inspired by Mask R-CNN to build a multi-task learning, two-branch architecture: one branch based on YOLOv2 for object detection, the other branch for instance segmentation. Simply tested on Rice and Shapes. MobileNet supported.
Stars: ✭ 100 (+19.05%)
Mutual labels:  multi-task-learning
Multi-task-Conditional-Attention-Networks
A prototype version of our submitted paper: Conversion Prediction Using Multi-task Conditional Attention Networks to Support the Creation of Effective Ad Creatives.
Stars: ✭ 21 (-75%)
Mutual labels:  multi-task-learning

torchmtl tries to help you compose modular multi-task architectures with minimal effort. All you need is a list of dictionaries in which you define your layers and how they build on each other. From this, torchmtl constructs a meta-computation graph, which is executed in each forward pass of the created MTLModel. Simple wrapper functions are provided to combine outputs from multiple layers.
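The idea of turning a list of task dictionaries into a graph that is evaluated on every forward pass can be sketched in plain Python. This is a minimal sketch, not torchmtl's implementation: the function `run_graph` is hypothetical, layers are plain callables instead of nn.Modules, and each node's output is cached so that a layer shared by several tasks runs only once per pass.

```python
from typing import Any, Dict, List


def run_graph(tasks: List[Dict[str, Any]], x: Any, output_tasks: List[str]) -> List[Any]:
    """Hypothetical executor: evaluate the requested nodes bottom-up.

    A task without 'anchor_layer' reads the raw input x; a task whose
    anchor is a list receives the list of its parents' outputs.
    No cycle detection: a cyclic anchor chain recurses until RecursionError.
    """
    by_name = {task['name']: task for task in tasks}
    cache: Dict[str, Any] = {}  # each node is computed at most once per call

    def evaluate(name: str) -> Any:
        if name in cache:
            return cache[name]
        task = by_name[name]
        anchor = task.get('anchor_layer')
        if anchor is None:
            inp = x  # no anchor: the layer reads the input directly
        elif isinstance(anchor, list):
            inp = [evaluate(a) for a in anchor]  # several parents
        else:
            inp = evaluate(anchor)  # single parent
        cache[name] = task['layers'](inp)
        return cache[name]

    return [evaluate(name) for name in output_tasks]


# Toy graph: one shared "embedding" feeding two heads
tasks = [
    {'name': "Embed", 'layers': lambda v: [2 * e for e in v]},
    {'name': "Sum", 'layers': sum, 'anchor_layer': "Embed"},
    {'name': "Max", 'layers': max, 'anchor_layer': "Embed"},
]
print(run_graph(tasks, [1, 2, 3], ['Sum', 'Max']))  # [12, 6]
```

Memoizing node outputs mirrors the point of a shared trunk: both heads reuse the same "Embed" result instead of recomputing it.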

Installation

torchmtl can be installed via pip:

pip install torchmtl

Quickstart (or find examples here)

Assume you want to train a network on three tasks as shown below.
[Figure: example architecture with three tasks]

To construct such an architecture with torchmtl, you simply have to define the following list:

from torch.nn import Sequential, Linear, MSELoss, BCEWithLogitsLoss
from torchmtl.wrapping_layers import Concat

tasks = [
    {
        'name': "Embed1",
        'layers': Sequential(Linear(16, 32), Linear(32, 8)),
        # No anchor_layer means this layer receives input directly
    },
    {
        'name': "Embed2",
        'layers': Sequential(Linear(16, 32), Linear(32, 8)),
        # No anchor_layer means this layer receives input directly
    },
    {
        'name': "CatTask",
        'layers': Concat(dim=1),
        'loss_weight': 1.0,
        'anchor_layer': ['Embed1', 'Embed2']
    },
    {
        'name': "Task1",
        'layers': Sequential(Linear(8, 32), Linear(32, 1)),
        'loss': MSELoss(),
        'loss_weight': 1.0,
        'anchor_layer': 'Embed1'
    },
    {
        'name': "Task2",
        'layers': Sequential(Linear(8, 64), Linear(64, 1)),
        'loss': BCEWithLogitsLoss(),
        'loss_weight': 1.0,
        'anchor_layer': 'Embed2'
    },
    {
        'name': "FNN",
        'layers': Sequential(Linear(16, 32), Linear(32, 32)),
        'anchor_layer': 'CatTask'
    },
    {
        'name': "Task3",
        'layers': Sequential(Linear(32, 16), Linear(16, 1)),
        'anchor_layer': 'FNN',
        'loss': MSELoss(),
        'loss_weight': 'auto',
        'loss_init_val': 1.0
    }
]
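Because anchor_layer chains layers together, each layer's input width must equal the output width of whatever feeds it. For the example above this can be checked by hand; a short worked check in plain Python follows. The widths are copied manually from the Linear layers in the list, and the dictionaries and the helpers `out_width` and `mismatches` are illustrative assumptions, not part of torchmtl. Both embeddings read the raw input, so X must have 16 features per sample.

```python
# (in_features, out_features) of each Sequential, copied by hand from the tasks list
linear_widths = {
    'Embed1': (16, 8),
    'Embed2': (16, 8),
    'Task1': (8, 1),
    'Task2': (8, 1),
    'FNN': (16, 32),
    'Task3': (32, 1),
}
# Concat(dim=1) stacks feature columns, so its width is the sum of its inputs' widths
concat_inputs = {'CatTask': ['Embed1', 'Embed2']}
# anchor_layer of every layer that does not read the raw input
anchors = {'Task1': 'Embed1', 'Task2': 'Embed2', 'FNN': 'CatTask', 'Task3': 'FNN'}


def out_width(name: str) -> int:
    """Number of output features produced by a node."""
    if name in concat_inputs:
        return sum(out_width(parent) for parent in concat_inputs[name])
    return linear_widths[name][1]


def mismatches() -> list:
    """(layer, expected input width, width delivered by its anchor) for each mismatch."""
    return [
        (name, linear_widths[name][0], out_width(parent))
        for name, parent in anchors.items()
        if linear_widths[name][0] != out_width(parent)
    ]


print(out_width('CatTask'))  # 16
print(mismatches())          # []
```

Changing FNN's first layer to Linear(8, 32), for instance, would make the check report ('FNN', 8, 16): the concatenation of two 8-dimensional embeddings delivers 16 features.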

You can build your final model with the following lines, specifying the layers whose outputs you would like to receive:

from torchmtl import MTLModel
model = MTLModel(tasks, output_tasks=['Task1', 'Task2', 'Task3'])

This constructs a meta-computation graph which is executed in each forward pass of your model. You can verify whether the graph was properly built by plotting it using the networkx library:

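import matplotlib.pyplot as plt
import networkx as nx

# model.g is the meta-computation graph built by MTLModel
pos = nx.planar_layout(model.g)
nx.draw(model.g, pos, font_size=14, node_color="y", node_size=450, with_labels=True)
plt.show()  # nx.draw renders via matplotlib; needed to display the plot in a script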
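If networkx is not at hand, the same dependency structure can be checked with the standard library. This is a minimal sketch, assuming Python 3.9+ for `graphlib`: the hypothetical helper `execution_order` derives each task's parents from its anchor_layer and returns an order in which every layer comes after the layers it reads from. `graphlib` raises `CycleError` if the anchors form a cycle.

```python
from graphlib import TopologicalSorter


def execution_order(tasks):
    """Map each task to its anchors and return a dependency-respecting order."""
    graph = {}
    for task in tasks:
        anchor = task.get('anchor_layer')
        if anchor is None:
            parents = []  # reads the model input directly
        elif isinstance(anchor, list):
            parents = anchor
        else:
            parents = [anchor]
        graph[task['name']] = parents  # TopologicalSorter expects node -> predecessors
    return list(TopologicalSorter(graph).static_order())


# Only the graph-relevant keys of the Quickstart tasks are needed here
example = [
    {'name': "Embed1"},
    {'name': "Embed2"},
    {'name': "CatTask", 'anchor_layer': ['Embed1', 'Embed2']},
    {'name': "Task1", 'anchor_layer': 'Embed1'},
    {'name': "Task2", 'anchor_layer': 'Embed2'},
    {'name': "FNN", 'anchor_layer': 'CatTask'},
    {'name': "Task3", 'anchor_layer': 'FNN'},
]
print(execution_order(example))
```

Several valid orders exist (the two embeddings are independent), so only the relative order of connected layers is guaranteed.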

[Figure: networkx plot of the meta-computation graph]

The training loop

You can now enter the typical PyTorch training loop, where you have access to everything you need to update your model:

# Assumes `data_loader` and `optimizer` are already defined,
# and that y holds one target per output task
for X, y in data_loader:
    optimizer.zero_grad()

    # Our model returns three lists ordered like `output_tasks`: the predictions,
    # the loss functions, and the loss weights (as defined in the tasks variable)
    y_hat, l_funcs, l_weights = model(X)

    loss = 0
    # We can now iterate over the tasks and accumulate the weighted losses
    for i in range(len(y_hat)):
        loss += l_weights[i] * l_funcs[i](y_hat[i], y[i])

    loss.backward()
    optimizer.step()
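The loop computes the total objective L = sum over i of w_i * loss_i(y_hat_i, y_i) across the output tasks. Since loss and loss_weight may be None or omitted for outputs you only want to inspect, multiplying every entry would fail for such tasks. A minimal sketch of a guard, assuming the returned lists hold None for those tasks (an assumption; check what your torchmtl version returns): the helper `weighted_loss` is hypothetical and only needs values that support * and +, so plain floats stand in for tensors here.

```python
def weighted_loss(y_hat, l_funcs, l_weights, targets):
    """Sum w_i * loss_i(pred_i, target_i), skipping tasks without a loss or weight."""
    total = 0.0
    for pred, l_func, weight, target in zip(y_hat, l_funcs, l_weights, targets):
        if l_func is None or weight is None:
            continue  # output requested only for inspection
        total = total + weight * l_func(pred, target)
    return total


def squared_error(a, b):
    return (a - b) ** 2


# Three outputs; the third has no loss attached
print(weighted_loss([1.0, 2.0, 5.0],
                    [squared_error, squared_error, None],
                    [1.0, 0.5, None],
                    [0.0, 0.0, 0.0]))  # 3.0
```

Here 1.0 * (1 - 0)^2 + 0.5 * (2 - 0)^2 = 3.0; the third output contributes nothing.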

Details on the layer definition

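There are six keys that can be specified (name and layers must always be present):

name
The identifier of the layer. Other layers refer to it through anchor_layer, and MTLModel refers to it through output_tasks.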

layers
Takes basically any nn.Module you can think of: you can plug in a transformer or just a handful of fully connected layers.

anchor_layer
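Defines which other layer (or layers) this layer receives its input from: either a single layer name or a list of names, as for CatTask in the example above. If omitted, the layer receives the model input directly. Take care that the respective dimensions match.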

loss
The loss function you want to compute on the output of this layer (l_funcs). Can be set to None or omitted altogether when only access to the layer's output is needed.

loss_weight
The scalar by which the respective loss is weighted (returned in l_weights). If set to 'auto', an nn.Parameter is returned instead, which is updated through backpropagation. Can be set to None or omitted altogether when only access to the layer's output is needed.

loss_init_val
Only needed if loss_weight is set to 'auto'. The initial value of the loss_weight parameter.
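The rules above can be checked before a model is built. This is a minimal sketch under stated assumptions: `validate_tasks` is a hypothetical helper, not part of torchmtl. It only encodes the constraints described in this section (name and layers required, only the six listed keys, anchors must refer to defined names, loss_init_val paired with loss_weight 'auto'), plus the assumption that names are unique, and it returns a list of problems instead of raising.

```python
ALLOWED_KEYS = {'name', 'layers', 'anchor_layer', 'loss', 'loss_weight', 'loss_init_val'}


def validate_tasks(tasks):
    """Return a list of problems found in a torchmtl-style task list (empty if none)."""
    problems = []
    names = [task['name'] for task in tasks if 'name' in task]
    if len(set(names)) != len(names):
        problems.append("duplicate task names")  # assumption: names must be unique
    defined = set(names)

    for task in tasks:
        label = task.get('name', '<unnamed>')
        for key in ('name', 'layers'):
            if key not in task:
                problems.append(f"{label}: missing required key '{key}'")
        for key in sorted(set(task) - ALLOWED_KEYS):
            problems.append(f"{label}: unknown key '{key}'")

        # anchor_layer may be a single name or a list of names
        anchors = task.get('anchor_layer')
        if isinstance(anchors, str):
            anchors = [anchors]
        for anchor in anchors or []:
            if anchor not in defined:
                problems.append(f"{label}: anchor '{anchor}' is not defined")

        weight = task.get('loss_weight')
        is_auto = isinstance(weight, str) and weight == 'auto'
        if is_auto and 'loss_init_val' not in task:
            problems.append(f"{label}: loss_weight 'auto' needs loss_init_val")
        if 'loss_init_val' in task and not is_auto:
            problems.append(f"{label}: loss_init_val is only needed with loss_weight 'auto'")
    return problems


print(validate_tasks([
    {'name': "Embed", 'layers': lambda v: v},
    {'name': "Head", 'layers': lambda v: v, 'anchor_layer': "Embd", 'loss_weight': 'auto'},
]))
```

For this input the check reports the misspelled anchor 'Embd' and the missing loss_init_val of "Head".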

Wrapping functions

Nodes of the meta-computation graph don't have to be PyTorch Modules. They can also be concatenation functions or indexing functions that return a certain element of the input. If your X consists of two types of input data, X=[X_1, X_2], you can use the SimpleSelect layer to select X_1 by setting:

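from torchmtl.wrapping_layers import SimpleSelect

# Task entry that picks X_1 out of X = [X_1, X_2];
# add the remaining keys (anchor_layer, loss, ...) as needed
{
    'name': "SelectX1",  # hypothetical name
    'layers': SimpleSelect(selection_axis=0),
}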

It should be trivial to write your own wrapping layers, but I try to provide useful ones with this library. If you have any layers in mind but no time to implement them, feel free to open an issue.
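Since nodes need not be PyTorch Modules, a wrapping layer can be any callable that takes its anchor's output. A minimal sketch with two hypothetical helpers, `select` and `concat_lists`, written over plain Python lists so it runs without PyTorch; in a real model you would operate on tensors (for example with torch.cat along the feature dimension), and whether your torchmtl version accepts a bare function in 'layers' is an assumption worth checking.

```python
def select(index):
    """Return a wrapping function that picks one element of a multi-part input."""
    def _select(inputs):
        return inputs[index]
    return _select


def concat_lists(parts):
    """Join the outputs of several anchors end to end (a list stand-in for torch.cat)."""
    joined = []
    for part in parts:
        joined.extend(part)
    return joined


# Hypothetical task entry: X = [X_1, X_2], keep X_1
select_x1 = {'name': "SelectX1", 'layers': select(0)}

x = [[1, 2], [3, 4, 5]]
print(select_x1['layers'](x))   # [1, 2]
print(concat_lists(x))          # [1, 2, 3, 4, 5]
```

Returning a closure from `select` keeps the index configurable per task entry, similar in spirit to passing selection_axis to SimpleSelect.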

Cite

@misc{Bock2020torchMTL,
  title = {torchMTL: A lightweight module for Multi-Task Learning in pytorch},
  author = {Bock, Christian},
  doi = {10.5281/zenodo.4362515},
  url = {https://github.com/chrisby/torchMTL},
  year = {2020}
}

Credits

Logo credits and license: I reused and remixed the PyTorch logo (moved the dot and rotated the resulting logo a couple of times), accessed through Wikimedia Commons, where it can be used under the Attribution-ShareAlike 4.0 International license. Hence, this logo falls under the same license.
