# Variational Recurrent Autoencoder Tensorflow
## Generating Sentences from a Continuous Space

Tensorflow implementation of *Generating Sentences from a Continuous Space*.
## Prerequisites

- Python 3.4 or higher
- Tensorflow r0.12
- Numpy
## Setting up the environment

Clone this repository:

```shell
git clone https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow.git
```

Set up a conda environment:

```shell
conda create -n vrae python=3.6
conda activate vrae
```

Install Python package requirements:

```shell
pip install -r requirements.txt
```
## Usage

Training:

```shell
python vrae.py --model_dir models --do train --new True
```

Reconstruct:

```shell
python vrae.py --model_dir models --do reconstruct --new False --input input.txt --output output.txt
```

Sample (this script reads only the first line of `input.txt`, generates `num_pts` samples, and writes them to `output.txt`):

```shell
python vrae.py --model_dir models --do sample --new False --input input.txt --output output.txt
```

Interpolate (this script requires that `input.txt` consist of exactly two sentences; it generates `num_pts` interpolations between them and writes the interpolated sentences to `output.txt`):

```shell
python vrae.py --model_dir models --do interpolate --new False --input input.txt --output output.txt
```
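Interpolation of this kind typically encodes the two sentences into latent vectors and walks linearly between them in latent space. A minimal sketch of the latent-space step (the function name and plain-list vector representation are illustrative, not the repo's actual API):

```python
def interpolate_latents(z_start, z_end, num_pts):
    """Return num_pts latent vectors evenly spaced between z_start and
    z_end (endpoints included). Assumes num_pts >= 2."""
    points = []
    for i in range(num_pts):
        t = i / (num_pts - 1)  # interpolation weight in [0, 1]
        points.append([(1 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return points

# Example with 2-D latent vectors and num_pts = 5:
points = interpolate_latents([0.0, 0.0], [1.0, 2.0], 5)
print(points[2])  # midpoint: [0.5, 1.0]
```

Each interpolated vector would then be fed through the decoder to produce one output sentence per point.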
- `model_dir`: the directory containing the config file `config.json` and the checkpoint files.
- `do`: accepts 4 values: `train`, `reconstruct`, `sample`, or `interpolate`.
- `new`: create a model with fresh parameters if set to `True`; otherwise read model parameters from the checkpoints in `model_dir`.
## config.json

Hyperparameters are not passed on the command line as in tensorflow/models/rnn/translate/translate.py. Instead, `vrae.py` reads hyperparameters from `config.json` in `model_dir`.
Below are the hyperparameters in `config.json`:

- `model`:
  - `size`: embedding size, and encoder/decoder state size.
  - `latent_dim`: latent space size.
  - `in_vocab_size`: source vocabulary size.
  - `out_vocab_size`: target vocabulary size.
  - `data_dir`: path to the corpus.
  - `num_layers`: number of layers for the encoder and decoder.
  - `use_lstm`: whether to use LSTM cells for the encoder and decoder. `BasicLSTMCell` is used if set to `True`; otherwise `GRUCell` is used.
  - `buckets`: a list of [input size, output size] pairs, one per bucket.
  - `bidirectional`: `bidirectional_rnn` is used if set to `True`.
  - `probablistic`: variance is set to zero if set to `False`.
  - `orthogonal_initializer`: `orthogonal_initializer` is used if set to `True`; otherwise `uniform_unit_scaling_initializer` is used.
  - `iaf`: inverse autoregressive flow is used if set to `True`.
  - `activation`: activation for the encoder-to-latent layer and the latent-to-decoder layer.
    - `elu`: exponential linear unit.
    - `prelu`: parametric rectified linear unit. (default)
    - `None`: linear.
- `train`:
  - `batch_size`
  - `beam_size`: beam size for decoding. Warning: beam search is still under implementation; a `NotImplementedError` is raised if `beam_size` is set greater than 1.
  - `learning_rate`: learning rate passed to `AdamOptimizer`.
  - `steps_per_checkpoint`: save a checkpoint every `steps_per_checkpoint` steps.
  - `anneal`: do KL cost annealing if set to `True`.
  - `kl_rate_rise_factor`: the KL term weight is increased by this amount every `steps_per_checkpoint` steps.
  - `max_train_data_size`: limit on the size of the training data (0: no limit).
  - `feed_previous`: if `True`, only the first of `decoder_inputs` will be used (the "GO" symbol), and all other decoder inputs will be generated by `next = embedding_lookup(embedding, argmax(previous_output))`. In effect, this implements a greedy decoder. It can also be used during training to emulate http://arxiv.org/abs/1506.03099. If `False`, `decoder_inputs` are used as given (the standard decoder case).
  - `kl_min`: the minimum information constraint. Should be a non-negative float (0 means no constraint).
  - `max_gradient_norm`: gradients will be clipped to at most this norm.
  - `word_dropout_keep_prob`: probability of keeping a conditioned-on word token rather than replacing it with the generic unknown word token `UNK`. When set to 0, the decoder sees no input.
- `reconstruct`:
  - `feed_previous`
  - `word_dropout_keep_prob`
- `sample`:
  - `feed_previous`
  - `word_dropout_keep_prob`
  - `num_pts`: sample `num_pts` points.
- `interpolate`:
  - `feed_previous`
  - `word_dropout_keep_prob`
  - `num_pts`: generate `num_pts` interpolation points.
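Putting the sections together, an illustrative `config.json` might look like the following. The keys mirror the list above, but all values are made-up examples, not recommended settings:

```json
{
  "model": {
    "size": 256,
    "latent_dim": 16,
    "in_vocab_size": 10000,
    "out_vocab_size": 10000,
    "data_dir": "data",
    "num_layers": 1,
    "use_lstm": true,
    "buckets": [[10, 10], [20, 20]],
    "bidirectional": false,
    "probablistic": true,
    "orthogonal_initializer": true,
    "iaf": false,
    "activation": "prelu"
  },
  "train": {
    "batch_size": 32,
    "beam_size": 1,
    "learning_rate": 0.001,
    "steps_per_checkpoint": 200,
    "anneal": true,
    "kl_rate_rise_factor": 0.01,
    "max_train_data_size": 0,
    "feed_previous": false,
    "kl_min": 0.0,
    "max_gradient_norm": 5.0,
    "word_dropout_keep_prob": 0.75
  },
  "reconstruct": {"feed_previous": true, "word_dropout_keep_prob": 1.0},
  "sample": {"feed_previous": true, "word_dropout_keep_prob": 1.0, "num_pts": 10},
  "interpolate": {"feed_previous": true, "word_dropout_keep_prob": 1.0, "num_pts": 10}
}
```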
## Data

The Penn Treebank corpus is included in the repo. We also provide a Chinese poem corpus, its preprocessed version (set `{"model": {"data_dir": "<corpus_dir>"}}` in `<model_dir>/config.json` to point to it), and its pretrained model (set `model_dir` to it), all of which can be found here.