March 08, 2019
Posted by Ian Fischer, Alex Alemi, Joshua V. Dillon, and the TFP Team
At the 2019 TensorFlow Developer Summit, we announced TensorFlow Probability (TFP) Layers. In that presentation, we showed how to build a powerful regression model in very few lines of code. Here, we will show how easy it is to make a Variational Autoencoder (VAE) using TFP Layers.
TensorFlow Probability Layers
TFP Layers provide…


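import tensorflow as tf
import tensorflow_probability as tfp

# Short aliases used throughout the post.
tfk = tf.keras
tfkl = tf.keras.layers
tfd = tfp.distributions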
encoded_size = 16
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)

tfpl = tfp.layers
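
To see what reinterpreted_batch_ndims=1 does to the prior, here is a minimal sketch in plain Python (no TFP required). It assumes a standard normal in every latent dimension, as in the prior above, and shows that the log-density of one 16-dimensional event is the sum of 16 scalar log-densities, which is the behavior tfd.Independent provides. The helper names are hypothetical and exist only for illustration.

```python
import math

def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log-density of a scalar Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

def independent_log_prob(xs):
    """Treat the whole vector as one event: sum the per-dimension log-densities."""
    return sum(normal_log_prob(x) for x in xs)

encoded_size = 16
z = [0.0] * encoded_size  # a point at the prior mean

per_dim = [normal_log_prob(v) for v in z]  # 16 separate scalar log-densities
joint = independent_log_prob(z)            # one scalar for the whole vector
```

Without the Independent wrapper, the 16 normals would be treated as a batch of separate scalar distributions, so a log-probability would come back as 16 numbers rather than one.
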
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior, weight=1.0)),
])

The encoder's final layer is a TFP Layer, tfpl.MultivariateNormalTriL, which transparently splits the activations from the final Dense() layer into the parts needed to specify both the mean and the lower-triangular scale matrix (a factor of the covariance matrix), the parameters of a Multivariate Normal. We used a helper, tfpl.MultivariateNormalTriL.params_size(encoded_size), to make the Dense() layer output the correct number of activations (i.e., the distribution’s parameters). Finally, we said that the distribution should contribute a “regularization” term to the final loss. Specifically, we are adding the KL divergence between the encoder and the prior to the loss, which is the KL term in the ELBO that we described above. (Fun fact: we can turn this VAE into a β-VAE simply by changing the weight argument to something other than 1!)

decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, 1, encoded_size]),
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),
    tfkl.Flatten(),
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])

vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))

negative_log_likelihood = lambda x, rv_x: -rv_x.log_prob(x)
vae.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
            loss=negative_log_likelihood)

The loss function's second argument, the model output, is named rv_x because it is a random variable. This example demonstrates some of the core magic of TFP Layers: even though Keras and TensorFlow view the TFP Layers as outputting tensors, TFP Layers actually output Distribution objects. Thus, we can make our loss function be the negative log likelihood of the data given the model: -rv_x.log_prob(x).

# Take the first 10 images of one evaluation batch (eager-mode iteration).
x = next(iter(eval_dataset))[0][:10]
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)

What about the model definition, outputs=decoder(encoder.outputs[0])? Well, in order for Keras to view the encoder distribution as a Tensor, TFP Layers actually “reifies” the distribution as a sample from that distribution, which is just a fancy way of saying that Keras sees the Distribution object as the Tensor we would have gotten, had we called sample() on the encoder's output distribution. But, when we need to access the Distribution object directly, we can, just like we do in the loss function when we call rv_x.log_prob(x). TFP Layers provides the distribution-like and Tensor-like behaviors automatically, so you don’t need to worry about Keras getting confused.

Training is then a call to vae.fit():

vae.fit(train_dataset,
        epochs=15,
        validation_data=eval_dataset)
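
The objective minimized by vae.fit() is the negative log likelihood plus the weighted KL regularizer attached to the encoder. The sketch below, in plain Python, shows how those two terms combine and how the weight argument plays the role of β. For readability it assumes a diagonal-covariance encoder and the closed-form KL to a standard normal prior. The encoder in this post uses a full lower-triangular scale instead, and TFP may estimate the KL term from samples rather than in closed form. All function names and numbers here are hypothetical.

```python
import math

def kl_diag_normal_to_standard(loc, scale):
    """Closed-form KL( N(loc, diag(scale^2)) || N(0, I) ), summed over dimensions."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s)
               for m, s in zip(loc, scale))

def negative_elbo(nll, loc, scale, beta=1.0):
    """Reconstruction term plus beta-weighted KL; beta=1.0 is the standard VAE."""
    return nll + beta * kl_diag_normal_to_standard(loc, scale)

# Hypothetical values for a 2-dimensional latent and a single example.
loc, scale = [1.0, 0.0], [1.0, 1.0]
nll = 100.0

kl = kl_diag_normal_to_standard(loc, scale)             # 0.5 for these values
vae_loss = negative_elbo(nll, loc, scale)                # beta = 1
beta_vae_loss = negative_elbo(nll, loc, scale, beta=4.0) # heavier KL penalty
```

Raising β above 1 penalizes encoder distributions that stray from the prior more heavily, at the cost of reconstruction quality.
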

Figure: Decoder modes generated by encoding images from the MNIST test set.

Figure: Decoder modes generated by sampling from the prior.
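
The "decoder modes" in the figures come from the Bernoulli distribution that the decoder outputs for each pixel. As a rough illustration in plain Python, assuming one logit per pixel, the snippet below computes the log-probability of a binary pixel (the per-pixel quantity that is summed inside rv_x.log_prob(x)) and the mode (the more likely pixel value). The function names and logit values are hypothetical, and returning 0 at a logit of exactly 0 is a choice made for this sketch.

```python
import math

def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def bernoulli_log_prob(x, logit):
    """log p(x) for x in {0, 1}, where p(x=1) = sigmoid(logit)."""
    return x * logit - softplus(logit)

def bernoulli_mode(logit):
    """Most likely value: 1 when sigmoid(logit) > 0.5, i.e. when logit > 0."""
    return 1 if logit > 0.0 else 0

logits = [2.0, -3.0, 0.5]  # hypothetical decoder outputs for three pixels
pixels = [1, 0, 1]         # observed binary pixel values

# Pixels are independent, so the image log-probability is a sum over pixels.
image_log_prob = sum(bernoulli_log_prob(x, l) for x, l in zip(pixels, logits))
modes = [bernoulli_mode(l) for l in logits]
```

Displaying the mode rather than a random sample gives a deterministic image for each latent code, which is why the figures show modes.
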
 