TensorFlow Extended (TFX): Using Apache Beam for large scale data processing
March 10, 2020
Posted by Reza Rokni, Developer Advocate, Google Cloud, on behalf of the TFX and Dataflow teams

TFX's core mission is to move models from research to production by creating and managing production pipelines. Many models are built from large volumes of data, which requires multiple hosts working in parallel to meet both the processing and serving needs of your production pipelines.

Using capabilities inherited from Apache Beam, TFX's data processing framework, we will look at how a TFX pipeline, developed against a small dataset, can be scaled out for your production dataset.

Apache Beam

The origins of Apache Beam can be traced back to FlumeJava, the data processing framework used at Google (discussed in the FlumeJava paper (2010)). Google Flume is still in heavy use across Google today, including as the data processing framework for Google's internal TFX usage.

Google Flume was the basis for the development of Google Cloud Dataflow (released in 2015). The SDK for Dataflow was open sourced in 2016 as Apache Beam. Similar to Google's internal TFX implementation (discussed in the TFX paper (2017)), the external version of TFX makes use of the external version of Google Flume, Apache Beam.

The Apache Beam portable API layer powers TFX libraries (for example TensorFlow Data Validation, TensorFlow Transform, and TensorFlow Model Analysis) within the context of a Directed Acyclic Graph (DAG) of execution. Apache Beam pipelines can be executed across a diverse set of execution engines, or "runners". A comprehensive list of runners and their capabilities can be found in the Apache Beam documentation.


The runner used in this blog is Dataflow, which shares a large percentage of its code with Google Flume, with further unification in progress.

Below we can see the graph created by the TFX component ExampleGen when it is run on the Dataflow runner.

Apache Beam Benefits

This freedom to choose different execution engines was an important factor in deciding to make use of Apache Beam for TFX. Development can be done on a local DirectRunner, with production workloads run on production runners. For example, the production Apache Flink runner can run in a local data center, or you can use a fully managed cloud runner like Dataflow.
By using production runners, we can make use of tens of thousands of cores, all working in parallel to carry out the computation done in TFX libraries, without changing the core code created during the development of the pipeline.

We will show this ability using two examples: first with a core TFX library transform, AnalyzeAndTransformDataset, and then with two TFX components, ExampleGen and StatisticsGen.

BigQuery and Dataflow are chargeable services. Please make sure you understand the cost implications before running any of the samples in this blog.

TFX Libraries

TFX pipeline components are built upon TFX libraries, for example TensorFlow Transform, which uses Apache Beam. We will explore this library using two different Apache Beam runners. Initially, the local development runner, DirectRunner, will be used. This will be followed by some minor code modifications to run the sample with the production Dataflow runner. The DirectRunner is a lightweight runner for development purposes. It runs locally and does not require a distributed processing framework.

The example pipeline below is borrowed from the tutorial (Preprocess data (beginner)) which provides an example of how TensorFlow Transform (tf.Transform) can be used to preprocess data.

For details of the preprocessing_fn, please refer back to the tutorial. For now, we just need to know that it is transforming the data points passed into the function.

Environment used for this blog post:
virtualenv tfx-beam --python=python3
source tfx-beam/bin/activate
pip install tfx
import pprint
import tempfile
import tensorflow_transform.beam as tft_beam

# raw_data, raw_data_metadata and preprocessing_fn are defined in the tutorial.
def main():
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, transform_fn = (
        (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
            preprocessing_fn))
  transformed_data, transformed_metadata = transformed_dataset
  print('\nRaw data:\n{}\n'.format(pprint.pformat(raw_data)))
  print('Transformed data:\n{}'.format(pprint.pformat(transformed_data)))

if __name__ == '__main__':
  main()
Apache Beam uses a special syntax to define and invoke transforms. For example, in this line:
result = pass_this | 'name this step' >> to_this_call
The method to_this_call is being invoked and passed the object called pass_this, and this operation will be referred to as name this step in a stack trace.
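
To make the syntax less mysterious, here is a toy sketch of how Python operator overloading can support a `value | 'label' >> transform` expression. It is not Beam's implementation: the `Transform` class and the `executed_steps` list are hypothetical names invented for this illustration, and the step runs eagerly rather than being added to a pipeline graph.

```python
executed_steps = []  # hypothetical log of step labels, for illustration only


class Transform:
    """A toy stand-in for a transform: wraps a function and a label."""

    def __init__(self, fn, label=None):
        self.fn = fn
        self.label = label or fn.__name__

    def __rrshift__(self, label):
        # Handles `'label' >> transform`. str has no __rshift__ for this
        # operand, so Python falls back to the right operand's __rrshift__.
        return Transform(self.fn, label)

    def __ror__(self, value):
        # Handles `value | transform`. list has no __or__, so Python falls
        # back to the right operand's __ror__.
        executed_steps.append(self.label)
        return self.fn(value)


double_all = Transform(lambda values: [v * 2 for v in values])

# `>>` binds more tightly than `|`, so the label is attached first and the
# labelled transform is then applied to the input.
result = [1, 2, 3] | 'name this step' >> double_all
```

Because `>>` has higher precedence than `|`, no parentheses are needed around the labelled transform.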

The example above implicitly uses the local development and test runner, DirectRunner. To switch from the local DirectRunner to Dataflow, we first need to wrap the beam_impl.Context within a beam.Pipeline. This makes it possible to pass in arguments, for example "--runner". For a quick local test, you can run the sample below with --runner set to DirectRunner.
import tempfile

import apache_beam as beam

# beam_impl, raw_data, raw_data_metadata and preprocessing_fn are assumed to be
# set up as in the tutorial.
argv = ['--runner=DirectRunner']

def main():
    with beam.Pipeline(argv=argv) as p:
        # Ignore the warnings
        with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
            input = p | beam.Create(raw_data)
            transformed_dataset, transform_fn = (
                (input, raw_data_metadata)
                | beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))
            transformed_dataset[0] | "Print Transformed Dataset" >> beam.Map(print)

if __name__ == '__main__':
    main()
Next, we will switch to using the Dataflow Runner. Since Dataflow is a fully managed runner working on Google Cloud, we will need to provide the pipeline with some environmental information. This includes the Google Cloud project and locations for the temporary files used by the pipeline.

You must set the correct permissions to submit a pipeline job to the Dataflow service.
Read more information on authentication at: https://cloud.google.com/dataflow/docs/concepts/security-and-permissions
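
As a rough illustration of the environmental information described above, the sketch below assembles the argument list for either runner. The helper `build_pipeline_args` is a hypothetical name, and the project and bucket values are placeholders, not real resources. The flag names (`--runner`, `--project`, `--region`, `--temp_location`, `--setup_file`) are standard Beam pipeline options.

```python
from typing import List, Optional


def build_pipeline_args(runner: str,
                        project: Optional[str] = None,
                        region: Optional[str] = None,
                        temp_location: Optional[str] = None,
                        setup_file: Optional[str] = None) -> List[str]:
    """Builds a Beam-style argv list, skipping options that were not given."""
    args = [f"--runner={runner}"]
    optional = {
        "project": project,
        "region": region,
        "temp_location": temp_location,
        "setup_file": setup_file,
    }
    for name, value in optional.items():
        if value is not None:
            args.append(f"--{name}={value}")
    return args


# Local development: only the runner is needed.
local_args = build_pipeline_args("DirectRunner")

# Dataflow: placeholder project and bucket, replace with your own.
dataflow_args = build_pipeline_args(
    "DataflowRunner",
    project="my-gcp-project",            # hypothetical project id
    region="us-central1",
    temp_location="gs://my-bucket/tmp",  # hypothetical bucket
    setup_file="./setup.py")
```

Keeping the runner choice in the arguments, rather than in the pipeline body, is what allows the same pipeline code to run in both environments.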

# Setup our Environment

## The location of Input / Output between various stages ( TFX Components )
## This will also be the location for the Metadata 

### Can be used when running the pipeline locally

### In production you want the input and output to be stored on non-local location



# Will need setup.py to make this work with Dataflow
# import setuptools
# setuptools.setup(
#   name='demo',
#   version='0.0',
#   install_requires=['tfx==0.21.1'],
#   packages=setuptools.find_packages(),)

SETUP_FILE = "./setup.py"

# argv and GOOGLE_CLOUD_TEMP_LOCATION are assumed to be defined in the
# environment setup above.
def main():
    with beam.Pipeline(argv=argv) as p:
        with beam_impl.Context(temp_dir=GOOGLE_CLOUD_TEMP_LOCATION):
            input = p | beam.Create(raw_data)
            transformed_data, transformed_metadata = (
                (input, raw_data_metadata)
                | beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))

if __name__ == '__main__':
    main()
To get a feel for how much work TFX has abstracted away, below is a visual representation of the graph that the pipeline processed. We had to shrink the image to fit it all in as there are a lot of transforms!
Figure: Visual representation of the graph that the TFX pipeline processed.

Using TFX Components with Beam

Next, let's make use of some TFX components, which are composed from the TFX libraries discussed above. We will use ExampleGen to ingest the data and StatisticsGen to generate descriptive statistics on the data.


The ExampleGen TFX Pipeline component ingests data into TFX pipelines. It consumes external files/services to generate Examples which will be read by other TFX components. It also splits the data into training and evaluation splits, or additional splits if required, and optionally shuffles the dataset. The process is listed below:
  1. Split data into training and evaluation sets (by default, 2/3 training + 1/3 eval)
  2. Convert data into the tf.Example format
  3. Copy data into the _tfx_root directory for other components to access
BigQueryExampleGen allows us to directly query data in BigQuery.
def createExampleGen(query: Text):
    # Output 2 splits: train:eval=3:1.
    output = example_gen_pb2.Output(
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=3),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))
    return BigQueryExampleGen(query=query, output_config=output)
As well as the SQL query to be run, the BigQueryExampleGen code is also passed configuration information via the SplitConfig object.
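
To build intuition for what hash_buckets=3 and hash_buckets=1 mean, here is a minimal sketch of hash-bucket splitting in plain Python. It is an illustration under assumptions, not ExampleGen's implementation: the `assign_split` helper is hypothetical, and MD5 is used only as a convenient deterministic hash, which may differ from what TFX uses internally.

```python
import hashlib
from collections import Counter

# Mirrors the split configuration: 3 buckets for train, 1 for eval.
SPLITS = [("train", 3), ("eval", 1)]


def assign_split(key: str, splits=SPLITS) -> str:
    """Maps a record key to a split name via a deterministic hash bucket."""
    total_buckets = sum(buckets for _, buckets in splits)
    digest = hashlib.md5(key.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % total_buckets
    upper = 0
    for name, buckets in splits:
        upper += buckets
        if bucket < upper:
            return name
    raise AssertionError("unreachable: bucket is always below total_buckets")


# Hypothetical record keys; real pipelines would hash the record contents.
counts = Counter(assign_split(f"record-{i}") for i in range(10_000))
```

Because the assignment depends only on the hash of the record, the same record lands in the same split on every run, and the split sizes approach the 3:1 ratio as the dataset grows.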

The data for this example comes from the public Chicago taxi trips dataset, one of the public datasets hosted on BigQuery (Google Cloud's data warehouse).


NOTE: You can find more details about BigQuery public datasets at: https://cloud.google.com/bigquery/public-data/

The query below extracts the data in the correct format for processing by ExampleGen.
  EXTRACT(MONTH FROM trip_start_timestamp)  trip_start_month,
  EXTRACT(HOUR FROM trip_start_timestamp)  trip_start_hour,
  EXTRACT(DAYOFWEEK FROM trip_start_timestamp)  trip_start_day,
  UNIX_Millis(trip_start_timestamp) trip_start_ms_timestamp,
Note the use of LIMIT 100, which will limit the output to 100 records, allowing us to quickly test out our code for correctness.
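
For readers less familiar with BigQuery's date functions, the sketch below reproduces in plain Python what the derived time columns in the query compute for a single timestamp. The function name `derive_time_features` is hypothetical. It assumes UTC, which is the default time zone for EXTRACT on a TIMESTAMP, and follows BigQuery's DAYOFWEEK convention of 1 for Sunday through 7 for Saturday.

```python
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def derive_time_features(ts: datetime) -> dict:
    """Computes the month, hour, day-of-week and epoch-millis features."""
    ts = ts.astimezone(timezone.utc)
    return {
        "trip_start_month": ts.month,
        "trip_start_hour": ts.hour,
        # isoweekday() is Monday=1..Sunday=7; shift to Sunday=1..Saturday=7.
        "trip_start_day": ts.isoweekday() % 7 + 1,
        # Whole milliseconds since the Unix epoch.
        "trip_start_ms_timestamp": (ts - EPOCH) // timedelta(milliseconds=1),
    }


# A Tuesday afternoon, used purely as an example input.
features = derive_time_features(
    datetime(2020, 3, 10, 14, 30, tzinfo=timezone.utc))
```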


The StatisticsGen TFX pipeline component generates descriptive statistics over both training and evaluation data, which can be used by other pipeline components. It works on the results of the previous step, ExampleGen.
def createStatisticsGen(bigQueryExampleGen: BigQueryExampleGen):
    # Computes statistics over data for visualization and example validation.
    return StatisticsGen(examples=bigQueryExampleGen.outputs['examples'])
As the output of ExampleGen is required by StatisticsGen, we now have a dependency between the two steps. This producer-consumer pattern is seen throughout most production ML pipelines. To automate this pipeline, we will need something that coordinates these dependencies.
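
The dependency described above can be sketched as a tiny DAG that is resolved with a topological sort. This is only an illustration of the idea, not how any TFX orchestrator is implemented: the `run_in_dependency_order` helper is hypothetical, and the "components" here are just names appended to a list.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Each component maps to the set of components whose output it consumes.
dependencies = {
    "ExampleGen": set(),
    "StatisticsGen": {"ExampleGen"},
}


def run_in_dependency_order(dependencies, run_component):
    """Runs each component only after all of its producers have run."""
    order = list(TopologicalSorter(dependencies).static_order())
    for name in order:
        run_component(name)
    return order


executed = []
order = run_in_dependency_order(dependencies, executed.append)
```

A real orchestrator adds what this sketch leaves out, such as retries, logging, and handling of failed steps.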

Pipeline Orchestration

One solution would be to write a simple, lightweight python script. However, what about debugging, failure modes, retries, logging, etc.?

Luckily for us, this is taken care of by TFX integrations with two pipeline orchestration engines - Kubeflow and Apache Airflow.

As well as these two orchestration engines, we can again make use of Apache Beam as an orchestrator, since the dependencies can be modeled as a DAG. So we can use a DAG with transforms that are themselves DAGs... "we have to go deeper" :-)

The choice of which engine to use is dependent on your production needs and requirements, which is beyond the scope of this blog. For now, we will use Apache Beam for the orchestration, via TFX's BeamDagRunner. This means we are using Beam in two different roles - as an execution engine for processing data, and as an orchestrator for sequencing the TFX tasks.
# Used for setting up the orchestration 
from tfx.orchestration import pipeline
from tfx.orchestration import metadata
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
The following code creates our pipeline object ready to be executed by the BeamDagRunner.
import os
from typing import Text
from typing import Type

def createTfxPipeline(pipeline_name: Text, pipeline_root: Text, query: Text,
                      beam_pipeline_args) -> pipeline.Pipeline:
    output = example_gen_pb2.Output(
        # Output 2 splits: train:eval=3:1.
        split_config=example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=3),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
        ]))

    # Brings data into the pipeline or otherwise joins/converts training data.
    example_gen = BigQueryExampleGen(query=query, output_config=output)
    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=[
          example_gen, statistics_gen
        ],
        metadata_connection_config=metadata.sqlite_metadata_connection_config(
          os.path.join(".", 'metadata', pipeline_name,'metadata.db')),
        beam_pipeline_args=beam_pipeline_args)
To test the code, run the query with "LIMIT 100" on the local DirectRunner.
tfx_pipeline = createTfxPipeline(
    beam_pipeline_args=                               {
You can inspect the results with TensorFlow Data Validation (tfdv) by loading the statistics written under LOCAL_PIPELINE_ROOT:
import os
import tensorflow_data_validation as tfdv

stats = tfdv.load_statistics(os.path.join(LOCAL_PIPELINE_ROOT,"StatisticsGen","statistics","","train","stats_tfrecord"))
That works fine for one hundred records, but what if the goal was to process all of the roughly 187 million rows in the dataset? For this, the pipeline is switched from the DirectRunner to the production Dataflow runner. A few extra environment parameters are also set, for example the Google Cloud project to run the pipeline in.
tfx_pipeline = createTfxPipeline(
The BeamDagRunner takes care of submitting ExampleGen and StatisticsGen as separate pipelines, ensuring that ExampleGen has completed successfully before StatisticsGen starts.

The Dataflow service automatically takes care of spinning up workers, autoscaling, retries in the event of worker failure, centralized logging, and monitoring. Autoscaling is based on various signals, including throughput rate.

The Dataflow monitoring console shows various metrics about the pipeline, for example the CPU utilization of the workers. During this run, utilization stayed consistently high as machines came online, with most workers over 90%.

Apache Beam supports custom counters, which allow developers to create metrics from within their pipelines. The TFX team has used this to create informational counters for the various components. Filtering the counters recorded during the StatisticsGen run for the keyword "num_*_feature" shows roughly a billion integer and float feature values.


In this blog, we showed how TFX's use of Apache Beam lets you switch from a development environment to production infrastructure without having to change the core code. We started with the TFX libraries and moved to a pipeline with two core TFX pipeline components ExampleGen and StatisticsGen.

For more information

To learn more about TFX, check out the TFX website, join the TFX discussion group, read the TFX blog, watch our TFX playlist on YouTube, and subscribe to the TensorFlow channel.