
reiinakano / invariant-risk-minimization

License: MIT
Implementation of Invariant Risk Minimization https://arxiv.org/abs/1907.02893

Programming Languages

Jupyter Notebook
Python

Projects that are alternatives of or similar to invariant-risk-minimization

causeinfer
Machine learning based causal inference/uplift in Python
Stars: ✭ 45 (-40.79%)
Mutual labels:  causality
pycid
Library for graphical models of decision making, based on pgmpy and networkx
Stars: ✭ 64 (-15.79%)
Mutual labels:  causality
cfvqa
[CVPR 2021] Counterfactual VQA: A Cause-Effect Look at Language Bias
Stars: ✭ 96 (+26.32%)
Mutual labels:  causality
ACE
Code for our paper, Neural Network Attributions: A Causal Perspective (ICML 2019).
Stars: ✭ 47 (-38.16%)
Mutual labels:  causality
cfml tools
My collection of causal inference algorithms built on top of accessible, simple, out-of-the-box ML methods, aimed at being explainable and useful in the business context
Stars: ✭ 24 (-68.42%)
Mutual labels:  causality
Dowhy
DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
Stars: ✭ 3,480 (+4478.95%)
Mutual labels:  causality
causaldag
Python package for the creation, manipulation, and learning of Causal DAGs
Stars: ✭ 82 (+7.89%)
Mutual labels:  causality
Scribe-py
Regulatory networks with Direct Information in python
Stars: ✭ 28 (-63.16%)
Mutual labels:  causality
CREST
A Causal Relation Schema for Text
Stars: ✭ 19 (-75%)
Mutual labels:  causality
Causal-Deconvolution-of-Networks
Causal Deconvolution of Networks by Algorithmic Generative Models
Stars: ✭ 25 (-67.11%)
Mutual labels:  causality
causal-learn
Causal Discovery for Python. Translation and extension of the Tetrad Java code.
Stars: ✭ 428 (+463.16%)
Mutual labels:  causality
iPerceive
Applying Common-Sense Reasoning to Multi-Modal Dense Video Captioning and Video Question Answering | Python3 | PyTorch | CNNs | Causality | Reasoning | LSTMs | Transformers | Multi-Head Self Attention | Published in IEEE Winter Conference on Applications of Computer Vision (WACV) 2021
Stars: ✭ 52 (-31.58%)
Mutual labels:  causality
Causal Reading Group
We will keep updating the paper list about machine learning + causal theory. We also internally discuss related papers between NExT++ (NUS) and LDS (USTC) by week.
Stars: ✭ 339 (+346.05%)
Mutual labels:  causality
Awesome-Neural-Logic
Awesome Neural Logic and Causality: MLN, NLRL, NLM, etc. Causal inference, neural logic, and frontier topics in logical reasoning toward strong AI.
Stars: ✭ 106 (+39.47%)
Mutual labels:  causality
Generalization-Causality
Reading notes on a wide variety of research: domain generalization, domain adaptation, causality, robustness, prompting, optimization, and generative models
Stars: ✭ 482 (+534.21%)
Mutual labels:  causality
ENCO
Official repository of the paper "Efficient Neural Causal Discovery without Acyclicity Constraints"
Stars: ✭ 52 (-31.58%)
Mutual labels:  causality
RECCON
This repository contains the dataset and the PyTorch implementations of the models from the paper Recognizing Emotion Cause in Conversations.
Stars: ✭ 126 (+65.79%)
Mutual labels:  causality

Implementation of Invariant Risk Minimization (https://arxiv.org/abs/1907.02893)

This is an attempt to reproduce the "Colored MNIST" experiments from the paper Invariant Risk Minimization by Arjovsky et al.
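For reference, a minimal sketch of the paper's Colored MNIST recipe: a binary label from the digit class (0–4 vs. 5–9), 25% label noise, and a color channel that correlates with the noisy label at a rate that varies per environment. The function name and tensor shapes below are illustrative, not the exact code in this repository:

```python
import torch

def make_colored_environment(images, digit_labels, color_flip_prob, label_noise=0.25):
    """Build one Colored MNIST environment from grayscale digits.

    images: (N, 28, 28) float tensor; digit_labels: (N,) digit classes 0-9.
    color_flip_prob controls how often the color disagrees with the label.
    """
    # Binary task label: 1 if the digit is >= 5, else 0
    labels = (digit_labels >= 5).float()
    # Corrupt the label with 25% noise, per the paper's recipe
    labels = torch.where(torch.rand(len(labels)) < label_noise, 1 - labels, labels)
    # Color agrees with the (noisy) label except with probability color_flip_prob
    colors = torch.where(torch.rand(len(labels)) < color_flip_prob, 1 - labels, labels)
    # Place the grayscale digit into the red or green channel
    colored = torch.zeros(len(images), 2, images.shape[1], images.shape[2])
    colored[torch.arange(len(images)), colors.long()] = images
    return colored, labels
```

In the paper, the color flip probability is low in the training environments and high in the test environment, so a classifier that latches onto color generalizes catastrophically.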

After trying lots of hyperparameters and various tricks, this implementation achieves close to the paper-reported values (train accuracy > 70%, test accuracy > 60%), though training can be quite unstable depending on the random seed.

The most common failure case occurs when the gradient norm penalty term is weighted too heavily relative to the ERM term. In that case, Φ converges to a function that returns the same value for all inputs. The classifier cannot recover from this point, and accuracy is stuck at 50% in all environments. This makes sense mathematically: if the intermediate representation is the same regardless of input, then any classifier is the ideal classifier, so the penalty gradient is 0.
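The penalty in question is the IRMv1 term from the paper: the squared gradient of each environment's risk with respect to a fixed dummy classifier scale w = 1.0. A minimal sketch for the binary case (function name and shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def irm_penalty(logits, labels):
    """IRMv1 penalty: squared norm of the gradient of the risk w.r.t.
    a dummy scale w = 1.0 multiplying the logits."""
    scale = torch.ones(1, requires_grad=True)
    loss = F.binary_cross_entropy_with_logits(logits * scale, labels)
    # create_graph=True so the penalty itself can be backpropagated through
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()
```

A collapsed Φ makes this term trivially minimized, which is why training cannot escape once it happens.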

Another failure case is when the gradient norm penalty is too low and the optimization essentially acts as in ERM (train accuracy > 80%, test accuracy ~10%).

The most important trick I used to get this to work is a scheduled increase of the gradient norm penalty weight: we start at 0, so training effectively begins as ERM, then slowly increase the weight each epoch.
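As a sketch, such a schedule can be as simple as a linear ramp (the ramp length and maximum weight below are illustrative, not the repository's actual hyperparameters):

```python
def penalty_weight(epoch, ramp_epochs=100, max_weight=1e4):
    """Linear ramp: weight 0 (pure ERM) at epoch 0, rising to
    max_weight over ramp_epochs epochs, then held constant."""
    return max_weight * min(epoch / ramp_epochs, 1.0)
```

The per-epoch loss would then be `erm_loss + penalty_weight(epoch) * penalty`.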

I use early stopping to stop training once the accuracy on all environments, including the test set, reaches an acceptable value. Yes, stopping training based on test-set performance is not good practice, but I could not find a principled way to stop by observing only the training environments. When applying IRM to real-world datasets, one option is to hold out a separate environment as a validation set and use it for early stopping. The downside is that we would then need a minimum of 4 environments to perform IRM (2 train, 1 validation, 1 test).
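The stopping criterion can be sketched as follows, using the ~70%/60% figures above as illustrative thresholds:

```python
def should_stop(train_accs, test_acc, train_threshold=0.70, test_threshold=0.60):
    """Stop once every training environment and the test environment clear
    their thresholds. As noted above, peeking at test accuracy is not
    rigorous practice; a held-out validation environment would be cleaner."""
    return all(acc > train_threshold for acc in train_accs) and test_acc > test_threshold
```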

Feel free to leave an issue if you find a bug or a set of hyperparameters that makes training stable. Otherwise, let's all just wait for the authors' code, which they say will be available soon. (As it turns out, the authors' original code is already here: https://github.com/facebookresearch/InvariantRiskMinimization, apparently posted two months before I started this; for some reason, I wasn't able to find it when I first searched.) Instead of gradually increasing the gradient norm penalty, they keep it at 0 for a few iterations and then jump straight up to the higher value for the rest of training. I think the important thing is to make sure training effectively starts as ERM (0 penalty) before the IRM penalty term is added in.
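The authors' schedule, as described above, is a step function rather than a ramp (the step position and weight below are illustrative placeholders, not values from their code):

```python
def step_penalty_weight(step, anneal_steps=100, max_weight=1e4):
    """Zero penalty (pure ERM) for the first iterations, then jump
    straight to the full weight for the rest of training."""
    return 0.0 if step < anneal_steps else max_weight
```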

How to run

You can run the provided notebook in Colaboratory.

Alternatively, you can run main.py locally. main.py also includes an implementation of ERM if you want to run a baseline. The code depends on PyTorch.
