
TheShadow29 / infnet-spen

License: Apache-2.0
TensorFlow implementation of the ICLR 2018 paper "Learning Approximate Inference Networks for Structured Prediction"

Programming Languages

python

Projects that are alternatives to or similar to infnet-spen

PerceptualGAN
Pytorch implementation of Image Manipulation with Perceptual Discriminators paper
Stars: ✭ 119 (+296.67%)
Mutual labels:  gan
Deep learning Coloring-Anime-image-and-satellite-image-house-damge-level-colorized
No description or website provided.
Stars: ✭ 16 (-46.67%)
Mutual labels:  gan
pytorch gan
Spectral Normalization and Projection Discriminator
Stars: ✭ 16 (-46.67%)
Mutual labels:  gan
GANs-Keras
GANs Implementations in Keras
Stars: ✭ 24 (-20%)
Mutual labels:  gan
videoMultiGAN
End to End learning for Video Generation from Text
Stars: ✭ 53 (+76.67%)
Mutual labels:  gan
CariMe-pytorch
Unpaired Caricature Generation with Multiple Exaggerations (TMM 2021)
Stars: ✭ 33 (+10%)
Mutual labels:  gan
unrolled-gans
PyTorch Implementation of Unrolled Generative Adversarial Networks
Stars: ✭ 35 (+16.67%)
Mutual labels:  gan
domain-adaptation-capls
Unsupervised Domain Adaptation via Structured Prediction Based Selective Pseudo-Labeling
Stars: ✭ 43 (+43.33%)
Mutual labels:  structured-prediction
Child-Face-Generation
Deep Convolutional Conditional GAN and Supervised CNN for generating children's faces given parents' faces
Stars: ✭ 26 (-13.33%)
Mutual labels:  gan
GAN-Project-2018
GAN in Tensorflow to be run via Linux command line
Stars: ✭ 21 (-30%)
Mutual labels:  gan
catgan pytorch
Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks
Stars: ✭ 50 (+66.67%)
Mutual labels:  gan
Adventures-with-GANS
Showcasing various fun adventures with GANs
Stars: ✭ 13 (-56.67%)
Mutual labels:  gan
anime2clothing
Pytorch official implementation of Anime to Real Clothing: Cosplay Costume Generation via Image-to-Image Translation.
Stars: ✭ 65 (+116.67%)
Mutual labels:  gan
GAN-RNN Timeseries-imputation
Recurrent GAN for imputation of time series data. Implemented in TensorFlow 2 on Wikipedia Web Traffic Forecast dataset from Kaggle.
Stars: ✭ 107 (+256.67%)
Mutual labels:  gan
TET-GAN
[AAAI 2019] TET-GAN: Text Effects Transfer via Stylization and Destylization
Stars: ✭ 74 (+146.67%)
Mutual labels:  gan
wgan-gp
Pytorch implementation of Wasserstein GANs with Gradient Penalty
Stars: ✭ 161 (+436.67%)
Mutual labels:  gan
Computer-Vision
implemented some computer vision problems
Stars: ✭ 25 (-16.67%)
Mutual labels:  gan
metrics
IS and FID score implementations in PyTorch and TF; the TF implementation is a wrapper of the official ones.
Stars: ✭ 91 (+203.33%)
Mutual labels:  gan
StyleGANCpp
Unofficial implementation of StyleGAN's generator
Stars: ✭ 25 (-16.67%)
Mutual labels:  gan
srgan
Pytorch implementation of "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
Stars: ✭ 39 (+30%)
Mutual labels:  gan

Learning Approximate Inference Networks for Structured Prediction

Initial Setup

  1. First, prepare the Figment dataset: http://cistern.cis.lmu.de/figment/

  2. Download the entity dataset and entity embeddings (around 2 GB) into data/figment. Be sure to unzip the entity dataset.

  3. python data/preprocess_figment.py

  4. Prepare the Bibtex dataset: http://mulan.sourceforge.net/datasets.html

  5. python data/preprocess_bibtex.py

  6. Download the bookmarks dataset from here.

  7. Place it in data/bookmarks. There is no need to run any preprocessing script (see the layout check below).
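
After these steps, the data/ directory should contain one subdirectory per dataset. The snippet below is a minimal sanity check, assuming the directory names used above; the exact files inside each directory depend on what was downloaded.

```python
# Minimal layout check (a sketch; directory names follow the setup steps
# above). Verifies each dataset directory exists and is non-empty before
# the preprocessing scripts are run.
from pathlib import Path

for name in ("figment", "bibtex", "bookmarks"):
    path = Path("data") / name
    ok = path.is_dir() and any(path.iterdir())
    print(f"{name:10s} {path} -> {'ok' if ok else 'missing or empty'}")
```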

Running

  1. python -m mains.infnet --config configs/figment.json
  2. python -m mains.infnet --config configs/bibtex.json
  3. python -m mains.infnet --config configs/bookmarks.json
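
To run all three experiments back to back, a small driver like the one below works; it simply replays the commands above from the repository root.

```python
# Run the three experiments in sequence by replaying the commands above.
# Assumes this script is launched from the repository root.
import subprocess
import sys

for cfg in ("configs/figment.json", "configs/bibtex.json", "configs/bookmarks.json"):
    subprocess.run([sys.executable, "-m", "mains.infnet", "--config", cfg], check=True)
```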

Code Description

  1. base/ : Contains the base model and base trainer classes, from which the actual model and trainer inherit.
  2. configs/ : Contains the configuration files, stored in JSON format. All hyper-parameters live here.
  3. data/ : Contains scripts to process the data files and store them in pickle format.
  4. data_loader/ : Contains the class DataGenerator, which is used to get data from the pipeline. Since most of our models are small, a naive implementation was sufficient; for bigger datasets it may be worth looking into the TensorFlow Dataset API. (A hypothetical sketch follows this list.)
  5. mains/ : Contains the main entry point, infnet.py. It takes in the configuration file and uses it to initialize the model, trainer, and hyper-parameters, and to decide which parameters to save for TensorBoard.
  6. models/ : Contains the model definitions, each of which is a class. There are four such classes: EnergyNet, InferenceNet, FeatureNet, and Spen. The first three are simple feed-forward networks; the last combines them into the full model that is actually used.
  7. trainers/ : Contains the trainer, which schedules training, evaluation, and TensorBoard logging, among other things.
  8. utils/ : Contains utility functions such as process_config, which is used to parse the configuration file.
  9. analysis.py, generate_configs.py, run.py : All used for hyper-parameter tuning.
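
As a rough illustration of the data_loader/ item above, a naive generator can be as small as the sketch below. The class name matches the description, but the pickle file name and the tuple layout are hypothetical, not the repository's actual interface.

```python
# Hypothetical sketch of a naive DataGenerator; the real class in data_loader/
# may differ. Loads pickled arrays produced by the preprocessing scripts and
# samples random batches from them.
import pickle

import numpy as np

class DataGenerator:
    def __init__(self, config):
        # "train.pkl" and the (inputs, labels) tuple are illustrative only.
        with open(config["data"]["data_dir"] + "/train.pkl", "rb") as f:
            self.inputs, self.labels = pickle.load(f)

    def next_batch(self, batch_size):
        # Naive uniform sampling; adequate for the small datasets used here.
        idx = np.random.choice(len(self.inputs), batch_size)
        yield self.inputs[idx], self.labels[idx]
```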

Config File Information

  1. exp_name : Name of the experiment
  2. data: Contains info about the data
    1. dataset: Name of the dataset
    2. data_dir: Path to the top-level data directory
    3. splits : Splits for train, validation, test.
    4. embeddings: True/False. True if pre-trained embeddings are available.
    5. vocab : Same as above
    6. data_generator: Name of the data generator defined in data_loader.py.
  3. tensorboard_train: Set to true to save TensorBoard summaries during stage 2
  4. tensorboard_infer: Same as above for stage 3
  5. feature_size: Size of the hidden layer in the feature network
  6. label_measurements: Same as above for the energy network
  7. type_vocab_size: Number of output labels.
  8. entities_vocab_size: Size of the embedding lookup table.
  9. embeddings_tune: Set to true if the embedding vectors should be updated during training.
  10. max_to_keep: Required by the TensorFlow saver used to save checkpoints.
  11. num_epochs : Number of epochs in each stage
  12. train: Info about how to train
    1. diff_type: The \nabla operator in the paper
    2. batch_size: Batch size for training
    3. state_size: Not required; kept for historical reasons.
    4. hidden_units: Hidden units in the inference and feature nets (depends on whether embeddings is true or false)
    5. lr_*: Learning rate for optimizing the corresponding variable
    6. lambda_*: Regularization weight for the corresponding variable
    7. lambda_pretrain_bias: How much to weight the pretrained network (another term in the paper).
    8. wgan_mode: Whether to use the improved WGAN penalty.
    9. lamb_wgan: Regularization weight for the WGAN penalty.
  13. ssvm: Implementation of SPEN (Belanger, 2016). Not complete, since we could not find an implementation of entropic gradient descent.
    1. enable: To be SSVM or not to be
    2. steps: Number of optimization steps in SSVM inference
    3. eval: True if SSVM inference is to be used.
    4. lr_inference: Learning rate for SSVM inference
  14. eval_print: What to print during evaluation
    1. f1: F1 score
    2. accuracy: Accuracy
    3. energy: Energy
    4. pretrain: Energy/loss of the pretrained network
    5. infnet: Energy/loss of the inference network
    6. f1_score_mode: Set to examples to compute the F1 score averaged over examples, or to labels to average over labels. The paper averages over examples.
    7. threshold: Decision threshold, tuned on the validation set
    8. time_taken: For timing; measures only the training/inference step time, not the total running time.
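
The configs are plain JSON, so a quick way to inspect one is shown below; this assumes configs/bibtex.json exists with the keys described in this section (the fields printed are examples, not a complete list).

```python
# Load a config and read back a few of the fields documented above; the key
# names follow this section, the values depend on the shipped config files.
import json

with open("configs/bibtex.json") as f:
    config = json.load(f)

print(config["exp_name"])                 # experiment name (item 1)
print(config["data"]["dataset"])          # dataset name (item 2.1)
print(config["train"]["batch_size"])      # training batch size (item 12.2)
print(config["eval_print"]["threshold"])  # evaluation threshold (item 14.7)
```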

Acknowledgements:

A major thanks to Lifu Tu and Kevin Gimpel (authors of the paper we implemented) for sharing their Theano code and responding promptly to our queries about the paper. We also thank David Belanger for the Bookmarks dataset and his original SPEN implementation.

We also thank the authors of the TensorFlow project template at https://github.com/MrGemy95/Tensorflow-Project-Template, which served as a starting point for our project.

References:

Lifu Tu and Kevin Gimpel. Learning approximate inference networks for structured prediction. CoRR, abs/1803.03376, 2018. URL http://arxiv.org/abs/1803.03376
