
eyalzk / Sketch_rnn_keras

License: MIT
Keras implementation of Sketch RNN

Projects that are alternatives to or similar to Sketch_rnn_keras

Nfft
Lightweight non-uniform Fast Fourier Transform in Python
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Smilecnn
Smile detection with a deep convolutional neural net, with Keras.
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Cnn Audio Denoiser
Tensorflow 2.0 implementation of the paper: A Fully Convolutional Neural Network for Speech Enhancement
Stars: ✭ 138 (+0%)
Mutual labels:  jupyter-notebook
Text Analytics
Unstructured Data Analysis (Graduate) @Korea University
Stars: ✭ 138 (+0%)
Mutual labels:  jupyter-notebook
Dat210x
Programming with Python for Data Science Microsoft
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Copy Paste Aug
Copy-paste augmentation for segmentation and detection tasks
Stars: ✭ 132 (-4.35%)
Mutual labels:  jupyter-notebook
Dask Docker
Docker images for dask
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Deep Reinforcement Stock Trading
A light-weight deep reinforcement learning framework for portfolio management. This project explores the possibility of applying deep reinforcement learning algorithms to stock trading in a highly modular and scalable framework.
Stars: ✭ 136 (-1.45%)
Mutual labels:  jupyter-notebook
End To End Generative Dialogue
A neural conversation model
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Rapids Single Cell Examples
Examples of single-cell genomic analysis accelerated with RAPIDS
Stars: ✭ 138 (+0%)
Mutual labels:  jupyter-notebook
Youtube Like Predictor
YouTube Like Count Predictions using Machine Learning
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Robustautoencoder
A combination of Autoencoder and Robust PCA
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Fetching Financial Data
Fetching financial data for technical & fundamental analysis and algorithmic trading from a variety of python packages and sources.
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Drl
Deep RL Algorithms implemented for UC Berkeley's CS 294-112: Deep Reinforcement Learning
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Easy slam tutorial
The first Chinese tutorial on implementing visual SLAM theory and practice from scratch, written in Python. Covers ORB feature extraction, epipolar geometry, visual odometry back-end optimization, and real-time 3D map reconstruction. An easy SLAM practical tutorial (Python), plus image processing and Otsu binarization. More tutorials on my CSDN blog.
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Zhihu Spider
A multi-threaded Python crawler that scrapes Zhihu user profile information.
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Qutip Notebooks
A collection of IPython notebooks using QuTiP: examples, tutorials, development test, etc.
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Minifold
MiniFold: Deep Learning for Protein Structure Prediction inspired by DeepMind AlphaFold algorithm
Stars: ✭ 136 (-1.45%)
Mutual labels:  jupyter-notebook
Dab And Tpose Controlled Lights
Control your lights with dab and t-pose, duh
Stars: ✭ 137 (-0.72%)
Mutual labels:  jupyter-notebook
Glasses
High-quality Neural Networks for Computer Vision 😎
Stars: ✭ 138 (+0%)
Mutual labels:  jupyter-notebook

A Keras Implementation of Sketch-RNN

This repo contains a Keras implementation of the Sketch-RNN algorithm,
as described in the paper A Neural Representation of Sketch Drawings by David Ha and Douglas Eck (Google AI).

The implementation is ported from the authors' official TensorFlow implementation, released as part of the Magenta project.

Overview

Sketch-RNN consists of a Sequence-to-Sequence Variational Autoencoder (Seq2SeqVAE), which encodes a series of pen strokes (a sketch) into a latent space, using a bidirectional LSTM as the encoder. The latent representation can then be decoded back into a series of strokes.
The model is trained to reconstruct the original stroke sequences while keeping the latent space elements close to a normal distribution. Since both the encoding and the decoder's sampling mechanism are stochastic, the reconstructed sketches are always different.
This lets a trained model draw new and unique sketches that it has not seen before. Designing the model as a variational autoencoder also makes it possible to manipulate the latent space and obtain interesting interpolations between different sketches.
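
The overall structure, as a minimal illustrative Keras sketch (this is not the repo's code: the sequence length is arbitrary, the input is assumed to be in the 5-element stroke format used by Sketch-RNN, the decoder here is fed the full stroke sequence instead of the previous stroke at each step, and the Gaussian-mixture output head is reduced to a single Dense layer for brevity):

        from keras import backend as K
        from keras.layers import Input, LSTM, Bidirectional, Dense, Lambda, RepeatVector, concatenate
        from keras.models import Model

        max_seq_len, z_size, enc_units, dec_units = 250, 128, 256, 512

        # Encoder: a bidirectional LSTM compresses the stroke sequence into (mu, log_sigma).
        strokes = Input(shape=(max_seq_len, 5), name='strokes')
        h = Bidirectional(LSTM(enc_units), name='encoder')(strokes)
        mu = Dense(z_size, name='mu')(h)
        log_sigma = Dense(z_size, name='log_sigma')(h)

        # Stochastic sampling of the latent vector z (reparameterization trick).
        def sample_z(args):
            mu, log_sigma = args
            eps = K.random_normal(shape=K.shape(mu))
            return mu + K.exp(log_sigma / 2.0) * eps

        z = Lambda(sample_z, name='z')([mu, log_sigma])

        # Decoder: z conditions an LSTM that emits per-time-step stroke parameters.
        z_seq = RepeatVector(max_seq_len)(z)
        decoded = LSTM(dec_units, return_sequences=True, name='decoder')(concatenate([strokes, z_seq]))
        outputs = Dense(5, name='stroke_params')(decoded)

        model = Model(strokes, outputs)
        model.summary()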

There is no need to elaborate on the specifics of the algorithm here, since many great resources already cover it.
I recommend David Ha's blog post Teaching Machines to Draw.

Implementation Details

You can find in this repo some useful solutions to common pitfalls when porting from TensorFlow to Keras (and when writing Keras in general), for example:

  • Injecting values into intermediate tensors and predicting the corresponding values of other tensors by building sub-models (see the sketch after this list)
  • Using custom generators to wrap data loader classes
  • Adding an auxiliary loss term that uses intermediate layers' outputs rather than the model's predictions
  • Using a CuDNN LSTM layer while still allowing inference on CPU
  • Resuming a training process from a checkpoint when custom callbacks with dynamic internal variables are used
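
For instance, the sub-model trick from the first bullet looks roughly like this (a generic, self-contained sketch built on a toy model; it is not code from this repo):

        # A generic sketch of the sub-model trick: expose an intermediate tensor of an
        # existing model, and separately run only the layers that come after it.
        import numpy as np
        from keras.layers import Input, Dense
        from keras.models import Model

        # A toy "full" model with a named intermediate layer standing in for the latent z.
        inp = Input(shape=(10,))
        z = Dense(4, name='z')(inp)
        out = Dense(10, name='decoder_dense')(z)
        full_model = Model(inp, out)

        # Sub-model 1: predict the values of the intermediate tensor for given inputs.
        encoder = Model(full_model.input, full_model.get_layer('z').output)
        z_values = encoder.predict(np.random.rand(3, 10))

        # Sub-model 2: inject values directly into the intermediate tensor and compute
        # the corresponding downstream outputs, reusing the already-built layer.
        z_in = Input(shape=(4,))
        decoder = Model(z_in, full_model.get_layer('decoder_dense')(z_in))
        outputs = decoder.predict(z_values)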

Dependencies

Tested in the following environment:

  • Keras 2.2.4 (TensorFlow 1.11 backend)
  • Python 3.5
  • Windows OS

I hope to update this soon with minimum version requirements.

Usage

Training

To train a model, you need a dataset in the appropriate format. You can download one of the many prepared sketch datasets released by Google. Simply download one or more .npz files and save them in the same directory (a datasets directory within the project's main directory is recommended).
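
A quick way to inspect a downloaded file (assuming the stroke-3 .npz format of the QuickDraw sketch-rnn datasets, where each sketch is an array of (dx, dy, pen_lifted) rows split into 'train'/'valid'/'test'):

        import numpy as np

        # Load one of the downloaded QuickDraw .npz files and look at its contents.
        data = np.load('datasets/cat.npz', encoding='latin1', allow_pickle=True)
        print(data.files)                        # expected: ['train', 'valid', 'test']
        train = data['train']
        print(len(train), 'training sketches')
        print(train[0].shape, train[0].dtype)    # (num_points, 3), typically int16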

Example usage:

python seq2seqVAE_train.py --data_dir=datasets --data_set=cat --experiment_dir=\sketch_rnn\experiments

Currently, configurable hyperparameters can only be modified by changing their default values in seq2seqVAE.py.
I might add an option to configure them via command line in the future.

You can also resume training from a saved checkpoint by supplying the --checkpoint=[file path] and --initial_epoch=[epoch from which you are resuming] arguments.

The full list of configurable parameters:

        # Experiment Params:
        'is_training': True,           # train mode (relevant only for accelerated LSTM mode)
        'data_set': 'cat',             # datasets to train on
        'epochs': 50,                  # how many times to go over the full train set (on average, since batches are drawn randomly)
        'save_every': None,            # Batches between checkpoints creation and validation set evaluation. Once an epoch if None.
        'batch_size': 100,             # Minibatch size. Recommend leaving at 100.
        'accelerate_LSTM': False,      # Flag for using CuDNNLSTM layer, gpu + tf backend only
        # Loss Params:    
        'optimizer': 'adam',           # adam or sgd
        'learning_rate': 0.001,    
        'decay_rate': 0.9999,          # Learning rate decay per minibatch.
        'min_learning_rate': .00001,   # Minimum learning rate.
        'kl_tolerance': 0.2,           # Level of KL loss at which to stop optimizing for KL.
        'kl_weight': 0.5,              # KL weight of loss equation. Recommend 0.5 or 1.0.
        'kl_weight_start': 0.01,       # KL start weight when annealing.
        'kl_decay_rate': 0.99995,      # KL annealing decay rate per minibatch.
        'grad_clip': 1.0,              # Gradient clipping. Recommend leaving at 1.0.
        # Architecture Params:
        'z_size': 128,                 # Size of latent vector z. Recommended 32, 64 or 128.
        'enc_rnn_size': 256,           # Units in encoder RNN.
        'dec_rnn_size': 512,           # Units in decoder RNN.
        'use_recurrent_dropout': True, # Dropout with memory loss. Recommended
        'recurrent_dropout_prob': 0.9, # Probability of recurrent dropout keep.
        'num_mixture': 20,             # Number of mixtures in Gaussian mixture model.
        # Data pre-processing Params:
        'random_scale_factor': 0.15,   # Random scaling data augmentation proportion.
        'augment_stroke_prob': 0.10    # Point dropping augmentation proportion.
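
To make the schedule parameters concrete, this is roughly how the per-minibatch decays behave in the original sketch-rnn implementation that this port follows (a sketch based on the Magenta code; treat the exact formulas as an assumption rather than documentation of this repo):

        # Rough sketch of the per-minibatch schedules implied by the parameters above
        # (based on the original Magenta sketch-rnn; details may differ in this repo).
        def learning_rate_at(step, lr=0.001, min_lr=0.00001, decay_rate=0.9999):
            return (lr - min_lr) * (decay_rate ** step) + min_lr

        def kl_weight_at(step, kl_weight=0.5, kl_weight_start=0.01, kl_decay_rate=0.99995):
            # The KL term itself is only optimized down to kl_tolerance,
            # i.e. the loss effectively uses max(KL, kl_tolerance).
            return kl_weight - (kl_weight - kl_weight_start) * (kl_decay_rate ** step)

        print(learning_rate_at(0), learning_rate_at(10000))   # decays from 0.001 towards min_learning_rate
        print(kl_weight_at(0), kl_weight_at(10000))           # anneals from 0.01 up towards 0.5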

Using a trained model to draw

[cat_interp: interpolation between cat sketches]

In the notebook Skecth_RNN_Keras.ipynb you can supply a path to a trained model and a dataset, and explore what the model has learned. There are examples of encoding and decoding sketches, interpolating in latent space, sampling under different temperature values, etc.
You can also load models trained on multiple datasets and generate nifty interpolations such as these guitar-cats! [guitar_cat_interp animations]
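
As a flavor of what the interpolation cells do, here is a generic latent-space interpolation sketch (the encoder/decoder mentioned in the comments are hypothetical stand-ins for whatever helpers the notebook defines; only the interpolation itself is shown):

        import numpy as np

        def interpolate(z0, z1, n_steps=10):
            """Linear interpolation between two latent vectors."""
            return [(1.0 - t) * z0 + t * z1 for t in np.linspace(0.0, 1.0, n_steps)]

        z_cat = np.random.randn(128)      # stands in for encoder.predict(cat_sketch)
        z_guitar = np.random.randn(128)   # stands in for encoder.predict(guitar_sketch)
        for z in interpolate(z_cat, z_guitar):
            pass                          # the decoder would sample a sketch conditioned on each z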

Examples in this notebook use models trained on my laptop's GPU:

  • Cat model : 50 epochs
  • Cat + Guitar model: 15 epochs

Both pre-trained models are included in this repo.
