VcampSoldiers / Swin-Transformer-Tensorflow

Licence: MIT License
Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Programming Languages

python

Projects that are alternatives of or similar to Swin-Transformer-Tensorflow

tensorflow-ml-nlp-tf2
Hands-on materials for "Natural Language Processing Starting with TensorFlow 2 and Machine Learning (from Logistic Regression to BERT and GPT-3)"
Stars: ✭ 245 (+444.44%)
Mutual labels:  tf2, transformer
Mmsegmentation
OpenMMLab Semantic Segmentation Toolbox and Benchmark.
Stars: ✭ 2,875 (+6288.89%)
Mutual labels:  transformer, swin-transformer
keras cv attention models
Keras/Tensorflow attention models including beit,botnet,CMT,CoaT,CoAtNet,convnext,cotnet,davit,efficientdet,efficientnet,fbnet,gmlp,halonet,lcnet,levit,mlp-mixer,mobilevit,nfnets,regnet,resmlp,resnest,resnext,resnetd,swin,tinynet,uniformer,volo,wavemlp,yolor,yolox
Stars: ✭ 159 (+253.33%)
Mutual labels:  tf2, tf
manning tf2 in action
The official code repository for "TensorFlow in Action" by Manning.
Stars: ✭ 61 (+35.56%)
Mutual labels:  tf2, tf
TransMorph Transformer for Medical Image Registration
TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
Stars: ✭ 130 (+188.89%)
Mutual labels:  transformer, swin-transformer
transformer-tensorflow2.0
transformer in tensorflow 2.0
Stars: ✭ 53 (+17.78%)
Mutual labels:  tf2, transformer
keras efficientnet v2
self defined efficientnetV2 according to official version. Including converted ImageNet/21K/21k-ft1k weights.
Stars: ✭ 56 (+24.44%)
Mutual labels:  tf2, tf
attention-is-all-you-need-paper
Implementation of Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017.
Stars: ✭ 97 (+115.56%)
Mutual labels:  transformer
saint
The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training
Stars: ✭ 209 (+364.44%)
Mutual labels:  transformer
YOLOv5-Multibackbone-Compression
YOLOv5 Series Multi-backbone(TPH-YOLOv5, Ghostnet, ShuffleNetv2, Mobilenetv3Small, EfficientNetLite, PP-LCNet, SwinTransformer YOLO), Module(CBAM, DCN), Pruning (EagleEye, Network Slimming) and Quantization (MQBench) Compression Tool Box.
Stars: ✭ 307 (+582.22%)
Mutual labels:  swin-transformer
pynmt
a simple and complete pytorch implementation of neural machine translation system
Stars: ✭ 13 (-71.11%)
Mutual labels:  transformer
GameTracking-TF2
📥 Game Tracker: Team Fortress 2
Stars: ✭ 59 (+31.11%)
Mutual labels:  tf2
charformer-pytorch
Implementation of the GBST block from the Charformer paper, in Pytorch
Stars: ✭ 74 (+64.44%)
Mutual labels:  transformer
trapper
State-of-the-art NLP through transformer models in a modular design and consistent APIs.
Stars: ✭ 28 (-37.78%)
Mutual labels:  transformer
TextPruner
A PyTorch-based model pruning toolkit for pre-trained language models
Stars: ✭ 94 (+108.89%)
Mutual labels:  transformer
ROS
ROS (Robot Operating System) study notes (written in summer 2020)
Stars: ✭ 102 (+126.67%)
Mutual labels:  tf
galerkin-transformer
[NeurIPS 2021] Galerkin Transformer: a linear attention without softmax
Stars: ✭ 111 (+146.67%)
Mutual labels:  transformer
tf2-rich-presence
Discord Rich Presence for Team Fortress 2
Stars: ✭ 30 (-33.33%)
Mutual labels:  tf2
linformer
Implementation of Linformer for Pytorch
Stars: ✭ 119 (+164.44%)
Mutual labels:  transformer
flexible-yolov5
More readable and flexible yolov5 with more backbones (resnet, shufflenet, mobilenet, efficientnet, hrnet, swin-transformer), modules (cbam, dcn and so on), and tensorrt
Stars: ✭ 282 (+526.67%)
Mutual labels:  swin-transformer

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official PyTorch implementation is available in Microsoft's Swin-Transformer repository on GitHub.

Introduction:

[Figure: Swin Transformer architecture diagram]

Swin Transformer (the name Swin stands for Shifted window) was first described in the arXiv paper at https://arxiv.org/abs/2103.14030 and serves as a general-purpose backbone for computer vision. It is a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme improves efficiency by limiting self-attention computation to non-overlapping local windows, while still allowing cross-window connections.
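The windowing idea can be illustrated with a small, dependency-free sketch. It is not taken from this repository: `window_ids` is a hypothetical helper that only labels each token position with the local window it falls into, with and without a cyclic shift. The real model additionally masks attention between tokens that end up adjacent only because of the wrap-around.

```python
def window_ids(height, width, window, shift=0):
    """Label every token position with the index of its local window.

    With shift=0 the grid is cut into regular, non-overlapping windows.
    With shift>0 the token grid is cyclically rolled up/left by `shift`
    before it is cut, so the new windows straddle the old boundaries.
    """
    assert height % window == 0 and width % window == 0
    windows_per_row = width // window
    grid = [[0] * width for _ in range(height)]
    for row in range(height):
        for col in range(width):
            # Where this token lands after the cyclic roll.
            r = (row - shift) % height
            c = (col - shift) % width
            grid[row][col] = (r // window) * windows_per_row + (c // window)
    return grid


# 8x8 tokens, 4x4 windows, shift of half a window (illustrative sizes).
regular = window_ids(8, 8, 4)
shifted = window_ids(8, 8, 4, shift=2)

# Tokens (3, 3) and (4, 4) sit on opposite sides of a window boundary ...
print(regular[3][3], regular[4][4])  # 0 3
# ... but share a window once the grid is shifted.
print(shifted[3][3], shifted[4][4])  # 0 0
```

Alternating regular and shifted layers is what lets information cross window boundaries while each attention computation stays local.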

Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.

Usage:

1. Run a pre-trained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

The possible options for cfg and weights_type are:

cfg                                          weights_type      22K model  1K model
configs/swin_tiny_patch4_window7_224.yaml    imagenet_1k       -          github
configs/swin_small_patch4_window7_224.yaml   imagenet_1k       -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_1k       -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_1k       -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22kto1k  -          github
configs/swin_base_patch4_window12_384.yaml   imagenet_22kto1k  -          github
configs/swin_large_patch4_window7_224.yaml   imagenet_22kto1k  -          github
configs/swin_large_patch4_window12_384.yaml  imagenet_22kto1k  -          github
configs/swin_base_patch4_window7_224.yaml    imagenet_22k      github     -
configs/swin_base_patch4_window12_384.yaml   imagenet_22k      github     -
configs/swin_large_patch4_window7_224.yaml   imagenet_22k      github     -
configs/swin_large_patch4_window12_384.yaml  imagenet_22k      github     -
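Because only some cfg and weights_type combinations are listed, it can help to check a combination before launching a run. The sketch below is a hypothetical helper, not part of this repository: `SUPPORTED` and `build_command` are illustrative names, and the pairs are copied from the table above. The flags it emits are the ones used in the commands earlier in this section.

```python
import shlex

_BASE_AND_LARGE = (
    "configs/swin_base_patch4_window7_224.yaml",
    "configs/swin_base_patch4_window12_384.yaml",
    "configs/swin_large_patch4_window7_224.yaml",
    "configs/swin_large_patch4_window12_384.yaml",
)

# weights_type -> config files listed for it in the table above.
SUPPORTED = {
    "imagenet_1k": {
        "configs/swin_tiny_patch4_window7_224.yaml",
        "configs/swin_small_patch4_window7_224.yaml",
        "configs/swin_base_patch4_window7_224.yaml",
        "configs/swin_base_patch4_window12_384.yaml",
    },
    "imagenet_22kto1k": set(_BASE_AND_LARGE),
    "imagenet_22k": set(_BASE_AND_LARGE),
}


def build_command(cfg, weights_type, include_top=1, resume=1):
    """Return the main.py command line for a listed cfg/weights_type pair."""
    if weights_type not in SUPPORTED:
        raise ValueError(f"unknown weights_type: {weights_type!r}")
    if cfg not in SUPPORTED[weights_type]:
        raise ValueError(f"{cfg!r} is not listed for {weights_type!r}")
    args = [
        "python", "main.py",
        "--cfg", cfg,
        "--include_top", str(include_top),
        "--resume", str(resume),
        "--weights_type", weights_type,
    ]
    # Quote each argument so the string is safe to paste into a shell.
    return " ".join(shlex.quote(arg) for arg in args)


print(build_command("configs/swin_tiny_patch4_window7_224.yaml", "imagenet_1k"))
```

A pair that is not in the table, such as the tiny config with imagenet_22k, raises ValueError instead of producing a command.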

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
custom_config = get_config(args, include_top=False)

# CUSTOM_NUM_CLASSES is a placeholder: set it to the number of classes in your dataset.
swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
])

Model outputs are logits, so remember to apply softmax during training and inference.
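To make the logits point concrete, here is a minimal, framework-free sketch of what softmax does to a vector of raw scores. The numbers are made up for illustration. In TensorFlow the same step is usually done with `tf.nn.softmax`, or folded into the loss by passing `from_logits=True` to a loss such as `tf.keras.losses.SparseCategoricalCrossentropy`.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the largest logit avoids overflow and leaves the result unchanged.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]   # illustrative model output for one image
probs = softmax(logits)

# The arg-max is the same for logits and probabilities; softmax matters
# when probabilities or a probability-based loss are needed.
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class, round(sum(probs), 6))
```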

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.
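The predefined file names listed earlier appear to follow the pattern swin_<variant>_patch<P>_window<W>_<resolution>.yaml. Assuming that convention holds, the sketch below reads the basic geometry from a file name, which can be a handy sanity check when naming a custom YAML file. `describe_config` is a hypothetical helper and does not read the YAML contents or reflect how this repository parses configs.

```python
import re
from pathlib import PurePosixPath

# Assumed naming convention, inferred from the predefined config file names.
_NAME_PATTERN = re.compile(
    r"swin_(?P<variant>[a-z]+)_patch(?P<patch>\d+)"
    r"_window(?P<window>\d+)_(?P<resolution>\d+)"
)


def describe_config(path):
    """Extract variant, patch size, window size and image size from a file name."""
    stem = PurePosixPath(path).stem
    match = _NAME_PATTERN.fullmatch(stem)
    if match is None:
        raise ValueError(f"file name does not follow the assumed pattern: {path!r}")
    patch = int(match["patch"])
    image = int(match["resolution"])
    return {
        "variant": match["variant"],
        "patch_size": patch,
        "window_size": int(match["window"]),
        "image_size": image,
        # Tokens per side after patch embedding (before any downsampling).
        "first_stage_grid": image // patch,
    }


info = describe_config("configs/swin_base_patch4_window12_384.yaml")
print(info)
```

For the predefined files the first-stage grid is a multiple of the window size (56 and 7, or 96 and 12), which is worth preserving in a custom config so the token grid splits evenly into windows.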

3. Convert PyTorch pretrained weights into TensorFlow checkpoints

We provide a Python script that converts the official PyTorch weights into TensorFlow checkpoints.

python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights
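One layout difference that a PyTorch-to-TensorFlow conversion generally has to handle is the orientation of fully connected weights: `torch.nn.Linear` stores its weight as (out_features, in_features), while `tf.keras.layers.Dense` stores its kernel as (in_features, out_features). The sketch below shows only that transpose on plain nested lists. It is an illustration of the idea, not the logic of convert_weights.py, and `linear_weight_to_dense_kernel` is a hypothetical name.

```python
def linear_weight_to_dense_kernel(weight):
    """Transpose an (out_features, in_features) matrix to (in_features, out_features)."""
    return [list(column) for column in zip(*weight)]


# A toy PyTorch-style weight: out_features=2, in_features=3.
torch_style = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
]

keras_style = linear_weight_to_dense_kernel(torch_style)
print(keras_style)  # [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```

Bias vectors have the same shape in both frameworks, so in this simplified picture they can be copied unchanged.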

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}