August 23, 2019
Posted by Bryan Cutler
Apache Arrow enables high-performance data exchange with TensorFlow that is both standardized and optimized for analytics and machine learning. The Arrow datasets from TensorFlow I/O provide a way to bring Arrow data directly into TensorFlow tf.data that works with existing input pipelines and tf.data.Dataset APIs.
This blog will cover the different Arrow d…
The Arrow datasets are an extension of tf.data.Dataset, so they leverage the same APIs to integrate with tf.data pipelines and can be used as input to tf.keras. Currently, TensorFlow I/O offers three varieties of Arrow datasets: ArrowDataset, ArrowFeatherDataset, and ArrowStreamDataset. All of these are fed by the same underlying Arrow data, which has two important aspects: the data is structured and batched.

Tensors can be batched by setting batch_size, with an optional batch_mode that is one of “keep_remainder”, “drop_remainder”, or “auto”. The modes “keep_remainder” and “drop_remainder” control what happens when the end of the dataset results in a partial batch that is smaller than batch_size. When using the mode “auto”, the batch size is set automatically to the size of the incoming Arrow record batch, and the batch_size option does not need to be set.

Setting batch_size here (or using “auto” mode) is more efficient than using tf.data.Dataset.batch() because Arrow can natively create batches of data and use them to efficiently convert the batched data into tensors.

The examples below use this sample Pandas DataFrame:

import numpy as np
import pandas as pd
data = {'label': np.random.binomial(1, 0.5, 10)}
data['x0'] = np.random.randn(10) + 5 * data['label']
data['x1'] = np.random.randn(10) + 5 * data['label']
df = pd.DataFrame(data)
print(df.head())
#   label        x0        x1
#0      1  5.241089  6.231621
#1      0  0.527365  0.966182

ArrowDataset works with Arrow data that is already loaded into memory. Because all data must fit in memory, this method is only recommended for small datasets, where it is a quick way to load data. With TensorFlow in graph mode, the data must be serialized into the operation, which can cause a spike in memory usage. However, with TensorFlow in eager mode and running in a local process, the data is automatically exchanged from Python to the C++ kernel with zero-copy. Let’s try this out using our sample DataFrame from above as input:

import tensorflow_io.arrow as arrow_io
ds = arrow_io.ArrowDataset.from_pandas(
    df,
    batch_size=2,
    preserve_index=False)
# Make an iterator to the dataset
ds_iter = iter(ds)
# Print the first batch
print(next(ds_iter))
# Prints a tuple of tensors, one per column, each with shape (2,)

from_pandas takes the Pandas DataFrame as the first argument, the batch_size is set to 2, and the DataFrame index column is omitted by setting preserve_index to False. The output types and shapes of the dataset are inferred automatically from the Pandas DataFrame schema. Because a batch_size of 2 was used, the output shape of each tensor is (2,).

ArrowFeatherDataset can load a set of files in Arrow Feather format. Feather is a lightweight file format that provides a simple and efficient way to write Pandas DataFrames to disk; see the Arrow Feather Format docs for more information. It is currently limited to primitive scalar data, but after Arrow 1.0.0 is released, full support for Arrow data and interop with R DataFrames is planned.

import tensorflow as tf
import tensorflow_io.arrow as arrow_io
from pyarrow.feather import write_feather
# Write the Pandas DataFrame to a Feather file
write_feather(df, '/path/to/df.feather')
# Create the dataset with one or more filenames
ds = arrow_io.ArrowFeatherDataset(
    ['/path/to/df.feather'],
    columns=(0, 1, 2),
    output_types=(tf.int64, tf.float64, tf.float64),
    output_shapes=([], [], []))
# Iterate over each row of each file
for record in ds:
   label, x0, x1 = record
   # use label and feature tensors

The columns argument allows selecting certain columns by index. Finally, the output_types and output_shapes are given. Alternatively, an Arrow schema can be used with the alternate constructor ArrowFeatherDataset.from_schema, which will automatically infer the types and shapes of the tensors.

ArrowStreamDataset is used to connect to one or more endpoints that are serving Arrow record batches in the Arrow stream format. See the Arrow stream docs for more on the stream format. Streaming batches is an excellent way to iterate over a large dataset, local or remote, that might not fit entirely into memory. While streaming, the batch size can be used to limit memory usage. Currently supported endpoints are a POSIX IPv4 socket with an endpoint of the form host:port or tcp://host:port, a Unix domain socket with an endpoint of the form unix://path, and STDIN with endpoint fd://0 or fd://-.

The from_pandas constructor used below slices the input DataFrame into chunks of batch_size rows, converts the slices to Arrow record batches, and serves them as a stream over a local socket.

import tensorflow_io.arrow as arrow_io
ds = arrow_io.ArrowStreamDataset.from_pandas(
    df,
    batch_size=2,
    preserve_index=False)

This mirrors the ArrowDataset example above, but since the data is chunked into batches and served as a stream, memory usage is much lower, which allows for working with very large DataFrames in memory. The constructor will also accept a sequence or iterator of DataFrames as long as the schema is the same.

import tensorflow as tf

def model_fit(ds):
  """Create and fit a Keras logistic regression model."""
  
  # Build the Keras model
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=(2,),
            activation='sigmoid'))
  model.compile(optimizer='sgd', loss='mean_squared_error',
                metrics=['accuracy'])
  # Fit the model on the given dataset
  model.fit(ds, epochs=5, shuffle=False)
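  return model

import pandas as pd
import pyarrow as pa
import pyarrow.csv
from sklearn.preprocessing import StandardScaler

def read_and_process(filename):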
  """Read the given CSV file and yield processed Arrow batches."""
  # Read a CSV file into an Arrow Table with threading enabled and
  # set block_size in bytes to break the file into chunks for granularity,
  # which determines the number of batches in the resulting pyarrow.Table
  opts = pyarrow.csv.ReadOptions(use_threads=True, block_size=4096)
  table = pyarrow.csv.read_csv(filename, opts)
  # Fit the feature transform
  df = table.to_pandas()
  scaler = StandardScaler().fit(df[['x0', 'x1']])
  # Iterate over batches in the pyarrow.Table and apply processing
  for batch in table.to_batches():
    df = batch.to_pandas()
    # Process the batch and apply feature transform
    X_scaled = scaler.transform(df[['x0', 'x1']])
    df_scaled = pd.DataFrame({'label': df['label'], 
                              'x0': X_scaled[:, 0],
                              'x1': X_scaled[:, 1]})
    batch_scaled = pa.RecordBatch.from_pandas(df_scaled, preserve_index=False)
    
    yield batch_scaled

The block_size option is set in the pyarrow CSV reader to break the file up into chunks, which controls thread granularity and produces multiple record batches. The batches are then iterated over to perform the feature transform by first converting each batch to a Pandas DataFrame, processing the data, then converting the result back into an Arrow record batch. Each of these conversion steps is done very efficiently on batches, so the cost of going to and from Pandas is minimal. Finally, instead of returning all batches at once, calling yield with the processed batch makes the function return a generator, which allows it to be used as an iterator for the input stream.

An ArrowStreamDataset will be used, which has a constructor that accepts a record batch iterator as input and can use the generator returned from the previous function read_and_process().

from functools import partial

def make_local_dataset(filename):
  """Make a TensorFlow Arrow Dataset that reads from a local CSV file."""
 
  # Read the local file and get a record batch iterator
  batch_iter = read_and_process(filename)
  
  # Create the Arrow Dataset as a stream from local iterator of record batches
  ds = arrow_io.ArrowStreamDataset.from_record_batches(
    batch_iter,
    output_types=(tf.int64, tf.float64, tf.float64),
    batch_mode='auto',
    record_batch_iter_factory=partial(read_and_process, filename))
  # Map the dataset to combine feature columns to single tensor
  ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
  return ds

ArrowStreamDataset.from_record_batches takes in the record batch iterator, the output_types definitions, and a batch_mode. The batch mode used is “auto”, which will automatically create batches of tensors with a batch size equal to that of the incoming Arrow record batches. This is useful since the input can control the original record batch size, but a different batch_size could also be specified.

The purpose of record_batch_iter_factory is to specify a function that will initialize the record batch iterator so that it can be consumed multiple times during the training epochs. The final line adds a call to tf.data.Dataset.map() that stacks the feature columns into a single tensor output.

ds = make_local_dataset(filename)
model = model_fit(ds)
print("Fit model with weights: {}".format(model.get_weights()))
# Fit model with weights:
# [array([[0.7793554 ], [0.61216295]], dtype=float32),
#  array([0.03328196], dtype=float32)]

import os

def read_and_process_dir(directory):
  """Read a directory of CSV files and yield processed Arrow batches."""
  for f in os.listdir(directory):
    if f.endswith(".csv"):
      filename = os.path.join(directory, f)
      for batch in read_and_process(filename):
        yield batch

import socket

def serve_csv_data(ip_addr, port_num, directory):
  """
  Create a socket and serve Arrow record batches as a stream read from the
  given directory containing CSV files.
  """
  # Create the socket
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  sock.bind((ip_addr, port_num))
  sock.listen(1)
  # Serve forever, each client will get one iteration over data
  while True:
    conn, _ = sock.accept()
    outfile = conn.makefile(mode='wb')
    writer = None
    try:
      # Read directory and iterate over each batch in each file
      batch_iter = read_and_process_dir(directory)
      for batch in batch_iter:
        # Initialize the pyarrow writer on first batch
        if writer is None:
          writer = pa.RecordBatchStreamWriter(outfile, batch.schema)
        # Write the batch to the client stream
        writer.write_batch(batch)
    # Cleanup client connection
    finally:
      if writer is not None:
        writer.close()
      outfile.close()
      conn.close()
  sock.close()

A RecordBatchStreamWriter is used to write a sequence of Arrow record batches to an output stream, a TCP socket in this case.

An ArrowStreamDataset will again be used, but instead of from_record_batches, the constructor will be passed an endpoint for our remote server. Internally, the dataset kernel will create a client connection and begin to read record batches over the socket.

def make_remote_dataset(endpoint):
  """Make a TensorFlow Arrow Dataset that reads from a remote Arrow stream."""
  # Create the Arrow Dataset from a remote host serving a stream
  ds = arrow_io.ArrowStreamDataset(
      [endpoint],
      columns=(0, 1, 2),
      output_types=(tf.int64, tf.float64, tf.float64),
      batch_mode='auto')
  # Map the dataset to combine feature columns to single tensor
  ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
  return ds

Here there is no need for the record_batch_iter_factory argument, since the server function will already repeat the dataset after the client disconnects and then reconnects.

serve_csv_data() could be executed in one or more remote processes and model_fit() in a separate process. A complete example showing both local and remote training can be found in the gist arrow_model_training_example.
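To make the reconnect behavior concrete, here is a minimal sketch of the "one full pass per client connection" pattern that serve_csv_data() relies on. It is not the TensorFlow I/O implementation: it uses only the Python standard library and substitutes newline-delimited text rows for Arrow record batches. The helper names serve_passes and read_one_pass, the sample rows, and the use of a background thread in place of a remote process are all assumptions made for illustration.

```python
import socket
import threading


def serve_passes(sock, rows, num_clients):
    """Hypothetical helper: give each accepted client one full pass over rows."""
    for _ in range(num_clients):
        conn, _addr = sock.accept()
        # Leaving the with-block flushes the file and closes the connection,
        # which the client observes as end-of-stream
        with conn, conn.makefile(mode="wb") as outfile:
            for row in rows:
                outfile.write((row + "\n").encode("utf-8"))


def read_one_pass(host, port):
    """Hypothetical helper: connect and read rows until the server closes."""
    with socket.create_connection((host, port)) as conn:
        with conn.makefile(mode="rb") as infile:
            return [line.decode("utf-8").rstrip("\n") for line in infile]


# Bind to port 0 so the OS picks a free port for this demo
server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_sock.bind(("127.0.0.1", 0))
server_sock.listen(1)
host, port = server_sock.getsockname()

# Stand-in data: text rows instead of Arrow record batches
rows = ["1,5.2,6.2", "0,0.5,0.9"]
num_epochs = 3

# The server runs in a thread here; in practice it could be a remote process
server = threading.Thread(target=serve_passes,
                          args=(server_sock, rows, num_epochs))
server.start()

# Each "epoch" is a fresh connection that receives the whole dataset again
epochs = [read_one_pass(host, port) for _ in range(num_epochs)]

server.join()
server_sock.close()
```

Because every new connection replays the data from the start, the client only has to reconnect to begin another epoch, which is why the remote dataset can do without an iterator factory.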