pytorch-transformer
This repository provides a PyTorch implementation of the Transformer model introduced in the paper Attention Is All You Need (Vaswani et al., 2017).
Installation
The easiest way to install this package is via pip:
pip install git+https://github.com/phohenecker/pytorch-transformer
Usage
import transformer
model = transformer.Transformer(...)
1. Computing Predictions given a Target Sequence
This is the default behaviour of a Transformer and is implemented in its forward method:
predictions = model(input_seq, target_seq)
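In the Transformer architecture of Vaswani et al. (2017), predictions for a whole target sequence can be computed in one pass because the decoder's self-attention is masked: each position may attend only to itself and to earlier positions, never to later ones. The pure-Python sketch below illustrates such a mask and its effect on attention weights. subsequent_mask and masked_softmax are hypothetical helpers written for illustration. They are not exported by this package, and the package's own masking code may be implemented differently.

```python
import math


def subsequent_mask(length):
    """Hypothetical helper: mask[i][j] is True iff position i may attend
    to position j, i.e. iff j <= i (no peeking at later positions)."""
    return [[j <= i for j in range(length)] for i in range(length)]


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores that assigns zero weight
    to masked-out (future) positions."""
    allowed = [s for s, ok in zip(scores, mask_row) if ok]
    top = max(allowed)  # subtract the max for numerical stability
    exps = [math.exp(s - top) if ok else 0.0 for s, ok in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]


# Toy example: 3 target positions, all raw attention scores equal.
mask = subsequent_mask(3)
weights = [masked_softmax([0.0, 0.0, 0.0], row) for row in mask]
for row in weights:
    print([round(w, 3) for w in row])
# prints:
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

Each row spreads its attention only over the current and earlier positions. This is why feeding the full target sequence does not let a position see the token it is supposed to predict.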
2. Evaluating the Probability of a Target Sequence
The probability of a target sequence given an input sequence under an already trained model can be evaluated with the function eval_probability:
probabilities = transformer.eval_probability(model, input_seq, target_seq, pad_index=...)
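By the chain rule, the probability of a target sequence is the product of the per-position conditional probabilities p(y_t | y_<t, x). The sketch below assumes that these per-step distributions have already been computed and that padded positions should be ignored. It accumulates the product in log space to avoid underflow on long sequences. sequence_log_prob is a hypothetical helper, not the package's eval_probability, and the exact return format of eval_probability is not documented here.

```python
import math


def sequence_log_prob(step_probs, target_seq, pad_index):
    """Hypothetical helper: log-probability of target_seq.

    step_probs[t][k] is assumed to be p(y_t = k | y_<t, x), i.e. the
    distribution a model produces for position t. Positions holding
    pad_index are skipped.
    """
    log_p = 0.0
    for dist, token in zip(step_probs, target_seq):
        if token == pad_index:
            continue  # padding contributes nothing to the probability
        log_p += math.log(dist[token])
    return log_p


# Toy vocabulary {0: <pad>, 1: "a", 2: "b"} and a 3-step target "a b <pad>".
step_probs = [
    [0.0, 0.6, 0.4],
    [0.0, 0.3, 0.7],
    [1.0, 0.0, 0.0],
]
target = [1, 2, 0]
prob = math.exp(sequence_log_prob(step_probs, target, pad_index=0))
print(round(prob, 2))  # 0.6 * 0.7 = 0.42
```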
3. Sampling an Output Sequence
To sample a random output sequence for a given input sequence from the distribution computed by a model, use the function sample_output:
output_seq = transformer.sample_output(model, input_seq, eos_index, pad_index, max_len)
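Sampling from an autoregressive model proceeds one token at a time. At each step, the next token is drawn from the model's distribution given the tokens generated so far and appended to the output. The loop stops once eos_index is produced or max_len tokens have been generated. The pure-Python sketch below shows this loop with a toy next-token function. sample_sequence, next_token_probs, and toy_model are hypothetical. Padding the result to max_len with pad_index is an assumption about how the parameters of sample_output are used, not documented behaviour.

```python
import random


def sample_sequence(next_token_probs, eos_index, pad_index, max_len, rng):
    """Hypothetical helper: sample tokens until eos_index or max_len.

    next_token_probs(prefix) is an assumed callable returning a list of
    probabilities over the vocabulary for the next token given prefix.
    The result is padded with pad_index up to max_len (an assumption).
    """
    output = []
    while len(output) < max_len:
        probs = next_token_probs(output)
        # Draw one token index according to the model's distribution.
        token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        output.append(token)
        if token == eos_index:
            break
    return output + [pad_index] * (max_len - len(output))


def toy_model(prefix):
    # Vocabulary {0: <pad>, 1: <eos>, 2: "a"}: emit "a" twice, then <eos>.
    return [0.0, 0.0, 1.0] if len(prefix) < 2 else [0.0, 1.0, 0.0]


seq = sample_sequence(toy_model, eos_index=1, pad_index=0, max_len=5,
                      rng=random.Random(0))
print(seq)  # [2, 2, 1, 0, 0]
```

Passing an explicit random.Random instance keeps sampling reproducible for a fixed seed. With a real model, the distributions are not one-hot, so different seeds give different sequences.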
Pretraining Encoders with BERT
For pretraining the encoder part of the transformer (i.e., transformer.Encoder) with BERT (Devlin et al., 2018), the class MLMLoss provides an implementation of the masked language-model loss function. A full example of how to implement pretraining with BERT can be found in examples/bert_pretraining.py.
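The masked language-model objective of Devlin et al. (2018) corrupts the input before it reaches the encoder. About 15% of token positions are chosen, and each chosen token is handled as follows:

- replaced by a [MASK] token 80% of the time,
- replaced by a random token 10% of the time,
- left unchanged 10% of the time.

The loss is the negative log-likelihood of the original tokens at the chosen positions only. The pure-Python sketch below illustrates that procedure. mask_tokens and mlm_loss are hypothetical helpers, and selecting each position independently with probability 0.15 is a simplification. Neither function describes the interface or internals of MLMLoss.

```python
import math
import random


def mask_tokens(tokens, mask_index, vocab_size, rng, mask_prob=0.15):
    """Hypothetical BERT-style corruption of a token sequence.

    Each position is selected independently with probability mask_prob
    (a simplification). A selected token becomes mask_index 80% of the
    time, a random token 10% of the time, and stays unchanged 10% of the
    time. Returns the corrupted sequence and the selected positions.
    """
    corrupted = list(tokens)
    selected = []
    for i in range(len(tokens)):
        if rng.random() >= mask_prob:
            continue
        selected.append(i)
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_index
        elif r < 0.9:
            corrupted[i] = rng.randrange(vocab_size)
        # else: keep the original token
    return corrupted, selected


def mlm_loss(step_probs, tokens, selected):
    """Mean negative log-likelihood of the ORIGINAL tokens at the selected
    positions; step_probs are the model's outputs for the corrupted input."""
    if not selected:
        return 0.0
    return -sum(math.log(step_probs[i][tokens[i]]) for i in selected) / len(selected)


rng = random.Random(42)
tokens = list(range(2, 22))  # 20 toy token ids; 0 = <pad>, 1 = [MASK]
corrupted, selected = mask_tokens(tokens, mask_index=1, vocab_size=30, rng=rng)
uniform = [[1.0 / 30] * 30 for _ in tokens]  # stand-in for encoder output
print(selected, mlm_loss(uniform, tokens, selected))
```

Keeping some selected tokens unchanged or randomised means the encoder cannot rely on seeing [MASK] wherever it must make a prediction. Devlin et al. (2018) motivate this as reducing the mismatch between pretraining and fine-tuning, where [MASK] never appears.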
References
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. Preprint at http://arxiv.org/abs/1706.03762.
Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. Preprint at http://arxiv.org/abs/1810.04805.