xuyxu / Soft-Decision-Tree

License: BSD-3-Clause
PyTorch implementation of "Distilling a Neural Network Into a Soft Decision Tree", Nicholas Frosst and Geoffrey Hinton, 2017.



Introduction

This is a PyTorch implementation of the Soft Decision Tree (SDT) proposed in the paper "Distilling a Neural Network Into a Soft Decision Tree" (Frosst and Hinton, 2017, https://arxiv.org/abs/1711.09784).

Quick Start

To run the demo on MNIST, simply use the following commands:

git clone https://github.com/AaronX121/Soft-Decision-Tree.git
cd Soft-Decision-Tree
python main.py

Parameters

Parameter    Type    Description
input_dim    int     The number of input dimensions.
output_dim   int     The number of output dimensions (e.g., the number of classes for multi-class classification).
depth        int     Tree depth; the default is 5.
lamda        float   The coefficient of the regularization term; the default is 1e-3.
use_cuda     bool    Whether to use the GPU for training / evaluation; the default is False.
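To make the roles of these parameters concrete, below is a minimal, self-contained sketch of what depth controls and how a single inner node softly routes a sample. Note that sigmoid and soft_route_right here are illustrative helpers written for this sketch, not the repository's actual API.

```python
import math

def sigmoid(x):
    """Logistic sigmoid used to turn an inner node's score into a routing probability."""
    return 1.0 / (1.0 + math.exp(-x))

def soft_route_right(x, weights, bias):
    """Probability that sample x is routed to the right child of one inner node.

    Each inner node applies a learned linear filter to the input, then squashes
    the score through a sigmoid, so every sample reaches every leaf with some
    probability instead of following a single hard path.
    """
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return sigmoid(z)

# A complete binary tree of depth d has 2**d - 1 inner nodes and 2**d leaves,
# so depth directly determines the model's size.
depth = 5
n_inner = 2 ** depth - 1   # 31 inner nodes
n_leaf = 2 ** depth        # 32 leaves
```

With zero weights and bias, the routing probability is exactly 0.5, i.e., the sample is split evenly between the two children.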

Frequently Asked Questions

  • Training loss suddenly turns into NaN
    • Reason: The sigmoid function used in the internal nodes of the SDT can be unstable during training, since its gradient is close to 0 when the absolute value of its input is large.
    • Solution: Using a smaller learning rate typically works.
  • Exact training time
    • Setup: MNIST dataset | Tree depth: 5 | Epochs: 40 | Batch size: 128
    • Result: Around 15 minutes on a single RTX 2080 Ti
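The saturation behind the NaN issue can be seen directly from the sigmoid derivative, σ'(x) = σ(x)(1 − σ(x)), which peaks at 0.25 at x = 0 and decays toward zero as |x| grows. A standalone numeric check (not code from the repository):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid(x)
    return s * (1.0 - s)

print(sigmoid_grad(0.0))   # 0.25, the maximum
print(sigmoid_grad(10.0))  # about 4.5e-05: the gradient has nearly vanished
```

When the inner nodes' linear scores grow large, updates flowing through these near-zero gradients can destabilize training, which is why a smaller learning rate helps.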

Experiment Result on MNIST

After training for 40 epochs with a batch size of 128, the best testing accuracy using SDT models of depth 5 and 7 is 94.15% and 94.38%, respectively (close to the accuracy reported in the original paper). The related hyper-parameters are available in main.py. Better and more stable performance can be achieved by fine-tuning the hyper-parameters.

Below are the testing accuracy curve and training loss curve. The testing accuracy of SDT is evaluated after each training epoch.

MNIST Experiment Result

Package Dependencies

SDT was originally developed in Python 3.6.5. The packages and versions used are listed below; in practice, it also works fine under other versions of Python and PyTorch.

  • pytorch 0.4.1
  • torchvision 0.2.1