License: MIT
Code for the ICML'20 paper "Improving Transformer Optimization Through Better Initialization"


ICML'20 Improving Transformer Optimization Through Better Initialization

[paper][supp][video]

Authors: Xiao Shi Huang, Felipe Perez, Jimmy Ba, Maksims Volkovs

Introduction

This repository contains a full implementation of the T-Fixup algorithm built on the fairseq library, including both training and evaluation routines on the IWSLT'14 De-En dataset.
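At a high level, T-Fixup removes learning rate warmup and layer normalization by applying depth-dependent gains on top of the standard Xavier initialization. A minimal sketch of the scale factors (the function name is illustrative and not from this repository; see the paper for the full per-parameter rules):

```python
def tfixup_scales(num_encoder_layers: int, num_decoder_layers: int):
    """Depth-dependent gains T-Fixup applies on top of Xavier init.

    Per the paper: value/output projections and feed-forward weights in
    each encoder block are scaled by 0.67 * N^(-1/4), while the
    corresponding decoder-block weights and the input embeddings are
    scaled by (9 * M)^(-1/4), where N and M are the encoder and decoder
    depths.
    """
    enc_gain = 0.67 * num_encoder_layers ** (-1 / 4)
    dec_gain = (9 * num_decoder_layers) ** (-1 / 4)
    return enc_gain, dec_gain
```

For the 6-layer encoder-decoder used on IWSLT'14, both gains come out well below 1, which is what keeps early updates small enough to train stably without warmup.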

T-Fixup was used by Javier Martin and Andres Torrubia in their 3rd-place solution (out of 3,395 teams) for the "Riiid Answer Correctness Prediction" Kaggle challenge. See this blog post.

Environment

The Python code was developed and tested in the following environment:

  • Python 3.7
  • PyTorch 1.2.0

Experiments on the IWSLT'14 De-En and En-De datasets were run on an NVIDIA V100 GPU with 32GB of memory; all other experiments were run on an IBM server with 160 POWER9 CPUs, 600GB of RAM, and 4 Tesla V100 GPUs.

Dataset

The example execution script Train_IWSLT_TFixup_example.sh builds the IWSLT'14 De-En dataset; for the WMT'14 En-De and WMT'17 En-De datasets, refer to fairseq's instructions here

Running The Code

  1. ./Train_IWSLT_TFixup_example.sh
  2. (Optionally) launch TensorBoard to monitor progress: tensorboard --logdir=<log_path>

This script trains the small 512-1024-4 Transformer encoder-decoder model (see the paper for details) with both layer normalization and learning rate warmup removed. The starting learning rate is set to the post-warmup value of 0.0005 (vs. 1e-07 with warmup). By default all available GPUs are used, but parameters such as batch size are set for 1 GPU. If multiple GPUs are available, either restrict the script to a single GPU or adjust the model parameters accordingly.
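For reference, the "512-1024-4" shorthand denotes an embedding dimension of 512, a feed-forward dimension of 1024, and 4 attention heads, which matches fairseq's transformer_iwslt_de_en architecture defaults. An illustrative fairseq-train invocation along these lines (a sketch, not the exact command in the script; paths are placeholders):

```shell
# Illustrative only: data path and log dir are placeholders, and the
# T-Fixup-specific options used by the actual script are omitted here.
fairseq-train data-bin/iwslt14.tokenized.de-en \
    --arch transformer_iwslt_de_en \
    --encoder-embed-dim 512 --encoder-ffn-embed-dim 1024 \
    --encoder-attention-heads 4 \
    --decoder-embed-dim 512 --decoder-ffn-embed-dim 1024 \
    --decoder-attention-heads 4 \
    --optimizer adam --lr 0.0005 \
    --tensorboard-logdir logs/
```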

Validation Curves

Training and validation loss curves for a Transformer model trained with T-Fixup on the IWSLT'14 De-En dataset during the first 300 epochs. One epoch is around 1100 updates and we checkpoint the model after each epoch.

The BLEU score, evaluated using the average of 10 checkpoints, reaches 35.73 at epochs 278-287.
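Checkpoint averaging is simply an element-wise mean of the saved parameter values across checkpoints (fairseq ships a utility for this in scripts/average_checkpoints.py). A minimal illustration of the averaging rule, using plain Python lists in place of real parameter tensors:

```python
def average_checkpoints(state_dicts):
    """Element-wise average of parameter values across checkpoints.

    Each checkpoint is modeled here as a dict mapping parameter names
    to lists of floats; real checkpoints hold tensors, but the
    averaging rule is the same.
    """
    n = len(state_dicts)
    return {
        name: [sum(vals) / n for vals in zip(*(sd[name] for sd in state_dicts))]
        for name in state_dicts[0]
    }
```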