inouye-lab / ShapleyExplanationNetworks

Licence: MIT license
Implementation of the paper "Shapley Explanation Networks"

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks", published at ICLR 2021. Note that this repo relies heavily on named tensors, an experimental PyTorch feature. The authors found the ideas confusing to implement with unnamed dimensions, and named tensors made the implementation far easier.

Dependencies

To run ShapNets on their own, you mostly need only PyTorch, NumPy, and SciPy.

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)
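
As background for the `explain=True` call above, here is a minimal pure-Python sketch of the textbook Shapley value computation for a small function. It is not the repository's implementation. The helper `exact_shapley`, the stand-in `linear` function, its weights, and the all-zero reference vector are all assumptions made for illustration.

```python
from itertools import combinations
from math import factorial


def exact_shapley(f, x, reference):
    """Exact Shapley values of f at x; absent features take their reference value."""
    n = len(x)
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for a coalition S of this size: |S|! (n-|S|-1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                present = set(subset)
                without_i = [x[j] if j in present else reference[j] for j in range(n)]
                with_i = list(without_i)
                with_i[i] = x[i]
                # Marginal contribution of feature i to coalition S
                values[i] += weight * (f(with_i) - f(without_i))
    return values


# Hypothetical linear function standing in for nn.Linear(4, 1)
weights = [0.5, -1.0, 2.0, 0.25]
bias = 0.1


def linear(z):
    return sum(w * v for w, v in zip(weights, z)) + bias


x = [1.0, 2.0, -1.0, 4.0]
reference = [0.0, 0.0, 0.0, 0.0]  # assumed baseline, not taken from the repo
phi = exact_shapley(linear, x, reference)
```

For a linear function with a zero reference, each attribution reduces to `weights[i] * x[i]`, and the attributions sum to `linear(x) - linear(reference)`. The cost grows exponentially with the number of features, which is why this is only practical for small modules such as the four-feature one above.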

For a Shallow ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim)
shallow_shapnet(inputs)
output, bias = shallow_shapnet(inputs, explain=True)
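
The module count `dim * (dim - 1) // 2` in the example above equals the number of unordered feature pairs, which suggests one two-feature module per pair. The sketch below only checks that arithmetic. The pairing itself, and the string keys modelled on the `"(0, 2)"` key used in the Deep ShapNet example, are assumptions rather than documented behaviour.

```python
from itertools import combinations

dim = 32

# All unordered pairs of feature indices: (0, 1), (0, 2), ..., (30, 31)
pairs = list(combinations(range(dim), 2))
num_modules = dim * (dim - 1) // 2

# Hypothetical string keys in the style of the "(0, 2)" key
keys = [str(pair) for pair in pairs]

# Each feature takes part in dim - 1 pairs, so the modules overlap heavily
appearances = [sum(1 for pair in pairs if i in pair) for i in range(dim)]
```

With `dim = 32` this gives 496 pairs, matching the length of `overlapping_modules`, and every feature appears in 31 of them.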

For a Deep ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim)


dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# ShapNet imports
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        np.prod(kernel_sizes[i]) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)
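
To sanity-check the kernel, dilation, padding, and stride settings above, the sketch below estimates the spatial size after each layer using the standard convolution output formula (the one documented for PyTorch's `nn.Conv2d`). It rests on assumptions: that the ShapNet convolutional layers follow this arithmetic, that scalar sizes expand to `(value, value)` pairs, and that pruning leaves spatial size unchanged. `as_pair` and `conv_output_size` are hypothetical helpers, not part of the library.

```python
def conv_output_size(size, kernel, dilation=1, padding=0, stride=1):
    """Standard convolution output length along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def as_pair(value):
    """Expand a scalar to (value, value); assumed to mirror process_list_sizes."""
    return (value, value) if isinstance(value, int) else tuple(value)


height, width = 32, 32
kernel_sizes = [2, (1, 3)]
dilations = [1, 2]
paddings = [0, 0]
strides = [1, 1]

shapes = []
h, w = height, width
for k, d, p, s in zip(kernel_sizes, dilations, paddings, strides):
    (kh, kw), (dh, dw), (ph, pw), (sh, sw) = map(as_pair, (k, d, p, s))
    h = conv_output_size(h, kh, dh, ph, sh)
    w = conv_output_size(w, kw, dw, pw, sw)
    shapes.append((h, w))
```

Under these assumptions a 32 x 32 input becomes 31 x 31 after the first layer and 31 x 27 after the second.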

Citation

If you find this useful, please cite our work as:

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}