EleutherAI / Dalle Mtf

Licence: mit
OpenAI's DALL-E for large-scale training in mesh-tensorflow.

DALL-E in Mesh-Tensorflow [WIP]

OpenAI's DALL-E in Mesh-Tensorflow.

If this implementation proves as efficient as GPT-Neo, this repo should be able to train models as large as, or larger than, OpenAI's DALL-E (12B parameters).

No pretrained models... Yet.

Thanks to Ben Wang for the TensorFlow VAE implementation and for getting the mtf version working, and to Aran Komatsuzaki for help building the mtf VAE and input pipeline.

Setup

git clone https://github.com/EleutherAI/DALLE-mtf
cd DALLE-mtf
pip3 install -r requirements.txt

Training Setup

The code runs on TPUs. It is untested on GPUs but should work in theory. The example configs are designed to run on a TPU v3-32 pod.

To set up TPUs, sign up for Google Cloud Platform and create a storage bucket.

Create your VM through a Google Cloud Shell (https://ssh.cloud.google.com/) with ctpu up --vm-only so that it can connect to your Google bucket and TPUs, then set up the repo as above.

VAE pretraining

DALL-E needs a pretrained VAE to compress images to tokens. To run the VAE pretraining, edit configs/vae_example.json: set train_path and eval_path to glob paths that match your dataset of jpgs, and set image_size to the size your images should be cropped or padded to.

  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg",
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg",
    "image_size": 32
  }
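
As a minimal sketch of this step, the dataset block can also be updated from Python with the standard json module. The helper name set_vae_dataset and the bucket paths are hypothetical, and the sketch assumes the config file on disk is plain JSON with a top-level "dataset" key as shown above. It works on a temporary file rather than the real configs/vae_example.json.

```python
import json
import tempfile
from pathlib import Path


def set_vae_dataset(config_path, train_glob, eval_glob, image_size):
    """Rewrite the "dataset" block of a VAE config (hypothetical helper)."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["dataset"] = {
        "train_path": train_glob,
        "eval_path": eval_glob,
        "image_size": image_size,
    }
    path.write_text(json.dumps(config, indent=2))
    return config


# Demo on a throwaway file standing in for configs/vae_example.json.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "vae_example.json"
    demo_path.write_text(json.dumps({"model_type": "vae", "dataset": {}}))
    updated = set_vae_dataset(
        demo_path,
        "gs://your-bucket/images/train/**/*.jpg",  # placeholder bucket
        "gs://your-bucket/images/test/**/*.jpg",   # placeholder bucket
        64,
    )
    reloaded = json.loads(demo_path.read_text())
```

All other keys in the file are left untouched, so the same helper could be pointed at a copy of the example config.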

Once this is all set up, create your TPU, then run:

python train_vae_tf.py --tpu your_tpu_name --model vae_example

Training logs image tensors and loss values. To check progress, you can run:

tensorboard --logdir your_model_dir

Dataset Creation [DALL-E]

Once the VAE is pretrained, you can move on to DALL-E.

Currently we are training on a dummy dataset. A public, large-scale dataset for DALL-E is in the works. In the meantime, to generate some dummy data, run:

python src/data/create_tfrecords.py

This should download CIFAR-10, and generate some random captions to act as text inputs.

Custom datasets should be formatted in a folder, with a jsonl file in the root folder containing caption data and paths to the respective images, as follows:

Folder structure:

        data_folder
            jsonl_file
            folder_1
                img1
                img2
                ...
            folder_2
                img1
                img2
                ...
            ...

jsonl structure:
    {"image_path": "folder_1/img1", "caption": "some words"}
    {"image_path": "folder_2/img2", "caption": "more words"}
    ...
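
A caption file in this layout could be produced with a few lines of standard-library Python. This is only a sketch: write_caption_jsonl and the file name captions.jsonl are hypothetical, and it assumes image paths are stored relative to the data folder, as in the example above.

```python
import json
import tempfile
from pathlib import Path


def write_caption_jsonl(data_folder, captions, jsonl_name="captions.jsonl"):
    """Write one JSON object per line, with image paths relative to data_folder.

    `captions` maps a relative image path to its caption. The function name
    and the default file name are hypothetical.
    """
    root = Path(data_folder)
    out_path = root / jsonl_name
    with out_path.open("w", encoding="utf-8") as f:
        for rel_path, caption in sorted(captions.items()):
            if not (root / rel_path).is_file():
                raise FileNotFoundError(rel_path)
            record = {"image_path": rel_path, "caption": caption}
            f.write(json.dumps(record) + "\n")
    return out_path


# Demo with empty placeholder files instead of real images.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel in ("folder_1/img1.jpg", "folder_2/img2.jpg"):
        (root / rel).parent.mkdir(parents=True, exist_ok=True)
        (root / rel).write_bytes(b"")
    jsonl_path = write_caption_jsonl(
        root,
        {"folder_1/img1.jpg": "some words", "folder_2/img2.jpg": "more words"},
    )
    lines = jsonl_path.read_text(encoding="utf-8").splitlines()
    records = [json.loads(line) for line in lines]
```

Using json.dumps for each record keeps the paths and captions correctly quoted and escaped, which hand-written lines can easily get wrong.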

You can then use the create_paired_dataset function in src/data/create_tfrecords.py to encode the dataset into tfrecords for use in training.

Once the dataset is created, copy it over to your bucket with gsutil:

gsutil cp -r DALLE-tfrecords gs://neo-datasets/

Finally, run training with:

python train_dalle.py --tpu your_tpu_name --model dalle_example

Config Guide

VAE:

{
  "model_type": "vae",
  "dataset": {
    "train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg", # glob path to training images
    "eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg", # glob path to eval images
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, 
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000, # how often to save a checkpoint
  "iterations": 500, # number of batches to infeed to the tpu at a time. Must be < steps_per_checkpoint
  "train_steps": 100000, # total training steps
  "eval_steps": 0, # run evaluation for this many steps every steps_per_checkpoint
  "model_path": "gs://neo-models/vae_test2/", # directory in which to save the model
  "mesh_shape": "data:16,model:2", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info
  "layout": "batch_dim:data", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info
  "num_tokens": 512, # vocab size
  "dim": 512, 
  "hidden_dim": 64, # size of hidden dim
  "n_channels": 3, # number of input channels
  "bf_16": false, # if true, the model is trained with bfloat16 precision
  "lr": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]
  "num_layers": 3, # number of blocks in the encoder / decoder
  "train_gumbel_hard": true, # whether to use hard or soft gumbel_softmax
  "eval_gumbel_hard": true
}
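
The # annotations above are explanatory only; standard JSON has no comment syntax, so they would need to be removed before parsing. The sketch below strips them and checks two of the constraints mentioned in the comments. The function names are hypothetical, and the check that the mesh_shape dimensions multiply to the number of TPU cores (32 on a v3-32) is an assumption based on how Mesh TensorFlow lays out devices, not something this repo is confirmed to enforce.

```python
import json
import math


def strip_hash_comments(text):
    """Drop '# ...' annotations that sit outside of double-quoted strings."""
    cleaned = []
    for line in text.splitlines():
        in_string = False
        for i, ch in enumerate(line):
            if ch == '"' and (i == 0 or line[i - 1] != "\\"):
                in_string = not in_string
            elif ch == "#" and not in_string:
                line = line[:i]
                break
        cleaned.append(line.rstrip())
    return "\n".join(cleaned)


def parse_mesh_shape(mesh_shape):
    """Turn 'data:16,model:2' into {'data': 16, 'model': 2}."""
    dims = {}
    for part in mesh_shape.split(","):
        name, size = part.split(":")
        dims[name.strip()] = int(size)
    return dims


def check_config(config, num_cores):
    """Return a list of problems found (hypothetical checks)."""
    problems = []
    if config["iterations"] >= config["steps_per_checkpoint"]:
        problems.append("iterations must be < steps_per_checkpoint")
    mesh = parse_mesh_shape(config["mesh_shape"])
    if math.prod(mesh.values()) != num_cores:
        problems.append("mesh_shape does not cover num_cores devices")
    return problems


annotated = """
{
  "steps_per_checkpoint": 1000, # how often to save a checkpoint
  "iterations": 500, # must be < steps_per_checkpoint
  "mesh_shape": "data:16,model:2"
}
"""
config = json.loads(strip_hash_comments(annotated))
problems = check_config(config, num_cores=32)
```

With the example values, iterations (500) is below steps_per_checkpoint (1000) and the mesh covers 16 × 2 = 32 cores, so no problems are reported.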

DALL-E:

{
  "model_type": "dalle",
  "dataset": {
    "train_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords", # glob path to tfrecords data
    "eval_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords",
    "image_size": 32 # size of images (all images will be cropped / padded to this size)
  },
  "train_batch_size": 32, # see above
  "eval_batch_size": 32,
  "predict_batch_size": 32,
  "steps_per_checkpoint": 1000,
  "iterations": 500,
  "train_steps": 100000,
  "predict_steps": 0,
  "eval_steps": 0,
  "n_channels": 3,
  "bf_16": false,
  "lr": 0.001,
  "model_path": "gs://neo-models/dalle_test/",
  "mesh_shape": "data:16,model:2",
  "layout": "batch_dim:data",
  "n_embd": 512, # size of embedding dim
  "text_vocab_size": 50258, # vocabulary size of the text tokenizer
  "image_vocab_size": 512, # vocabulary size of the vae - should equal num_tokens above
  "text_seq_len": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)
  "n_layers": 6, 
  "n_heads": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128
  "vae_model": "vae_example" # path to or name of vae model config
}
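
Two of the comments above describe relationships between settings: n_embd / n_heads should equal 128, and image_vocab_size should equal the VAE's num_tokens. A small sketch of checking them, together with the pad-or-truncate behaviour described for text_seq_len, might look like this. The function names and the padding id of 0 are hypothetical and are not taken from the repo.

```python
def check_dalle_against_vae(dalle_config, vae_config):
    """Collect mismatches between a DALL-E config and its VAE config."""
    notes = []
    head_dim = dalle_config["n_embd"] / dalle_config["n_heads"]
    if head_dim != 128:
        notes.append(f"n_embd / n_heads is {head_dim:g}; 128 is recommended")
    if dalle_config["image_vocab_size"] != vae_config["num_tokens"]:
        notes.append("image_vocab_size should equal the VAE's num_tokens")
    return notes


def pad_or_truncate(token_ids, seq_len, pad_id=0):
    """Force a token list to exactly seq_len entries (pad_id is a placeholder)."""
    clipped = list(token_ids[:seq_len])
    return clipped + [pad_id] * (seq_len - len(clipped))


# Values copied from the example configs above.
dalle_config = {
    "n_embd": 512,
    "n_heads": 4,
    "image_vocab_size": 512,
    "text_seq_len": 256,
}
vae_config = {"num_tokens": 512}

notes = check_dalle_against_vae(dalle_config, vae_config)
padded = pad_or_truncate([5, 6, 7], dalle_config["text_seq_len"])
```

For the example configs, 512 / 4 = 128 and both vocabulary sizes are 512, so the list of notes comes back empty.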