sicara / tf-explain

License: MIT
Interpretability methods for tf.keras models with TensorFlow 2.x

Programming Languages

Python
139,335 projects - #7 most used programming language

Projects that are alternatives to or similar to tf-explain

knowledge-neurons
A library for finding knowledge neurons in pretrained transformer models.
Stars: ✭ 72 (-90.77%)
Mutual labels:  interpretability
Neural Backed Decision Trees
Making decision trees competitive with neural networks on CIFAR10, CIFAR100, TinyImagenet200, Imagenet
Stars: ✭ 411 (-47.31%)
Mutual labels:  interpretability
Xai resources
Interesting resources related to XAI (Explainable Artificial Intelligence)
Stars: ✭ 553 (-29.1%)
Mutual labels:  interpretability
SPINE
Code for SPINE - Sparse Interpretable Neural Embeddings. Jhamtani H.*, Pruthi D.*, Subramanian A.*, Berg-Kirkpatrick T., Hovy E. AAAI 2018
Stars: ✭ 44 (-94.36%)
Mutual labels:  interpretability
Interpret
Fit interpretable models. Explain blackbox machine learning.
Stars: ✭ 4,352 (+457.95%)
Mutual labels:  interpretability
Lucid
A collection of infrastructure and tools for research in neural network interpretability.
Stars: ✭ 4,344 (+456.92%)
Mutual labels:  interpretability
summit
🏔️ Summit: Scaling Deep Learning Interpretability by Visualizing Activation and Attribution Summarizations
Stars: ✭ 95 (-87.82%)
Mutual labels:  interpretability
Awesome Federated Learning
Federated Learning Library: https://fedml.ai
Stars: ✭ 624 (-20%)
Mutual labels:  interpretability
Awesome deep learning interpretability
Highly cited and top-conference papers from recent years on the interpretability of neural network models in deep learning (with code)
Stars: ✭ 401 (-48.59%)
Mutual labels:  interpretability
Interpretable machine learning with python
Examples of techniques for training interpretable ML models, explaining ML models, and debugging ML models for accuracy, discrimination, and security.
Stars: ✭ 530 (-32.05%)
Mutual labels:  interpretability
removal-explanations
A lightweight implementation of removal-based explanations for ML models.
Stars: ✭ 46 (-94.1%)
Mutual labels:  interpretability
Facet
Human-explainable AI.
Stars: ✭ 269 (-65.51%)
Mutual labels:  interpretability
Tcav
Code for the TCAV ML interpretability project
Stars: ✭ 442 (-43.33%)
Mutual labels:  interpretability
shapeshop
Towards Understanding Deep Learning Representations via Interactive Experimentation
Stars: ✭ 16 (-97.95%)
Mutual labels:  interpretability
Flashtorch
Visualization toolkit for neural networks in PyTorch
Stars: ✭ 561 (-28.08%)
Mutual labels:  interpretability
neuron-importance-zsl
[ECCV 2018] code for Choose Your Neuron: Incorporating Domain Knowledge Through Neuron Importance
Stars: ✭ 56 (-92.82%)
Mutual labels:  interpretability
Mli Resources
H2O.ai Machine Learning Interpretability Resources
Stars: ✭ 428 (-45.13%)
Mutual labels:  interpretability
Ad examples
A collection of anomaly detection methods (iid/point-based, graph and time series) including active learning for anomaly detection/discovery, bayesian rule-mining, description for diversity/explanation/interpretability. Analysis of incorporating label feedback with ensemble and tree-based detectors. Includes adversarial attacks with Graph Convolutional Network.
Stars: ✭ 641 (-17.82%)
Mutual labels:  interpretability
Xai
XAI - An eXplainability toolbox for machine learning
Stars: ✭ 596 (-23.59%)
Mutual labels:  interpretability
Deeplift
Public facing deeplift repo
Stars: ✭ 512 (-34.36%)
Mutual labels:  interpretability

tf-explain

Badges: PyPI version, build status, documentation status, supported Python and TensorFlow versions, code style: black

tf-explain implements interpretability methods as TensorFlow 2.x callbacks to help you understand your neural networks. See the introductory blog post, Introducing tf-explain, Interpretability for TensorFlow 2.0.

Documentation: https://tf-explain.readthedocs.io

Installation

tf-explain is available on PyPI as an alpha release. To install it:

virtualenv venv -p python3.6
source venv/bin/activate
pip install tf-explain

tf-explain is compatible with TensorFlow 2.x. TensorFlow is not declared as a dependency, so that you can choose between the full and CPU-only packages. In addition to the install above, run:

# For CPU or GPU
pip install tensorflow==2.2.0

OpenCV is also a dependency. To install it, run:

# For CPU or GPU
pip install opencv-python
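
To check that everything is in place, a minimal sanity check (assuming the virtualenv created above is active) is to import the three packages:

# Verify the installation: all three imports should succeed inside the virtualenv
import tensorflow as tf
import cv2  # provided by opencv-python
import tf_explain

print("TensorFlow:", tf.__version__)
print("OpenCV:", cv2.__version__)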

Quickstart

tf-explain offers two ways to apply interpretability methods. The full list is given in the Available Methods section below.

On trained model

The most common option is to load a trained model and apply the methods to it.

import tensorflow as tf

from tf_explain.core.grad_cam import GradCAM

# Load a pretrained model or your own
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

# Load a sample image (or multiple ones); IMAGE_PATH points to your input image
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
data = ([img], None)

# Start explainer
explainer = GradCAM()
grid = explainer.explain(data, model, class_index=281)  # 281 is the tabby cat index in ImageNet

explainer.save(grid, ".", "grad_cam.png")

During training

If you want to follow your model during training, you can also use each method as a Keras callback and see the results directly in TensorBoard.

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
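
The callbacks write their outputs as TensorBoard summaries under output_dir, so pointing TensorBoard at that directory (for example with tensorboard --logdir output_dir) should display the generated heatmaps after each epoch.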

Available Methods

  1. Activations Visualization
  2. Vanilla Gradients
  3. Gradients*Inputs
  4. Occlusion Sensitivity
  5. Grad CAM (Class Activation Maps)
  6. SmoothGrad
  7. Integrated Gradients
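
Each callback listed above also has a counterpart in tf_explain.core exposing the same explain/save interface used in the Quickstart. As a minimal sketch (assuming the core classes take the same parameters as their callbacks, as the Quickstart Grad CAM example suggests), Vanilla Gradients can be applied to a trained model like this:

import numpy as np
import tensorflow as tf

from tf_explain.core.vanilla_gradients import VanillaGradients

# Trained classifier and a small batch of inputs (random placeholders here)
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)
images = np.random.random((2, 224, 224, 3)).astype("float32")
data = (images, None)

explainer = VanillaGradients()
grid = explainer.explain(data, model, class_index=281)
explainer.save(grid, ".", "vanilla_gradients.png")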

Activations Visualization

Visualize how a given input comes out of a specific activation layer

from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
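
The same visualization is available outside of training through the core API. A minimal sketch, assuming the core ExtractActivations class takes the same layers_name argument as the callback and reusing the data and model from the Quickstart:

from tf_explain.core.activations import ExtractActivations

# "activation_1" must be the name of a layer in your model
explainer = ExtractActivations()
grid = explainer.explain(data, model, layers_name=["activation_1"])
explainer.save(grid, ".", "activations.png")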

Vanilla Gradients

Visualize gradient importance on the input image

from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback

model = [...]

callbacks = [
    VanillaGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Gradients*Inputs

Variant of Vanilla Gradients that weights the gradients by the input values

from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback

model = [...]

callbacks = [
    GradientsInputsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Occlusion Sensitivity

Visualize how parts of the image affect the neural network's confidence by occluding them iteratively

from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Occlusion Sensitivity for Tabby class (stripes differentiate tabby cat from other ImageNet cat classes)
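
The method can also be run on an already trained model via the core API. A minimal sketch, assuming the core class mirrors the callback parameters and reusing data and model from the Quickstart:

from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

explainer = OcclusionSensitivity()
# patch_size controls the size of the occluding square, as in the callback
grid = explainer.explain(data, model, class_index=281, patch_size=4)
explainer.save(grid, ".", "occlusion_sensitivity.png")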

Grad CAM

Visualize how parts of the image affect the neural network's output by looking into the activation maps

From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

SmoothGrad

Visualize stabilized gradients on the inputs towards the decision

From SmoothGrad: removing noise by adding noise

from tf_explain.callbacks.smoothgrad import SmoothGradCallback

model = [...]

callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
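
As with the other methods, a core-API version is available. A minimal sketch, assuming the core SmoothGrad class takes the same num_samples and noise parameters as the callback and reusing data and model from the Quickstart:

from tf_explain.core.smoothgrad import SmoothGrad

explainer = SmoothGrad()
# Averages gradients over num_samples noisy copies of the input; noise is the stddev of the added Gaussian noise
grid = explainer.explain(data, model, class_index=281, num_samples=20, noise=1.0)
explainer.save(grid, ".", "smoothgrad.png")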

Integrated Gradients

Visualize an average of the gradients along a path from a baseline to the input, towards the decision

From Axiomatic Attribution for Deep Networks

from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
    IntegratedGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        n_steps=20,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
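
The core-API equivalent follows the same pattern. A minimal sketch, assuming the core IntegratedGradients class takes the same n_steps parameter as the callback and reusing data and model from the Quickstart:

from tf_explain.core.integrated_gradients import IntegratedGradients

explainer = IntegratedGradients()
# n_steps is the number of interpolation steps between the baseline and the input
grid = explainer.explain(data, model, class_index=281, n_steps=20)
explainer.save(grid, ".", "integrated_gradients.png")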

Roadmap

Contributing

To contribute to the project, please read the dedicated section.
