Machine Translation with Transformer
====================================
In this notebook, we will show how to train the Transformer model
introduced in [1] and how to evaluate a pretrained model using GluonNLP.
The model is both more accurate and cheaper to train than previous
seq2seq models. We will go through the following steps together:
1) Use the state-of-the-art pretrained Transformer model: we will
   evaluate the pretrained SOTA Transformer model and translate a few
   sentences ourselves with the ``BeamSearchTranslator`` using the SOTA
   model;
2) Train the Transformer yourself: this includes loading and processing
   the dataset, defining the Transformer model, writing the training
   script, and evaluating the trained model. Note that obtaining
   state-of-the-art results on the WMT 2014 English-German dataset
   takes around one day of training. To let you run through the
   Transformer quickly, we suggest starting with the ``TOY`` dataset
   sampled from the WMT dataset (the default in this notebook).
Preparation
-----------
Load MXNet and GluonNLP
~~~~~~~~~~~~~~~~~~~~~~~
.. code:: python

    import warnings
    warnings.filterwarnings('ignore')

    import random

    import numpy as np
    import mxnet as mx
    from mxnet import gluon
    import gluonnlp as nlp
Set Environment
~~~~~~~~~~~~~~~
.. code:: python

    np.random.seed(100)
    random.seed(100)
    mx.random.seed(10000)
    ctx = mx.gpu(0)  # assumes a GPU is available; use mx.cpu() otherwise
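Fixing the seeds above makes every run of the notebook draw the same
random numbers, so results are reproducible. A minimal pure-Python
illustration of this behavior (``mx.random.seed`` plays the analogous
role for MXNet operators):

```python
import random

# Reseeding a generator replays the exact same sequence of draws.
random.seed(100)
first_run = [random.random() for _ in range(3)]

random.seed(100)
second_run = [random.random() for _ in range(3)]

print(first_run == second_run)  # True: identical sequences
```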
Use the SOTA Pretrained Transformer model
-----------------------------------------
In this subsection, we first load the SOTA Transformer model from the
GluonNLP model zoo, then load the full WMT 2014 English-German test
dataset, and finally evaluate the model.
Get the SOTA Transformer
~~~~~~~~~~~~~~~~~~~~~~~~
Next, we load the pretrained SOTA Transformer using the model API in
GluonNLP. This gives us easy access to the SOTA machine translation
model, which you can then use in your own applications.
.. code:: python

    import nmt

    wmt_model_name = 'transformer_en_de_512'

    wmt_transformer_model, wmt_src_vocab, wmt_tgt_vocab = \
        nmt.transformer.get_model(wmt_model_name,
                                  dataset_name='WMT2014',
                                  pretrained=True,
                                  ctx=ctx)

    print(wmt_src_vocab)
    print(wmt_tgt_vocab)
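The printed vocabularies map tokens to integer indices and back. As a
rough illustration (a hypothetical toy class, not the real
``gluonnlp.Vocab``, which additionally manages reserved tokens such as
``<pad>``, ``<bos>``, and ``<eos>``), a vocabulary provides roughly this
interface:

```python
class ToyVocab:
    """A minimal sketch of a bidirectional token <-> index mapping."""

    def __init__(self, tokens, unknown_token='<unk>'):
        # Index 0 is reserved for the unknown token.
        self.idx_to_token = [unknown_token] + sorted(set(tokens))
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def __getitem__(self, token):
        # Out-of-vocabulary tokens map to the <unk> index.
        return self.token_to_idx.get(token, 0)

    def __len__(self):
        return len(self.idx_to_token)

vocab = ToyVocab(['hello', 'world', 'hello'])
print(len(vocab))                              # 3: <unk>, hello, world
print(vocab['hello'])                          # 1
print(vocab.idx_to_token[vocab['world']])      # 'world'
```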
The Transformer model architecture is shown below:
.. raw:: html