
akanimax / Pro_gan_pytorch

Licence: mit
ProGAN package implemented as an extension of PyTorch nn.Module

Programming Languages

python

Projects that are alternatives to or similar to Pro_gan_pytorch

Pycadl
Python package with source code from the course "Creative Applications of Deep Learning w/ TensorFlow"
Stars: ✭ 356 (-16.24%)
Mutual labels:  gan
Gan Timeline
A timeline showing the development of Generative Adversarial Networks (GAN).
Stars: ✭ 379 (-10.82%)
Mutual labels:  gan
Simgan Captcha
Solve captcha without manually labeling a training set
Stars: ✭ 405 (-4.71%)
Mutual labels:  gan
Sdv
Synthetic Data Generation for tabular, relational and time series data.
Stars: ✭ 360 (-15.29%)
Mutual labels:  gan
Stylegan2 Tensorflow 2.0
StyleGAN 2 in Tensorflow 2.0
Stars: ✭ 370 (-12.94%)
Mutual labels:  gan
Autogan
[ICCV 2019] "AutoGAN: Neural Architecture Search for Generative Adversarial Networks" by Xinyu Gong, Shiyu Chang, Yifan Jiang and Zhangyang Wang
Stars: ✭ 388 (-8.71%)
Mutual labels:  gan
Cat Generator
Generate cat images with neural networks
Stars: ✭ 354 (-16.71%)
Mutual labels:  gan
Wassersteingan.tensorflow
Tensorflow implementation of Wasserstein GAN - arxiv: https://arxiv.org/abs/1701.07875
Stars: ✭ 419 (-1.41%)
Mutual labels:  gan
Mixup
Implementation of the mixup training method
Stars: ✭ 377 (-11.29%)
Mutual labels:  gan
Pytorch Rl
This repository contains model-free deep reinforcement learning algorithms implemented in Pytorch
Stars: ✭ 394 (-7.29%)
Mutual labels:  gan
Pytorch Mnist Celeba Gan Dcgan
Pytorch implementation of Generative Adversarial Networks (GAN) and Deep Convolutional Generative Adversarial Networks (DCGAN) for MNIST and CelebA datasets
Stars: ✭ 363 (-14.59%)
Mutual labels:  gan
Anycost Gan
[CVPR 2021] Anycost GANs for Interactive Image Synthesis and Editing
Stars: ✭ 367 (-13.65%)
Mutual labels:  gan
Sean
SEAN: Image Synthesis with Semantic Region-Adaptive Normalization (CVPR 2020, Oral)
Stars: ✭ 387 (-8.94%)
Mutual labels:  gan
Advanced Tensorflow
Little More Advanced TensorFlow Implementations
Stars: ✭ 364 (-14.35%)
Mutual labels:  gan
Simgan
Implementation of Apple's Learning from Simulated and Unsupervised Images through Adversarial Training
Stars: ✭ 406 (-4.47%)
Mutual labels:  gan
Time Series Prediction
A collection of time series prediction methods: rnn, seq2seq, cnn, wavenet, transformer, unet, n-beats, gan, kalman-filter
Stars: ✭ 351 (-17.41%)
Mutual labels:  gan
Tensorflow Generative Model Collections
Collection of generative models in Tensorflow
Stars: ✭ 3,785 (+790.59%)
Mutual labels:  gan
Deep Learning Resources
由淺入深的深度學習資源 Collection of deep learning materials for everyone
Stars: ✭ 422 (-0.71%)
Mutual labels:  gan
Tensorflow Tutorial
Tensorflow tutorial from basic to hard, 莫烦Python 中文AI教学
Stars: ✭ 4,122 (+869.88%)
Mutual labels:  gan
Igan
Interactive Image Generation via Generative Adversarial Networks
Stars: ✭ 3,845 (+804.71%)
Mutual labels:  gan

pro_gan_pytorch

This package contains an implementation of ProGAN, from the paper titled "Progressive Growing of GANs for Improved Quality, Stability, and Variation".
link -> https://arxiv.org/abs/1710.10196
Trained Examples at -> https://github.com/akanimax/pro_gan_pytorch-examples

⭐️ [New] Pretrained Models:

Please find the pretrained models under the saved_models/ directory at the drive_link

⭐️ [New] Demo:


The repository now includes a latent-space interpolation animation demo in the samples/ directory. Just download all the pretrained weights from the above-mentioned drive_link and put them in the samples/ directory alongside the demo.py script. There are a few tweakable parameters at the beginning of demo.py that you can play around with.

The demo generates images for random latent points and then linearly interpolates between them to produce a smooth animation. You need a good GPU (at least a GTX 1070) to get a decent frame rate. The demo could, however, be optimized to generate the images in parallel (it is currently completely sequential).
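As a rough illustration of what the demo does, here is a minimal, stdlib-only sketch of linear interpolation between random latent points. This is not the code in demo.py. The function names and the latent size of 512 are hypothetical, and the sketch assumes the generator takes standard-normal latent vectors.

```python
import random


def random_latent(latent_size, rng):
    # Standard-normal latent vector (assumption: the generator expects N(0, I) noise)
    return [rng.gauss(0.0, 1.0) for _ in range(latent_size)]


def lerp(a, b, t):
    # Element-wise linear interpolation: (1 - t) * a + t * b
    return [(1.0 - t) * x + t * y for x, y in zip(a, b)]


def interpolation_path(points, steps_per_segment):
    """Latent vectors visited when moving linearly between consecutive key
    points: each segment contributes `steps_per_segment` frames, and the
    final key point is appended at the end."""
    frames = []
    for start, end in zip(points, points[1:]):
        for i in range(steps_per_segment):
            frames.append(lerp(start, end, i / steps_per_segment))
    frames.append(list(points[-1]))
    return frames


rng = random.Random(0)
latent_size = 512  # hypothetical; use the size the pretrained generator was trained with
key_points = [random_latent(latent_size, rng) for _ in range(4)]
frames = interpolation_path(key_points, steps_per_segment=30)
print(len(frames))  # 3 segments * 30 frames + 1 final point = 91
```

Each resulting vector would be passed through the generator to render one animation frame. One way to get the parallel generation mentioned above is to stack all frames into a single batch and make one generator call instead of looping.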

Loading weights into the Generator follows the standard PyTorch model-loading process:

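import torch as th
from pro_gan_pytorch import PRO_GAN as pg

# use the GPU if one is available, otherwise fall back to the CPU
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# depth=9 -> 1024 x 1024 generator; DataParallel so the parameter names match the checkpoint.
# .to(device) puts the parameters where DataParallel expects them when running on GPU.
gen = th.nn.DataParallel(pg.Generator(depth=9)).to(device)
# load the EMA ("shadow") weights, remapping the saved tensors onto the current device
gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device)))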

Notes for the above code:

  1. Create a new generator module using pg (depth = 9 means the generated resolution will be 1024 x 1024).
  2. DataParallel is required here because I trained the models on multiple GPUs, so the saved
    parameter names carry DataParallel's "module." prefix. You wouldn't need to wrap the Generator
    in DataParallel if you trained on CPU, which I don't think is feasible for a GAN in general (:D).
  3. You can load the weights into gen directly, since it is implemented as a PyTorch Module.
  4. The map_location argument takes care of device mismatch, e.g. when you trained on GPU but are running inference on CPU.
  5. Also note that you need to use the GAN_GEN_SHADOW_8.pth model and not GAN_GEN_8.pth.
    The shadow model contains the exponential moving average (EMA) of the weights, which is more stable.
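You may prefer not to wrap the Generator in DataParallel, for example for CPU-only inference. A common alternative is to strip DataParallel's "module." prefix from the checkpoint keys before loading. Below is a minimal sketch. The helper name and the example keys are hypothetical, and a plain OrderedDict stands in for the real state dict so the snippet runs without PyTorch.

```python
from collections import OrderedDict


def strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from every key
    that starts with it; other keys and all values are kept unchanged."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Stand-in for a checkpoint saved from a DataParallel-wrapped model
# (hypothetical key names, for illustration only).
saved = OrderedDict([("module.layers.0.weight", 1), ("module.rgb_converters.0.bias", 2)])
print(list(strip_data_parallel_prefix(saved)))  # ['layers.0.weight', 'rgb_converters.0.bias']
```

With PyTorch installed, the idea would be to build pg.Generator(depth=9) without DataParallel. You would then call its load_state_dict on strip_data_parallel_prefix(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device))). This has not been tested against the released checkpoints.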

Exemplar Samples :)

Training gif (fixed latent points):


Generated Samples:



Other links

medium blog -> https://medium.com/@animeshsk3/the-unprecedented-effectiveness-of-progressive-growing-of-gans-37475c88afa3
Full training video -> https://www.youtube.com/watch?v=lzTm6Lq76Mo

Steps to use:

1.) Install the appropriate version of PyTorch. The torch dependency in this package pulls in the most basic "cpu" version; follow the instructions on http://pytorch.org to install the "gpu" version of PyTorch.

2.) Install this package using pip:

$ workon [your virtual environment]
$ pip install pro-gan-pth

3.) In your code:

import pro_gan_pytorch.PRO_GAN as pg

Use the modules pg.Generator, pg.Discriminator and pg.ProGAN. Mostly, you'll only need the ProGAN module for training; for inference, you will probably only need pg.Generator. (The CIFAR-10 example below uses the class-conditional variant, pg.ConditionalProGAN.)

4.) Example code for the CIFAR-10 dataset:

import torch as th
import torchvision as tv
import pro_gan_pytorch.PRO_GAN as pg

# select the device to be used for training
device = th.device("cuda" if th.cuda.is_available() else "cpu")
data_path = "cifar-10/"

def setup_data(download=False):
    """
    set up the CIFAR-10 dataset for training the GAN
    :param download: Boolean for whether to download the data
    :return: classes, trainset, testset => class names and the training and testing datasets
    """
    # data setup:
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    # convert PIL images to float tensors in [0, 1]
    transform = tv.transforms.ToTensor()

    trainset = tv.datasets.CIFAR10(root=data_path, train=True,
                                   transform=transform,
                                   download=download)

    testset = tv.datasets.CIFAR10(root=data_path, train=False,
                                  transform=transform,
                                  download=download)

    return classes, trainset, testset


if __name__ == '__main__':

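    # some parameters:
    depth = 4  # 4 depths: 4x4 -> 8x8 -> 16x16 -> 32x32 (CIFAR-10 resolution)
    # hyper-parameters per depth (resolution)
    num_epochs = [10, 20, 20, 20]
    fade_ins = [50, 50, 50, 50]  # % of each depth's training spent fading in the new layer
    batch_sizes = [128, 128, 128, 128]
    latent_size = 128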

    # get the data. Ignore the test data and their classes
    _, dataset, _ = setup_data(download=True)

    # ======================================================================
    # This line creates the class-conditional PRO-GAN (10 CIFAR-10 classes)
    # ======================================================================
    pro_gan = pg.ConditionalProGAN(num_classes=10, depth=depth,
                                   latent_size=latent_size, device=device)
    # ======================================================================

    # ======================================================================
    # This line trains the PRO-GAN
    # ======================================================================
    pro_gan.train(
        dataset=dataset,
        epochs=num_epochs,
        fade_in_percentage=fade_ins,
        batch_sizes=batch_sizes
    )
    # ======================================================================  
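The per-depth lists above describe a progressive schedule: each depth doubles the resolution, and fade_ins (passed as fade_in_percentage) controls how long a newly added layer is faded in. Below is a minimal sketch of how such a schedule is commonly interpreted in ProGAN. It assumes a blending weight alpha that ramps linearly from 0 to 1 over the first fade_in_percentage percent of a depth's iterations. It also assumes the new block's output is mixed with the upsampled output of the previous block. The helper names are hypothetical and not part of pro_gan_pytorch, and 50,000 is the size of the CIFAR-10 training split.

```python
import math


def resolution_for_depth(depth_index):
    # Zero-based depth index: 0 -> 4x4, each further depth doubles the side length
    return 4 * 2 ** depth_index


def fade_in_iterations(epochs, fade_in_percentage, batches_per_epoch):
    # Number of iterations over which alpha ramps from 0 to 1 at one depth
    return int(fade_in_percentage / 100 * epochs * batches_per_epoch)


def fade_in_alpha(iteration, fader_point):
    # Linear ramp: 0 at the start of the depth, 1 once fader_point is reached
    if fader_point <= 0:
        return 1.0
    return min(1.0, iteration / fader_point)


def blend(new_output, old_output_upsampled, alpha):
    # alpha * (new block) + (1 - alpha) * (previous block, upsampled).
    # At the first depth there is no previous block, so no blending happens there.
    return [alpha * n + (1.0 - alpha) * o for n, o in zip(new_output, old_output_upsampled)]


num_epochs = [10, 20, 20, 20]
fade_ins = [50, 50, 50, 50]
batch_sizes = [128, 128, 128, 128]
num_train_images = 50_000  # size of the CIFAR-10 training split

for depth_index, (ep, fade, bs) in enumerate(zip(num_epochs, fade_ins, batch_sizes)):
    batches = math.ceil(num_train_images / bs)  # the last, partial batch is kept
    fader_point = fade_in_iterations(ep, fade, batches)
    res = resolution_for_depth(depth_index)
    print(f"{res}x{res}: {ep} epochs, {batches} batches/epoch, fade-in over {fader_point} iterations")
# 4x4: 10 epochs, 391 batches/epoch, fade-in over 1955 iterations
# 8x8: 20 epochs, 391 batches/epoch, fade-in over 3910 iterations
```

Under these assumptions, half of every depth's iterations are spent blending the new layer in, and the rest train it at full strength (alpha = 1).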

Thanks

Please feel free to open PRs / issues / suggestions here if you train on other datasets using this architecture.

Best regards,
@akanimax :)
