BERT & ALBERT in a flask
This repo demonstrates how to serve both vanilla BERT and ALBERT predictions through a simple dockerized Flask API. The repo contains implementations of both google-research/bert and google-research/ALBERT for multiclass classification. The model implementation is based on kpe/bert-for-tf2.
The code is written to be compatible with TensorFlow 2.0 and Docker>=19. If you want to train on a GPU, you will also need to have nvidia-container-toolkit installed. Installation instructions are provided by the nvidia-container-toolkit project. The project uses the Stack Overflow dataset for demonstration purposes.
NEWS
- 04-02-2020 - Now supports google-research/ALBERT. Data and model loading have been simplified and improved.
- 06-01-2020 - Codebase has been overhauled and ported to Tensorflow 2.0. The implementation is now based on kpe/bert-for-tf2.
- 25-10-2019 - Bugfixes. Code is now runnable, with a data example. Codebase still under construction; considering upgrading the project to TF 2.0.
USAGE
The most common Docker commands used in this project have been wrapped in Makefile targets for convenience. Training and API serving are executed in Docker containers. The corresponding Dockerfiles are dev.Dockerfile for training/development and api.Dockerfile for the API. They can be built using the following targets:
make build_dev
make build_api
The development Docker image uses Docker volumes and boxboat/fixuid to mount the local folder with proper permissions, for easier development and debugging. The API Docker image, on the other hand, copies the code into the image at build time. This means you should rerun make build_api whenever you change the serving code or train a new model you want to serve.
TRAINING
Pretrained models
A number of pretrained models from both google-research/bert and google-research/ALBERT are supported. Below are the lists of supported BERT and ALBERT models:
BERT
- uncased_L-12_H-768_A-12
- uncased_L-24_H-1024_A-16
- cased_L-12_H-768_A-12
- cased_L-24_H-1024_A-16
- multi_cased_L-12_H-768_A-12
- wwm_uncased_L-24_H-1024_A-16
- wwm_cased_L-24_H-1024_A-16
ALBERT
The training script automatically adapts to either type of model. Simply specify which pretrained model you want to use by changing the variable model_name in the file src/train.py. Accepted values are the ones listed above; they can also be found in src/utils/loader.py.
Parameters
Below is a list of high-level parameters that can be adjusted in src/train.py. Note that the default values in src/train.py are given as examples and should not be considered optimal.
Name | Description |
---|---|
batch_size | Batch size used during training. |
max_seq_len | Upper bound on the length of input sequences. |
n_epochs | Total number of epochs to run. |
max_learn_rate | The upper bound on the learning rate. |
min_learn_rate | The lower bound on the learning rate. |
warmup_proportion | Fraction of total epochs spent on linear warmup. |
By default, the learning rate used during training grows linearly toward an upper bound (max_learn_rate) until the fraction warmup_proportion of the total epochs is reached, and then decays exponentially toward a lower bound (min_learn_rate). If you do not want this behaviour, remove learning_rate_callback from the list of Keras callbacks.
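The schedule described above can be sketched as a plain Python function of the epoch index. This is a minimal sketch, not the code in src/train.py: the function name lr_at_epoch is hypothetical, and it assumes warmup lasts int(warmup_proportion * n_epochs) epochs and that the decay reaches min_learn_rate exactly on the last epoch.

```python
import math


def lr_at_epoch(epoch, n_epochs, max_learn_rate, min_learn_rate, warmup_proportion):
    """Hypothetical helper: learning rate for a 0-based epoch index.

    Linear warmup up to max_learn_rate, then exponential decay that
    reaches min_learn_rate on the final epoch (epoch == n_epochs - 1).
    """
    if not 0.0 <= warmup_proportion < 1.0:
        raise ValueError("warmup_proportion must be in [0, 1)")
    # Number of warmup epochs, capped so at least one epoch is left for decay.
    warmup_epochs = min(int(warmup_proportion * n_epochs), n_epochs - 1)

    if epoch < warmup_epochs:
        # Linear growth: max_learn_rate / warmup_epochs, twice that, ..., max_learn_rate.
        return max_learn_rate * (epoch + 1) / warmup_epochs

    # Fraction of the decay phase completed, in (0, 1]; equals 1 on the last epoch.
    progress = (epoch - warmup_epochs + 1) / (n_epochs - warmup_epochs)
    # Geometric interpolation from max_learn_rate down to min_learn_rate.
    return max_learn_rate * math.exp(progress * math.log(min_learn_rate / max_learn_rate))
```

To plug a function like this into Keras, it can be wrapped in tf.keras.callbacks.LearningRateScheduler, which calls the schedule with the epoch index at the start of each epoch, e.g. tf.keras.callbacks.LearningRateScheduler(lambda epoch: lr_at_epoch(epoch, n_epochs, max_learn_rate, min_learn_rate, warmup_proportion)). Interpolating in log space makes the decay exponential: each post-warmup epoch multiplies the rate by the same constant factor.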
Running
To train the model run the following command:
make train TENSORBOARD_PORT=<PORT> GPU=<GPU>
TensorBoard is initialized during training in models/trained, and its port defaults to 6006. The GPU parameter, which expects an integer (defaults to 1), specifies the GPU you want to use for training; it is only relevant if you have more than one GPU. The code is written for single-GPU setups.
Model files and TensorBoard logs are exported to a datetime-stamped folder under models/trained during training. Furthermore, when training finishes, a confusion matrix plot computed from the model's performance on the test set is generated and saved to this directory.
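A confusion matrix counts how often each true label was predicted as each label. The sketch below builds one in plain Python from lists of true and predicted labels; it only illustrates the computation, and confusion_matrix_counts is a hypothetical name, not the function the training script uses to produce its plot.

```python
def confusion_matrix_counts(y_true, y_pred, labels):
    """Hypothetical helper: count (true label, predicted label) pairs.

    Row i corresponds to the true label labels[i], column j to the
    predicted label labels[j]; the diagonal holds correct predictions.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[index[true_label]][index[pred_label]] += 1
    return matrix
```

For example, confusion_matrix_counts(["java", "python", "python"], ["java", "java", "python"], ["java", "python"]) returns [[1, 0], [1, 1]]: one java question classified correctly, one python question mislabelled as java, and one python question classified correctly. If scikit-learn is available, sklearn.metrics.confusion_matrix(y_true, y_pred, labels=labels) computes the same counts as a NumPy array.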
Out of memory exceptions
Depending on your GPU hardware, you might experience out-of-memory (OOM) exceptions. To avoid them, reduce batch_size or max_seq_len and/or choose a smaller pretrained base model.
INFERENCE
To serve predictions, a minimal working example of a dockerized Flask API is provided.
The API loads the latest trained model by reading the file models/latest_model_config.json, which is overwritten every time a new model is trained. Modify this file manually if you want to serve a specific model.
The API can be booted up using the command:
make start_api api_port=<PORT>
Once the API container is running, requests can be made in the following form using the assigned API port:
curl -H "Content-Type: application/json" --request POST --data '<JSON_OBJECT>' http://localhost:<PORT>/predict
The JSON object should have the following format. Note that the input is a list, so several input texts can be sent in a single request:
{
"x": ["Should i use public or private when declaring a variable in a class?", "I get an ImportError everytime i try to import a module in my main.py script"]
}
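The same request can be sent from Python using only the standard library. This is a minimal sketch: it assumes the API is reachable on localhost at the port passed to make start_api, and the helper names build_predict_request and predict are hypothetical.

```python
import json
import urllib.request


def build_predict_request(texts, port):
    """Build a POST request for the /predict endpoint with an {"x": [...]} body."""
    body = json.dumps({"x": list(texts)}).encode("utf-8")
    return urllib.request.Request(
        f"http://localhost:{port}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(texts, port, timeout=30):
    """Send the request and return the decoded JSON response body."""
    request = build_predict_request(texts, port)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

Calling predict with a list of question texts and the API port returns the parsed response. Judging from the request log example, the response is expected to contain keys such as model, predictions and probabilities, with one prediction per input text.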
The API logs a minimal amount of information about each call. The log format is defined in src/app.py. Below is an example of a request log:
{
"endpoint": "/predict",
"response": {
"model": "BERT 2020-01-06 22:06:23 4b4e44b2",
"predictions": ["java", "python"],
"probabilities": [0.9852411150932312, 0.9999822378158569]
},
"status_code": 200,
"response_time": 0.48097777366638184,
"user_ip": null,
"user_agent": "curl/7.58.0"
}
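A record with the fields shown above could be assembled as follows. This is a hypothetical sketch, not the code in src/app.py: make_log_record is an illustrative name, and it assumes response_time is wall-clock seconds measured with time.perf_counter() and that each record is emitted as one JSON line.

```python
import json
import time


def make_log_record(endpoint, response, status_code, started_at,
                    user_ip=None, user_agent=None):
    """Build a dict with the same keys as the example request log.

    started_at is a time.perf_counter() value taken when the request arrived.
    """
    return {
        "endpoint": endpoint,
        "response": response,
        "status_code": status_code,
        "response_time": time.perf_counter() - started_at,  # elapsed seconds
        "user_ip": user_ip,
        "user_agent": user_agent,
    }


if __name__ == "__main__":
    start = time.perf_counter()
    record = make_log_record(
        "/predict",
        {"model": "example-model", "predictions": ["java"], "probabilities": [0.98]},
        200,
        start,
        user_agent="curl/7.58.0",
    )
    # One JSON object per line keeps the log easy to grep and parse.
    print(json.dumps(record))
```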
LICENSE
MIT (see the License file)