March 08, 2019 —
Posted by Ian Fischer, Alex Alemi, Joshua V. Dillon, and the TFP Team
At the 2019 TensorFlow Developer Summit, we announced TensorFlow Probability (TFP) Layers. In that presentation, we showed how to build a powerful regression model in very few lines of code. Here, we will show how easy it is to make a Variational Autoencoder (VAE) using TFP Layers.
TensorFlow Probability Layers
TFP Layers provide…
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
encoded_size = 16
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)
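To make the prior concrete, here is a minimal NumPy sketch (not TFP code) of the log-density this prior assigns to a latent code. It assumes only what the code above states, a standard Normal in each of the 16 dimensions treated as a single event; the function name standard_normal_log_prob is our own.

```python
import numpy as np

encoded_size = 16  # matches the latent dimensionality used above


def standard_normal_log_prob(z):
    """Log-density of an isotropic standard Normal, summed over the last axis.

    Summing the per-dimension log-densities mirrors what wrapping tfd.Normal in
    tfd.Independent(..., reinterpreted_batch_ndims=1) does: the 16 scalar
    Normals are treated as one 16-dimensional event.
    """
    z = np.asarray(z, dtype=np.float64)
    per_dim = -0.5 * (np.log(2.0 * np.pi) + z ** 2)  # log N(z_i; 0, 1)
    return per_dim.sum(axis=-1)


rng = np.random.default_rng(0)
z = rng.standard_normal((3, encoded_size))  # a batch of 3 latent codes
print(standard_normal_log_prob(z).shape)     # one log-density per code: (3,)
print(standard_normal_log_prob(np.zeros(encoded_size)))  # density at the mode
```

Because the 16 dimensions form one event, a batch of codes yields one scalar log-density per code rather than 16 separate values.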
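Here, we have just created a TFP independent Gaussian distribution with no learned parameters, and we have specified that our latent variable, z, will have 16 dimensions.
tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
input_shape = (28, 28, 1)  # MNIST images: 28 x 28 pixels, 1 channel
base_depth = 32  # assumed number of filters in the first conv layer; not given in this excerpt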
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=1.0)),
])
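The last two encoder layers can be sketched in plain NumPy to show what a full-covariance Gaussian needs and what the KL regularizer measures. This is an illustration under stated assumptions, not TFP's implementation: the helper names are hypothetical, the packing of the flat parameter vector into a lower-triangular matrix uses a row-major order with a softplus on the diagonal chosen for this sketch (TFP uses its own ordering and transform), and the KL is computed in closed form, whereas tfpl.KLDivergenceRegularizer may use a sampled estimate depending on how it is configured. The parameter count, 16 means plus 16 · 17 / 2 = 136 lower-triangular entries, gives 152 for encoded_size = 16.

```python
import numpy as np


def tril_params_size(d):
    # d entries for the mean plus d * (d + 1) / 2 for the lower-triangular scale
    return d + d * (d + 1) // 2


def softplus(x):
    # Numerically stable log(1 + exp(x)); always positive for finite x
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)


def split_params(params, d):
    """Hypothetical unpacking of a flat vector into (loc, scale_tril).

    Assumption: entries fill the lower triangle in row-major order and the
    diagonal goes through softplus so that scale_tril is a valid Cholesky factor.
    """
    loc = params[:d]
    scale_tril = np.zeros((d, d))
    scale_tril[np.tril_indices(d)] = params[d:]
    diag = np.diag_indices(d)
    scale_tril[diag] = softplus(scale_tril[diag])
    return loc, scale_tril


def kl_to_standard_normal(loc, scale_tril):
    """Closed-form KL(N(loc, L L^T) || N(0, I))."""
    d = loc.shape[0]
    trace = np.sum(scale_tril ** 2)  # tr(L L^T) is the squared Frobenius norm of L
    log_det = 2.0 * np.sum(np.log(np.diag(scale_tril)))  # log det(L L^T)
    return 0.5 * (trace + loc @ loc - d - log_det)


encoded_size = 16
n_params = tril_params_size(encoded_size)
print(n_params)  # 152

rng = np.random.default_rng(0)
loc, scale_tril = split_params(rng.normal(size=n_params), encoded_size)
beta = 1.0  # plays the role of the `weight` argument; beta != 1 gives a beta-VAE
print(beta * kl_to_standard_normal(loc, scale_tril))
```

Setting beta to a value other than 1 mirrors changing the weight argument of KLDivergenceRegularizer, which is how the β-VAE variant reweights the KL term.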
The encoder is just a normal Keras Sequential model, consisting of convolutions and dense layers, but the output is passed to a TFP Layer, MultivariateNormalTriL(), which transparently splits the activations from the final Dense() layer into the parts needed to specify both the mean and the lower-triangular scale matrix (a Cholesky factor of the covariance), the parameters of a Multivariate Normal. We used a helper, tfpl.MultivariateNormalTriL.params_size(encoded_size), to make the Dense() layer output the correct number of activations (i.e., the distribution’s parameters). Finally, we said that the distribution should contribute a “regularization” term to the final loss. Specifically, we are adding the KL divergence between the encoder’s distribution and the prior to the loss, which is the KL term in the ELBO that we described above. (Fun fact: we can turn this VAE into a β-VAE simply by changing the weight argument to something other than 1!)
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, 1, encoded_size]),
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),
    tfkl.Flatten(),
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])
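Why do these particular kernel sizes, strides, and paddings line up? The short pure-Python sketch below tracks the spatial size through each layer of the encoder and decoder as defined above. It assumes Keras' standard output-size rules for Conv2D and for Conv2DTranspose with the default output_padding; the helper names are our own.

```python
def conv_out(n, kernel, stride, padding):
    # Keras Conv2D spatial output size
    if padding == 'same':
        return -(-n // stride)  # ceil(n / stride)
    return -(-(n - kernel + 1) // stride)  # 'valid': ceil((n - kernel + 1) / stride)


def conv_transpose_out(n, kernel, stride, padding):
    # Keras Conv2DTranspose spatial output size (default output_padding)
    if padding == 'same':
        return n * stride
    return n * stride + max(kernel - stride, 0)  # 'valid'


# (kernel, stride, padding) for each layer, copied from the models above
encoder_layers = [(5, 1, 'same'), (5, 2, 'same'), (5, 1, 'same'),
                  (5, 2, 'same'), (7, 1, 'valid')]
decoder_layers = [(7, 1, 'valid'), (5, 1, 'same'), (5, 2, 'same'),
                  (5, 1, 'same'), (5, 2, 'same'), (5, 1, 'same')]

size = 28
encoder_sizes = [size]
for k, s, p in encoder_layers:
    size = conv_out(size, k, s, p)
    encoder_sizes.append(size)

size = 1  # the decoder's Reshape gives a 1 x 1 x encoded_size input
decoder_sizes = [size]
for k, s, p in decoder_layers:
    size = conv_transpose_out(size, k, s, p)
    decoder_sizes.append(size)
# The final Conv2D(filters=1, kernel_size=5, padding='same') keeps 28 x 28.
decoder_sizes.append(conv_out(size, 5, 1, 'same'))

print(encoder_sizes)  # [28, 28, 14, 14, 7, 1]
print(decoder_sizes)  # [1, 7, 7, 14, 14, 28, 28, 28]
```

The 7 x 7 'valid' convolution collapses the encoder's 7 x 7 feature map to 1 x 1 before Flatten(), and the matching 7 x 7 'valid' transposed convolution expands the decoder's 1 x 1 input back to 7 x 7; the two stride-2 layers on each side account for the remaining factor of 4 between 7 and 28.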
The form here is essentially the same as the encoder, but now we are using transposed convolutions to take our latent representation, which is a 16-dimensional vector, and turn it into a 28 x 28 x 1 tensor. That final tensor, flattened to 784 logits, parameterizes the pixel-independent Bernoulli distribution.
vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))
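For reference, the loss this model optimizes can be written out explicitly. The ELBO derivation this section refers back to is not reproduced in this excerpt, so the notation here is our own: q(z | x) is the encoder, p(z) the prior, p(x | z) the decoder, and β the weight passed to KLDivergenceRegularizer.

```latex
\mathcal{L}(x) = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction term (passed to compile)}}
  + \underbrace{\beta\,\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)}_{\text{added by the encoder's regularizer}},
\qquad \beta = 1 \;\Rightarrow\; \mathcal{L}(x) = -\mathrm{ELBO}(x)
```

Because the decoder receives a sample from the encoder's distribution, each training step estimates the expectation in the first term from that single sample.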
Our model is just a Keras Model where the outputs are defined as the composition of the encoder and the decoder. Since the encoder already added the KL term to the loss, we need to specify only the reconstruction loss (the first term of the ELBO above).
negative_log_likelihood = lambda x, rv_x: -rv_x.log_prob(x)
vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negative_log_likelihood)
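Because the decoder ends in tfpl.IndependentBernoulli with logits, -rv_x.log_prob(x) amounts to a per-image sum of per-pixel Bernoulli negative log likelihoods. Here is a minimal NumPy sketch of that quantity, assuming binarized pixels (x in {0, 1}); the helper name is our own, and this is not TFP code.

```python
import numpy as np


def bernoulli_nll_from_logits(x, logits):
    """-log p(x | logits) for independent Bernoulli pixels, summed per image.

    For one pixel, -log p(x) = softplus(l) - x * l, which equals
    -[x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l))].
    """
    x = np.asarray(x, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    softplus = np.logaddexp(0.0, logits)  # numerically stable log(1 + e^l)
    per_pixel = softplus - x * logits
    # Flatten everything except the batch axis, then sum over pixels
    return per_pixel.reshape(per_pixel.shape[0], -1).sum(axis=1)


x = np.array([[[0.0, 1.0], [1.0, 0.0]]])  # one binary 2 x 2 "image"
logits = np.zeros_like(x)                 # p = 0.5 for every pixel
print(bernoulli_nll_from_logits(x, logits))  # 4 * log(2) nats
```

Confident, correct logits drive the loss toward zero, while uninformative logits cost log 2 nats per pixel.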
The loss function takes two arguments: the original input, x, and the output of the model. We call that output rv_x because it is a random variable. This example demonstrates some of the core magic of TFP Layers: even though Keras and TensorFlow treat the TFP Layers as outputting tensors, TFP Layers actually output Distribution objects. Thus, we can make our loss function the negative log likelihood of the data given the model: -rv_x.log_prob(x).
x = next(iter(eval_dataset))[0][:10]  # first 10 images of one evaluation batch
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)
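The dual behavior of a TFP Layer's output can be illustrated with a small toy class. This is a hypothetical stand-in written for this explanation, not how TFP implements it: ReifiedNormal exposes log_prob() like a Distribution, while NumPy's __array__ conversion hook returns a sample, playing the role of the Tensor that Keras sees. In this toy, the sample is drawn once and cached so that repeated conversions agree.

```python
import numpy as np


class ReifiedNormal:
    """Hypothetical toy: a diagonal Normal that also converts to an array.

    - log_prob() is available, like on a Distribution object.
    - np.asarray(obj) yields one sample, drawn on first conversion and cached.
    """

    def __init__(self, loc, scale, seed=0):
        self.loc = np.asarray(loc, dtype=np.float64)
        self.scale = np.asarray(scale, dtype=np.float64)
        self._rng = np.random.default_rng(seed)
        self._value = None  # cached sample, created on first conversion

    def sample(self):
        return self.loc + self.scale * self._rng.standard_normal(self.loc.shape)

    def log_prob(self, x):
        z = (np.asarray(x) - self.loc) / self.scale
        return np.sum(-0.5 * z ** 2 - np.log(self.scale)
                      - 0.5 * np.log(2 * np.pi), axis=-1)

    def __array__(self, dtype=None, copy=None):
        # Conversion hook used by np.asarray: hand back (a copy of) the cached sample
        if self._value is None:
            self._value = self.sample()
        target = self._value.dtype if dtype is None else dtype
        return self._value.astype(target, copy=True)


q = ReifiedNormal(loc=np.zeros(4), scale=np.ones(4))
as_tensor = np.asarray(q)  # "tensor view": a single cached sample
print(as_tensor.shape)  # (4,)
print(np.array_equal(as_tensor, np.asarray(q)))  # True: the same sample is reused
print(q.log_prob(as_tensor))  # "distribution view": log_prob is still available
```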
But if a TFP Layer returns a Distribution, what happens when we compose the decoder with the output of the encoder, as in decoder(encoder.outputs[0])? For Keras to treat the encoder’s distribution as a Tensor, TFP Layers “reify” the distribution as a sample from it, which is just a fancy way of saying that Keras sees the Distribution object as the Tensor we would have gotten had we called .sample() on it. But when we need to access the Distribution object directly, we can, just as we do in the loss function when we call rv_x.log_prob(x). TFP Layers provide both the distribution-like and the Tensor-like behaviors automatically, so you don’t need to worry about Keras getting confused.
To train the model, we call vae.fit():
vae.fit(train_dataset,
        epochs=15,
        validation_data=eval_dataset)
With this model, we are able to get a negative ELBO of around 115 nats (the nat is the natural-logarithm equivalent of the bit; 115 nats is about 166 bits). Of course, this performance isn’t state-of-the-art, but it is easy to make any of the three components (prior, encoder, and decoder) more powerful starting from this basic setup. Also, it already generates nice-looking digits!
Figure: Decoder modes generated by encoding images from the MNIST test set.
Figure: Decoder modes generated by sampling from the prior.
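The nats-to-bits conversion quoted for this result is simple arithmetic: dividing nats by ln 2 gives bits, and dividing by the 784 pixels of an MNIST image gives a per-pixel average. A quick check, using only the 115-nat figure reported above:

```python
import math

nll_nats = 115.0                   # negative ELBO reported above, per image
nll_bits = nll_nats / math.log(2)  # 1 nat = 1 / ln(2) bits
pixels = 28 * 28                   # MNIST image size
print(round(nll_bits, 1))          # 165.9
print(round(nll_nats / pixels, 3))  # average nats per pixel
```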