August 10, 2018
By Xuechen Li, Software Engineering Intern
Overview
Eager execution simplifies the model building experience in TensorFlow, whereas graph execution can provide optimizations that make models run faster with better memory efficiency. This blog post showcases how to write TensorFlow code so that models built using eager execution with the tf.keras API can be converted to graphs and eventually deploye…
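tf-nightly or tf-nightly-gpu is highly recommended).

We define the reversible block by subclassing tf.keras.Model, and its forward pass by overriding the call method, following the block equations y1 = x1 + F(x2) and y2 = x2 + G(y1), where F and G are implemented by self.f and self.g:

import tensorflow as tf


class Residual(tf.keras.Model):
    """Reversible block: y1 = x1 + f(x2), y2 = x2 + g(y1)."""

    def __init__(self, filters, axis=1):
        super(Residual, self).__init__()
        # Channel axis along which the input is split in half (1 for NCHW inputs).
        self.axis = axis
        # ResidualInner (defined in the full code, not shown) must return tensors shaped
        # like its input half, so `filters` should equal half the input channels.
        self.f = ResidualInner(filters=filters, strides=(1, 1))
        self.g = ResidualInner(filters=filters, strides=(1, 1))

    def call(self, x, training=True):
        x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
        f_x2 = self.f(x2, training=training)
        y1 = f_x2 + x1
        g_y1 = self.g(y1, training=training)
        y2 = g_y1 + x2
        return tf.concat([y1, y2], axis=self.axis)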
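Because each half of the output is the other half plus a function of already-known values, the input can be recovered exactly from the output by running the equations backwards: x2 = y2 - G(y1), then x1 = y1 - F(x2). A minimal sketch on plain Python floats, assuming scalar halves and arbitrary hypothetical functions F and G in place of the ResidualInner sub-networks (F and G themselves need not be invertible):

```python
import math
from typing import Callable, Tuple

Pair = Tuple[float, float]
Fn = Callable[[float], float]


def F(x: float) -> float:
    # Hypothetical stand-in for self.f; it does not need to be invertible.
    return math.tanh(3.0 * x) + x * x


def G(x: float) -> float:
    # Hypothetical stand-in for self.g.
    return math.sin(x) - 0.5 * x


def rev_forward(x: Pair, f: Fn = F, g: Fn = G) -> Pair:
    """Same structure as Residual.call: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    x1, x2 = x
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def rev_inverse(y: Pair, f: Fn = F, g: Fn = G) -> Pair:
    """Recompute the input from the output: x2 = y2 - g(y1), then x1 = y1 - f(x2)."""
    y1, y2 = y
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


x = (0.3, -1.2)
y = rev_forward(x)
x_rec = rev_inverse(y)
print("input:", x, "reconstructed:", x_rec)
```

Only the output is needed to recover the input (up to floating-point rounding), which is what lets RevNet avoid storing activations for the backward pass.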
The training argument determines the mode of batch normalization. With eager execution enabled, the batch norm running averages are updated automatically whenever the block is called with training=True. When executing the equivalent graph, these updates do not run on their own; they need to be fetched manually with the get_updates_for method.
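To see what the training flag changes, here is a toy, framework-free batch normalization over a list of floats. It is a simplified sketch rather than tf.keras.layers.BatchNormalization: there is no learned scale or offset, and the momentum value and update rule are illustrative assumptions. With training=True it normalizes with the batch statistics and, as a side effect, updates its running averages; with training=False it uses the running averages and leaves them unchanged.

```python
import math
from typing import List


class ToyBatchNorm:
    """Simplified batch norm over a 1-D batch of floats (no learned scale or offset)."""

    def __init__(self, momentum: float = 0.9, eps: float = 1e-3) -> None:
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, batch: List[float], training: bool = True) -> List[float]:
        if training:
            # Normalize with the statistics of the current batch...
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            # ...and, as a side effect, update the running averages.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1.0 - m) * mean
            self.moving_var = m * self.moving_var + (1.0 - m) * var
        else:
            # Inference: use the running averages and leave them unchanged.
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / math.sqrt(var + self.eps) for v in batch]


bn = ToyBatchNorm()
out = bn([1.0, 2.0, 3.0, 4.0], training=True)  # updates the running averages
print(bn.moving_mean, bn.moving_var)
bn([10.0, 20.0], training=False)  # running averages stay the same
print(bn.moving_mean, bn.moving_var)
```

In eager mode the side effect happens as soon as the layer is called. In a graph, the corresponding assignments are separate ops that nothing else depends on, which is why the model_fn later in this post wraps its training op in tf.control_dependencies(model.get_updates_for(features)).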
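For the backward pass, we add a backward_grads method to Residual, using tf.GradientTape as a context manager to trace gradients only where needed:

    def backward_grads(self, y, dy, training=True):
        """Given outputs y = (y1, y2) and upstream gradients dy = (dy1, dy2),
        reconstruct the inputs and return (x, dx, grads)."""
        dy1, dy2 = dy
        y1, y2 = y
        # Backprop through g, with dy2 as the upstream gradient.
        with tf.GradientTape() as gtape:
            gtape.watch(y1)
            gy1 = self.g(y1, training=training)
        grads_combined = gtape.gradient(
            gy1, [y1] + self.g.trainable_variables, output_gradients=dy2)
        dg = grads_combined[1:]
        dx1 = dy1 + grads_combined[0]
        # Reconstruct x2 from the output.
        x2 = y2 - gy1

        # Backprop through f, with dx1 as the upstream gradient.
        with tf.GradientTape() as ftape:
            ftape.watch(x2)
            fx2 = self.f(x2, training=training)
        grads_combined = ftape.gradient(
            fx2, [x2] + self.f.trainable_variables, output_gradients=dx1)
        df = grads_combined[1:]
        dx2 = dy2 + grads_combined[0]
        # Reconstruct x1 from the output.
        x1 = y1 - fx2

        x = x1, x2
        dx = dx1, dx2
        grads = df + dg  # gradients of f's variables followed by g's
        return x, dx, grads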
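The same recipe can be checked by hand on scalars. In the sketch below, F and G are hypothetical one-parameter functions standing in for self.f and self.g, and backward_grads is a plain-Python analogue of the method above, not the TensorFlow code itself. Choosing the loss L = dy1*y1 + dy2*y2 makes dL/dy equal to (dy1, dy2), which is exactly the role output_gradients plays, so central finite differences of L should match the reconstructed gradients.

```python
import math


# Hypothetical one-parameter stand-ins for self.f (parameter a) and self.g (parameter b).
def F(x, a):
    return a * math.tanh(x)


def G(y, b):
    return b * y * y


def forward(x1, x2, a, b):
    y1 = x1 + F(x2, a)
    y2 = x2 + G(y1, b)
    return y1, y2


def backward_grads(y, dy, a, b):
    """Scalar analogue of Residual.backward_grads: returns (x, dx, (df, dg))."""
    y1, y2 = y
    dy1, dy2 = dy
    # Through g, with upstream gradient dy2.
    gy1 = G(y1, b)
    dx1 = dy1 + (2.0 * b * y1) * dy2      # dG/dy1 = 2*b*y1
    dg = (y1 * y1) * dy2                  # dG/db = y1**2
    x2 = y2 - gy1                         # reconstruct x2
    # Through f, with upstream gradient dx1.
    fx2 = F(x2, a)
    t = math.tanh(x2)
    dx2 = dy2 + a * (1.0 - t * t) * dx1   # dF/dx2 = a * (1 - tanh(x2)**2)
    df = t * dx1                          # dF/da = tanh(x2)
    x1 = y1 - fx2                         # reconstruct x1
    return (x1, x2), (dx1, dx2), (df, dg)


def numeric_grads(x1, x2, a, b, dy, h=1e-6):
    """Central differences of L = dy1*y1 + dy2*y2 w.r.t. (x1, x2, a, b)."""
    def loss(*args):
        y1, y2 = forward(*args)
        return dy[0] * y1 + dy[1] * y2

    point = [x1, x2, a, b]
    grads = []
    for i in range(len(point)):
        up, down = list(point), list(point)
        up[i] += h
        down[i] -= h
        grads.append((loss(*up) - loss(*down)) / (2.0 * h))
    return grads


x1, x2, a, b = 0.4, -0.7, 1.5, 0.3
dy = (0.8, -1.1)
y = forward(x1, x2, a, b)
x_rec, dx, (df, dg) = backward_grads(y, dy, a, b)
print("reversible:", [dx[0], dx[1], df, dg])
print("numerical: ", numeric_grads(x1, x2, a, b, dy))
```

The reconstructed input also matches the original up to rounding, so the backward pass never needed the block's stored input.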
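The exact gradient computations can be found in Algorithm 1 of the paper (in our code, we simplified the intermediate steps that use the variable z1). The algorithm is designed so that within each reversible block, given both the output and the gradient of the loss with respect to the output, the gradients with respect to the input and the model variables are computed along with a reconstruction of the input. Calling tape.gradient(y, x) computes the gradient of y with respect to x. The output_gradients argument lets us apply the chain rule explicitly: it supplies the gradient of the loss with respect to y, and the call then returns the gradient of the loss with respect to x (a vector-Jacobian product).

To check the implementation, we compare its input gradients with those computed by regular backprop:

# Assumes eager execution is enabled (TF 1.x: tf.enable_eager_execution() at startup).
N, C, H, W = 8, 32, 16, 16  # example NCHW input shape; C must be even
block = Residual(filters=C // 2)  # each half of the input has C // 2 channels
x = tf.random_normal(shape=(N, C, H, W))
dy = tf.random_normal(shape=(N, C, H, W))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = block(x)
# Compute true grads
dx_true = tape.gradient(y, x, output_gradients=dy)
# Compute grads from reconstruction; backward_grads takes (y1, y2) and (dy1, dy2)
y_halves = tf.split(y, num_or_size_splits=2, axis=1)
dy_halves = tf.split(dy, num_or_size_splits=2, axis=1)
_, dx_halves, _ = block.backward_grads(y_halves, dy_halves)
dx = tf.concat(dx_halves, axis=1)
# Check whether the difference is below a certain threshold
thres = 1e-6
diff_abs = tf.reshape(abs(dx - dx_true), [-1])
assert all(diff_abs < thres)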
In the above snippet, dx_true is the gradient returned by regular backprop, whereas dx is the gradient returned by our implementation of reversible backprop. Because eager execution integrates with native Python, built-in functions like all and abs can be applied directly to Tensors.

To save and load models, we use the tf.train.Checkpoint API. We create a tf.train.Checkpoint with all the objects we want to store. This may include our model, the optimizers we use, the learning rate schedule, and the global step:

# Every value must be a checkpointable object, e.g. learning_rate and global_step as tf.Variables.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer,
                                 learning_rate=learning_rate,
                                 global_step=global_step)
We can save and restore a particular trained instance as follows:

save_path = checkpoint.save(file_prefix)  # returns the full path of the new checkpoint
checkpoint.restore(save_path)
To get the benefits of graph execution while keeping eager-style code, we can wrap Python functions with tf.contrib.eager.defun. When training a deep learning model, there are typically three major places where we can apply tf.contrib.eager.defun: 1) the forward computation, 2) the backward computation for the gradients, and 3) the application of gradients to variables. As an example, we can defun the forward pass and the gradient computation as follows:

import tensorflow as tf

tfe = tf.contrib.eager
# compute_gradients is a method of the RevNet model from the full code (not shown here).
model.call = tfe.defun(model.call)
model.compute_gradients = tfe.defun(model.compute_gradients)
To defun the optimizer’s apply_gradients step, we need to wrap it inside another function:

def apply_gradients(optimizer, gradients, variables, global_step=None):
    optimizer.apply_gradients(
        zip(gradients, variables), global_step=global_step)

apply_gradients = tfe.defun(apply_gradients)
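As an analogy for what defun does, the sketch below implements a toy tracer in plain Python. It is hypothetical and far simpler than TensorFlow: calling the wrapped function with placeholder Symbol objects records a small graph of operations instead of computing numbers, later calls replay that graph, and Python side effects run only during tracing. Retracing is keyed only on the number of arguments, which is a simplification. The last example shows why a Python if on a traced value cannot be converted: its truth value is unknown while tracing.

```python
import operator


class Symbol:
    """A value recorded during tracing (hypothetical toy, not a TensorFlow Tensor)."""

    def __init__(self, graph, node_id):
        self.graph, self.node_id = graph, node_id

    def _op(self, other, fn):
        other_id = other.node_id if isinstance(other, Symbol) else self.graph.const(other)
        return self.graph.add_op(fn, [self.node_id, other_id])

    def __add__(self, other):
        return self._op(other, operator.add)

    def __mul__(self, other):
        return self._op(other, operator.mul)

    def __gt__(self, other):
        return self._op(other, operator.gt)

    def __bool__(self):
        # The real value is unknown while tracing, so a Python `if` on it cannot be recorded.
        raise TypeError("cannot use a traced value as a Python bool")


class Graph:
    """Nodes are kept in creation order, which is already a valid execution order."""

    def __init__(self):
        self.nodes = []  # ("input", index) | ("const", value) | ("op", fn, arg_ids)

    def add_node(self, node):
        self.nodes.append(node)
        return len(self.nodes) - 1

    def const(self, value):
        return self.add_node(("const", value))

    def add_op(self, fn, arg_ids):
        return Symbol(self, self.add_node(("op", fn, arg_ids)))

    def run(self, inputs, output_id):
        values = []
        for node in self.nodes:
            if node[0] == "input":
                values.append(inputs[node[1]])
            elif node[0] == "const":
                values.append(node[1])
            else:
                values.append(node[1](*(values[i] for i in node[2])))
        return values[output_id]


def toy_defun(func):
    """Trace func once per argument count (a simplification), then replay the graph."""
    cache = {}

    def wrapper(*args):
        key = len(args)
        if key not in cache:
            graph = Graph()
            inputs = [Symbol(graph, graph.add_node(("input", i))) for i in range(key)]
            cache[key] = (graph, func(*inputs).node_id)  # tracing happens here
        graph, output_id = cache[key]
        return graph.run(args, output_id)  # execution replays the recorded ops

    return wrapper


trace_count = 0

@toy_defun
def affine(x, w, b):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * w + b

print(affine(2.0, 3.0, 1.0))  # 7.0
print(affine(4.0, 0.5, 1.0))  # 3.0
print(trace_count)            # 1: the second call reused the traced graph

@toy_defun
def clipped(x):
    if x > 0:  # data-dependent Python control flow
        return x
    return x * 0.0

try:
    clipped(1.5)
    trace_error = None
except TypeError as err:
    trace_error = str(err)
print("tracing failed:", trace_error)
```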
tf.contrib.eager.defun is under active development, and applying it is an evolving technique; for more information, consult its docstring.

tf.contrib.eager.defun causes the TensorFlow API calls in the Python function to build a graph instead of executing operations immediately, enabling whole-program optimizations. Not all Python functions can be successfully converted to an equivalent graph, particularly those with dynamic control flow (e.g., an if or while that depends on the contents of a Tensor). tf.contrib.autograph is a related tool that expands the range of Python code that can be converted to a TensorFlow graph. As of August 2018, integration of autograph with defun was in progress.

For the input pipeline, we use the tf.data.Dataset API. For example, we can read a TFRecords file:

# parser (user-defined, not shown) decodes one serialized record into (image, label).
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.repeat(epochs).map(parser).batch(batch_size)
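To improve performance, we can also call the prefetch method and set the num_parallel_calls argument of map.

With eager execution enabled, we can iterate over the dataset directly in a Python loop:

for image, label in dataset:
    logits = model(image, training=True)
    ...  # rest of the training step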
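As a rough mental model of the pipeline above, repeat(epochs) replays the records epochs times, map(parser) then transforms each record, and batch(batch_size) groups consecutive elements, so a batch can span an epoch boundary and the final batch may be smaller. The sketch below mimics those semantics with plain Python generators; it is an illustration under that assumption, not how tf.data is implemented, the records and parser are hypothetical, and prefetching and parallel map calls are left out.

```python
from itertools import chain, islice, repeat


def repeat_ds(records, epochs):
    # Like Dataset.repeat(epochs): replay the whole sequence `epochs` times.
    return chain.from_iterable(repeat(records, epochs))


def map_ds(items, fn):
    # Like Dataset.map(fn): transform elements lazily, one at a time.
    return (fn(item) for item in items)


def batch_ds(items, batch_size):
    # Like Dataset.batch(batch_size): group consecutive elements; the last group may be short.
    it = iter(items)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield chunk


def parser(record):
    # Hypothetical parser for text records of the form "image_id,label".
    image, label = record.split(",")
    return image, int(label)


records = ["img0,1", "img1,0", "img2,1"]
batches = list(batch_ds(map_ds(repeat_ds(records, epochs=2), parser), batch_size=4))
for batch in batches:
    print(batch)
```

Here six parsed records become one batch of four (the last element already comes from the second epoch) and one batch of two.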
Since the tf.keras API also supports graph building, the same model built using eager execution can also be used as a graph-construction function provided to an Estimator, with few changes to the code. To adapt the RevNet example built with eager execution, we only need to wrap the Keras model in a model_fn and use it according to the tf.estimator API:

def model_fn(features, labels, mode, params):
    model = RevNet(params["hyperparameters"])  # RevNet comes from the full code (not shown)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # learning_rate and momentum are assumed to be defined in the enclosing scope.
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum)
        logits, saved_hidden = model(features, training=True)
        grads, loss = model.compute_gradients(saved_hidden, labels, training=True)
        # In graph mode, the batch norm updates must be run explicitly.
        with tf.control_dependencies(model.get_updates_for(features)):
            # Estimator tracks training progress through the global step.
            train_op = optimizer.apply_gradients(
                zip(grads, model.trainable_variables),
                global_step=tf.train.get_or_create_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
    # The EVAL and PREDICT modes are omitted from this excerpt.
The input_fn required by the tf.estimator API can be defined as usual using the tf.data API, reading from TFRecords.

Wrapping the model in an Estimator also allows it to run on Cloud TPUs. This involves switching from tf.estimator.Estimator to tf.contrib.tpu.TPUEstimator and wrapping the optimizer in tf.contrib.tpu.CrossShardOptimizer.
tf.contrib.tpu.keras_to_tpu_model in the future.
tf.GradientTape, coupled with a simplification of the gradient computation that obviates the need for an extra forward pass, allows us to implement RevNet’s reversible backprop with a computational overhead of just 25% compared to regular backprop.
In this post, we built RevNet using eager execution and the tf.keras API. This simplifies the model building experience, and with little extra effort we can convert the model to an Estimator and deploy it on Cloud TPUs for high performance. You can find the complete code for this article here. Also, make sure to check out other examples with eager execution.