alexlee-gk / Video_prediction

Licence: MIT
Stochastic Adversarial Video Prediction

[Project Page] [Paper]

TensorFlow implementation for stochastic adversarial video prediction. Given a sequence of initial frames, our model predicts future frames for multiple possible futures. For example, in the next two sequences, we show the ground truth sequence on the left and random predictions of our model on the right. Predicted frames are indicated by the yellow bar at the bottom. For more examples, visit the project page.

Stochastic Adversarial Video Prediction,
Alex X. Lee, Richard Zhang, Frederik Ebert, Pieter Abbeel, Chelsea Finn, Sergey Levine.
arXiv preprint arXiv:1804.01523, 2018.

An alternative implementation of SAVP is available in the Tensor2Tensor library.

Getting Started

Prerequisites

  • Linux or macOS
  • Python 3
  • CPU or NVIDIA GPU + CUDA and cuDNN

Installation

  • Clone this repo:
git clone -b master --single-branch https://github.com/alexlee-gk/video_prediction.git
cd video_prediction
  • Install TensorFlow >= 1.9 and dependencies from http://tensorflow.org/
  • Install ffmpeg (optional, used to generate GIFs for visualization, e.g. in TensorBoard)
  • Install other dependencies
pip install -r requirements.txt

Miscellaneous installation considerations

  • In Python >= 3.6, make sure to add the root directory to the PYTHONPATH, e.g. export PYTHONPATH=path/to/video_prediction.
  • For the best speed and experimental results, we recommend using cuDNN version 7.3.0.29 and any TensorFlow version >= 1.9 and <= 1.12. The final training loss is worse when using cuDNN versions 7.3.1.20 or 7.4.1.5, compared to when using versions 7.3.0.29 and below.
  • On macOS, make sure that bash >= 4.0 is used (needed for the associative arrays in the download_model.sh script).
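The version recommendations above can be checked before starting a long training run. The sketch below is a hypothetical helper and is not part of the repository. The function names parse_version, is_recommended_tf_version and cudnn_in_tested_good_range are illustrative. The helper treats any cuDNN version newer than 7.3.0.29 as outside the tested range. That is a conservative reading of the note above, not a claim about untested versions.

```python
# Hypothetical helper, not part of the repository: checks version strings against
# the ranges recommended above (TensorFlow 1.9 to 1.12, cuDNN 7.3.0.29 or older).


def parse_version(version, n_parts):
    """Turn '1.10.0rc1' into (1, 10) for n_parts=2; missing parts are padded with 0."""
    parts = []
    for piece in version.split(".")[:n_parts]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (n_parts - len(parts))
    return tuple(parts)


def is_recommended_tf_version(version):
    # Compare major.minor only, so patch releases like 1.12.3 count as 1.12.
    return (1, 9) <= parse_version(version, 2) <= (1, 12)


def cudnn_in_tested_good_range(version):
    # 7.3.1.20 and 7.4.1.5 were reported to give a worse final training loss;
    # anything newer than 7.3.0.29 is conservatively rejected here.
    return parse_version(version, 4) <= (7, 3, 0, 29)
```

For example, after import tensorflow as tf, calling is_recommended_tf_version(tf.__version__) reports whether the installed TensorFlow falls in the 1.9 to 1.12 range.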

Use a Pre-trained Model

  • Download and preprocess a dataset (e.g. bair):
bash data/download_and_preprocess_dataset.sh bair
  • Download a pre-trained model (e.g. ours_savp) for the action-free version of that dataset (i.e. bair_action_free):
bash pretrained_models/download_model.sh bair_action_free ours_savp
  • Sample predictions from the model:
CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
  --dataset_hparams sequence_length=30 \
  --checkpoint pretrained_models/bair_action_free/ours_savp \
  --mode test \
  --results_dir results_test_samples/bair_action_free
  • The predictions are saved as images and GIFs in results_test_samples/bair_action_free/ours_savp.
  • Evaluate predictions from the model using full-reference metrics:
CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair \
  --dataset_hparams sequence_length=30 \
  --checkpoint pretrained_models/bair_action_free/ours_savp \
  --mode test \
  --results_dir results_test/bair_action_free
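Full-reference metrics compare each predicted frame against the corresponding ground-truth frame. PSNR is one common example of such a metric. The sketch below shows the per-step computation only and is not the repository's implementation. It makes two assumptions: frames are flattened to sequences of intensities in [0, 1], and psnr and psnr_per_step are hypothetical names. The exact set of metrics that evaluate.py reports may differ.

```python
import math


def psnr(target, prediction, max_value=1.0):
    """PSNR in dB between two equally sized frames given as flat pixel sequences."""
    if len(target) != len(prediction):
        raise ValueError("frames must have the same number of pixels")
    mse = sum((t - p) ** 2 for t, p in zip(target, prediction)) / len(target)
    if mse == 0:
        return math.inf  # identical frames have unbounded PSNR
    return 10.0 * math.log10(max_value ** 2 / mse)


def psnr_per_step(target_frames, predicted_frames, max_value=1.0):
    """One PSNR value per prediction step, pairing frames by time index."""
    return [psnr(t, p, max_value) for t, p in zip(target_frames, predicted_frames)]
```

Plotting psnr_per_step over time gives a metric-versus-prediction-step curve. This is the kind of curve the training summaries report.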

Model Training

  • To train a model, download and preprocess a dataset (e.g. bair):
bash data/download_and_preprocess_dataset.sh bair
  • Train a model (e.g. our SAVP model on the BAIR action-free robot pushing dataset):
CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair \
  --model savp --model_hparams_dict hparams/bair_action_free/ours_savp/model_hparams.json \
  --output_dir logs/bair_action_free/ours_savp
  • To view training and validation information (e.g. loss plots, GIFs of predictions), run tensorboard --logdir logs/bair_action_free --port 6006 and open http://localhost:6006.
    • Summaries corresponding to the training and validation set are named the same except that the tags of the latter end in "_1".
    • Summaries corresponding to the validation set with sequences that are longer than the ones used in training end in "_2", if applicable (i.e. if the dataset's long_sequence_length differs from sequence_length).
    • Summaries of the metrics over prediction steps are shown as 2D plots in the repurposed PR curves section. To see them, tensorboard needs to be built from source after commenting out two lines from their source code (see tensorflow/tensorboard#1110).
    • Summaries with names starting with "eval_" correspond to the best/average/worst metrics/images out of 100 samples for the stochastic models (as in the paper). The ones starting with "accum_eval_" are the same except that they were computed over (roughly) the whole validation set, as opposed to only a single minibatch of the validation set.
  • For multi-GPU training, set CUDA_VISIBLE_DEVICES to a comma-separated list of devices, e.g. CUDA_VISIBLE_DEVICES=0,1,2,3. To use the CPU, set CUDA_VISIBLE_DEVICES="".
  • See more training details for other datasets and models in scripts/train_all.sh.
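The "eval_" summaries reduce many stochastic samples of the same sequence to best, average and worst curves. The sketch below is one plausible way to do that reduction and is not the repository's code. It rests on three assumptions. Each sample is ranked by its metric averaged over prediction steps. "Average" is the per-step mean over all samples. best_average_worst is a hypothetical name.

```python
def best_average_worst(scores, higher_is_better=True):
    """scores[s][t] is the metric for sample s at prediction step t.

    Returns (best, average, worst) per-step curves: the sample with the best
    time-averaged score, the per-step mean over samples, and the sample with
    the worst time-averaged score.
    """
    if not scores or not scores[0]:
        raise ValueError("need at least one sample with at least one step")
    means = [sum(row) / len(row) for row in scores]
    # Sort sample indices so that index 0 is the best and -1 the worst.
    order = sorted(range(len(scores)), key=lambda s: means[s], reverse=higher_is_better)
    best, worst = scores[order[0]], scores[order[-1]]
    average = [sum(step) / len(scores) for step in zip(*scores)]
    return list(best), average, list(worst)
```

Pass higher_is_better=False for distance-style metrics, where lower values are better.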

Datasets

Download the datasets using the following script. These datasets were collected by other researchers. Please cite their papers if you use the data.

  • Download and preprocess the dataset.
bash data/download_and_preprocess_dataset.sh dataset_name

The dataset_name should be one of the following:

To use a different dataset, preprocess it into TFRecord files and define a class for it. See kth_dataset.py for an example where the original dataset is given as videos.
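What preprocessing involves depends on the dataset. The sketch below covers one step that may be needed when the source data is long videos: splitting each video into windows with a fixed number of frames. It is an assumption about the workflow, not a description of kth_dataset.py. Writing the windows to TFRecord files and defining the dataset class are left out. split_into_sequences is a hypothetical name.

```python
def split_into_sequences(frames, sequence_length, stride=None):
    """Split a list of frames into windows of exactly sequence_length frames.

    stride defaults to sequence_length (non-overlapping windows). Trailing
    frames that cannot fill a complete window are dropped.
    """
    if sequence_length <= 0:
        raise ValueError("sequence_length must be positive")
    stride = stride or sequence_length
    return [frames[start:start + sequence_length]
            for start in range(0, len(frames) - sequence_length + 1, stride)]
```

A smaller stride yields overlapping windows and therefore more training sequences per video, at the cost of correlated samples.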

Note: the bair dataset is used for both the action-free and action-conditioned experiments. Set the hyperparameter use_state=True to use the action-conditioned version of the dataset.
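Hyperparameters such as use_state=True, or the sequence_length=30 passed to --dataset_hparams earlier, are written as key=value text. The sketch below shows how such a string could be turned into typed values. It is a hypothetical stdlib parser, not the scripts' own, and the scripts may accept a different syntax. It handles only comma-separated scalars, so list-valued hyperparameters are out of scope.

```python
def parse_hparams_overrides(text):
    """Parse 'key=value,key=value' into a dict with bool, int, float or str values."""
    overrides = {}
    for item in filter(None, (part.strip() for part in text.split(","))):
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError("expected key=value, got %r" % item)
        overrides[key.strip()] = _convert(raw.strip())
    return overrides


def _convert(raw):
    """Best-effort conversion: booleans first, then int, then float, else keep the string."""
    if raw in ("True", "true"):
        return True
    if raw in ("False", "false"):
        return False
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw
```

For example, parse_hparams_overrides("use_state=True,sequence_length=30") yields {"use_state": True, "sequence_length": 30}.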

Models

  • Download the pre-trained models using the following script.
bash pretrained_models/download_model.sh dataset_name model_name

The dataset_name should be one of the following: bair_action_free, kth, or bair. The model_name should be one of the available pre-trained models:

  • ours_savp: our complete model, trained with variational and adversarial losses. Also referred to as ours_vae_gan.

The following are ablations of our model:

  • ours_gan: trained with L1 and adversarial loss, with latent variables sampled from the prior at training time.
  • ours_vae: trained with L1 and KL loss.
  • ours_deterministic: trained with L1 loss, with no stochastic latent variables.
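The variants above differ mainly in which loss terms they enable. The sketch below encodes that relationship as a simplified weighted sum. It makes several assumptions. The full model also keeps the L1 term that every ablation uses. Its variational and adversarial parts are each collapsed into a single "kl" and "adversarial" term. The weights are placeholders, not the values used in the paper. LOSS_TERMS and total_loss are hypothetical names.

```python
# Hypothetical table of enabled loss terms, following the variant descriptions above.
LOSS_TERMS = {
    "ours_savp": ("l1", "kl", "adversarial"),
    "ours_gan": ("l1", "adversarial"),
    "ours_vae": ("l1", "kl"),
    "ours_deterministic": ("l1",),
}


def total_loss(model_name, losses, weights):
    """Weighted sum of the loss terms enabled for model_name.

    losses and weights map term names to floats; terms not enabled for the
    variant are ignored even if present in losses.
    """
    return sum(weights[name] * losses[name] for name in LOSS_TERMS[model_name])
```

This sketch does not capture how latent variables are sampled. As described above, ours_gan draws them from the prior at training time.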

See pretrained_models/download_model.sh for a complete list of available pre-trained models.

Model and Training Hyperparameters

The implementation is designed such that each video prediction model defines its architecture and training procedure and includes reasonable hyperparameters as defaults. Still, a few of the hyperparameters should be overridden for each variant of dataset and model. The hyperparameters used in our experiments are provided in hparams as JSON files, and they can be passed to the training script with the --model_hparams_dict flag.
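The pattern described above merges model defaults with per-experiment overrides from a JSON file such as hparams/bair_action_free/ours_savp/model_hparams.json. The sketch below shows that merge. DEFAULT_HPARAMS and its values are hypothetical placeholders, not the defaults of any model in the repository, and load_model_hparams is likewise an illustrative name. Rejecting unknown keys is a design choice of this sketch, so a misspelled hyperparameter fails loudly; the repository's behavior may differ.

```python
import json

# Hypothetical defaults for illustration only; real models define their own.
DEFAULT_HPARAMS = {"batch_size": 16, "lr": 0.001, "l1_weight": 1.0, "kl_weight": 0.0}


def load_model_hparams(defaults, json_text):
    """Return a copy of defaults updated with JSON overrides, rejecting unknown keys."""
    overrides = json.loads(json_text)
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError("unknown hyperparameters: %s" % ", ".join(sorted(unknown)))
    merged = dict(defaults)  # copy so the defaults themselves are never mutated
    merged.update(overrides)
    return merged
```

With a file on disk, the text would come from open(path).read() before being passed as json_text.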

Citation

If you find this useful for your research, please cite the following.

@article{lee2018savp,
  title={Stochastic Adversarial Video Prediction},
  author={Alex X. Lee and Richard Zhang and Frederik Ebert and Pieter Abbeel and Chelsea Finn and Sergey Levine},
  journal={arXiv preprint arXiv:1804.01523},
  year={2018}
}