mit-han-lab / Lite Transformer

Licence: other
[ICLR 2020] Lite Transformer with Long-Short Range Attention

Lite Transformer with Long-Short Range Attention

@inproceedings{Wu2020LiteTransformer,
  title={Lite Transformer with Long-Short Range Attention},
  author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}

Overview

We release the PyTorch code for the Lite Transformer. [Paper|Website|Slides]

Consistent Improvement by Tradeoff Curves

[Figure: tradeoff curves]

Save 20000x Searching Cost of Evolved Transformer

[Figure: comparison with the Evolved Transformer]

Further Compress Transformer by 18.2x

[Figure: compression]

How to Use

Prerequisite

  • Python version >= 3.6
  • PyTorch version >= 1.0.0
  • configargparse >= 0.14
  • For training new models, you'll also need an NVIDIA GPU and NCCL
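
The version requirements above can be checked before installing. The following is a minimal sketch, not part of the repository: the helper names `version_tuple` and `meets_minimum` are hypothetical, and it assumes that each package exposes a `__version__` attribute (if one does not, the script reports the version as unknown rather than guessing).

```python
import importlib
import re
import sys


def version_tuple(version):
    """Turn a string such as '1.4.0+cu100' into (1, 4, 0).

    Parsing stops at the first component that does not start with a digit.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, required):
    """Return True if `installed` is at least `required` (dotted versions)."""
    have, need = version_tuple(installed), version_tuple(required)
    width = max(len(have), len(need))
    # Pad with zeros so that '1.0' and '1.0.0' compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


if __name__ == "__main__":
    print("Python >= 3.6:", sys.version_info >= (3, 6))
    # Minimum versions copied from the prerequisite list.
    for module_name, minimum in (("torch", "1.0.0"), ("configargparse", "0.14")):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(module_name, "is not installed")
            continue
        installed = getattr(module, "__version__", None)
        if installed is None:
            print(module_name, "is installed, version unknown")
        else:
            print(module_name, installed, "ok:", meets_minimum(installed, minimum))
```

The script does not check for a GPU or NCCL, which are only needed for training.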

Installation

  1. Codebase

    To install fairseq from source and develop locally:

    pip install --editable .
    
  2. Customized Modules

    The lightconv and dynamicconv layers also need to be built for GPU support.

    Lightconv_layer

    cd fairseq/modules/lightconv_layer
    python cuda_function_gen.py
    python setup.py install
    

    Dynamicconv_layer

    cd fairseq/modules/dynamicconv_layer
    python cuda_function_gen.py
    python setup.py install
    

Data Preparation

IWSLT'14 De-En

We follow the data preparation in fairseq. To download and preprocess the data, one can run

bash configs/iwslt14.de-en/prepare.sh

WMT'14 En-Fr

We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

bash configs/wmt14.en-fr/prepare.sh

WMT'16 En-De

We follow the data pre-processing in fairseq. One should first download the preprocessed data from the Google Drive provided by Google. To binarize the data, one can run

bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file]

WIKITEXT-103

As the language model task requires a substantial amount of additional code, we place it in a separate branch: language-model. We follow the data pre-processing in fairseq. To download and preprocess the data, one can run

git checkout language-model
bash configs/wikitext-103/prepare.sh

Testing

For example, to test the models on WMT'14 En-Fr, one can run

configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid]

For instance, to evaluate Lite Transformer on GPU 0 and report the BLEU score on the WMT'14 En-Fr test set, one can run

configs/wmt14.en-fr/test.sh embed496/ 0 test

We provide several pretrained models in the Models section at the bottom. After downloading a model, you can extract the archive with

tar -xzvf [filename]
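
The same extraction can be done from Python, which may be convenient in a notebook or an evaluation script. This is a minimal sketch using the standard-library `tarfile` module; the function name `extract_checkpoint` and the example file name are hypothetical. As with `tar`, only extract archives from a source you trust.

```python
import tarfile
from pathlib import Path


def extract_checkpoint(archive_path, dest_dir="."):
    """Extract a gzip-compressed tarball and return the member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    # "r:gz" matches the -z (gzip) flag of `tar -xzvf`.
    with tarfile.open(str(archive_path), "r:gz") as archive:
        names = archive.getnames()
        archive.extractall(str(dest))
    return names


if __name__ == "__main__":
    # Hypothetical file name; substitute the archive you downloaded.
    archive = Path("checkpoint.tar.gz")
    if archive.exists():
        for name in extract_checkpoint(archive):
            print(name)
```

The extracted directory can then be passed to the test script as the path to the model checkpoints.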

Training

We provide several examples of training Lite Transformer with this repo:

To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run

python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml

To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32

In general, to train a model, one can run

python train.py [path to the data binary] --configs [path to config file] [override options]

Note that --update-freq should be adjusted according to the number of GPUs (16 for 8 GPUs, 32 for 4 GPUs).
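
The two values quoted here, together with the 16-GPU distributed example in this README (which uses --update-freq 8), are all consistent with keeping the product of GPU count and --update-freq at 128. The sketch below works under that assumption. The constant 128 and the helper name `update_freq_for` are inferred for illustration and are not taken from the repository.

```python
# Assumption: num_gpus * update_freq stays at 128, as in the README examples
# (4 GPUs -> 32, 8 GPUs -> 16, 16 GPUs -> 8).
TARGET_PRODUCT = 128


def update_freq_for(num_gpus, target_product=TARGET_PRODUCT):
    """Return the --update-freq value that keeps the product constant."""
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    if target_product % num_gpus != 0:
        raise ValueError(
            "%d GPUs do not divide %d evenly" % (num_gpus, target_product)
        )
    return target_product // num_gpus


if __name__ == "__main__":
    for num_gpus in (4, 8, 16):
        print(num_gpus, "GPUs -> --update-freq", update_freq_for(num_gpus))
```

For GPU counts that do not divide 128 evenly, the helper raises an error instead of rounding, since rounding would change the effective batch size.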

Distributed Training (optional)

Lite Transformer can also be trained in a distributed manner, for example on two GPU nodes with 16 GPUs in total:

# On host1
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=0 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
# On host2
python -m torch.distributed.launch \
        --nproc_per_node=8 \
        --nnodes=2 --node_rank=1 \
        --master_addr=host1 --master_port=8080 \
        train.py data/binary/wmt14_en_fr \
        --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \
        --distributed-no-spawn \
        --update-freq 8
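
The two commands above differ only in --node_rank. When launching on more nodes, it can help to generate the per-node command instead of editing it by hand. The following is a minimal sketch: the function `launch_command` is hypothetical, and it only reuses the flags and paths shown above.

```python
import shlex


def launch_command(node_rank, nnodes=2, nproc_per_node=8,
                   master_addr="host1", master_port=8080, update_freq=8):
    """Build the torch.distributed.launch command for one node as a list."""
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    return [
        "python", "-m", "torch.distributed.launch",
        "--nproc_per_node=%d" % nproc_per_node,
        "--nnodes=%d" % nnodes,
        "--node_rank=%d" % node_rank,
        "--master_addr=%s" % master_addr,
        "--master_port=%d" % master_port,
        "train.py", "data/binary/wmt14_en_fr",
        "--configs",
        "configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml",
        "--distributed-no-spawn",
        "--update-freq", str(update_freq),
    ]


if __name__ == "__main__":
    # Print one shell-quoted command per node; run each on its own host.
    for rank in range(2):
        print("# node", rank)
        print(" ".join(shlex.quote(arg) for arg in launch_command(rank)))
```

The returned list can also be passed directly to `subprocess.run` on each host.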

Models

We provide the checkpoints for our Lite Transformer reported in the paper:

| Dataset | #Mult-Adds | Test Score | Model and Test Set |
|:--:|:--:|:--:|:--:|
| WMT'14 En-Fr | 90M | 35.3 | download |
| | 360M | 39.1 | download |
| | 527M | 39.6 | download |
| WMT'16 En-De | 90M | 22.5 | download |
| | 360M | 25.6 | download |
| | 527M | 26.5 | download |
| CNN / DailyMail | 800M | 38.3 (R-L) | download |
| WIKITEXT-103 | 1147M | 22.2 (PPL) | download |
