
richardwth / MMD-GAN

License: Apache-2.0
Improving MMD-GAN training with repulsive loss function

Programming Languages

python
139335 projects - #7 most used programming language
Jupyter Notebook
11667 projects

Projects that are alternatives of or similar to MMD-GAN

Semantic image inpainting
Semantic Image Inpainting
Stars: ✭ 140 (+70.73%)
Mutual labels:  generative-adversarial-network, dcgan, generative-model
coursera-gan-specialization
Programming assignments and quizzes from all courses within the GANs specialization offered by deeplearning.ai
Stars: ✭ 277 (+237.8%)
Mutual labels:  generative-adversarial-network, dcgan, generative-model
Generative models tutorial with demo
Generative Models Tutorial with Demo: Bayesian Classifier Sampling, Variational Auto Encoder (VAE), Generative Adversarial Networks (GANs), Popular GAN Architectures, Auto-Regressive Models, Important Generative Model Papers, Courses, etc.
Stars: ✭ 276 (+236.59%)
Mutual labels:  generative-adversarial-network, dcgan, generative-model
Image generator
DCGAN image generator 🖼️.
Stars: ✭ 173 (+110.98%)
Mutual labels:  generative-adversarial-network, dcgan
DCGAN-CIFAR10
An implementation of DCGAN (Deep Convolutional Generative Adversarial Networks) for CIFAR-10 images
Stars: ✭ 18 (-78.05%)
Mutual labels:  generative-adversarial-network, dcgan
Tensorflow Mnist Gan Dcgan
Tensorflow implementation of Generative Adversarial Networks (GAN) and Deep Convolutional Generative Adversarial Networks for the MNIST dataset.
Stars: ✭ 163 (+98.78%)
Mutual labels:  generative-adversarial-network, dcgan
Generative adversarial networks 101
Keras implementations of Generative Adversarial Networks. GANs, DCGAN, CGAN, CCGAN, WGAN and LSGAN models with MNIST and CIFAR-10 datasets.
Stars: ✭ 138 (+68.29%)
Mutual labels:  generative-adversarial-network, dcgan
Triple Gan
See Triple-GAN-V2 in PyTorch: https://github.com/taufikxu/Triple-GAN
Stars: ✭ 203 (+147.56%)
Mutual labels:  generative-adversarial-network, generative-model
Dragan
A stable algorithm for GAN training
Stars: ✭ 189 (+130.49%)
Mutual labels:  generative-adversarial-network, generative-model
Anogan Tf
Unofficial Tensorflow Implementation of AnoGAN (Anomaly GAN)
Stars: ✭ 218 (+165.85%)
Mutual labels:  generative-adversarial-network, dcgan
Deeplearning.ai-GAN-Specialization-Generative-Adversarial-Networks
This repository contains my full work and notes on Deeplearning.ai GAN Specialization (Generative Adversarial Networks)
Stars: ✭ 59 (-28.05%)
Mutual labels:  discriminator, generative-adversarial-network
Stylegan2 Pytorch
Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
Stars: ✭ 2,656 (+3139.02%)
Mutual labels:  generative-adversarial-network, generative-model
Conditional Gan
Anime Generation
Stars: ✭ 141 (+71.95%)
Mutual labels:  generative-adversarial-network, generative-model
Dcgan wgan wgan Gp lsgan sngan rsgan began acgan pggan tensorflow
Implementation of some different variants of GANs by tensorflow, Train the GAN in Google Cloud Colab, DCGAN, WGAN, WGAN-GP, LSGAN, SNGAN, RSGAN, RaSGAN, BEGAN, ACGAN, PGGAN, pix2pix, BigGAN
Stars: ✭ 166 (+102.44%)
Mutual labels:  generative-adversarial-network, dcgan
Sgan
Stacked Generative Adversarial Networks
Stars: ✭ 240 (+192.68%)
Mutual labels:  generative-adversarial-network, generative-model
Pytorch-conditional-GANs
Implementation of Conditional Generative Adversarial Networks in PyTorch
Stars: ✭ 91 (+10.98%)
Mutual labels:  generative-adversarial-network, dcgan
Neuralnetworks.thought Experiments
Observations and notes to understand the workings of neural network models and other thought experiments using Tensorflow
Stars: ✭ 199 (+142.68%)
Mutual labels:  generative-adversarial-network, generative-model
pytorch-gans
PyTorch implementation of GANs (Generative Adversarial Networks). DCGAN, Pix2Pix, CycleGAN, SRGAN
Stars: ✭ 21 (-74.39%)
Mutual labels:  generative-adversarial-network, dcgan
Deep Learning With Python
Example projects I completed to understand Deep Learning techniques with Tensorflow. Please note that I do no longer maintain this repository.
Stars: ✭ 134 (+63.41%)
Mutual labels:  generative-adversarial-network, dcgan
Gesturegan
[ACM MM 2018 Oral] GestureGAN for Hand Gesture-to-Gesture Translation in the Wild
Stars: ✭ 136 (+65.85%)
Mutual labels:  generative-adversarial-network, generative-model

MMD-GAN with Repulsive Loss Function

GAN: generative adversarial nets; MMD: maximum mean discrepancy; TF: TensorFlow

This repository contains the code for MMD-GAN and the repulsive loss proposed in the ICLR paper [1]:
Wei Wang, Yuan Sun, Saman Halgamuge. Improving MMD-GAN Training with Repulsive Loss Function. ICLR 2019. URL: https://openreview.net/forum?id=HygjqjR9Km.

About the code

The code defines the neural network architectures as dictionaries and strings to ease testing of different models (a hypothetical sketch of this style follows the list below). It also contains many other models I have tried, so sorry if you find it a little bit confusing.

The structure of code:

  1. DeepLearning/my_sngan/SNGan defines how a general GAN model is trained and evaluated.
  2. GeneralTools contains various tools:
    1. graph_func contains functions to run a model graph and metrics for evaluating generative models (Line 1595).
    2. input_func contains functions to handle datasets and input pipeline.
    3. layer_func contains functions to convert a network architecture dictionary to operations.
    4. math_func defines various mathematical operations. You may find spectral normalization at Line 397, loss functions for GAN at Line 2088, repulsive loss at Line 2505, repulsive with bounded kernel (referred to as rmb) at Line 2530.
    5. misc_fun contains FLAGs for the code.
  3. The my_test_ scripts contain the specific model architectures and hyperparameters.
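
As a hypothetical illustration of this "architecture as data" style (the dictionary keys and layer types below are invented for illustration and are not the repository's actual format; see layer_func for that), an architecture can be written as a spec and converted to operations:

```python
# Hypothetical sketch only: specify layers as data, then build callables from them.
import numpy as np

ARCH = [
    ("fc1", {"type": "dense", "in": 128, "out": 256}),
    ("act1", {"type": "relu"}),
    ("out", {"type": "dense", "in": 256, "out": 1}),
]

def build(arch, rng):
    """Turn the spec list into a list of forward functions with their own weights."""
    layers = []
    for _, spec in arch:
        if spec["type"] == "dense":
            w = rng.standard_normal((spec["in"], spec["out"])) * 0.02
            layers.append(lambda x, w=w: x @ w)
        elif spec["type"] == "relu":
            layers.append(lambda x: np.maximum(x, 0.0))
    return layers

def forward(layers, x):
    for f in layers:
        x = f(x)
    return x

x = np.zeros((4, 128))
print(forward(build(ARCH, np.random.default_rng(0)), x).shape)  # (4, 1)
```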

Running the tests

  1. Modify GeneralTools/misc_fun accordingly;
  2. Read Data/ReadMe.md; download and prepare the datasets;
  3. Run a my_test_ script with proper hyperparameters.

About the algorithms

Here we introduce the algorithms and tricks.

Proposed Methods

The paper [1] proposed three methods:

  1. Repulsive loss

$$L_D^{\mathrm{rep}} = \mathbb{E}_{P_X}\left[k_D(x, x')\right] - \mathbb{E}_{P_G}\left[k_D(y, y')\right]$$

$$L_G = \mathbb{E}_{P_X}\left[k_D(x, x')\right] - 2\,\mathbb{E}_{P_X, P_G}\left[k_D(x, y)\right] + \mathbb{E}_{P_G}\left[k_D(y, y')\right]$$

where $x, x' \sim P_X$ are real samples, $y, y' \sim P_G$ are generated samples, and $k_D(a, b) = k(D(a), D(b))$ is the kernel formed by the discriminator $D$ and a kernel $k$. The discriminator loss of the previous MMD-GAN [2], or what we call the attractive loss, is $L_D^{\mathrm{att}} = -L_G$.

Below is an illustration of the effects of the MMD losses on free R(eal) and G(enerated) particles (code in the Figures folder). The particles stand for the discriminator outputs of samples, but, for illustration purposes, we allow them to move freely. These GIFs extend Figure 1 of paper [1].

[GIF grid from the Figures folder: mmd_d_att (attractive discriminator loss) and mmd_d_rep (repulsive discriminator loss) in the top row, each paired with the corresponding generator-loss animation, mmd_g_att and mmd_g_rep, in the bottom row.]

In the first row, we randomly initialized the particles and applied $L_D^{\mathrm{att}}$ or $L_D^{\mathrm{rep}}$ for 600 steps, moving each particle along the negative gradient of the applied loss. In the second row, we took the particle positions at the 450th step of the first row and applied $L_G$ for another 600 steps in the same way. The blue and orange arrows stand for the gradients of the attractive and repulsive components of the MMD losses, respectively. In summary, these GIFs indicate how the MMD losses may move free particles. Of course, the actual case of MMD-GAN is much more complex, as we update the model parameters rather than the output scores directly, and both networks are updated at each step.

We argue that $L_D^{\mathrm{att}}$ may cause opposite gradients from the attractive and repulsive components of both $L_D^{\mathrm{att}}$ and $L_G$ during training, and thus slow down the training process. Note that this is different from the end stage of training, when the gradients should be opposite and cancel out to reach 0. Another interpretation is that, by minimizing $L_D^{\mathrm{att}}$, the discriminator maximizes the similarity between the outputs of real samples, which results in D focusing on the similarities among real images and possibly ignoring the fine details that separate them. The repulsive loss actively learns such fine details so that the outputs of real samples repel each other.
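
As a rough, self-contained illustration (not the repository code), the following NumPy sketch computes these losses from batches of discriminator outputs, assuming a Gaussian kernel on the outputs; the function names and kernel bandwidth are ours for illustration only.

```python
# Minimal sketch of the MMD-GAN losses above on discriminator outputs.
import numpy as np

def gaussian_kernel(a, b, sigma=1.0):
    """Pairwise Gaussian kernel between rows of a and b."""
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-d2 / (2.0 * sigma ** 2))

def off_diag_mean(k):
    """Mean of k(i, j) over i != j (within-set estimate without self-pairs)."""
    n = k.shape[0]
    return (k.sum() - np.trace(k)) / (n * (n - 1))

def mmd_losses(scores_x, scores_y, sigma=1.0):
    """scores_x = D(real batch), scores_y = D(generated batch)."""
    k_xx = gaussian_kernel(scores_x, scores_x, sigma)
    k_yy = gaussian_kernel(scores_y, scores_y, sigma)
    k_xy = gaussian_kernel(scores_x, scores_y, sigma)

    e_xx, e_yy, e_xy = off_diag_mean(k_xx), off_diag_mean(k_yy), k_xy.mean()
    loss_g = e_xx - 2.0 * e_xy + e_yy   # L_G = MMD^2(P_X, P_G)
    loss_d_rep = e_xx - e_yy            # repulsive discriminator loss
    loss_d_att = -loss_g                # attractive loss of the original MMD-GAN
    return loss_g, loss_d_rep, loss_d_att

rng = np.random.default_rng(0)
print(mmd_losses(rng.normal(size=(64, 16)), rng.normal(size=(64, 16))))
```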

  2. Bounded kernel (used only in $L_D^{\mathrm{rep}}$)

$$k_D^{\mathrm{b}}(x, x') = \exp\left(-\frac{\min\{\lVert D(x) - D(x')\rVert^2,\ b_u\}}{2\sigma^2}\right) \quad \text{for real sample pairs}$$

$$k_D^{\mathrm{b}}(y, y') = \exp\left(-\frac{\max\{\lVert D(y) - D(y')\rVert^2,\ b_l\}}{2\sigma^2}\right) \quad \text{for generated sample pairs}$$

The gradient of the Gaussian kernel is near 0 when the input distance is either very small or very large. The bounded kernel avoids kernel saturation by truncating the two tails of the distance distribution at the bounds $b_u$ and $b_l$, an idea inspired by the hinge loss. This prevents the discriminator from becoming too confident.
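
Continuing the sketch above, the truncation can be written as follows; the bound values, function names, and exact placement of the bounds are illustrative and may differ in detail from the paper's settings.

```python
# Sketch of a bounded (truncated) Gaussian kernel, to be used only inside the
# discriminator loss; bounds here are placeholders, not the paper's values.
import numpy as np

def bounded_gaussian_kernel(a, b, sigma=1.0, lower=None, upper=None):
    """Gaussian kernel on truncated squared distances between rows of a and b."""
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    if upper is not None:      # cap large distances (real-real pairs)
        d2 = np.minimum(d2, upper)
    if lower is not None:      # floor small distances (generated-generated pairs)
        d2 = np.maximum(d2, lower)
    return np.exp(-d2 / (2.0 * sigma ** 2))

# In the repulsive discriminator loss one would then use, e.g.:
# k_xx = bounded_gaussian_kernel(scores_x, scores_x, upper=b_u)
# k_yy = bounded_gaussian_kernel(scores_y, scores_y, lower=b_l)
```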

  3. Power iteration for convolution (used in spectral normalization)

Finally, we proposed a method to estimate the spectral norm of a convolution kernel. At iteration $t$, for a convolution kernel $W$ (viewed as a linear operator $A_W$), do $u_{t+1} = A_W v_t$ (a convolution), $\hat{v}_{t+1} = A_W^{\mathsf{T}} u_{t+1}$ (a transposed convolution), and $v_{t+1} = \hat{v}_{t+1} / \lVert \hat{v}_{t+1} \rVert$. The spectral norm is then estimated as $\sigma(A_W) \approx \lVert u_{t+1} \rVert$.
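
As a rough illustration of the idea (not the repository's implementation), here is a generic power iteration that estimates the spectral norm of a linear operator from its forward and adjoint maps. For a convolution layer, the forward map would be the convolution and the adjoint the transposed convolution; a dense matrix stands in here so the example is self-contained.

```python
# Generic power-iteration sketch for the spectral norm of a linear operator.
import numpy as np

def spectral_norm(forward, adjoint, v, n_iter=20):
    v = v / np.linalg.norm(v)
    for _ in range(n_iter):
        u = forward(v)                 # u_{t+1} = A v_t
        v = adjoint(u)                 # v_{t+1} = A^T u_{t+1}
        v = v / np.linalg.norm(v)      # re-normalize v
    return np.linalg.norm(forward(v))  # sigma ~ ||A v|| for the converged v

rng = np.random.default_rng(0)
A = rng.standard_normal((32, 64))
sigma = spectral_norm(lambda v: A @ v, lambda u: A.T @ u, rng.standard_normal(64))
print(sigma, np.linalg.svd(A, compute_uv=False)[0])  # the two should be close
```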

Practical Tricks and Issues

We recommend using the following tricks.

  1. Spectral normalization, initially proposed in [3]. The idea is, at each layer, to use $W / \sigma(W)$ for the convolution/dense multiplication, where $\sigma(W)$ is the spectral norm of the weight $W$. Here we multiply the signal by a constant after each spectral normalization to compensate for the decrease of the signal norm at each layer (see the sketch after this list). In the main text of paper [1], we used an empirically chosen constant; in Appendix C.3 of paper [1], we tested a variety of values.
  2. Two time-scale update rule (TTUR) [4]. The idea is to use different learning rates for the generator and discriminator.
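
A minimal sketch of both tricks; the scale constant and learning rates below are placeholders, not the values used in the paper.

```python
# Spectral normalization with a compensating scale (item 1). sigma_w would in
# practice come from the power iteration sketched earlier.
import numpy as np

def spectral_normalize(w, sigma_w, scale=1.0):
    """Divide the weight by its spectral norm, then rescale by a constant
    to compensate for the shrinking signal norm at each layer."""
    return scale * w / sigma_w

w = np.random.default_rng(0).standard_normal((3, 3, 64, 64))  # e.g. a conv kernel
w_sn = spectral_normalize(w, sigma_w=2.5, scale=1.5)          # illustrative numbers

# TTUR (item 2) simply means separate optimizers with different learning rates
# for G and D, e.g. lr_g = 1e-4 and lr_d = 4e-4 (illustrative values only).
```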

Unlike the case of Wasserstein GAN, we do not encourage using the repulsive loss of the discriminator or the MMD loss of the generator as an indicator of training progress. You may find that, during the training process,

  • both $L_D^{\mathrm{rep}}$ and $L_G$ may be close to 0 initially; this is because both G and D are weak.
  • $L_G$ may gradually increase during training; this is because it becomes harder for G to generate high-quality samples and fool D (and G may not have the capacity to do so).

For balanced and capable G and D, we would expect both $L_D^{\mathrm{rep}}$ and $L_G$ to stay close to 0 during the whole training process, and the kernel values (i.e., $k_D(x, x')$, $k_D(x, y)$ and $k_D(y, y')$) to stay away from 0 and 1, somewhere in the middle (e.g., 0.6).

In some cases, you may find that training with the repulsive loss diverges. Do not panic. It may be that the learning rate is not suitable. Please try other learning rates or the bounded kernel.

Final Comments

Thank you for reading!

Please feel free to leave comments if things do not work or suddenly work, or if exploring my code ruins your day. :)

Reference

[1] Wei Wang, Yuan Sun, Saman Halgamuge. Improving MMD-GAN Training with Repulsive Loss Function. ICLR 2019. URL: https://openreview.net/forum?id=HygjqjR9Km.
[2] Chun-Liang Li, Wei-Cheng Chang, Yu Cheng, Yiming Yang, and Barnabas Poczos. MMD GAN: Towards deeper understanding of moment matching network. In NeurIPS, 2017.
[3] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, and Yuichi Yoshida. Spectral normalization for generative adversarial networks. In ICLR, 2018.
[4] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs Trained by a Two Time-Scale Update Rule Converge to a Nash Equilibrium. In NeurIPS, 2017.
