August 10, 2018
By Xuechen Li, Software Engineering Intern
Overview

Eager execution simplifies the model building experience in TensorFlow, whereas graph execution can provide optimizations that make models run faster with better memory efficiency. This blog post showcases how to write TensorFlow code so that models built using eager execution with the tf.keras API can be converted to graphs and eventually deploye…
tf-nightly or tf-nightly-gpu is highly recommended).
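
As a reference for the code that follows, the forward and inverse computations of a reversible block, read off from the call and backward_grads methods shown below, can be written as follows. The symbols F and G for self.f and self.g, and the subscripted halves of the input and output, are notation assumed here rather than taken from the post.

```latex
\begin{aligned}
y_1 &= x_1 + F(x_2), & y_2 &= x_2 + G(y_1) \\
x_2 &= y_2 - G(y_1), & x_1 &= y_1 - F(x_2)
\end{aligned}
```

The second row is what lets the inputs be recomputed from the outputs during the backward pass instead of being kept in memory.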


tf.keras.Model and the forward pass by overriding the call method as in the above equations:

class Residual(tf.keras.Model):
  def __init__(self, filters):
    super(Residual, self).__init__()
    self.f = ResidualInner(filters=filters, strides=(1, 1))
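    self.g = ResidualInner(filters=filters, strides=(1, 1))
    # Channel axis used by call(); assumes NCHW inputs such as the (N, C, H, W) test tensors.
    self.axis = 1
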
  def call(self, x, training=True):
    x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
    f_x2 = self.f(x2, training=training)
    y1 = f_x2 + x1
    g_y1 = self.g(y1, training=training)
    y2 = g_y1 + x2
    return tf.concat([y1, y2], axis=self.axis)

The training argument here is used to determine the state of batch normalization. With eager execution enabled, the running averages of batch norm are updated automatically when training=True. When executing the equivalent graph, the batch norm updates need to be manually fetched with the method get_updates_for.

tf.GradientTape as a context manager to trace gradients only where needed:

  def backward_grads(self, y, dy, training=True):
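    # y and dy are (first half, second half) pairs of block outputs and their gradients
    dy1, dy2 = dy
    y1, y2 = y
    # Recompute g(y1) under a tape so its gradients are available
    with tf.GradientTape() as gtape:
      gtape.watch(y1)
      gy1 = self.g(y1, training=training)
    grads_combined = gtape.gradient(
        gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
    dg = grads_combined[1:]
    dx1 = dy1 + grads_combined[0]
    # Reconstruct x2 from the outputs instead of storing it
    x2 = y2 - gy1
    with tf.GradientTape() as ftape:
      ftape.watch(x2)
      fx2 = self.f(x2, training=training)
    grads_combined = ftape.gradient(
        fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
    df = grads_combined[1:]
    dx2 = dy2 + grads_combined[0]
    # Reconstruct x1, then collect inputs, input gradients, and variable gradients
    x1 = y1 - fx2
    x = x1, x2
    dx = dx1, dx2
    grads = df + dg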
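    return x, dx, grads

tape.gradient(y, x) computes the gradient of y with respect to x. We can also use the argument output_gradients to explicitly apply the chain rule.

# Each half of the split has C // 2 channels; this assumes ResidualInner(filters=k) outputs k channels.
block = Residual(filters=C // 2)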
x = tf.random_normal(shape=(N, C, H, W))
dy = tf.random_normal(shape=(N, C, H, W))
with tf.GradientTape() as tape:
  tape.watch(x)
  y = block(x)
# Compute true grads
dx_true = tape.gradient(y, x, output_gradients=dy)
# Compute grads from reconstruction
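# backward_grads takes the channel-wise halves of y and dy and returns (x, dx, grads)
y_halves = tf.split(y, num_or_size_splits=2, axis=1)
dy_halves = tf.split(dy, num_or_size_splits=2, axis=1)
_, dx_halves, _ = block.backward_grads(y_halves, dy_halves)
dx = tf.concat(list(dx_halves), axis=1)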
# Check whether the difference is below a certain threshold
thres = 1e-6
diff_abs = tf.reshape(abs(dx - dx_true), [-1])
assert all(diff_abs < thres)

dx_true is the gradient returned by regular backprop, whereas dx is the gradient returned by our implementation of reversible backprop. Eager execution integrates with native Python so that functions like all and abs can be directly applied to Tensors.

tf.train.Checkpoint API.

tf.train.Checkpoint with all the objects we want to store. This may include our model, optimizers we use, the learning rate schedule, and the global step:

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer,
        learning_rate=learning_rate, global_step=global_step)

checkpoint.save(file_prefix)
checkpoint.restore(save_path)

tf.contrib.eager.defun. When training a deep learning model, there are typically three major places where we can apply tf.contrib.eager.defun: 1) the forward computation, 2) the backward computation for the gradients, and 3) the application of gradients to variables. As an example, we can defun the forward pass and the gradient computation as follows:

tfe = tf.contrib.eager
model.call = tfe.defun(model.call)
model.compute_gradients = tfe.defun(model.compute_gradients)

def apply_gradients(optimizer, gradients, variables, global_step=None):
    optimizer.apply_gradients(
        zip(gradients, variables), global_step=global_step)
apply_gradients = tfe.defun(apply_gradients)

tf.contrib.eager.defun is under active development, and applying it is an evolving technique; for more information, consult its docstring.

tf.contrib.eager.defun causes the TensorFlow API calls in the Python function to build a graph instead of immediately executing operations, enabling whole program optimizations. Not all Python functions can be successfully converted to an equivalent graph, particularly those with dynamic control flow (e.g., an if or while on Tensor contents). tf.contrib.autograph is a related tool that increases the surface area of Python code that can be converted to a TensorFlow graph. As of August 2018, integration of autograph with defun was in progress.

tf.data.Dataset API. We can read a TFRecords file:

dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.repeat(epochs).map(parser).batch(batch_size)

prefetch function and adjust num_parallel_calls.

for image, label in dataset:
  logits = model(image, training=True)
  ...

Since the tf.keras API also supports graph building, the same model built using eager execution can also be used as a graph-construction function provided to an Estimator, with few changes to the code. To modify the RevNet example built in eager execution, we need only wrap the keras model in a model_fn and use it according to the tf.estimator API.

def model_fn(features, labels, mode, params):
  model = RevNet(params["hyperparameters"])
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
    logits, saved_hidden = model(features, training=True)
    grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
    with tf.control_dependencies(model.get_updates_for(features)):
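      # Pass the global step so the Estimator's step counter advances during training.
      train_op = optimizer.apply_gradients(
          zip(grads, model.trainable_variables),
          global_step=tf.train.get_or_create_global_step())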
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

The input_fn required by the tf.estimator API can be defined as usual using the tf.data API, reading from TFRecords.

Estimator allows the model to run on Cloud TPUs.

tf.estimator.Estimator to tf.contrib.tpu.TPUEstimator

tf.contrib.tpu.CrossShardOptimizer

tf.contrib.tpu.keras_to_tpu_model in the future.
tf.GradientTape, coupled with a simplification of the gradient computation that obviates the need for an extra forward pass, allows us to implement RevNet’s reversible backprop with a computational overhead of just 25% compared to regular backprop.
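
To make the reversible backprop idea concrete without TensorFlow, here is a minimal scalar sketch of the same arithmetic used in backward_grads. The functions f, g and their derivatives df, dg are hypothetical stand-ins for the residual sub-networks, chosen only so the result can be checked by hand; this is not the post's implementation. It reconstructs the inputs from the outputs and compares the resulting input gradients with finite differences.

```python
import math

# Hypothetical scalar stand-ins for the residual functions and their derivatives.
def f(v):
    return math.tanh(v)

def df(v):
    return 1.0 - math.tanh(v) ** 2

def g(v):
    return 0.5 * v * v

def dg(v):
    return v

def forward(x1, x2):
    """Reversible block: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def backward(y1, y2, dy1, dy2):
    """Rebuild the inputs from the outputs and return the input gradients."""
    gy1 = g(y1)
    dx1 = dy1 + dg(y1) * dy2  # total gradient reaching y1, which equals dL/dx1
    x2 = y2 - gy1             # invert the second forward equation
    fx2 = f(x2)
    dx2 = dy2 + df(x2) * dx1  # chain rule through f
    x1 = y1 - fx2             # invert the first forward equation
    return (x1, x2), (dx1, dx2)

def numeric_grads(x1, x2, dy1, dy2, eps=1e-6):
    """Central finite differences of L = dy1 * y1 + dy2 * y2."""
    def loss(a, b):
        out1, out2 = forward(a, b)
        return dy1 * out1 + dy2 * out2
    d1 = (loss(x1 + eps, x2) - loss(x1 - eps, x2)) / (2 * eps)
    d2 = (loss(x1, x2 + eps) - loss(x1, x2 - eps)) / (2 * eps)
    return d1, d2

x1, x2 = 0.3, -1.2
dy1, dy2 = 0.7, -0.4
y1, y2 = forward(x1, x2)
(rx1, rx2), (dx1, dx2) = backward(y1, y2, dy1, dy2)
n1, n2 = numeric_grads(x1, x2, dy1, dy2)

print("reconstruction error:", abs(rx1 - x1), abs(rx2 - x2))
print("gradient error:", abs(dx1 - n1), abs(dx2 - n2))
```

Note that backward never receives x1 or x2: both are recovered from y1 and y2, which is why the block's input activations do not have to be stored during the forward pass.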
 To verify memory savings, we plot memory usage as training progresses. The blue and black curves are regular and reversible backprop respectively. The plot records 100 iterations of RevNet-104 graph-mode training on mock ImageNet data with a batch size of 128. The plot was generated by mprof while training was performed on CPU so that we can train with the same batch size in regular backprop.
tf.keras API. This simplifies the model building experience; moreover, with little extra effort, we can convert our models to Estimators and deploy them on Cloud TPUs for high performance. You can find the complete code for this article here. Also, make sure to check out the other examples with eager execution.