somepago / saint

Licence: Apache-2.0 License
The official PyTorch implementation of the paper "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training".

This repository is the official PyTorch implementation of SAINT. The paper is available on arXiv (arXiv:2106.01342).

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Requirements

We recommend using Anaconda or Miniconda for Python. Our code has been tested with python=3.8 on Linux.

Create a conda environment from the yml file and activate it.

conda env create -f saint_environment.yml
conda activate saint_env

Make sure the following requirements are met:

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb for logging, but it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

To train the model(s) in the paper, run this command:

python train.py --dset_id <openml_dataset_id> --task <task_name> --attentiontype <attention_type> 

Pretraining is useful when only a few labeled training samples are available. Use train_robust.py for pretraining and robustness experiments; a sample command looks like this:

python train_robust.py --dset_id <openml_dataset_id> --task <task_name> --attentiontype <attention_type>  --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_samples <Number_of_labeled_samples>
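
To give a feel for the two augmentations that pretraining can use (mixup and CutMix, selected with --pt_aug), here is a small illustrative sketch in plain Python. It follows the usual definitions of these augmentations: CutMix swaps a random subset of feature values with those of another row, and mixup takes a convex combination of two vectors. The function names, the swap probability, and the mixing weight are assumptions made for this example; the repository's own implementation may differ in where and how it applies them.

```python
import random


def cutmix(row, other, swap_prob, rng):
    """Replace each feature of `row` by the matching feature of `other`
    with probability `swap_prob` (a per-feature Bernoulli mask)."""
    return [o if rng.random() < swap_prob else r for r, o in zip(row, other)]


def mixup(vec, other, alpha):
    """Element-wise convex combination: alpha * vec + (1 - alpha) * other."""
    return [alpha * v + (1.0 - alpha) * o for v, o in zip(vec, other)]


# Toy rows with hypothetical values, only to show the mechanics.
rng = random.Random(0)
row_a = [1.0, 2.0, 3.0, 4.0]
row_b = [10.0, 20.0, 30.0, 40.0]

mixed_features = cutmix(row_a, row_b, swap_prob=0.3, rng=rng)
blended = mixup(row_a, row_b, alpha=0.8)
print(mixed_features)
print(blended)
```

Each entry of mixed_features comes from either row_a or row_b at the same position, while blended lies between the two rows, closer to row_a because alpha is 0.8.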

Arguments

  • --dset_id : Dataset id from OpenML. Works with all the datasets mentioned in the paper, and with other OpenML datasets as well.
  • --task : The task we want to perform. Pick from 'regression','multiclass', or 'binary'.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_samples : Number of labeled samples used in semi-supervised experiments.
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : Enables logging to wandb. This is optional.
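
The flags listed above can be mirrored in a small argparse parser, which may help when wrapping the scripts in other tooling. This is a sketch and not the repository's actual parser: the value types, which flags are marked required, and the absence of defaults are assumptions, and only the three attention types documented above are accepted. The dataset id 12345 and the values passed to --pt_aug in the usage example are placeholders.

```python
import argparse


def build_parser():
    """Build an illustrative parser that mirrors the documented flags."""
    parser = argparse.ArgumentParser(
        description="Illustrative SAINT-style CLI (not the repository's parser)"
    )
    parser.add_argument("--dset_id", type=int, required=True)
    parser.add_argument(
        "--task", required=True, choices=["regression", "multiclass", "binary"]
    )
    parser.add_argument("--attentiontype", choices=["col", "row", "colrow"])
    parser.add_argument("--embedding_size", type=int)
    parser.add_argument("--transformer_depth", type=int)
    parser.add_argument("--attention_heads", type=int)
    parser.add_argument("--cont_embeddings", type=str)
    parser.add_argument("--pretrain", action="store_true")
    # The README says multiple values may be passed for these two flags.
    parser.add_argument("--pt_tasks", nargs="*", default=[])
    parser.add_argument("--pt_aug", nargs="*", default=[])
    parser.add_argument("--ssl_samples", type=int)
    parser.add_argument("--pt_projhead_style", type=str)
    parser.add_argument("--nce_temp", type=float)
    parser.add_argument("--active_log", action="store_true")
    return parser


# Placeholder values, only to show how a pretraining command would be parsed.
args = build_parser().parse_args(
    [
        "--dset_id", "12345",
        "--task", "binary",
        "--attentiontype", "colrow",
        "--pretrain",
        "--pt_aug", "mixup", "cutmix",
        "--ssl_samples", "50",
    ]
)
print(args.task, args.attentiontype, args.pt_aug, args.ssl_samples)
```

Restricting --task and --attentiontype to fixed choices makes a mistyped value fail at parse time rather than partway through training.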

Most hyperparameters are hardcoded in train.py. For datasets with a very large number of features, we suggest using a smaller batch size, a lower embedding dimension, and fewer attention heads.
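
Some rough arithmetic shows why these three settings matter for wide datasets. The sketch below assumes standard multi-head self-attention over feature columns, where every head stores one score per pair of features for each sample. It counts tensor elements only and ignores intersample (row) attention, gradients, and framework overhead, so treat the numbers as relative rather than as memory predictions. The function names and the example settings are hypothetical.

```python
def column_attention_score_elements(batch_size, heads, n_features):
    """Elements in the attention score tensor of shape
    (batch_size, heads, n_features, n_features)."""
    return batch_size * heads * n_features * n_features


def embedding_elements(batch_size, n_features, embedding_size):
    """Elements in the embedded input of shape
    (batch_size, n_features, embedding_size)."""
    return batch_size * n_features * embedding_size


# Hypothetical settings for a dataset with 1000 features.
large = column_attention_score_elements(batch_size=256, heads=8, n_features=1000)
small = column_attention_score_elements(batch_size=64, heads=4, n_features=1000)
print(f"score elements, larger settings:  {large:,}")
print(f"score elements, smaller settings: {small:,}")
print(f"reduction factor: {large // small}x")
```

Because the score tensor grows with the square of the number of features, lowering the batch size and the number of heads is the direct way to offset a wide dataset; the embedding dimension affects the embedded input, which grows linearly with the number of features.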

Evaluation

We select the best model by its performance on the validation set. After training completes, the script prints that model's test-set AuROC (binary classification), accuracy (multiclass classification), or RMSE (regression). If wandb is enabled, these values are logged to the 'test_auroc_bestep', 'test_accuracy_bestep', and 'test_rmse_bestep' variables.
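
The selection rule described above can be sketched in a few lines of plain Python. This is an illustration and not the code in train.py: the function names and the per-epoch history format are assumptions, and the metric values in the usage example are made up. AuROC is computed here as the fraction of (positive, negative) pairs that the scores rank correctly, counting ties as half.

```python
import math


def auroc(labels, scores):
    """Area under the ROC curve via pairwise ranking of positives vs negatives."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AuROC needs at least one positive and one negative label")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def accuracy(labels, predictions):
    """Fraction of predictions that match the labels."""
    return sum(y == p for y, p in zip(labels, predictions)) / len(labels)


def rmse(targets, predictions):
    """Root mean squared error."""
    squared = sum((t - p) ** 2 for t, p in zip(targets, predictions))
    return math.sqrt(squared / len(targets))


# AuROC and accuracy are maximised; RMSE is minimised.
HIGHER_IS_BETTER = {"binary": True, "multiclass": True, "regression": False}


def select_best_epoch(history, task):
    """history holds one (validation_metric, test_metric) pair per epoch.
    Return the epoch with the best validation metric and its test metric."""
    pick = max if HIGHER_IS_BETTER[task] else min
    best_epoch = pick(range(len(history)), key=lambda epoch: history[epoch][0])
    return best_epoch, history[best_epoch][1]


# Made-up per-epoch values, only to show the selection rule.
history = [(0.70, 0.68), (0.90, 0.85), (0.80, 0.90)]
print(select_best_epoch(history, "binary"))
```

Note that the test metric is only reported for the epoch chosen on validation data; in the example the third epoch has a higher test value, but it is not selected.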

What's new in this version?

  • Regression and multiclass classification models have been added.
  • Data can be loaded directly from OpenML by passing the id of the dataset.

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilities.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}
