December 09, 2020
 
Posted by Arno Eigenwillig, Software Engineer and Luiz Gustavo Martins, Developer Advocate
BERT and other Transformer encoder architectures have been very successful in natural language processing (NLP) for computing vector-space representations of text, both in advancing the state of the art on academic benchmarks and in large-scale applications like Google Search. BERT has been available for TensorFlow since it was created, but it originally relied on non-TensorFlow Python code to transform raw text into model inputs.
Today, we are excited to announce a more streamlined approach to using BERT built entirely in TensorFlow. This solution makes both pre-trained encoders and the matching text preprocessing models available on TensorFlow Hub. BERT in TensorFlow can now be run on text inputs with just a few lines of code:
An animation of the preprocessing model that makes it easy for you to input text into BERT (described below).
import tensorflow_hub as hub
import tensorflow_text  # Registers the TF.text ops used by the preprocessing model.

# Load BERT and the preprocessing model from TF Hub.
preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1')
encoder = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3')
# Use BERT on a batch of raw text inputs.
inputs = preprocess(['Batch of inputs', 'TF Hub makes BERT easy!', 'More text.'])
pooled_output = encoder(inputs)["pooled_output"]
print(pooled_output)
tf.Tensor(
[[-0.8384154  -0.26902363 -0.3839138  ... -0.3949695  -0.58442086  0.8058556 ]
 [-0.8223734  -0.2883956  -0.09359277 ... -0.13833837 -0.6251748   0.88950026]
 [-0.9045408  -0.37877116 -0.7714909  ... -0.5112085  -0.70791864  0.92950743]],
shape=(3, 768), dtype=float32)
These encoder and preprocessing models have been built with TensorFlow Model Garden’s NLP library and exported to TensorFlow Hub in the SavedModel format. Under the hood, preprocessing uses TensorFlow ops from the TF.text library to do the tokenization of input text – allowing you to build your own TensorFlow model that goes from raw text inputs to prediction outputs without Python in the loop. This accelerates the computation, removes boilerplate code, is less error prone, and enables the serialization of the full text-to-outputs model, making BERT easier to serve in production.
To show in more detail how these models can help you, we’ve published two new tutorials: a beginner tutorial and an advanced tutorial.
BERT models are pre-trained on a large corpus of text (for example, an archive of Wikipedia articles) using self-supervised tasks like predicting words in a sentence from the surrounding context. This type of training allows the model to learn a powerful representation of the semantics of the text without needing labeled data. However, it also takes a significant amount of computation to train – 4 days on 16 TPUs (as reported in the 2018 BERT paper). Fortunately, after this expensive pre-training has been done once, we can efficiently reuse this rich representation for many different tasks.
TensorFlow Hub offers a variety of BERT and BERT-like models:
These models are BERT encoders. The links above take you to their documentation on TF Hub, which refers to the right preprocessing model for use with each of them.
We encourage developers to visit these model pages to learn more about the different applications targeted by each model. Thanks to their common interface, it's easy to experiment and compare the performance of different encoders on your specific task by changing the URLs of the encoder model and its preprocessing.
For each BERT encoder, there is a matching preprocessing model. It transforms raw text to the numeric input tensors expected by the encoder, using TensorFlow ops provided by the TF.text library. Unlike preprocessing with pure Python, these ops can become part of a TensorFlow model for serving directly from text inputs. Each preprocessing model from TF Hub is already configured with a vocabulary and its associated text normalization logic and needs no further set-up.
We’ve already seen the simplest way of using the preprocessing model above. Let’s look again more closely:
preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/1')
inputs = preprocess(["This is an amazing movie!"])
 
{'input_word_ids': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=
  array([[ 101, 2023, 2003, 2019, 6429, 3185,  999,  102,    0,  ...]])>,
 'input_mask': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=
  array([[   1,    1,    1,    1,    1,    1,    1,    1,    0,  ...,]])>,
 'input_type_ids': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=
  array([[   0,    0,    0,    0,    0,    0,    0,    0,    0,  ...,]])>}
 
Calling preprocess() like this transforms raw text inputs into a fixed-length input sequence for the BERT encoder. You can see that it consists of a tensor input_word_ids with numerical ids for each tokenized input, including start, end and padding tokens, plus two auxiliary tensors: an input_mask (that tells non-padding from padding tokens) and input_type_ids for each token (that can distinguish multiple text segments per input, which we will discuss below). 
The same preprocessing SavedModel also offers a second, more fine-grained API, which supports putting one or two distinct text segments into one input sequence for the encoder. Let’s look at a sentence entailment task, in which BERT is used to predict if a premise entails a hypothesis or not:
text_premises = ["The fox jumped over the lazy dog.",
                 "Good day."]
tokenized_premises = preprocess.tokenize(text_premises)
 
<tf.RaggedTensor
  [[[1996], [4419], [5598], [2058], [1996], [13971], [3899], [1012]],
  [[2204], [2154], [1012]]]>
 
 
text_hypotheses = ["The dog was lazy.",  # Entailed.
                   "Axe handle!"]        # Not entailed.
tokenized_hypotheses = preprocess.tokenize(text_hypotheses)
 
<tf.RaggedTensor
  [[[1996], [3899], [2001], [13971], [1012]],
  [[12946], [5047], [999]]]>
The result of each tokenization is a RaggedTensor of numeric token ids, representing each of the text inputs in full. If some pairs of premise and hypothesis are too long to fit within the seq_length for BERT inputs in the next step, you can do additional preprocessing here, such as trimming the text segment or splitting it into multiple encoder inputs.
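The splitting option mentioned above can be sketched in plain Python. This is only a minimal sketch and not part of the TF Hub preprocessing API: the helper `split_into_windows` and its `max_tokens` and `stride` parameters are hypothetical names, and it assumes the token ids of one text segment have already been flattened into a Python list.

```python
def split_into_windows(token_ids, max_tokens, stride=None):
    """Split one segment's token ids into windows of at most max_tokens ids.

    Hypothetical helper for illustration; not provided by the preprocessing model.
    stride is the step between window starts. It defaults to max_tokens,
    which yields non-overlapping windows; a smaller stride yields overlap.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    if stride is None:
        stride = max_tokens
    if stride <= 0:
        raise ValueError("stride must be positive")

    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + max_tokens])
        # Stop once the current window reaches the end of the segment.
        if start + max_tokens >= len(token_ids):
            break
        start += stride
    return windows


# Token ids of the first premise above, flattened to a plain list.
premise_ids = [1996, 4419, 5598, 2058, 1996, 13971, 3899, 1012]
for window in split_into_windows(premise_ids, max_tokens=3):
    print(window)
```

With a single segment per input, a `max_tokens` of `seq_length - 2` would leave room for the start and end tokens that are added during packing.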
The tokenized input then gets packed into a fixed-length input sequence for the BERT encoder:
encoder_inputs = preprocess.bert_pack_inputs(
   [tokenized_premises, tokenized_hypotheses],
   seq_length=18)  # Optional argument, defaults to 128.
 
{'input_word_ids': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=
  array([[  101,  1996,  4419,  5598,  2058,  1996, 13971,  3899,  1012,
            102,  1996,  3899,  2001, 13971,  1012,   102,     0,     0],
         [  101,  2204,  2154,  1012,   102, 12946,  5047,   999,   102,
              0,     0,     0,     0,     0,     0,     0,     0,     0]])>,
 'input_mask': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=
  array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>,
 'input_type_ids': <tf.Tensor: shape=(2, 18), dtype=int32, numpy=
  array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>}
The result of packing is the already-familiar dict of input_word_ids, input_mask and input_type_ids (which are 0 and 1 for the first and second input, respectively). All outputs have a common seq_length (128 by default). Inputs that would exceed seq_length are truncated to approximately equal sizes during packing.
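To make the layout of the packed tensors concrete, here is a plain-Python sketch that reproduces the values shown above for this example. It is an illustration only, not the implementation behind `bert_pack_inputs`: the function name `pack_inputs` is hypothetical, the special token ids 101, 102 and 0 are read off the outputs shown above, and the trimming rule (drop the last token of the currently longest segment) is an assumption standing in for the model's actual truncation.

```python
# Ids as they appear in the bert_en_uncased outputs above.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pack_inputs(segments, seq_length=128):
    """Pack one example's token-id segments into fixed-length lists.

    Hypothetical, simplified stand-in for illustration.
    segments: a list with one or two lists of token ids.
    """
    # Room left for real tokens after the start token and one end token per segment.
    budget = seq_length - 1 - len(segments)
    if budget < 0:
        raise ValueError("seq_length is too small for the special tokens")

    # Assumed trimming rule: shorten the longest segment until everything fits.
    segments = [list(s) for s in segments]
    while sum(len(s) for s in segments) > budget:
        max(segments, key=len).pop()

    word_ids, type_ids = [CLS_ID], [0]
    for type_id, segment in enumerate(segments):
        word_ids += segment + [SEP_ID]
        type_ids += [type_id] * (len(segment) + 1)

    num_real = len(word_ids)
    num_pad = seq_length - num_real
    return {
        "input_word_ids": word_ids + [PAD_ID] * num_pad,
        "input_mask": [1] * num_real + [0] * num_pad,
        "input_type_ids": type_ids + [0] * num_pad,
    }


premise = [1996, 4419, 5598, 2058, 1996, 13971, 3899, 1012]
hypothesis = [1996, 3899, 2001, 13971, 1012]
packed = pack_inputs([premise, hypothesis], seq_length=18)
for name, values in packed.items():
    print(name, values)
```

Reading the three lists side by side shows the pattern: padding positions are 0 in all of them, and input_type_ids switches to 1 only for the second segment and its end token.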
TensorFlow Hub provides BERT encoder and preprocessing models as separate pieces to enable accelerated training, especially on TPUs.
Tensor Processing Units (TPUs) are Google’s custom-developed accelerator hardware that excels at large-scale machine learning computations such as those required to fine-tune BERT. TPUs operate on dense Tensors and expect that variable-length data like strings has already been transformed into fixed-size Tensors by the host CPU.
The split between the BERT encoder model and its associated preprocessing model enables distributing the encoder fine-tuning computation to TPUs as part of model training, while the preprocessing model executes on the host CPU. The preprocessing computation can be run asynchronously on a dataset using tf.data.Dataset.map() with dense outputs ready to be consumed by the encoder model on the TPU. Asynchronous preprocessing like this can improve performance with other accelerators as well.
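The input pipeline described above might look roughly like the following sketch. To keep it runnable without downloading anything, it uses a hypothetical `standin_preprocess` function that only mimics the output structure of the preprocessing model (three dense int32 tensors of a fixed length) and does not perform real BERT tokenization. The tiny `SEQ_LENGTH`, the example texts and the labels are likewise assumptions for illustration; in a real pipeline, the stand-in would be replaced by the preprocessing model loaded from TF Hub.

```python
import tensorflow as tf

SEQ_LENGTH = 8  # Tiny value for illustration; the real model defaults to 128.


def standin_preprocess(sentences):
    """Hypothetical stand-in: maps a batch of strings to dense int32 tensors."""
    words = tf.strings.split(sentences)  # RaggedTensor [batch, (num_words)]
    # Hash each word to a non-zero id so that 0 can serve as padding.
    ids = tf.ragged.map_flat_values(
        lambda w: tf.cast(tf.strings.to_hash_bucket_fast(w, 1000), tf.int32) + 1,
        words)
    # Pad (or truncate) every row to the fixed length.
    word_ids = ids.to_tensor(default_value=0, shape=[None, SEQ_LENGTH])
    return {
        "input_word_ids": word_ids,
        "input_mask": tf.cast(tf.not_equal(word_ids, 0), tf.int32),
        "input_type_ids": tf.zeros_like(word_ids),
    }


texts = ["TF Hub makes BERT easy!", "More text.", "Batch of inputs"]
labels = [1, 0, 1]

dataset = (
    tf.data.Dataset.from_tensor_slices((texts, labels))
    .batch(2)
    # Preprocessing runs on the host CPU, in parallel with the training step.
    .map(lambda text, label: (standin_preprocess(text), label),
         num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

features, batch_labels = next(iter(dataset))
print({name: tensor.shape for name, tensor in features.items()})
```

Every element the dataset yields is a dict of fixed-size dense tensors plus a label, which is the form an encoder running on an accelerator can consume directly.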
Our advanced BERT tutorial can be run in a Colab runtime that uses a TPU worker and demonstrates this end-to-end.
Using BERT and similar models in TensorFlow has just gotten simpler. TensorFlow Hub makes available a large collection of pre-trained BERT encoders and text preprocessing models that are easy to use in just a few lines of code.
Take a look at our interactive beginner and advanced tutorials to learn more about how to use the models for sentence and sentence-pair classification. Let us know what you build with these new BERT models and tag your posts with #TFHub.
We’d like to thank a number of colleagues for their contribution to this work.
The new preprocessing models have been created in collaboration with Chen Chen, Terry Huang, Mark Omernick and Rajagopal Ananthanarayanan.
Additional BERT models have been published to TF Hub on this occasion by Sebastian Ebert (Small BERTs), Le Hou and Hongkun Yu (Lambert, Talking Heads).
Mark Daoust, Josh Gordon and Elizabeth Kemp have greatly improved the presentation of the material in this post and the associated tutorials. Thanks also to Tom Small for the beautiful BERT animation.
 