TensorFlow with Apache Arrow Datasets
August 23, 2019
Posted by Bryan Cutler

Apache Arrow provides a means of high-performance data exchange with TensorFlow that is both standardized and optimized for analytics and machine learning. The Arrow datasets from TensorFlow I/O provide a way to bring Arrow data directly into TensorFlow tf.data so that it works with existing input pipelines and tf.data.Dataset APIs.

This blog will cover the different Arrow datasets available and how they can be used to feed common TensorFlow workloads. Starting with Pandas as a sample input source, each type of Arrow dataset will be introduced to show usage and describe best practices. Then a complete example will walk through how Arrow can be used to train a Keras model, first locally and then scaled up to a much larger remote set of data. The last section will talk about additional features to look for in the future of Arrow with TensorFlow, and ongoing work to better integrate and improve column-based processing in TensorFlow I/O.

At its core, Apache Arrow is a standard format for in-memory columnar data that is designed for efficiency and interoperability between systems. The project provides many tools that make it easy to work with Arrow data in your own application and leverage the built-in optimizations whether you are using one or a combination of the library implementations in Python, C++, Java, Go, and Rust among others.

Using Apache Arrow with TensorFlow has several advantages. First, because it is a standard it will ensure type safety and data integrity no matter what the source of the data is. Second, since it is an in-memory format, it allows data exchange between systems without the need to implement serialization or a number of converters for different file formats. Finally, Arrow has been designed from the beginning to be optimized in all aspects of working with the data, from zero-copy reads to support for accelerated operations on modern hardware. So you can be sure that data is being handled as efficiently as possible, and can seamlessly integrate with different systems at any scale.

NOTE: the examples in this post use tensorflow 1.14.0 in eager mode, tensorflow_io 0.8.0 (pre-release), pyarrow 0.11.1, and sklearn 0.21.2.

Arrow Dataset Overview

The Arrow datasets are an extension of tf.data.Dataset, so they leverage the same APIs to integrate with tf.data pipelines and can be used as input to tf.keras. Currently, TensorFlow I/O offers 3 varieties of Arrow datasets. By name they are: ArrowDataset, ArrowFeatherDataset, and ArrowStreamDataset. All of these are fed by the same underlying Arrow data which has two important aspects: the data is structured and batched.

Structured Arrow Data

As mentioned previously, Arrow defines a columnar data format and a schema is used to describe each column with a name, data type, bit width, etc. This means that incoming Arrow data is self-described and the data types and shapes can be automatically inferred, ensuring that TensorFlow will be using the exact type specification no matter where the source came from.

The format of Arrow data is language agnostic and is designed to be able to transfer data across language boundaries, e.g. Java to C++, without the need for data serialization or intermediate processing.

Currently, only primitive data types are supported in TensorFlow I/O Arrow datasets, and can be scalar or array values. The latter will translate to a dense tensor vector.

Batching Arrow Natively

Arrow data can be used most efficiently when it is chunked into record batches that consist of a set of columns with an equal number of rows. The batches can then be exchanged via stream or file formats. Each Arrow Dataset supports the option batch_size with an optional batch_mode that is one of “keep_remainder”, “drop_remainder”, or “auto”. The modes “keep_remainder” and “drop_remainder” control what happens when the end of the Dataset results in a partial batch that is smaller than the batch_size. When using the mode “auto”, the batch size is set automatically to the size of the incoming Arrow record batch and the batch_size option does not need to be set.

Setting the batch_size here (or using “auto” mode) is more efficient than using tf.data.Dataset.batch() because Arrow can natively create batches of data and use them to efficiently convert the batched data into tensors.
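
The three batch modes can be illustrated with a plain-Python sketch. This is not the TensorFlow I/O implementation: the function batch_rows is hypothetical, record batches are modeled as simple lists of rows, and it assumes that the two fixed-size modes regroup rows across record batch boundaries.

```python
def batch_rows(record_batches, batch_size=None, batch_mode="keep_remainder"):
    """Regroup rows from incoming record batches into output batches.

    Hypothetical helper that only mimics the described semantics;
    each record batch is modeled as a list of rows.
    """
    if batch_mode == "auto":
        # Output batches mirror the incoming record batches; batch_size is unused
        return [list(rb) for rb in record_batches]
    if batch_mode not in ("keep_remainder", "drop_remainder"):
        raise ValueError("unknown batch_mode: " + batch_mode)
    if not batch_size or batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Assumption: rows are regrouped across record batch boundaries
    rows = [row for rb in record_batches for row in rb]
    batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]

    # A trailing partial batch is dropped only in drop_remainder mode
    if batch_mode == "drop_remainder" and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


incoming = [[0, 1, 2, 3], [4, 5, 6]]  # two record batches, 7 rows in total

keep = batch_rows(incoming, batch_size=3, batch_mode="keep_remainder")
drop = batch_rows(incoming, batch_size=3, batch_mode="drop_remainder")
auto = batch_rows(incoming, batch_mode="auto")

print(keep)  # [[0, 1, 2], [3, 4, 5], [6]]
print(drop)  # [[0, 1, 2], [3, 4, 5]]
print(auto)  # [[0, 1, 2, 3], [4, 5, 6]]
```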

Create a Sample DataFrame Input

Before taking a look at the differences between datasets, let’s first make a sample Pandas DataFrame that can be used as an input source:
import numpy as np
import pandas as pd

data = {'label': np.random.binomial(1, 0.5, 10)}
data['x0'] = np.random.randn(10) + 5 * data['label']
data['x1'] = np.random.randn(10) + 5 * data['label']

df = pd.DataFrame(data)

print(df.head())
#   label        x0        x1
#0      1  5.241089  6.231621
#1      0  0.527365  0.966182
PyArrow integrates very nicely with Pandas and has many built-in capabilities of converting to and from Pandas efficiently. The Arrow datasets make use of these conversions internally, and the model training example below will show how this is done.

Create a Dataset from Arrow Memory

The ArrowDataset works with Arrow data that is already loaded into memory. Because all data must fit in memory, this method is only recommended for small datasets, where it is a quick way to load data. With TensorFlow in graph mode, the data will need to be serialized in the operation, which can cause a spike in memory usage. However, with TensorFlow in eager mode and running in a local process the data is automatically exchanged from Python to the C++ kernel with zero-copy. Let’s try this out using our sample DataFrame from above as input:
import tensorflow_io.arrow as arrow_io

ds = arrow_io.ArrowDataset.from_pandas(
    df,
    batch_size=2,
    preserve_index=False)

# Make an iterator to the dataset
ds_iter = iter(ds)

# Print the first batch
print(next(ds_iter))
# Prints a tuple of 3 tensors (label, x0, x1), each with shape (2,)
The dataset constructor from_pandas takes the Pandas DataFrame as the first argument, the batch_size is set to 2, and the DataFrame index column is omitted by setting preserve_index to False. The output types and shapes of the dataset can automatically be inferred from the Pandas DataFrame schema.

The output of the first batch of data is shown. The 3 columns produce 3 tensors and since a batch_size of 2 was used, the output shape of each tensor is (2,).

Loading Arrow Feather Files

The ArrowFeatherDataset can load a set of files in Arrow Feather format. Feather is a light-weight file format that provides a simple and efficient way to write Pandas DataFrames to disk; see the Arrow Feather Format docs for more information. It is currently limited to primitive scalar data, but after Arrow 1.0.0 is released, it is planned to have full support for Arrow data and also interop with R DataFrames.

This dataset will be ideal if your workload processes many DataFrames and writing to disk is desired. The Arrow Feather readers/writers are designed to maximize performance when loading/saving Arrow record batches. However, if your files are intended for long-term storage, other columnar formats, such as Apache Parquet, might be better suited. Using our sample DataFrame from above, the following code will save it as a feather file and then create a dataset with a list of filenames.
import tensorflow as tf
import tensorflow_io.arrow as arrow_io
from pyarrow.feather import write_feather

# Write the Pandas DataFrame to a Feather file
write_feather(df, '/path/to/df.feather')

# Create the dataset with one or more filenames
ds = arrow_io.ArrowFeatherDataset(
    ['/path/to/df.feather'],
    columns=(0, 1, 2),
    output_types=(tf.int64, tf.float64, tf.float64),
    output_shapes=([], [], []))

# Iterate over each row of each file
for record in ds:
  label, x0, x1 = record
  # use label and feature tensors
The first argument to the dataset constructor is a string or list of strings with each filename to be read. The next argument columns allows for selecting certain columns by index. Finally, the output_types and output_shapes are given. Alternatively, an Arrow schema can be used with the alternate constructor ArrowFeatherDataset.from_schema which will automatically infer the type and shape of the tensors.

Reading Streams of Arrow Batches

The ArrowStreamDataset is used to connect to one or more endpoints that are serving Arrow record batches in the Arrow stream format. See the Arrow stream docs for more on the stream format. Streaming batches is an excellent way to iterate over a large dataset, whether local or remote, that might not fit entirely into memory. While streaming, the batch size can be used to limit memory usage. Currently supported endpoints are a POSIX IPv4 socket with endpoint IP:PORT or tcp://IP:PORT, a Unix Domain Socket with endpoint unix://PATHNAME, and STDIN with endpoint fd://0 or fd://-. Here IP, PORT, and PATHNAME are placeholders for the actual values.

The sample Pandas DataFrame from above can also be used as input into the dataset. This will internally make zero-copy slices of the DataFrame sized to the batch_size, convert the slices to Arrow record batches, and serve them as a stream over a local socket.
import tensorflow_io.arrow as arrow_io

ds = arrow_io.ArrowStreamDataset.from_pandas(
    df,
    batch_size=2,
    preserve_index=False)
The constructor is nearly identical to the ArrowDataset example above, but since the data is chunked into batches and served as a stream, memory usage is much lower, which allows for working with very large DataFrames in memory. The constructor will also accept a sequence or iterator of DataFrames as long as the schema is the same.

Model Training with Arrow

Perhaps the best way to show what Arrow is all about is through an example. In this section, Arrow will be used to read CSV data and train a simple classification model from it. First, the model will be tested locally with a small data sample until the results look satisfying. Then with a few additions, the dataset can be changed to read from a larger remote set of files and offload data processing from the machine running model training.

Train a Model Locally

Let’s first define our model and training step. The data will be in the form of the sample Pandas DataFrame from above, with a label and 2 feature columns. To build a classifier, a simple logistic regression model is made with Keras:
import tensorflow as tf

def model_fit(ds):
  """Create and fit a Keras logistic regression model."""
  
  # Build the Keras model
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=(2,),
            activation='sigmoid'))
  model.compile(optimizer='sgd', loss='mean_squared_error',
                metrics=['accuracy'])

  # Fit the model on the given dataset
  model.fit(ds, epochs=5, shuffle=False)
  return model
Next, make a function that reads a CSV file into Arrow data and applies a feature transform to each batch. It uses the PyArrow CSV reader, which is highly optimized to read data into Arrow record batches.
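import pandas as pd
import pyarrow as pa
import pyarrow.csv
from sklearn.preprocessing import StandardScaler

def read_and_process(filename):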
  """Read the given CSV file and yield processed Arrow batches."""

  # Read a CSV file into an Arrow Table with threading enabled and
  # set block_size in bytes to break the file into chunks for granularity,
  # which determines the number of batches in the resulting pyarrow.Table
  opts = pyarrow.csv.ReadOptions(use_threads=True, block_size=4096)
  table = pyarrow.csv.read_csv(filename, opts)

  # Fit the feature transform
  df = table.to_pandas()
  scaler = StandardScaler().fit(df[['x0', 'x1']])

  # Iterate over batches in the pyarrow.Table and apply processing
  for batch in table.to_batches():
    df = batch.to_pandas()

    # Process the batch and apply feature transform
    X_scaled = scaler.transform(df[['x0', 'x1']])
    df_scaled = pd.DataFrame({'label': df['label'], 
                              'x0': X_scaled[:, 0],
                              'x1': X_scaled[:, 1]})
    batch_scaled = pa.RecordBatch.from_pandas(df_scaled, preserve_index=False)
    
    yield batch_scaled
The block_size option is set in the pyarrow CSV reader to break the file up into chunks to control thread granularity and produce multiple record batches. Then the batches are iterated over to perform the feature transform by first converting each batch to a Pandas DataFrame, processing the data, then converting the result back into an Arrow record batch. Each of these conversion steps is done very efficiently on batches, so the cost is minimal to go to/from Pandas. Finally, instead of returning all batches at once, the function uses yield to hand back each processed batch, which makes it a generator that can be used as an iterator for the input stream.

The next step is to build the Arrow dataset. The ArrowStreamDataset will be used which has a constructor that accepts a record batch iterator as input and can use the generator returned from the previous function read_and_process().
from functools import partial

import tensorflow as tf
import tensorflow_io.arrow as arrow_io

def make_local_dataset(filename):
  """Make a TensorFlow Arrow Dataset that reads from a local CSV file."""
 
  # Read the local file and get a record batch iterator
  batch_iter = read_and_process(filename)
  
  # Create the Arrow Dataset as a stream from local iterator of record batches
  ds = arrow_io.ArrowStreamDataset.from_record_batches(
    batch_iter,
    output_types=(tf.int64, tf.float64, tf.float64),
    batch_mode='auto',
    record_batch_iter_factory=partial(read_and_process, filename))

  # Map the dataset to combine feature columns to single tensor
  ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
  return ds
The constructor ArrowStreamDataset.from_record_batches takes in the record batch iterator, the output_types definitions, and a batch_mode. The batch mode used is ‘auto’ which will automatically create batches of tensors with a batch size equal to that of the incoming Arrow record batches. This is useful since the input can control the original record batch size, but a different batch_size could also be specified.

The argument record_batch_iter_factory is to specify a function that will initialize the record batch iterator so that it could be consumed multiple times during the training epochs. The final line adds a call to tf.data.Dataset.map() that will stack the feature columns into a single tensor output.
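
The need for a factory comes from how Python generators behave: once consumed, a generator yields nothing more. The plain-Python sketch below shows this with a hypothetical stand-in generator named read_batches (strings take the place of record batches) and uses functools.partial in the same way as the code above.

```python
from functools import partial


def read_batches(name, num_batches):
    """Hypothetical stand-in for a record batch generator such as read_and_process()."""
    for i in range(num_batches):
        yield f"{name}-batch-{i}"


# A generator can only be consumed once
batch_iter = read_batches("sample", 2)
first_epoch = list(batch_iter)
second_epoch = list(batch_iter)
print(first_epoch)   # ['sample-batch-0', 'sample-batch-1']
print(second_epoch)  # []

# A factory builds a fresh generator each time it is called,
# so every epoch can iterate over the full data again
factory = partial(read_batches, "sample", 2)
epochs = [list(factory()) for _ in range(3)]
```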

Now the model can be trained on a local file by running the 2 functions above:
ds = make_local_dataset(filename)
model = model_fit(ds)

print("Fit model with weights: {}".format(model.get_weights()))
# Fit model with weights:
# [array([[0.7793554 ], [0.61216295]], dtype=float32),
#  array([0.03328196], dtype=float32)]

Scaling up to a Remote Dataset

Using what has already been done locally, making some small tweaks will enable reading from a larger remote dataset. This will also allow moving the processing to the server, so that the machine performing training does not have to load data and do feature transform. Instead, it can better utilize all resources to focus on training.

First, let’s expand our single file to walk a directory of CSV files. This can be done in a few extra lines of Python:
import os

def read_and_process_dir(directory):
  """Read a directory of CSV files and yield processed Arrow batches."""

  for f in os.listdir(directory):
    if f.endswith(".csv"):
      filename = os.path.join(directory, f)
      for batch in read_and_process(filename):
        yield batch
Next, write the serving function, which listens on a TCP socket, reads and processes each file, and then streams each batch to the client.
import socket
import pyarrow as pa

def serve_csv_data(ip_addr, port_num, directory):
  """
  Create a socket and serve Arrow record batches as a stream read from the
  given directory containing CSV files.
  """

  # Create the socket
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  sock.bind((ip_addr, port_num))
  sock.listen(1)

  # Serve forever, each client will get one iteration over data
  while True:
    conn, _ = sock.accept()
    outfile = conn.makefile(mode='wb')
    writer = None
    try:

      # Read directory and iterate over each batch in each file
      batch_iter = read_and_process_dir(directory)
      for batch in batch_iter:

        # Initialize the pyarrow writer on first batch
        if writer is None:
          writer = pa.RecordBatchStreamWriter(outfile, batch.schema)

        # Write the batch to the client stream
        writer.write_batch(batch)

    # Cleanup client connection
    finally:
      if writer is not None:
        writer.close()
      outfile.close()
      conn.close()
  sock.close()
Much of this is boilerplate for setting up a socket server. The important part is that the PyArrow RecordBatchStreamWriter writes a sequence of Arrow record batches to an output stream, which is a TCP socket in this case.

To build the Arrow dataset, the ArrowStreamDataset will again be used but instead of from_record_batches, the constructor will be passed an endpoint for our remote server. Internally, the dataset kernel will create a client connection and begin to read record batches over the socket.
def make_remote_dataset(endpoint):
  """Make a TensorFlow Arrow Dataset that reads from a remote Arrow stream."""

  # Create the Arrow Dataset from a remote host serving a stream
  ds = arrow_io.ArrowStreamDataset(
      [endpoint],
      columns=(0, 1, 2),
      output_types=(tf.int64, tf.float64, tf.float64),
      batch_mode='auto')

  # Map the dataset to combine feature columns to single tensor
  ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
  return ds
A single endpoint is used here, but it could also be a list with multiple endpoints. The rest is identical to our local dataset, except it is not necessary to specify the record_batch_iter_factory argument, since the server function already repeats the dataset each time the client disconnects and then reconnects.

To run training, the server function serve_csv_data() could be executed in one or more remote processes and model_fit() in a separate process. A complete example showing both local and remote training can be found in the gist arrow_model_training_example.

An important takeaway in this example is that because Arrow was used as the data format, the data was transferred from a Python server directly to the C++ client in the dataset kernel without the need to use a proprietary format or save to intermediate files. If the Python server no longer fits your needs, it can be changed to a Java server without the need to modify anything on the model training end. Likewise, if your data needs outgrow the CSV format, only the server side would need to support a different data source.

Future Work with Arrow and TensorFlow I/O

At the time of writing this, the Arrow project is working on a 1.0 release. This is significant in that it will provide compatibility guarantees throughout the 1.x life-cycle. Beyond that, adding support for tensors as a logical type in the Arrow record batch would allow blending matrices and higher-order data with standard columnar types, and better integration with many ML workflows. Track the issues ARROW-1614 and ARROW-5819 for updates or keep an eye on the project blog https://arrow.apache.org/blog/.

Upcoming improvements for Arrow with TensorFlow I/O include the addition of an Arrow Flight dataset that will provide a client to connect with an Arrow Flight server to transfer data using an Arrow-native RPC framework. This will make it a breeze to connect TensorFlow with other distributed applications on a network to exchange Arrow data. Currently, Arrow Flight is being hardened to be production ready and should be solid by the Arrow 1.0 release. Follow the issue at tensorflow/io/issues/398 for updates.

Additionally, TensorFlow I/O is working to expand columnar operations with Arrow and related datasets like Apache Parquet, HDF5 and JSON. This will enable things like split, merge, selecting columns and other operations on a mix of different columnar datasets. See the issue at tensorflow/io/issues/315 for more information.

Thanks to Maureen McElaney.
