
Tramac / pytorch-cam

Licence: Apache-2.0 License
Class Activation Map (CAM) Visualizations in PyTorch.



PyTorch-CAM

This project provides a script for class activation map (CAM) visualizations, which can be used to explain model predictions and support model interpretability.

Installation

  • Install via pip
$ pip install torchcam
  • Install from source
$ pip install --upgrade git+https://github.com/Tramac/pytorch-cam.git

Usage

from torchcam import open_image, image2batch, int2tensor, getCAM
from torchvision.models import resnet18

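# Load the image, resize it to 224x224 and convert it to RGB
img = open_image('./data/cat.jpg', (224, 224), convert_mode='RGB')
# Turn the image into a batched model input
input = image2batch(img)
image_class = 284  # ImageNet class index 284 (Siamese cat)
target = int2tensor(image_class)

# ImageNet-pretrained ResNet-18 (newer torchvision releases prefer the weights= argument)
model = resnet18(pretrained=True)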
  • Basic
# gradcam
cam = getCAM(model, img, input, target, display=True, save=False)
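
For readers unfamiliar with what a Grad-CAM heatmap is, the combination step of the published Grad-CAM method can be sketched in plain NumPy. This is a minimal illustration and not the code behind `getCAM`: the function name `grad_cam_map` and the toy arrays are assumptions made up for this example, and the activations and gradients are taken as given rather than extracted from a real model.

```python
import numpy as np

def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Hypothetical helper: combine feature maps and gradients into a heatmap.

    activations: (K, H, W) feature maps from the chosen convolutional layer.
    gradients:   (K, H, W) gradient of the target class score w.r.t. those maps.
    """
    # Channel weights: average each channel's gradient over the spatial dims.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    # Scale to [0, 1] for display; an all-zero map is returned unchanged.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy inputs (made up for illustration): 2 channels on a 2x2 grid.
acts = np.array([[[1.0, 2.0], [3.0, 4.0]],
                 [[4.0, 3.0], [2.0, 1.0]]])
grads = np.array([[[1.0, 1.0], [1.0, 1.0]],
                  [[-0.5, -0.5], [-0.5, -0.5]]])
heatmap = grad_cam_map(acts, grads)
print(heatmap)
```

In a real pipeline the heatmap would be upsampled to the input resolution and overlaid on the image; that step is omitted here.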
  • Advanced

Besides the default gradcam method, the following additional methods are also available: vanilla_grad, grad_x_input, saliency, integrate_grad, deconv, smooth_grad.

from torchcam import saliency

results = saliency.get_image_saliency_result(model, img, input, target, methods=['smooth_grad', 'vanilla_grad', 'grad_x_input', 'saliency'])
figure = saliency.get_image_saliency_plot(results, display=True, save=False)
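
To give a rough idea of how several of the gradient-based methods named above differ, here is a minimal NumPy sketch on a toy scalar function with a known gradient. Everything in it is an assumption for illustration: the toy model `score`, its weights `W`, and the standalone functions below only borrow the method names and are not torchcam's implementations.

```python
import numpy as np

W = np.array([0.5, 1.0, 2.0])  # fixed weights of the toy model (made up)

def score(x):
    """Toy differentiable 'model': f(x) = sum(W * x**2)."""
    return float(np.sum(W * x ** 2))

def score_grad(x):
    """Analytic gradient of the toy model, df/dx = 2 * W * x."""
    return 2.0 * W * x

def vanilla_grad(x):
    # Plain gradient of the score with respect to the input.
    return score_grad(x)

def grad_x_input(x):
    # Gradient multiplied elementwise by the input.
    return x * score_grad(x)

def integrated_grad(x, baseline=None, steps=50):
    # Average the gradient along the straight path from baseline to x
    # (midpoint rule), then scale by the input difference.
    baseline = np.zeros_like(x) if baseline is None else baseline
    alphas = (np.arange(steps) + 0.5) / steps
    path_grads = np.stack([score_grad(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * path_grads.mean(axis=0)

def smooth_grad(x, noise_std=0.1, samples=50, seed=0):
    # Average the gradient over noisy copies of the input.
    rng = np.random.default_rng(seed)
    noisy = x + rng.normal(0.0, noise_std, size=(samples,) + x.shape)
    return np.stack([score_grad(n) for n in noisy]).mean(axis=0)

x = np.array([1.0, -2.0, 3.0])
ig = integrated_grad(x)
print(vanilla_grad(x), grad_x_input(x), ig)
```

For this toy function the integrated gradients sum to `score(x) - score(baseline)`, which is the completeness property that distinguishes the method from the plain gradient.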
  • Your own model
model = YourModel()

cam = getCAM(model, img, input, target, layer_path=['xxx'])  # key of the layer in your model where backpropagation ends

Result

TODO

  • Support more explainers
  • Optimize the code
  • Test with your own model

Reference
