akanimax / Pro_gan_pytorch Examples
pro_gan_pytorch-examples
This repository contains examples trained using the Python package pro-gan-pth. You can find the GitHub repo for the project at github-repository and the PyPI package at pypi.
Two examples are presented here: one for the LFW dataset and one for the MNIST dataset. Please refer to the following sections for how to train these models and/or load the provided trained weights.
Prior Setup
Before running any of the following training experiments, please set up your VirtualEnv with the required packages for this project. Importantly, please install the progan package using `$ pip install pro-gan-pth` and the appropriate GPU / CPU version of PyTorch 0.4.0. Once this is done, you can proceed with the following experiments.
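Before starting a run, a quick sanity check of the environment can save time. The following is a minimal sketch (the helper name `check_setup` is hypothetical, not part of the package) that reports the installed PyTorch version, whether CUDA is usable, and whether the `pro_gan_pytorch` module can be imported:

```python
def check_setup():
    """Report the installed PyTorch version, CUDA availability and whether pro_gan_pytorch imports."""
    try:
        import torch
    except ImportError:
        # PyTorch is missing entirely, so the package cannot work either
        return {"torch": None, "cuda": False, "pro_gan_pytorch": False}

    try:
        import pro_gan_pytorch  # noqa: F401  (installed by `pip install pro-gan-pth`)
        has_progan = True
    except ImportError:
        has_progan = False

    return {
        "torch": torch.__version__,
        "cuda": torch.cuda.is_available(),
        "pro_gan_pytorch": has_progan,
    }


if __name__ == "__main__":
    info = check_setup()
    print(info)
    if info["torch"] is not None and not info["torch"].startswith("0.4"):
        # the examples in this repository were written against PyTorch 0.4.0
        print("warning: PyTorch", info["torch"], "is installed, not 0.4.x")
```

If `"cuda"` comes back `False` on a machine with a GPU, the CPU build of PyTorch was probably installed.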
LFW Experiment
The configuration used for the LFW training experiment can be found in `implementation/configs/lfw.conf` in this repository. The training was performed using the wgan-gp loss function.
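For reference, the wgan-gp loss in its standard form (Gulrajani et al., "Improved Training of Wasserstein GANs") is written out below, with critic D, real data distribution P_r, generator distribution P_g and gradient-penalty weight λ (10 in that paper). The ProGAN paper additionally adds a small drift term ε_drift E[D(x)^2] with ε_drift = 0.001 to the critic loss; whether and how pro-gan-pth includes such terms should be checked against the package source rather than assumed from this sketch:

```latex
% Critic (discriminator) loss; \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, with \epsilon \sim U[0, 1]
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
    + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]

% Generator loss
L_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
```

The penalty term pushes the critic's gradient norm towards 1 on points interpolated between real and generated images, which is what keeps the critic approximately 1-Lipschitz without weight clipping.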
[Figures: generated LFW samples and a sample loss plot]
MNIST Experiment
The configuration used for the MNIST training experiment can be found in `implementation/configs/mnist.conf` in this repository. The training was performed using the lsgan loss function.
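The lsgan loss replaces the usual cross-entropy with squared errors against fixed targets. The sketch below shows the common least-squares formulation (target 1 for real images, 0 for fakes, and 1 for the generator's fakes) on plain Python lists of discriminator scores. It illustrates the technique only; `lsgan_losses` is a hypothetical helper, and the package's own LSGAN implementation may use different constants or averaging:

```python
from statistics import mean


def lsgan_losses(real_scores, fake_scores):
    """Least-squares GAN losses for one batch.

    real_scores: discriminator outputs D(x) on real images
    fake_scores: discriminator outputs D(G(z)) on generated images
    Targets: 1 for real, 0 for fake, and 1 for the generator's fakes.
    """
    # discriminator: pull D(x) towards 1 and D(G(z)) towards 0
    dis_loss = 0.5 * mean((s - 1.0) ** 2 for s in real_scores) + 0.5 * mean(
        s ** 2 for s in fake_scores
    )
    # generator: pull D(G(z)) towards the "real" target 1
    gen_loss = 0.5 * mean((s - 1.0) ** 2 for s in fake_scores)
    return dis_loss, gen_loss
```

For a discriminator that already scores real images as 1 and fakes as 0, `lsgan_losses([1.0, 1.0], [0.0, 0.0])` gives a discriminator loss of 0.0 and a generator loss of 0.5. Because the penalty grows quadratically with distance from the target, fakes that are confidently rejected still produce useful gradients for the generator.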
[Figures: generated MNIST samples and a sample loss plot]
How to use:
To run the training script, use the following commands:
$ cd implementation
$ python train_network.py --config=configs/mnist.conf
You can tinker with the configuration to get your desired behaviour. The training script also demonstrates some of the use cases of the package's API.
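If you want to launch runs from Python (for example, from a notebook), the same command can be wrapped with `subprocess`. This is a minimal sketch: `training_command` and `run_training` are hypothetical helpers, and the only assumption about the repository is the `implementation/configs/<name>.conf` layout described above:

```python
import subprocess
import sys


def training_command(dataset, python=sys.executable):
    """Build the train_network.py command for configs/<dataset>.conf (run from implementation/)."""
    return [python, "train_network.py", "--config=configs/{}.conf".format(dataset)]


def run_training(dataset, workdir="implementation"):
    # cwd mirrors the `cd implementation` step; check=True raises
    # subprocess.CalledProcessError if training exits with a non-zero status
    subprocess.run(training_command(dataset), cwd=workdir, check=True)
```

For example, `run_training("lfw")` called from the repository root starts the LFW experiment with `configs/lfw.conf`. Using `sys.executable` keeps the run inside the same VirtualEnv as the calling interpreter.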
The loss logs are written while training is in progress. You can generate loss plots from these `loss-logs` with the provided script:
$ python generate_loss_plots.py --logdir=training_runs/mnist/losses/ \
--plotdir=training_runs/mnist/losses/loss_plots/
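If you would rather inspect the logs directly, the sketch below reads a log file and smooths it with a trailing moving average before plotting. The log format here is an assumption (one record per line whose last two whitespace-separated fields are the discriminator and generator losses), and all three functions are hypothetical helpers; check the files under `training_runs/mnist/losses/` and adjust `read_loss_log` if they differ. `matplotlib` is imported only inside `plot_losses`:

```python
def read_loss_log(path):
    """Parse a loss log into (dis_losses, gen_losses).

    Assumes (hypothetically) one record per line whose last two whitespace-separated
    fields are the discriminator and generator losses; blank lines are skipped.
    """
    dis_losses, gen_losses = [], []
    with open(path) as log:
        for line in log:
            fields = line.split()
            if len(fields) < 2:
                continue
            dis_losses.append(float(fields[-2]))
            gen_losses.append(float(fields[-1]))
    return dis_losses, gen_losses


def moving_average(values, window=50):
    """Trailing moving average that smooths noisy GAN loss curves."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed, running = [], 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed


def plot_losses(path, out_png, window=50):
    import matplotlib.pyplot as plt

    dis_losses, gen_losses = read_loss_log(path)
    plt.figure()
    plt.plot(moving_average(dis_losses, window), label="discriminator")
    plt.plot(moving_average(gen_losses, window), label="generator")
    plt.xlabel("logged step")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig(out_png)
    plt.close()
```

Smoothing matters for GAN training curves because per-batch losses oscillate heavily, and the trend is usually more informative than individual values.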
If you just wish to use a trained model for generating samples, please refer to the following code snippet:
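import torch as th
import pro_gan_pytorch.PRO_GAN as pg
import matplotlib.pyplot as plt

# use the GPU if available, otherwise fall back to the CPU
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# depth and latent_size must match the values the model was trained with
gen = pg.Generator(depth=4, latent_size=128, use_eql=False).to(device)

# map_location allows weights saved on a GPU to be loaded on a CPU-only machine
gen.load_state_dict(
    th.load("training_runs/saved_models/GAN_GEN_3.pth", map_location=device)
)

# sample one latent vector and generate at the final depth index (depth - 1 = 3)
noise = th.randn(1, 128).to(device)
with th.no_grad():
    sample_image = gen(noise, depth=3, alpha=1)

# CHW -> HWC on the CPU, rescaled from [-1, 1] to [0, 1] for display
plt.imshow((sample_image[0].permute(1, 2, 0).cpu() / 2 + 0.5).clamp(0, 1).numpy())
plt.show()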
The trained weights can be found in the saved_models directory inside the respective training_runs directory.
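When several checkpoints exist, a small helper can pick the most fully grown generator. This sketch assumes the `GAN_GEN_<n>.pth` naming used in the snippet above and that `n` grows as training reaches deeper stages; `latest_generator_checkpoint` is a hypothetical helper, not part of the package:

```python
import re
from pathlib import Path

# matches e.g. GAN_GEN_3.pth but not GAN_DIS_3.pth or GAN_GEN_SHADOW_3.pth
_GEN_PATTERN = re.compile(r"GAN_GEN_(\d+)\.pth")


def latest_generator_checkpoint(saved_models_dir):
    """Return the GAN_GEN_<n>.pth path with the largest n, or None if there is none."""
    candidates = []
    for path in Path(saved_models_dir).glob("GAN_GEN_*.pth"):
        match = _GEN_PATTERN.fullmatch(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # compare numerically so that GAN_GEN_10 ranks above GAN_GEN_2
    return max(candidates, key=lambda item: item[0])[1]
```

The returned path can be passed to `th.load` in place of the hard-coded filename; remember that the depth passed to the generator has to match the checkpoint.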
How to use on Google Colab Notebook:
This code can be run on Google Colaboratory with GPU acceleration. Colab offers a free Tesla K80 GPU with up to ~12GB of VRAM. However, each instance is time-limited and shuts down after a certain period, at which point all installed libraries and saved files are reset. A workaround is to save training results to Google Drive. The packages need to be reinstalled after every instance reset.
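A minimal sketch of the Google Drive workaround: mount Drive with `google.colab.drive.mount` (importable only inside a Colab runtime) and copy the local `training_runs` folder there. The destination path and the helper `backup_training_runs` are hypothetical choices; adjust them to your own Drive layout:

```python
import shutil
from pathlib import Path


def backup_training_runs(src="training_runs", dst="/content/drive/MyDrive/progan/training_runs"):
    """Copy src (logs and saved models) into dst, overwriting files that already exist there."""
    src_path, dst_path = Path(src), Path(dst)
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok (Python 3.8+) lets repeated backups update an existing copy
    shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    return dst_path


try:
    # only importable inside a Colab runtime; prompts for Drive authorization
    from google.colab import drive

    drive.mount("/content/drive")
except ImportError:
    drive = None  # not on Colab: backup_training_runs can still target a local path
```

Calling `backup_training_runs()` periodically (for example, after each depth finishes) keeps the logs and saved models on Drive; after an instance reset, the same helper with `src` and `dst` swapped copies them back.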
Step-by-step instructions for running this on Google Colab are available in the ProGAN Colaboratory Notebook.
Thanks:
Please feel free to open PRs here if you train on other datasets using this package.
Best regards,
@akanimax :)