Bootstrap Your Own Latent (BYOL), in Pytorch

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.
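
Concretely, BYOL trains an online network to predict the target network's projection of a differently augmented view of the same image, using a normalized mean squared error, which reduces to 2 - 2 times the cosine similarity. A minimal sketch of that loss (tensor names here are illustrative, not the library's API):

import torch.nn.functional as F

def byol_loss(online_pred, target_proj):
    # l2-normalize both vectors onto the unit sphere
    online_pred = F.normalize(online_pred, dim = -1)
    target_proj = F.normalize(target_proj, dim = -1)
    # normalized MSE reduces to 2 - 2 * cosine similarity
    return (2 - 2 * (online_pred * target_proj).sum(dim = -1)).mean()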

This repository offers a module with which one can wrap any image-based neural network (residual network, discriminator, policy network) and immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well.

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the idea that batch statistics are needed for BYOL to work.

Update 3: Finally, we have some analysis of why this works.

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions as well as (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
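
As an example of a downstream step, you might load the saved weights into a fresh backbone and attach a task-specific head (a sketch; the 10-class head and the file path are purely illustrative):

import torch
from torchvision import models

# load the self-supervised weights back into a fresh backbone
resnet = models.resnet50()
resnet.load_state_dict(torch.load('./improved-net.pt'))

# swap the final fully-connected layer for a task-specific head
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)

# fine-tune as usual, or freeze the backbone for a linear probe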

BYOL → SimSiam

A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. If you go this route, you will no longer need to invoke update_moving_average, as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Advanced

While the hyperparameters have already been set to what the paper found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
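
For reference, moving_average_decay controls how slowly the target encoder tracks the online encoder. A simplified sketch of the exponential moving average that update_moving_average performs (the actual implementation may differ in detail):

# target <- decay * target + (1 - decay) * online
def ema_update(target_params, online_params, decay = 0.99):
    for t, o in zip(target_params, online_params):
        t.data = t.data * decay + o.data * (1 - decay)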

By default, this library will use the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the augment_fn keyword.

import torch.nn as nn
import kornia

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
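
For reference, a SimCLR-style pipeline built with kornia might look roughly like the following. The exact transforms and parameters of this library's defaults may differ, so treat this as a sketch:

import torch.nn as nn
import kornia

simclr_style_augment = nn.Sequential(
    kornia.augmentation.RandomResizedCrop((256, 256), scale = (0.2, 1.0)),
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.augmentation.ColorJitter(0.8, 0.8, 0.8, 0.2, p = 0.8),
    kornia.augmentation.RandomGrayscale(p = 0.2),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)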

In the paper, they seem to ensure that one of the augmentation pipelines has a higher gaussian blur probability than the other. You can adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)

To fetch the embeddings or the projections, you simply have to pass a return_embedding = True flag to the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
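
The returned tensors can be used directly; for example, a quick nearest-neighbour check via pairwise cosine similarity (a sketch):

import torch.nn.functional as F

normed = F.normalize(embedding, dim = -1)
similarity = normed @ normed.t()   # pairwise cosine similarity matrix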

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{chen2020exploring,
    title = {Exploring Simple Siamese Representation Learning},
    author = {Xinlei Chen and Kaiming He},
    year = {2020},
    eprint = {2011.10566},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}