
alfonmedela / triplet-loss-pytorch

License: Apache-2.0
Highly efficient PyTorch version of the Semi-hard Triplet loss ⚡️


Projects that are alternatives of or similar to triplet-loss-pytorch

finetuner
Finetuning any DNN for better embedding on neural search tasks
Stars: ✭ 442 (+459.49%)
Mutual labels:  metric-learning, triplet-loss
MHCLN
Deep Metric and Hash Code Learning Network for Content Based Retrieval of Remote Sensing Images
Stars: ✭ 30 (-62.03%)
Mutual labels:  metric-learning, triplet-loss
GeDML
Generalized Deep Metric Learning.
Stars: ✭ 30 (-62.03%)
Mutual labels:  metric-learning, loss-functions
triplet
Re-implementation of tripletloss function in FaceNet
Stars: ✭ 94 (+18.99%)
Mutual labels:  triplet-loss, triplet
Addressing-Class-Imbalance-FL
This is the code for Addressing Class Imbalance in Federated Learning (AAAI-2021).
Stars: ✭ 62 (-21.52%)
Mutual labels:  loss-functions
cosine-ood-detector
Hyperparameter-Free Out-of-Distribution Detection Using Softmax of Scaled Cosine Similarity
Stars: ✭ 30 (-62.03%)
Mutual labels:  pytorch-implementation
RandLA-Net-pytorch
🍀 Pytorch Implementation of RandLA-Net (https://arxiv.org/abs/1911.11236)
Stars: ✭ 69 (-12.66%)
Mutual labels:  pytorch-implementation
proxy-synthesis
Official PyTorch implementation of "Proxy Synthesis: Learning with Synthetic Classes for Deep Metric Learning" (AAAI 2021)
Stars: ✭ 30 (-62.03%)
Mutual labels:  metric-learning
DocuNet
Code and dataset for the IJCAI 2021 paper "Document-level Relation Extraction as Semantic Segmentation".
Stars: ✭ 84 (+6.33%)
Mutual labels:  pytorch-implementation
MMD-GAN
Improving MMD-GAN training with repulsive loss function
Stars: ✭ 82 (+3.8%)
Mutual labels:  loss-functions
facenet-pytorch-glint360k
A PyTorch implementation of the 'FaceNet' paper for training a facial recognition model with Triplet Loss using the glint360k dataset. A pre-trained model using Triplet Loss is available for download.
Stars: ✭ 186 (+135.44%)
Mutual labels:  triplet-loss
deep-blueberry
If you've always wanted to learn about deep-learning but don't know where to start, then you might have stumbled upon the right place!
Stars: ✭ 17 (-78.48%)
Mutual labels:  pytorch-implementation
Awesome-Pytorch-Tutorials
Awesome Pytorch Tutorials
Stars: ✭ 23 (-70.89%)
Mutual labels:  pytorch-implementation
semi-supervised-paper-implementation
Reproduce some methods in semi-supervised papers.
Stars: ✭ 35 (-55.7%)
Mutual labels:  pytorch-implementation
TitleStylist
Source code for our "TitleStylist" paper at ACL 2020
Stars: ✭ 72 (-8.86%)
Mutual labels:  pytorch-implementation
MobileHumanPose
This repo is official PyTorch implementation of MobileHumanPose: Toward real-time 3D human pose estimation in mobile devices(CVPRW 2021).
Stars: ✭ 206 (+160.76%)
Mutual labels:  pytorch-implementation
visual-compatibility
Context-Aware Visual Compatibility Prediction (https://arxiv.org/abs/1902.03646)
Stars: ✭ 92 (+16.46%)
Mutual labels:  metric-learning
ActiveSparseShifts-PyTorch
Implementation of Sparse Shift Layer and Active Shift Layer (3D, 4D, 5D tensors) for PyTorch(CPU,GPU)
Stars: ✭ 27 (-65.82%)
Mutual labels:  pytorch-implementation
Generative MLZSL
[TPAMI Under Submission] Generative Multi-Label Zero-Shot Learning
Stars: ✭ 37 (-53.16%)
Mutual labels:  pytorch-implementation
simple-cnaps
Source codes for "Improved Few-Shot Visual Classification" (CVPR 2020), "Enhancing Few-Shot Image Classification with Unlabelled Examples" (WACV 2022), and "Beyond Simple Meta-Learning: Multi-Purpose Models for Multi-Domain, Active and Continual Few-Shot Learning" (Neural Networks 2022 - in submission)
Stars: ✭ 88 (+11.39%)
Mutual labels:  metric-learning

Triplet SemiHardLoss

A PyTorch implementation of the semi-hard triplet loss, based on the TensorFlow Addons version that can be found here. There is no need to create a siamese architecture with this implementation; it is as simple as following the CNN creation process in main_train_triplet.py!
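To illustrate why no siamese branches are needed, here is a minimal sketch of batch-wise triplet mining. Note this is a simplified batch-all variant written for clarity, not the exact semi-hard mining this repo implements: a single forward pass produces one embedding per sample, and all valid (anchor, positive, negative) triplets are formed inside the batch.

```python
import torch
import torch.nn.functional as F

def batch_all_triplet_loss(embeddings, labels, margin=1.0):
    """Mean hinge loss over all valid (anchor, positive, negative) triplets in a batch."""
    dist = torch.cdist(embeddings, embeddings)          # (B, B) pairwise Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)   # (B, B) same-label mask
    eye = torch.eye(len(labels), dtype=torch.bool)
    pos_mask = same & ~eye                              # anchor != positive, same label
    neg_mask = ~same                                    # negative has a different label
    # triplet[a, p, n] = d(a, p) - d(a, n) + margin
    triplet = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    valid = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)
    losses = F.relu(triplet[valid])
    return losses.mean() if losses.numel() > 0 else losses.sum()
```

Semi-hard mining additionally restricts each anchor-positive pair to negatives with d(a, p) < d(a, n) < d(a, p) + margin, which is what the TF Addons version (and this repo) does.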

The triplet loss is a great choice for classification problems where N_CLASSES >> N_SAMPLES_PER_CLASS, for example face recognition.

The CNN architecture we use with the triplet loss needs to be cut off before the classification layer. In addition, an L2 normalization layer has to be added.
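A minimal sketch of such an embedding network (a toy backbone of my own, not the architecture used in this repo): the classification head is replaced by a linear embedding layer, and the output is L2-normalized so embeddings live on the unit hypersphere.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingNet(nn.Module):
    """Toy CNN: convolutional backbone, no classifier, L2-normalized embedding output."""
    def __init__(self, emb_dim=64):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc = nn.Linear(16, emb_dim)  # embedding layer replaces the classification layer

    def forward(self, x):
        x = self.fc(self.backbone(x))
        return F.normalize(x, p=2, dim=1)  # unit-norm embeddings for the triplet loss
```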

Results on MNIST

I tested the triplet loss on the MNIST dataset. We can't compare directly against TF Addons because I didn't run that experiment, but the comparison could still be interesting from a performance point of view. Here are the training logs if you want to compare results. The logged accuracy is not relevant and shouldn't be there, as we are not training a classification model.

Phase 1

First, we train the last layer and the batch-normalization layers, reaching a validation loss close to 0.079.

Phase 2

Finally, by unfreezing all the layers it is possible to get close to 0.05 with enough training and hyperparameter tuning.
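The two-phase schedule can be sketched as follows. This is a minimal sketch under assumptions of mine: the final embedding layer is assumed to be exposed as `model.fc`, and the repo's actual training code may differ.

```python
import torch.nn as nn

def freeze_for_phase1(model):
    # Phase 1: train only the BatchNorm layers and the final layer (assumed `model.fc`)
    for m in model.modules():
        trainable = isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))
        for p in m.parameters(recurse=False):
            p.requires_grad = trainable
    for p in model.fc.parameters():
        p.requires_grad = True

def unfreeze_for_phase2(model):
    # Phase 2: fine-tune every parameter
    for p in model.parameters():
        p.requires_grad = True
```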

Test

In order to test, there are two interesting options: training a classification model on top of the embeddings, or plotting the train and test embeddings to see whether samples of the same category cluster together. The following figure contains the original 10,000 validation samples.

TSNE

We get around 99.3% validation accuracy by training a linear SVM or a simple kNN on the embeddings. This repository is not focused on maximizing this accuracy by tweaking data augmentation, architecture, and hyperparameters, but on providing an effective implementation of the triplet loss in PyTorch. For more info on state-of-the-art results on MNIST, check out this amazing Kaggle discussion.
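The kNN evaluation can be sketched in a few lines (the helper name and the use of cosine similarity are my assumptions, not taken from this repo). Since the embeddings are L2-normalized, cosine similarity reduces to a dot product:

```python
import torch

def knn_accuracy(train_emb, train_y, test_emb, test_y, k=5):
    sims = test_emb @ train_emb.T              # (N_test, N_train) cosine similarities
    topk = sims.topk(k, dim=1).indices         # indices of the k nearest training neighbors
    preds = train_y[topk].mode(dim=1).values   # majority vote among the neighbors' labels
    return (preds == test_y).float().mean().item()
```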

Contact me with any questions: [email protected] | alfonsomedela.com

Watch my latest TEDx talk: The medicine of the future


Donations ₿

BTC Wallet: 1DswCAGmXYQ4u2EWVJWitySM7Xo7SH4Wdf

IMPORTANT

If you're using the fastai library, learn.predict will return an error when predicting the embeddings. fastai internally knows that your data has N classes; if the embedding vector has M dimensions, with M > N, and the index of the highest predicted value is larger than N, that class does not exist and an error is raised. So either create your own prediction function or make a simple modification of the source code that adjusts the length of the self.classes list.
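A hedged sketch of such a custom prediction function (the function name is mine; `learn` is assumed to be a fastai Learner wrapping the embedding model): bypass learn.predict entirely and run the underlying model directly.

```python
import torch

def predict_embedding(learn, x):
    # Skip learn.predict's class-vocabulary lookup (which breaks when M > N)
    # and return the raw M-dimensional embedding from the wrapped model.
    model = learn.model.eval()
    with torch.no_grad():
        return model(x)
```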
