
jacobgil / Pytorch Explain Black Box

PyTorch implementation of Interpretable Explanations of Black Boxes by Meaningful Perturbation

Programming Languages

python

Projects that are alternatives of or similar to Pytorch Explain Black Box

Similarity-Adaptive-Deep-Hashing
Unsupervised Deep Hashing with Similarity-Adaptive and Discrete Optimization (TPAMI2018)
Stars: ✭ 18 (-93.73%)
Mutual labels:  deeplearning
Jejunet
Real-Time Video Segmentation on Mobile Devices with DeepLab V3+, MobileNet V2. Worked on the project in 🏝 Jeju island
Stars: ✭ 258 (-10.1%)
Mutual labels:  deeplearning
Python web crawler da ml dl
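Python, from the most basic syntax through networking, front-end, back-end, web crawling and data fundamentals, all the way to machine learning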
Stars: ✭ 272 (-5.23%)
Mutual labels:  deeplearning
WSCNNTDSaliency
[BMVC17] Weakly Supervised Saliency Detection with A Category-Driven Map Generator
Stars: ✭ 19 (-93.38%)
Mutual labels:  deeplearning
deep sort
Deep Sort algorithm C++ version
Stars: ✭ 60 (-79.09%)
Mutual labels:  deeplearning
Pytorch Correlation Extension
Custom implementation of Correlation Module
Stars: ✭ 268 (-6.62%)
Mutual labels:  deeplearning
Multi-Face-Comparison
This repo provides a backend API for face comparison and computer vision. It is built on the Python Flask framework.
Stars: ✭ 20 (-93.03%)
Mutual labels:  deeplearning
Mnist center loss pytorch
A PyTorch implementation of center loss on MNIST
Stars: ✭ 278 (-3.14%)
Mutual labels:  deeplearning
Data-Analysis
Different types of data analytics projects: EDA, PDA, DDA, TSA and much more.
Stars: ✭ 22 (-92.33%)
Mutual labels:  deeplearning
Randwirenn
Pytorch Implementation of: "Exploring Randomly Wired Neural Networks for Image Recognition"
Stars: ✭ 270 (-5.92%)
Mutual labels:  deeplearning
HistoGAN
Reference code for the paper HistoGAN: Controlling Colors of GAN-Generated and Real Images via Color Histograms (CVPR 2021).
Stars: ✭ 158 (-44.95%)
Mutual labels:  deeplearning
Yolov5-deepsort-driverDistracted-driving-behavior-detection
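A deep-learning-based early-warning system for distracted driving behavior (fatigue + dangerous behaviors), using YOLOv5+Deepsort to monitor and warn about drivers' dangerous driving behavior.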
Stars: ✭ 107 (-62.72%)
Mutual labels:  deeplearning
In Prestissimo
A very fast neural network computing framework optimized for mobile platforms. QQ group: 676883532 [verification message to enter: 绝影]
Stars: ✭ 268 (-6.62%)
Mutual labels:  deeplearning
recurrent-defocus-deblurring-synth-dual-pixel
Reference github repository for the paper "Learning to Reduce Defocus Blur by Realistically Modeling Dual-Pixel Data". We propose a procedure to generate realistic DP data synthetically. Our synthesis approach mimics the optical image formation found on DP sensors and can be applied to virtual scenes rendered with standard computer software. Lev…
Stars: ✭ 30 (-89.55%)
Mutual labels:  deeplearning
Shendusuipian
To know stats by heart
Stars: ✭ 275 (-4.18%)
Mutual labels:  deeplearning
Deep-Learning
It contains the coursework and the practice I have done while learning Deep Learning.🚀 👨‍💻💥 🚩🌈
Stars: ✭ 21 (-92.68%)
Mutual labels:  deeplearning
Fixmatch Pytorch
Unofficial PyTorch implementation of "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"
Stars: ✭ 259 (-9.76%)
Mutual labels:  deeplearning
Holy Edge
Holistically-Nested Edge Detection
Stars: ✭ 277 (-3.48%)
Mutual labels:  deeplearning
Bert Ch Ner
Chinese named entity recognition based on BERT
Stars: ✭ 274 (-4.53%)
Mutual labels:  deeplearning
Deeplearning.ai Assignments
Stars: ✭ 268 (-6.62%)
Mutual labels:  deeplearning


The paper: https://arxiv.org/abs/1704.03296

What makes the deep learning network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':

[Images: Dog, Cat]

A perturbation of the dog that caused the dog category score to vanish:

[Image: Perturbed]

What makes the deep learning network think the image label is 'flute, transverse flute':

[Image: Flute]


Usage: python explain.py <path_to_image>

This is a PyTorch implementation of "Interpretable Explanations of Black Boxes by Meaningful Perturbation" by Ruth Fong and Andrea Vedaldi, with some deviations.

This uses the VGG19 model from torchvision. Its pretrained weights are downloaded the first time it is used.

This learns a mask over the image pixels that explains the output of a black box. The mask is learned by posing an optimization problem and solving directly for the mask values.

This differs from other visualization techniques like Grad-CAM, which use heuristics such as high positive gradient values as an indication of relevance to the network score.

In our case the black box is the VGG19 model, but any differentiable model can be used.


How it works

[Image: Equation]

Taken from the paper https://arxiv.org/abs/1704.03296

The goal is to solve for a mask that explains why the network outputs a given score for a certain category.
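
As a hedged sketch, this goal can be written as a single minimization problem. The notation is introduced here for illustration and is not taken from this README: m is the 28x28 mask, x_0 the input image, f_c the network's score for category c, lambda_1, lambda_2 and beta are weighting hyperparameters, and a tilde marks the up-sampled mask and the perturbed (blurred) image. The three terms correspond to properties 1 to 3 listed below; the paper's own equation may differ in detail.

```latex
m^{\ast} = \arg\min_{m \in [0,1]^{28 \times 28}}
  f_c\big(\Phi(x_0; m)\big)
  + \lambda_1 \lVert \mathbf{1} - m \rVert_1
  + \lambda_2 \sum_{u} \lVert \nabla m(u) \rVert_{\beta}^{\beta},
\qquad
\Phi(x_0; m) = \tilde{m} \odot x_0 + (\mathbf{1} - \tilde{m}) \odot \tilde{x}_0
```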

We create a low-resolution (28x28) mask and use it to perturb the input image of a deep learning network.

The perturbation combines a blurred version of the image, the regular image, and the up-sampled mask.

Wherever the mask contains low values, the input image will become more blurry.
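
A minimal PyTorch sketch of this step, assuming (N, C, H, W) tensors, bilinear up-sampling via `torch.nn.functional.interpolate`, and the convention that a mask value of 1 keeps the original pixel while 0 replaces it with the blurred one. The function name `blend_with_mask` and the all-zero "blurred" stand-in are hypothetical, chosen only to keep the example self-contained.

```python
import torch
import torch.nn.functional as F


def blend_with_mask(image, blurred, mask):
    """Blend an image with its blurred copy using a low-resolution mask.

    image, blurred: tensors of shape (N, 3, H, W)
    mask: tensor of shape (N, 1, h, w) with values in [0, 1], e.g. h = w = 28
    A mask value of 1 keeps the original pixel; 0 uses the blurred pixel.
    """
    # Up-sample the coarse mask to the image resolution (bilinear is an assumption).
    up = F.interpolate(mask, size=image.shape[-2:], mode="bilinear", align_corners=False)
    # The single mask channel broadcasts over the 3 color channels.
    return up * image + (1 - up) * blurred


image = torch.rand(1, 3, 224, 224)
blurred = torch.zeros_like(image)  # hypothetical stand-in for a real blurred copy
keep_all = torch.ones(1, 1, 28, 28)
drop_all = torch.zeros(1, 1, 28, 28)
print(torch.allclose(blend_with_mask(image, blurred, keep_all), image))    # True
print(torch.allclose(blend_with_mask(image, blurred, drop_all), blurred))  # True
```

Because the weight broadcasts over the channel dimension, one 28x28 mask controls all three color channels at once.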

We optimize for the following properties:

  1. When the mask is used to blend the input image with its blurred version, the score of the target category should drop significantly. The evidence for the category should be removed!
  2. The mask should be sparse: ideally it is the smallest mask that makes the category score drop, so that only a few pixels are blurred. This translates to an L1(1 - mask) term in the cost function.
  3. The mask should be smooth. This translates to a total variation regularization term in the cost function.
  4. The mask shouldn't overfit the network. Since the network activations might contain a lot of noise, the mask can easily learn random values that make the score drop without being visually coherent. Besides the other terms, this is addressed by solving for a low-resolution 28x28 mask.
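
The cost function implied by properties 1 to 3 might be sketched as follows. This is an illustration, not the repository's code: the helper names `tv_norm` and `explanation_loss`, the use of means rather than sums, and the coefficients `l1_coeff=0.01`, `tv_coeff=0.2` and `beta=3` are assumptions. Property 4 does not appear as a term; it comes from optimizing a 28x28 mask instead of a full-resolution one.

```python
import torch


def tv_norm(mask: torch.Tensor, beta: float = 3.0) -> torch.Tensor:
    """Total variation of a 2-D mask: mean of |neighbor difference| ** beta."""
    # Differences between vertically and horizontally adjacent mask values.
    dy = (mask[1:, :] - mask[:-1, :]).abs().pow(beta).mean()
    dx = (mask[:, 1:] - mask[:, :-1]).abs().pow(beta).mean()
    return dx + dy


def explanation_loss(mask, target_prob, l1_coeff=0.01, tv_coeff=0.2, beta=3.0):
    """Cost to minimize; all coefficients are illustrative assumptions."""
    # Property 2 (sparsity): penalize the deleted area, i.e. L1(1 - mask).
    sparsity = (1 - mask).abs().mean()
    # Property 3 (smoothness): total variation regularizer.
    smoothness = tv_norm(mask, beta)
    # Property 1: the target-category probability on the perturbed input should drop.
    return l1_coeff * sparsity + tv_coeff * smoothness + target_prob


keep_all = torch.ones(28, 28)
# With an all-ones mask, both regularizers vanish and only the score term remains.
print(float(explanation_loss(keep_all, torch.tensor(0.9))))
```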

Deviations from the paper

The paper uses a Gaussian kernel whose sigma is modulated by the value of the mask. This is computationally costly, since the mask values are updated during the iterations, meaning we need a different kernel for every mask pixel at every iteration.

Initially, I tried approximating this by first filtering the image with a filter bank of Gaussian kernels with varying sigmas. Then, during optimization, each input image pixel would use its quantized mask value to select the appropriate filter-bank output pixel (high mask value -> lower channel).

This was done using the PyTorch gather/index_select functions. However, it turns out that gather and index_select in PyTorch are not differentiable with respect to the indices.
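
The problem can be reproduced in a few lines. This sketch assumes a hypothetical 4-level filter bank on an 8x8 single-channel image and one possible quantization rule; it is not the code that was originally tried. Gradients reach the filter-bank values that `torch.gather` selects, but converting the mask into integer indices cuts the autograd graph, so the mask itself never receives a gradient. `torch.index_select` has the same limitation, since its index argument must also be an integer tensor.

```python
import torch

torch.manual_seed(0)
K, H, W = 4, 8, 8  # hypothetical sizes: 4 blur levels, 8x8 image

# Hypothetical filter bank: K blurred copies of a single-channel image.
bank = torch.rand(K, H, W, requires_grad=True)
# The mask we would like to optimize, with values in [0, 1).
mask = torch.rand(H, W, requires_grad=True)

# Quantize the mask into one channel index per pixel:
# high mask value -> lower channel (less blur).
channel = ((1 - mask) * (K - 1)).round().long()

# Pick one filter-bank output per pixel.
selected = torch.gather(bank, 0, channel.unsqueeze(0)).squeeze(0)
selected.sum().backward()

print(bank.grad is not None)  # True: the selected values receive gradients
print(mask.grad)              # None: the integer indices block the gradient
```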

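Instead, we compute a perturbed image just once, and then blend the image and the perturbed image using:

input_image = mask * image + (1 - mask) * perturbed_image

Here a mask value of 1 keeps the original pixel, and a value of 0 replaces it with the perturbed one.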

This works well in practice.

The perturbed image here is the average of a Gaussian-blurred and a median-blurred version of the image, but many other combinations could be used instead (try it out and find something better!).
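
A pure-PyTorch sketch of such a perturbed image, for readers who want to experiment with other combinations. The 11-pixel kernel size, the Gaussian sigma of 5, reflect padding and the helper names are assumptions made for this example; the median filter is built by hand from `Tensor.unfold` to keep the example dependency-free.

```python
import torch
import torch.nn.functional as F


def gaussian_blur(img, kernel_size=11, sigma=5.0):
    """Separable Gaussian blur of an (N, C, H, W) tensor with reflect padding."""
    half = kernel_size // 2
    x = torch.arange(kernel_size, dtype=img.dtype) - half
    k1d = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    k1d = k1d / k1d.sum()  # normalize so constant regions stay unchanged
    c = img.shape[1]
    # Depthwise convolution: one copy of the 1-D kernel per channel.
    kx = k1d.view(1, 1, 1, -1).repeat(c, 1, 1, 1)
    ky = k1d.view(1, 1, -1, 1).repeat(c, 1, 1, 1)
    out = F.pad(img, (half, half, half, half), mode="reflect")
    out = F.conv2d(out, kx, groups=c)  # horizontal pass
    return F.conv2d(out, ky, groups=c)  # vertical pass


def median_blur(img, kernel_size=11):
    """Median filter of an (N, C, H, W) tensor with reflect padding."""
    half = kernel_size // 2
    padded = F.pad(img, (half, half, half, half), mode="reflect")
    # Sliding windows of shape (N, C, H, W, k, k).
    windows = padded.unfold(2, kernel_size, 1).unfold(3, kernel_size, 1)
    return windows.reshape(*windows.shape[:4], -1).median(dim=-1).values


def perturbed_image(img):
    # Average of the Gaussian-blurred and median-blurred copies.
    return 0.5 * (gaussian_blur(img) + median_blur(img))


torch.manual_seed(0)
image = torch.rand(1, 3, 64, 64)
perturbed = perturbed_image(image)
print(perturbed.shape)  # torch.Size([1, 3, 64, 64])
```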

In addition, Gaussian noise with a sigma of 0.2 is now added to the preprocessed image at each iteration, inspired by Google's SmoothGrad.
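
Putting the pieces together, a compact and hedged sketch of the optimization loop might look like this. To stay self-contained and fast, it replaces VGG19 with a tiny, randomly initialized stand-in classifier and replaces the Gaussian/median perturbation with a simple average-pooling blur; the iteration count, learning rate and loss coefficients are also illustrative assumptions. Fresh Gaussian noise with sigma 0.2 is added to the network input at every iteration, and the mask is clamped to [0, 1] after each Adam step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for VGG19: any differentiable classifier can be used.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
).eval()
for p in model.parameters():
    p.requires_grad_(False)  # only the mask is optimized

image = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed image
blurred = F.avg_pool2d(image, kernel_size=11, stride=1, padding=5)  # stand-in perturbation

with torch.no_grad():
    category = model(image).argmax(dim=1).item()  # category to explain

# Illustrative hyperparameters (assumptions, not taken from this README).
l1_coeff, tv_coeff, beta, noise_sigma, lr, steps = 0.01, 0.2, 3.0, 0.2, 0.1, 50

mask = torch.ones(1, 1, 28, 28, requires_grad=True)  # start by keeping every pixel
optimizer = torch.optim.Adam([mask], lr=lr)

for _ in range(steps):
    up = F.interpolate(mask, size=image.shape[-2:], mode="bilinear", align_corners=False)
    perturbed = up * image + (1 - up) * blurred
    # Fresh Gaussian noise each iteration, in the spirit of SmoothGrad.
    noisy = perturbed + noise_sigma * torch.randn_like(perturbed)
    prob = F.softmax(model(noisy), dim=1)[0, category]

    m = mask[0, 0]
    tv = ((m[1:, :] - m[:-1, :]).abs().pow(beta).mean()
          + (m[:, 1:] - m[:, :-1]).abs().pow(beta).mean())
    loss = l1_coeff * (1 - m).abs().mean() + tv_coeff * tv + prob

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        mask.clamp_(0.0, 1.0)  # keep mask values in [0, 1]

print(f"category {category}, mean mask value {mask.mean().item():.3f}")
```

With a real pretrained classifier in place of the stand-in, the regions where the mask ends up near 0 are the ones whose blurring the optimization found to reduce the category score.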
