scottgigante / m-phate

License: GPL-3.0
Multislice PHATE for tensor embeddings

M-PHATE

Demonstration M-PHATE plot

Multislice PHATE (M-PHATE) is a dimensionality reduction algorithm for the visualization of time-evolving data. To learn more about M-PHATE, you can read our preprint on arXiv, in which we apply it to the evolution of neural networks over the course of training. Above we show a demonstration of M-PHATE applied to a 3-layer MLP over 300 epochs of training, colored by epoch (left), hidden layer (center) and the digit label that most strongly activates each hidden unit (right). Below is the same network, trained with dropout and embedded in 3D, also colored by the digit label that most strongly activates each unit.

3D rotating gif

How it works

Multislice PHATE (M-PHATE) combines a novel multislice kernel construction with the PHATE visualization. Our kernel captures the dynamics of an evolving graph structure which, when visualized, gives unique intuition about the evolution of a system; in our preprint, we show this applied to a neural network over the course of training and re-training. We compare M-PHATE to other dimensionality reduction techniques, showing that the combination of the multislice kernel and PHATE provides significant improvements to visualization. In two vignettes, we demonstrate the use of M-PHATE on established training tasks and learning methods in continual learning, and on regularization techniques commonly used to improve generalization performance.

The multislice kernel used in M-PHATE consists of building graphs over time slices of data (e.g. epochs in neural network training) and then connecting these slices by connecting each point to itself over time, weighted by its similarity. The result is a highly sparse, structured kernel which provides insight into the evolving structure of the data.
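The construction described above can be sketched in plain NumPy. This is a simplified illustration, not the library's implementation: it uses a dense Gaussian kernel with fixed bandwidths, and the names `gaussian_kernel`, `multislice_kernel`, `intra_bandwidth` and `inter_bandwidth` are hypothetical. The kernel built by M-PHATE itself is nearest-neighbor based and differs in detail.

```python
import numpy as np


def gaussian_kernel(x, bandwidth):
    """Dense Gaussian affinities between the rows of x."""
    sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2 * bandwidth ** 2))


def multislice_kernel(data, intra_bandwidth=1.0, inter_bandwidth=1.0):
    """Assemble a toy multislice kernel.

    data has shape (n_time_steps, n_points, n_dim).
    Row/column index t * n_points + p refers to point p at time step t.
    """
    n_time, n_points, _ = data.shape
    K = np.zeros((n_time * n_points, n_time * n_points))

    # intraslice blocks: compare all points within one time slice
    for t in range(n_time):
        idx = t * n_points + np.arange(n_points)
        K[np.ix_(idx, idx)] = gaussian_kernel(data[t], intra_bandwidth)

    # interslice entries: compare each point with itself at other time steps
    for p in range(n_points):
        idx = np.arange(n_time) * n_points + p
        K_p = gaussian_kernel(data[:, p, :], inter_bandwidth)
        np.fill_diagonal(K_p, 0)  # self-affinities were set by the intraslice blocks
        K[np.ix_(idx, idx)] += K_p

    return K


# small random-walk example: 5 time steps, 4 points, 3 dimensions
rng = np.random.default_rng(42)
data = np.cumsum(rng.normal(0, 1, (5, 4, 3)), axis=0)
K = multislice_kernel(data)
print(K.shape)  # (20, 20)
```

Entries linking different points at different time steps are never filled in, which is where the sparsity and block structure of the kernel come from.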

For more details, check out our NeurIPS publication, read the tweetorial or have a look at our poster.

Example of multislice graph

Example of multislice kernel

Installation

Install from PyPI

pip install --user m-phate

Install from source

pip install --user git+https://github.com/scottgigante/m-phate.git

Usage

Basic usage example

Below we apply M-PHATE to simulated data of 50 points undergoing random motion.

import numpy as np
import m_phate
import scprep

# create fake data
n_time_steps = 100
n_points = 50
n_dim = 25
np.random.seed(42)
data = np.cumsum(np.random.normal(0, 1, (n_time_steps, n_points, n_dim)), axis=0)

# embedding
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(data)

# plot
time = np.repeat(np.arange(n_time_steps), n_points)
scprep.plot.scatter2d(m_phate_data, c=time, ticks=False, label_prefix="M-PHATE")

Example embedding
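The `time` vector in the example works because the embedding has one row per (time step, point) pair, with time as the slow axis; this ordering is what the `np.repeat` call implies. Under that assumption, a matching per-point label can be built with `np.tile`. The names `point`, `row` and `trajectory_rows` below are illustrative only.

```python
import numpy as np

n_time_steps = 100
n_points = 50

# one label per embedded row, assuming row t * n_points + p is point p at time step t
time = np.repeat(np.arange(n_time_steps), n_points)
point = np.tile(np.arange(n_points), n_time_steps)

# sanity check: this row should be point 23 at time step 2
row = 2 * n_points + 23
assert (time[row], point[row]) == (2, 23)

# rows belonging to the trajectory of point 0 across all time steps
trajectory_rows = np.flatnonzero(point == 0)
```

With these labels, `m_phate_data[trajectory_rows]` would select the path of point 0 through the embedding, and `c=point` could be passed to `scprep.plot.scatter2d` in place of `c=time`.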

Network training

To apply M-PHATE to neural networks, we provide helper classes to store the samples from the network during training. In order to use these, you must install tensorflow and keras.

import numpy as np

import keras
import scprep

import m_phate
import m_phate.train
import m_phate.data

# load data
x_train, x_test, y_train, y_test = m_phate.data.load_mnist()

# select trace examples
trace_idx = [np.random.choice(np.argwhere(y_test[:, i] == 1).flatten(),
                              10, replace=False)
             for i in range(10)]
trace_data = x_test[np.concatenate(trace_idx)]

# build neural network
lrelu = keras.layers.LeakyReLU(alpha=0.1)
inputs = keras.layers.Input(
    shape=(x_train.shape[1],), dtype='float32', name='inputs')
h1 = keras.layers.Dense(128, activation=lrelu, name='h1')(inputs)
h2 = keras.layers.Dense(64, activation=lrelu, name='h2')(h1)
h3 = keras.layers.Dense(128, activation=lrelu, name='h3')(h2)
outputs = keras.layers.Dense(10, activation='softmax', name='output_all')(h3)

# build trace model helper
model_trace = keras.models.Model(inputs=inputs, outputs=[h1, h2, h3])
trace = m_phate.train.TraceHistory(trace_data, model_trace)

# compile network
model = keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['categorical_accuracy', 'categorical_crossentropy'])

# train network
model.fit(x_train, y_train, batch_size=128, epochs=200,
          verbose=1, callbacks=[trace],
          validation_data=(x_test,
                           y_test))

# extract trace data
trace_data = np.array(trace.trace)
epoch = np.repeat(np.arange(trace_data.shape[0]), trace_data.shape[1])

# apply M-PHATE
m_phate_op = m_phate.M_PHATE()
m_phate_data = m_phate_op.fit_transform(trace_data)

# plot the result
scprep.plot.scatter2d(m_phate_data, c=epoch, ticks=False,
                      label_prefix="M-PHATE")
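Beyond epoch, the embedding can also be colored by hidden layer or by the digit that most strongly activates each unit. The sketch below shows one way such labels might be computed. It runs on a random stand-in for `trace_data` and rests on two assumptions: that the trace has shape (epochs, hidden units, trace samples) with units stacked in the order h1, h2, h3, and that trace samples are grouped by digit as in the `trace_idx` construction above. Names such as `layer_of_unit` and `digit_of_sample` are illustrative.

```python
import numpy as np

# stand-in for np.array(trace.trace); assumed shape (n_epochs, n_units, n_trace_samples)
n_epochs, layer_sizes, n_per_digit = 3, [128, 64, 128], 10
rng = np.random.default_rng(0)
trace_data = rng.normal(size=(n_epochs, sum(layer_sizes), 10 * n_per_digit))

# hidden layer of each unit, assuming units are stacked in the order h1, h2, h3
layer_of_unit = np.repeat(np.arange(len(layer_sizes)), layer_sizes)
layer_id = np.tile(layer_of_unit, n_epochs)

# digit of each trace sample: 10 samples of digit 0, then 10 of digit 1, and so on
digit_of_sample = np.repeat(np.arange(10), n_per_digit)

# total absolute activation of every unit for every digit, per epoch
digit_activity = np.stack([
    np.abs(trace_data[:, :, digit_of_sample == d]).sum(axis=2)
    for d in range(10)
])  # shape (10, n_epochs, n_units)

# digit with the largest total activation, flattened to one label per embedded row
most_active_digit = np.argmax(digit_activity, axis=0).flatten()
```

Either `layer_id` or `most_active_digit` could then be passed as `c=` to `scprep.plot.scatter2d` in place of `epoch`.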

Example notebooks

For detailed examples, see our sample notebooks for keras and tensorflow in the examples directory.

Parameter tuning

The key to tuning M-PHATE is balancing interslice connectivity against intraslice connectivity. This is primarily controlled by the interslice_knn and intraslice_knn parameters. You can see an example of the effects of parameter tuning in this notebook.

Figure reproduction

We provide scripts to reproduce all of the empirical figures in the preprint.

To run them:

git clone https://github.com/scottgigante/m-phate
cd m-phate
pip install --user .

# change this if you want to store the data elsewhere
DATA_DIR=~/data/checkpoints/m_phate

# choose between cifar and mnist
DATASET="mnist"
EXTRA_ARGS="--dataset ${DATASET}"

# remove to use validation data
EXTRA_ARGS="${EXTRA_ARGS} --sample-train-data"

chmod +x scripts/generalization/generalization_train.sh
chmod +x scripts/task_switching/classifier_mnist_task_switch_train.sh

./scripts/generalization/generalization_train.sh "${DATA_DIR}" "${EXTRA_ARGS}"
./scripts/task_switching/classifier_mnist_task_switch_train.sh "${DATA_DIR}" "${EXTRA_ARGS}"

python scripts/demonstration_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/comparison_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/generalization_plot.py "${DATA_DIR}" "${DATASET}"
python scripts/task_switch_plot.py "${DATA_DIR}" "${DATASET}"

TODO

  • Provide support for PyTorch
  • Notebook examples for:
    • Classification, pytorch
    • Autoencoder, pytorch
  • Build readthedocs page

Help

If you have any questions, please feel free to open an issue.
