nmichlo / disent

License: MIT License
🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib

Projects that are alternatives of or similar to disent

concept-based-xai
Library implementing state-of-the-art Concept-based and Disentanglement Learning methods for Explainable AI
Stars: ✭ 41 (+0%)
Mutual labels:  vae, disentanglement, disentangled-representations
Fun-with-MNIST
Playing with MNIST. Machine Learning. Generative Models.
Stars: ✭ 23 (-43.9%)
Mutual labels:  vae, autoencoders
Awesome Vaes
A curated list of awesome work on VAEs, disentanglement, representation learning, and generative models.
Stars: ✭ 418 (+919.51%)
Mutual labels:  vae, representation-learning
hydra-zen
Pythonic functions for creating and enhancing Hydra applications
Stars: ✭ 165 (+302.44%)
Mutual labels:  configurable, pytorch-lightning
Declutr
The corresponding code from our paper "DeCLUTR: Deep Contrastive Learning for Unsupervised Textual Representations". Do not hesitate to open an issue if you run into any trouble!
Stars: ✭ 111 (+170.73%)
Mutual labels:  metric-learning, representation-learning
srVAE
VAE with RealNVP prior and Super-Resolution VAE in PyTorch. Code release for https://arxiv.org/abs/2006.05218.
Stars: ✭ 56 (+36.59%)
Mutual labels:  vae, representation-learning
Srl Zoo
State Representation Learning (SRL) zoo with PyTorch - Part of S-RL Toolbox
Stars: ✭ 125 (+204.88%)
Mutual labels:  vae, representation-learning
continuous Bernoulli
There are C language computer programs about the simulator, transformation, and test statistic of continuous Bernoulli distribution. More than that, the book contains continuous Binomial distribution and continuous Trinomial distribution.
Stars: ✭ 22 (-46.34%)
Mutual labels:  vae, autoencoders
amr
Official adversarial mixup resynthesis repository
Stars: ✭ 31 (-24.39%)
Mutual labels:  representation-learning, autoencoders
TCE
This repository contains the code implementation used in the paper Temporally Coherent Embeddings for Self-Supervised Video Representation Learning (TCE).
Stars: ✭ 51 (+24.39%)
Mutual labels:  metric-learning, representation-learning
Pointglr
Global-Local Bidirectional Reasoning for Unsupervised Representation Learning of 3D Point Clouds (CVPR 2020)
Stars: ✭ 86 (+109.76%)
Mutual labels:  metric-learning, representation-learning
autoencoders tensorflow
Automatic feature engineering using deep learning and Bayesian inference using TensorFlow.
Stars: ✭ 66 (+60.98%)
Mutual labels:  representation-learning, autoencoders
Multi object datasets
Multi-object image datasets with ground-truth segmentation masks and generative factors.
Stars: ✭ 121 (+195.12%)
Mutual labels:  datasets, representation-learning
Disentangling Vae
Experiments for understanding disentanglement in VAE latent representations
Stars: ✭ 398 (+870.73%)
Mutual labels:  vae, representation-learning
Codesearchnet
Datasets, tools, and benchmarks for representation learning of code.
Stars: ✭ 1,378 (+3260.98%)
Mutual labels:  datasets, representation-learning
HiCMD
[CVPR2020] Hi-CMD: Hierarchical Cross-Modality Disentanglement for Visible-Infrared Person Re-Identification
Stars: ✭ 64 (+56.1%)
Mutual labels:  representation-learning, disentanglement
DisCont
Code for the paper "DisCont: Self-Supervised Visual Attribute Disentanglement using Context Vectors".
Stars: ✭ 13 (-68.29%)
Mutual labels:  disentanglement, disentangled-representations
linguistic-style-transfer-pytorch
Implementation of "Disentangled Representation Learning for Non-Parallel Text Style Transfer(ACL 2019)" in Pytorch
Stars: ✭ 55 (+34.15%)
Mutual labels:  disentanglement, disentangled-representations
ladder-vae-pytorch
Ladder Variational Autoencoders (LVAE) in PyTorch
Stars: ✭ 59 (+43.9%)
Mutual labels:  vae, representation-learning
opendatasets
A Python library for downloading datasets from Kaggle, Google Drive, and other online sources.
Stars: ✭ 161 (+292.68%)
Mutual labels:  datasets

🧶 Disent

A modular disentangled representation learning framework built with PyTorch Lightning

[Badges: license · python versions · pypi version · tests status]

Visit the docs for more info, or browse the releases.

Contributions are welcome!


Overview

Disent is a modular disentangled representation learning framework for auto-encoders, built upon PyTorch Lightning. The framework consists of composable components that can be used to build and benchmark disentanglement pipelines for vision tasks.

The name of the framework is derived from both disentanglement and scientific dissent.

Get started with disent by installing it with pip install disent, or by cloning this repository.

Goals

Disent aims to meet the following criteria:

  1. Provide high-quality, readable, consistent and easily comparable implementations of frameworks
  2. Highlight differences between framework implementations by overriding hooks and minimising duplicate code
  3. Use best practices, e.g. torch.distributions
  4. Be extremely flexible & configurable
  5. Support low memory systems

Citing Disent

Please use the following citation if you use Disent in your own research:

@Misc{Michlo2021Disent,
  author =       {Nathan Juraj Michlo},
  title =        {Disent - A modular disentangled representation learning framework for pytorch},
  howpublished = {Github},
  year =         {2021},
  url =          {https://github.com/nmichlo/disent}
}

Architecture

The disent module structure (a quick import-level tour follows this list):

  • disent.dataset: dataset wrappers, datasets & sampling strategies
    • disent.dataset.data: raw datasets
    • disent.dataset.sampling: sampling strategies for DisentDataset when multiple elements are required by frameworks, e.g. for triplet loss
    • disent.dataset.transform: common data transforms and augmentations
    • disent.dataset.wrapper: wrapped datasets are no longer ground-truth datasets; they may have some elements masked out. These classes can still be unwrapped to obtain the original datasets for benchmarking.
  • disent.frameworks: frameworks, including Auto-Encoders and VAEs
    • disent.frameworks.ae: Auto-Encoder based frameworks
    • disent.frameworks.vae: Variational Auto-Encoder based frameworks
  • disent.metrics: metrics for evaluating disentanglement using ground truth datasets
  • disent.model: common encoder and decoder models used for VAE research
  • disent.nn: torch components for building models including layers, transforms, losses and general maths
  • disent.schedule: annealing schedules that can be registered to a framework
  • disent.util: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework.
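
As a quick orientation, the import paths below mirror the structure above; each of these names also appears in the basic example later in this README.

import disent                                         # root package
from disent.dataset import DisentDataset              # dataset wrapper
from disent.dataset.data import XYObjectData          # a raw ground-truth dataset
from disent.dataset.sampling import SingleSampler     # a sampling strategy
from disent.dataset.transform import ToImgTensorF32   # a data transform
from disent.frameworks.vae import BetaVae             # a VAE framework
from disent.metrics import metric_mig                 # a disentanglement metric
from disent.model import AutoEncoder                  # encoder/decoder container
from disent.schedule import CyclicSchedule            # an annealing schedule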

Please Note The API Is Still Unstable ⚠️

Disent is still under active development. Features and APIs are mostly stable but may change! A limited set of tests currently exists and will be expanded over time.

Hydra Experiment Directories

Experiments can easily be run with Hydra config; note that these files are not included when installing via pip.

  • experiment/run.py: entrypoint for running basic experiments with hydra config
  • experiment/config/config.yaml: the main configuration file; this is probably what you want to edit!
  • experiment/config: root folder for hydra config files
  • experiment/util: various helper code for experiments

Features

Disent includes implementations of modules, metrics and datasets from various papers. Please note that items marked with a "🧵" were introduced in, and are unique to, disent!

Frameworks

Many popular disentanglement frameworks still need to be added; please submit an issue if you have a request for an additional framework. Each framework is configured through its own cfg object, as sketched after the list below.

todo

  • FactorVAE
  • GroupVAE
  • MLVAE
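
A minimal sketch of how an existing framework is configured; the cfg values are copied from the BetaVae construction in the basic example later in this README.

from disent.frameworks.vae import BetaVae

# each framework exposes a cfg object holding its hyper-parameters
# (values copied from the basic example below)
cfg = BetaVae.cfg(
  optimizer='adam',
  optimizer_kwargs=dict(lr=1e-3),
  loss_reduction='mean_sum',
  beta=4,
)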

Metrics

Some popular metrics still need to be added; please submit an issue if you wish to add your own, or if you have a request. The included metrics share a common call pattern, sketched below.

todo
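
A minimal sketch of the shared call pattern, reusing the metric calls from the basic example later in this README; dataset (a DisentDataset over a ground-truth dataset) and get_repr (a function mapping a batch of observations to representations) are constructed there.

from disent.metrics import metric_dci
from disent.metrics import metric_mig

# `dataset` and `get_repr` are constructed in the basic example below
scores = {
  **metric_dci(dataset, get_repr, num_train=1000, num_test=500),
  **metric_mig(dataset, get_repr, num_train=2000),
}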

Datasets

Various common datasets used in disentanglement research are included, with hash verification and automatic chunk-size optimization of the underlying HDF5 formats for low-memory disk-based access. A usage sketch follows this section.

  • Ground Truth:

    • Cars3D
    • dSprites
    • MPI3D
    • SmallNORB
    • Shapes3D
  • Ground Truth Synthetic:

    • 🧵 XYObject: A simplistic version of dSprites with a single square.
    • 🧵 XYObjectShaded: Exact same dataset as XYObject, but ground truth factors have a different representation
    • 🧵 DSpritesImagenet: Version of dSprites with the foreground or background deterministically masked out using tiny-imagenet data

    [Image: XYObject dataset factor traversals]

Input Transforms + Input/Target Augmentations

  • Input based transforms are supported.
  • Input and Target CPU and GPU based augmentations are supported.
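
A minimal sketch tying the above together: instantiate a raw ground-truth dataset, then wrap it with a sampler and an input transform, as in the basic example later in this README. The factor_names and factor_sizes attributes are assumptions based on the ground-truth dataset interface.

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32

# raw ground-truth dataset
data = XYObjectData()
print(data.factor_names, data.factor_sizes)  # assumed attributes of ground-truth datasets

# wrap it with a sampler and an input transform (as in the basic example below)
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
obs = dataset[0]  # fetch one element (the exact format depends on the sampler)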

Schedules & Annealing

Hyper-parameter annealing is supported through the use of schedules. The currently implemented schedules include the following (a registration sketch follows the list):

  • Linear Schedule
  • Cyclic Schedule
  • Cosine Wave Schedule
  • Various other wrapper schedules
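
A minimal sketch of the registration pattern, reusing the CyclicSchedule call from the basic example later in this README; module is a framework instance such as the BetaVae constructed there.

from disent.schedule import CyclicSchedule

# anneal the 'beta' value in the framework's cfg: the initial value is
# saved and multiplied by the schedule's ratio on each step
module.register_schedule(
  'beta', CyclicSchedule(period=1024),
)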

Examples

Python Example

The following is a basic working example of disent that trains a BetaVAE with a cyclic beta schedule and evaluates the trained model with various metrics.

💾 Basic Example

import os
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64
from disent.schedule import CyclicSchedule

# create the dataset & dataloaders
# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count())

# create the BetaVAE model
# - adjusting the beta, learning rate, and representation size.
module = BetaVae(
  model=AutoEncoder(
    # z_multiplier is needed to output mu & logvar when parameterising normal distribution
    encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2),
    decoder=DecoderConv64(x_shape=data.x_shape, z_size=10),
  ),
  cfg=BetaVae.cfg(
    optimizer='adam',
    optimizer_kwargs=dict(lr=1e-3),
    loss_reduction='mean_sum',
    beta=4,
  )
)

# cyclic schedule for target 'beta' in the config/cfg. The initial value from the
# config is saved and multiplied by the ratio from the schedule on each step.
# - based on: https://arxiv.org/abs/1903.10145
module.register_schedule(
  'beta', CyclicSchedule(
    period=1024,  # repeat every: trainer.global_step % period
  )
)

# train model
# - for 2048 batches/steps
trainer = pl.Trainer(
  max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False
)
trainer.fit(module, dataloader)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
  **metric_dci(dataset, get_repr, num_train=1000, num_test=500, show_progress=True),
  **metric_mig(dataset, get_repr, num_train=2000),
}

# evaluate
print('metrics:', metrics)

Visit the docs for more examples!

Hydra Config Example

The entrypoint for basic experiments is experiment/run.py.

Some configuration will be required, but basic experiments can be adjusted by modifying the Hydra 1.0 config files in experiment/config (please note that Hydra 1.1 is not yet supported).

Modifying the main experiment/config/config.yaml is all you need for most basic experiments. The main config file contains a defaults list with entries corresponding to yaml configuration files (config options) in the subfolders (config groups) in experiment/config/<config_group>/<option>.yaml.

💾 Config Defaults Example

defaults:
  # data
  - sampling: default__bb
  - dataset: xyobject
  - augment: none
  # system
  - framework: adavae_os
  - model: vae_conv64
  # training
  - optimizer: adam
  - schedule: beta_cyclic
  - metrics: fast
  - run_length: short
  # logs
  - run_callbacks: vis
  - run_logging: wandb
  # runtime
  - run_location: local
  - run_launcher: local
  - run_action: train

# <rest of config.yaml left out>
...

Easily modify any of these values to adjust how the basic experiment will be run. For example, change framework: adavae_os to framework: betavae, or change the dataset from xyobject to shapes3d. Add new options by adding new yaml files in the config group folders.
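
For instance, making both of those changes only touches two entries in the defaults list; a sketch of the edited lines, with the rest of the file unchanged (option names taken from the paragraph above):

defaults:
  # data
  - dataset: shapes3d     # was: xyobject
  # system
  - framework: betavae    # was: adavae_os
  # <other entries unchanged>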

Weights and Biases is supported by changing run_logging: none to run_logging: wandb. However, you will need to log in from the command line. W&B logging supports visualisations of latent traversals.


Why?

  • Created as part of my Computer Science MSc, scheduled for completion in 2021.
  • I needed custom high-quality implementations of various VAEs.
  • A PyTorch version of disentanglement_lib.
  • I didn't have time to wait for Weakly-Supervised Disentanglement Without Compromises to release their code as part of disentanglement_lib. (As of September 2020 it has been released, but has unresolved discrepancies.)
  • disentanglement_lib still uses outdated TensorFlow 1.0, and the flow of data is unintuitive because of its use of Gin Config.
