RandWire_tensorflow

TensorFlow implementation of Exploring Randomly Wired Neural Networks for Image Recognition, using CIFAR-10 and MNIST.

License: MIT


Requirements

  • TensorFlow 1.x - GPU version recommended
  • Python 3.x
  • networkx 2.x
  • pyyaml 5.x

Dataset

Please download the dataset from this link. Both the CIFAR-10 and MNIST datasets have been converted into TFRecord format for convenience. Put the train.tfrecords and test.tfrecords files into dataset/cifar10 and dataset/mnist respectively.

You can also create a TFRecord file from your own dataset with dataset/dataset_generator.py (a sketch of what the conversion might look like follows the options list below).

python dataset_generator.py --image_dir ./cifar10/test/images --label_dir ./cifar10/test/labels --output_dir ./cifar10 --output_filename test.tfrecord

Options:

  • --image_dir (str) - directory of your image files. It is recommended to name the images with integers, e.g. 0.png.
  • --label_dir (str) - directory of your label files. It is recommended to name the labels with integers, e.g. 0.txt. Each label text file must contain the class label as an integer, e.g. 8.
  • --output_dir (str) - directory for the output tfrecord file.
  • --output_filename (str) - filename of the output tfrecord file.
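
The script's exact serialization logic isn't reproduced here, but a minimal sketch of how such a converter might write integer-named image/label pairs into a TFRecord file (TF 1.x API) could look like the following. The write_tfrecord helper and the 'image'/'label' feature keys are illustrative assumptions, not the script's actual identifiers.

  # Sketch of a TFRecord converter for integer-named image/label files.
  # Helper name and feature keys are assumptions, not the repo's actual code.
  import os
  import numpy as np
  import tensorflow as tf
  from PIL import Image

  def write_tfrecord(image_dir, label_dir, output_path):
      # Sort files numerically: 0.png, 1.png, ...
      names = sorted(os.listdir(image_dir), key=lambda f: int(os.path.splitext(f)[0]))
      with tf.python_io.TFRecordWriter(output_path) as writer:
          for name in names:
              idx = os.path.splitext(name)[0]
              image = np.array(Image.open(os.path.join(image_dir, name)))
              with open(os.path.join(label_dir, idx + '.txt')) as f:
                  label = int(f.read().strip())  # e.g. 8
              example = tf.train.Example(features=tf.train.Features(feature={
                  'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                  'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
              }))
              writer.write(example.SerializeToString())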

Experiments

Datasets    Model                         Parameters   Accuracy   Epochs
CIFAR-10    ResNet-110 (paper)            1.7M         93.57%     300
CIFAR-10    RandWire (my_small_regime)    1.2M         93.64%     60
CIFAR-100   RandWire (my_regime)          8M           74.49%     100

(19.04.18 changed) I trained on the CIFAR-10 dataset and got a 6.36% error on the test set. You can download the pretrained network from here. Unzip the file, move everything into the checkpoint folder (or your own checkpoint directory), and run the test script to check the accuracy. The CIFAR-10 model uses about 1.2M parameters, a result comparable to ResNet-110 (6.43% error), which used 1.7M parameters.

(19.04.16 added) I trained on the CIFAR-100 dataset and got 74.49% accuracy on the test set. You can download the pretrained network from the same link above.

Training

CIFAR-10

python train.py --class_num 10 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar10 --train_record_dir ./dataset/cifar10/train.tfrecord --val_record_dir ./dataset/cifar10/test.tfrecord

Options:

  • --class_num (int) - number of output classes. CIFAR-10 has 10 classes.
  • --image_shape (int nargs) - shape of the input image. CIFAR-10 images have shape 32 32 3.
  • --stages (int) - number of stages (blocks) in the RandWire network.
  • --channel_count (int) - channel count of the RandWire network; please refer to the paper.
  • --graph_model (str) - random graph model. You can choose from 'er', 'ba', and 'ws'.
  • --graph_param (float nargs) - the first value is the node count. 'er' and 'ba' take one extra parameter (e.g. 32 0.4 or 32 7); 'ws' takes two extra parameters, as in the command above (see the sketch after this list).
  • --dropout_rate (float) - dropout rate for the dropout layer after each ReLU-Conv-BN triplet unit; set it to 0.0 to disable dropout.
  • --learning_rate (float) - initial learning rate.
  • --momentum (float) - momentum for the momentum optimizer.
  • --weight_decay (float) - weight decay factor.
  • --train_set_size (int) - number of training examples. CIFAR-10 has 50000.
  • --val_set_size (int) - number of validation examples. I used the test data for validation, so there are 10000.
  • --batch_size (int) - size of a mini batch.
  • --epochs (int) - number of epochs.
  • --checkpoint_dir (str) - directory to save checkpoints.
  • --checkpoint_name (str) - file name of the checkpoint.
  • --train_record_dir (str) - file location of the training set tfrecord.
  • --val_record_dir (str) - file location of the test set tfrecord (used for validation).
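
For orientation, this sketch shows how the three --graph_model choices plausibly map onto networkx generators (networkx is a requirement above); the parameter values mirror the commands here, but the exact call sites in the repo may differ:

  # Sketch: mapping --graph_model / --graph_param onto networkx generators.
  import networkx as nx

  nodes = 32
  ws = nx.watts_strogatz_graph(nodes, k=4, p=0.75)  # --graph_model ws --graph_param 32 4 0.75
  er = nx.erdos_renyi_graph(nodes, p=0.4)           # --graph_model er --graph_param 32 0.4
  ba = nx.barabasi_albert_graph(nodes, m=7)         # --graph_model ba --graph_param 32 7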

MNIST

python train.py --class_num 10 --image_shape 28 28 1 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_mnist --train_record_dir ./dataset/mnist/train.tfrecord --val_record_dir ./dataset/mnist/test.tfrecord

Options:

  • Options are the same as for CIFAR-10.

CIFAR-100 (19.04.16 added)

python train.py --class_num 100 --image_shape 32 32 3 --stages 4 --channel_count 78 --graph_model ws --graph_param 32 4 0.75 --dropout_rate 0.2 --learning_rate 0.1 --momentum 0.9 --weight_decay 0.0001 --train_set_size 50000 --val_set_size 10000 --batch_size 100 --epochs 100 --checkpoint_dir ./checkpoint --checkpoint_name randwire_cifar100 --train_record_dir ./dataset/cifar100/train.tfrecord --val_record_dir ./dataset/cifar100/test.tfrecord

Options:

  • Options are the same as for CIFAR-10.

Testing

python test.py --class_num 10 --checkpoint_dir ./checkpoint/best --test_record_dir ./dataset/cifar10/test.tfrecord --batch_size 256

Options:

  • --class_num (int) - the number of classes
  • --checkpoint_dir (str) - directory for the checkpoint you want to load and test
  • --test_record_dir (str) - directory for the test dataset
  • --batch_size (int) - batch size for testing

test.py loads the network graph and tensors from the saved meta file and evaluates the model; a rough sketch of that restore step follows.
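
This is a minimal sketch of what the restore-and-evaluate step might look like in TF 1.x; the .meta filename and tensor names are assumptions, so inspect the saved graph for the actual identifiers:

  # Sketch of restoring a TF 1.x meta graph and fetching an eval tensor.
  import tensorflow as tf

  with tf.Session() as sess:
      saver = tf.train.import_meta_graph('./checkpoint/best/randwire_cifar10.meta')  # assumed filename
      saver.restore(sess, tf.train.latest_checkpoint('./checkpoint/best'))
      graph = tf.get_default_graph()
      accuracy = graph.get_tensor_by_name('accuracy:0')  # assumed tensor name
      print(sess.run(accuracy))  # feeds omitted; depends on how the graph reads its input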

Implementation Details

  • The learning rate is multiplied by 0.1 at 50% and 75% of the total training steps (see the first sketch after the code block at the end of this section).

  • I added an init_subsample option to my_regime, my_small_regime, and small_regime in RandWire.py that skips stride 2 in the initial convolutional layer, since CIFAR-10 and MNIST have low resolution. If you set init_subsample to False, stride 2 is used.

  • While training, it saves the checkpoint with the best validation accuracy.

  • While training, it saves TensorBoard logs of training and validation accuracy and loss in [YOUR_CHECKPOINT_DIRECTORY]/log, which you can visualize with TensorBoard.

  • I'm currently working on drop connection for regularization and on downloading the ImageNet dataset to train with this implementation.

  • I added a dropout layer after each ReLU-Conv-BN triplet unit for regularization. You can set dropout_rate to 0.0 to disable it (see the second sketch at the end of this section).

  • In train.py, you can use small_regime or regular_regime instead of my_regime or my_small_regime, as shown below. my_regime and my_small_regime avoid stride 2 to prevent subsampling and preserve spatial information, since the CIFAR datasets are not large enough.

  # output logit from NN
  output = RandWire.my_small_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
                              args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False, training)
  # output = RandWire.small_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
  #                             args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', False,
  #                             training)
  # output = RandWire.regular_regime(images, args.stages, args.channel_count, args.class_num, args.dropout_rate,
  #                             args.graph_model, args.graph_param, args.checkpoint_dir + '/' + 'graphs', training)
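
The learning-rate schedule described above (x0.1 at 50% and 75% of training) could be implemented roughly like this in TF 1.x; the variable names and the piecewise-constant approach are my assumptions, not necessarily how train.py does it:

  # Sketch: piecewise-constant LR decay at 50% and 75% of total steps.
  import tensorflow as tf

  steps_per_epoch = 50000 // 100          # train_set_size // batch_size
  total_steps = steps_per_epoch * 100     # * epochs
  global_step = tf.train.get_or_create_global_step()
  boundaries = [int(total_steps * 0.5), int(total_steps * 0.75)]
  values = [0.1, 0.01, 0.001]             # initial LR, then x0.1 at each boundary
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
  optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)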
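
And this is a minimal sketch of the ReLU-Conv-BN triplet unit with the trailing dropout layer described above (TF 1.x layers API; the function name and argument choices are illustrative, not the repo's actual code):

  # Sketch: ReLU-Conv-BN triplet unit followed by dropout.
  import tensorflow as tf

  def triplet_unit(x, channels, dropout_rate, training, stride=1):
      x = tf.nn.relu(x)
      x = tf.layers.conv2d(x, channels, kernel_size=3, strides=stride, padding='same')
      x = tf.layers.batch_normalization(x, training=training)
      x = tf.layers.dropout(x, rate=dropout_rate, training=training)  # rate 0.0 disables it
      return x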