April 11, 2018
Posted by Josh Dillon, Software Engineer; Mike Shwe, Product Manager; and Dustin Tran, Research Scientist — on behalf of the TensorFlow Probability Team
At the 2018 TensorFlow Developer Summit, we announced TensorFlow Probability: a probabilistic programming toolbox for machine learning researchers and practitioners to quickly and reliably build sophisticated models that leverage state-of-the-art…
An overview of TensorFlow Probability. The probabilistic programming toolbox provides benefits for users ranging from Data Scientists and Statisticians to all TensorFlow Users.
tf.linalg in core TF.
tf.contrib.distributions, tf.distributions: A large collection of probability distributions and related statistics with batch and broadcasting semantics.
tf.contrib.distributions.bijectors: Reversible and composable transformations of random variables. Bijectors provide a rich class of transformed distributions, from classical examples like the log-normal distribution to sophisticated deep learning models such as masked autoregressive flows.
tfp.edward2: A probabilistic programming language for specifying flexible probabilistic models as programs.
tfp.layers: Neural network layers with uncertainty over the functions they represent, extending TensorFlow Layers.
tfp.trainable_distributions: Probability distributions parameterized by a single Tensor, making it easy to build neural nets that output probability distributions.
tfp.mcmc: Algorithms for approximating integrals via sampling. Includes Hamiltonian Monte Carlo, random-walk Metropolis-Hastings, and the ability to build custom transition kernels.
tfp.vi: Algorithms for approximating integrals via optimization.
tfp.optimizer: Stochastic optimization methods, extending TensorFlow Optimizers. Includes Stochastic Gradient Langevin Dynamics.
tfp.monte_carlo: Tools for computing Monte Carlo expectations.

(tfp.edward2), which extends Edward. The program below reifies the model in terms of its generative process.

import tensorflow as tf
from tensorflow_probability import edward2 as ed

# `num_students` and `num_instructors` are assumed to be defined elsewhere.
def model(features):
  # Set up fixed effects and other parameters.
  intercept = tf.get_variable("intercept", [])
  service_effects = tf.get_variable("service_effects", [])
  student_stddev_unconstrained = tf.get_variable(
      "student_stddev_pre", [])
  instructor_stddev_unconstrained = tf.get_variable(
      "instructor_stddev_pre", [])

  # Set up random effects.
  student_effects = ed.MultivariateNormalDiag(
      loc=tf.zeros(num_students),
      scale_identity_multiplier=tf.exp(
          student_stddev_unconstrained),
      name="student_effects")
  instructor_effects = ed.MultivariateNormalDiag(
      loc=tf.zeros(num_instructors),
      scale_identity_multiplier=tf.exp(
          instructor_stddev_unconstrained),
      name="instructor_effects")

  # Set up likelihood given fixed and random effects.
  ratings = ed.Normal(
      loc=(service_effects * features["service"] +
           tf.gather(student_effects, features["students"]) +
           tf.gather(instructor_effects, features["instructors"]) +
           intercept),
      scale=1.,
      name="ratings")
  return ratings
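As a rough, library-free sketch of the generative process that the program above describes, the following plain-Python function draws one set of ratings. It is not TensorFlow Probability code: the function name `simulate_ratings`, the toy feature values, and the fixed parameter values are all assumptions made for illustration.

```python
import random


def simulate_ratings(features, num_students, num_instructors,
                     intercept=3.0, service_effect=0.5,
                     student_stddev=1.0, instructor_stddev=1.0, seed=0):
    """Draw one sample of ratings from a linear mixed-effects model."""
    rng = random.Random(seed)
    # Random effects: one zero-mean Gaussian draw per student and instructor.
    student_effects = [rng.gauss(0.0, student_stddev)
                       for _ in range(num_students)]
    instructor_effects = [rng.gauss(0.0, instructor_stddev)
                          for _ in range(num_instructors)]
    ratings = []
    for service, student, instructor in zip(features["service"],
                                            features["students"],
                                            features["instructors"]):
        # The mean combines fixed effects with the looked-up random effects.
        loc = (service_effect * service
               + student_effects[student]
               + instructor_effects[instructor]
               + intercept)
        # Likelihood: unit-variance Gaussian noise around the mean.
        ratings.append(rng.gauss(loc, 1.0))
    return ratings


# Hypothetical toy data: three course evaluations.
toy_features = {
    "service": [0.0, 1.0, 1.0],
    "students": [0, 1, 1],
    "instructors": [0, 0, 1],
}
sample = simulate_ratings(toy_features, num_students=2, num_instructors=2)
```

Each call with the same seed reproduces the same draw; changing the seed gives a different generation of the ratings.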
The model takes as input a features dictionary with the keys “service”, “students”, and “instructors”; each is a vector whose elements describe an individual course. The model regresses on these inputs, posits latent random variables, and returns a distribution over the courses’ evaluation ratings. Running this output in a TensorFlow session returns a generated sample of the ratings.

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.distributions.bijectors

# Example: Log-Normal Distribution
log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())

# Example: Kumaraswamy Distribution
Kumaraswamy = tfd.TransformedDistribution(
    distribution=tfd.Uniform(low=0., high=1.),
    bijector=tfb.Kumaraswamy(
        concentration1=2.,
        concentration0=2.))

# Example: Masked Autoregressive Flow
# https://arxiv.org/abs/1705.07057
shift_and_log_scale_fn = tfb.masked_autoregressive_default_template(
    hidden_layers=[512, 512],
    event_shape=[28*28])
maf = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.MaskedAutoregressiveFlow(
        shift_and_log_scale_fn=shift_and_log_scale_fn))
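To illustrate what a transformed distribution computes, here is a minimal plain-Python sketch of the change-of-variables rule behind the log-normal example above. It does not use TensorFlow Probability; the helper names `normal_log_prob` and `log_normal_log_prob` are assumptions made for this illustration.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_normal_log_prob(y, loc=0.0, scale=1.0):
    """Log density of Y = exp(X), with X ~ Normal(loc, scale), at y > 0."""
    x = math.log(y)                 # apply the inverse transformation
    inverse_log_det = -math.log(y)  # log |d(log y) / dy| = -log y
    # Change of variables: base log density plus the inverse log-Jacobian.
    return normal_log_prob(x, loc, scale) + inverse_log_det
```

The same pattern (invert the transformation, evaluate the base density, add the log-determinant of the inverse Jacobian) applies to any reversible transformation, which is why such transformations compose well.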
The “Gaussian Copula” creates a few custom Bijectors and then shows how to easily build several different copulas. For more background on distributions, see “Understanding TensorFlow Distributions Shapes.” It describes how to manage shapes for sampling, batch training, and modeling events.

import tensorflow as tf
import tensorflow_probability as tfp

# Assumes user supplies `likelihood`, `prior`, `surrogate_posterior`
# functions and that each returns a
# tf.distributions.Distribution-like object.
elbo_loss = tfp.vi.monte_carlo_csiszar_f_divergence(
    f=tfp.vi.kl_reverse,  # Equivalent to "Evidence Lower BOund"
    p_log_prob=lambda z: likelihood(z).log_prob(x) + prior().log_prob(z),
    q=surrogate_posterior(x),
    num_draws=1)
train = tf.train.AdamOptimizer(
    learning_rate=0.01).minimize(elbo_loss)
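The quantity being minimized above can be sketched without any library. The plain-Python example below estimates the reverse-KL objective, the average of log q(z) - log p(x, z) over draws z from q, for a toy conjugate model. The model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)), the Gaussian surrogate, and the function names are assumptions chosen so that the result can be checked by hand; this is not the TensorFlow Probability implementation.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log density of a Normal(loc, scale) distribution at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def elbo_loss(x, q_loc, q_scale, num_draws=1, seed=0):
    """Monte Carlo estimate of the negative ELBO for the toy model.

    Toy model (assumed for this sketch): z ~ Normal(0, 1) and
    x | z ~ Normal(z, 1), with surrogate q(z) = Normal(q_loc, q_scale).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_draws):
        z = rng.gauss(q_loc, q_scale)  # draw z from the surrogate posterior
        log_joint = normal_log_prob(x, z, 1.0) + normal_log_prob(z, 0.0, 1.0)
        log_q = normal_log_prob(z, q_loc, q_scale)
        total += log_q - log_joint     # reverse-KL integrand
    return total / num_draws


x_obs = 1.5
# Exact posterior for this conjugate model: Normal(x / 2, sqrt(1 / 2)).
exact_loss = elbo_loss(x_obs, q_loc=x_obs / 2.0, q_scale=math.sqrt(0.5))
# Marginally x ~ Normal(0, sqrt(2)), which gives the negative log evidence.
neg_log_evidence = -normal_log_prob(x_obs, 0.0, math.sqrt(2.0))
# A deliberately mismatched surrogate, averaged over many draws.
loose_loss = elbo_loss(x_obs, q_loc=0.0, q_scale=1.0, num_draws=5000)
```

When the surrogate equals the exact posterior, every draw yields the negative log evidence, so `exact_loss` matches `neg_log_evidence` up to floating-point error. A mismatched surrogate gives a larger loss on average, and the gap is the KL divergence from the surrogate to the true posterior.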
(tfp.layers).

import tensorflow as tf
import tensorflow_probability as tfp

model = tf.keras.Sequential([
    tf.keras.layers.Reshape([32, 32, 3]),
    tfp.layers.Convolution2DFlipout(
        64, kernel_size=5, padding='SAME', activation=tf.nn.relu),
    tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                 strides=[2, 2],
                                 padding='SAME'),
    tf.keras.layers.Reshape([16 * 16 * 64]),
    tfp.layers.DenseFlipout(10)
])

logits = model(features)
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)
kl = sum(model.get_losses_for(inputs=None))
loss = neg_log_likelihood + kl
train_op = tf.train.AdamOptimizer().minimize(loss)
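The loss in the code above adds a data-fit term to a KL penalty on the weights. As a small plain-Python sketch of those two pieces, the example below computes a softmax cross-entropy and the closed-form KL divergence between a Gaussian weight posterior and a standard normal prior. The function names, the three-class logits, and the two posterior parameter pairs are hypothetical values for illustration, and the standard normal prior is an assumption of this sketch rather than a statement about the library's defaults.

```python
import math


def softmax_cross_entropy(logits, label):
    """Negative log-likelihood of an integer class label under softmax."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_norm = max_logit + math.log(
        sum(math.exp(v - max_logit) for v in logits))
    return log_norm - logits[label]


def gaussian_kl(q_loc, q_scale, p_loc=0.0, p_scale=1.0):
    """KL(Normal(q_loc, q_scale) || Normal(p_loc, p_scale)), closed form."""
    return (math.log(p_scale / q_scale)
            + (q_scale ** 2 + (q_loc - p_loc) ** 2) / (2.0 * p_scale ** 2)
            - 0.5)


# Hypothetical values: logits for 3 classes and two weight posteriors.
logits = [2.0, 0.5, -1.0]
neg_log_likelihood = softmax_cross_entropy(logits, label=0)
kl = sum(gaussian_kl(loc, scale) for loc, scale in [(0.3, 0.8), (-0.1, 1.2)])
loss = neg_log_likelihood + kl
```

The KL term is zero only when a weight posterior equals the prior, so it acts as a regularizer that pulls the learned weight distributions toward the prior.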
The model object composes neural net layers on an input tensor, and it performs stochastic forward passes with respect to the probabilistic convolutional layer and the probabilistic densely connected layer. The model returns an output tensor whose shape is the batch size by 10 values. Each row of this tensor represents the logits (unconstrained probability values) that each data point belongs to one of the 10 classes.

tfp.layers can also be used with eager execution using the tf.keras.Model class.

class MNISTModel(tf.keras.Model):

  def __init__(self):
    super(MNISTModel, self).__init__()
    self.dense1 = tfp.layers.DenseFlipout(units=10)
    self.dense2 = tfp.layers.DenseFlipout(units=10)

  def call(self, input):
    """Run the model."""
    result = self.dense1(input)
    result = self.dense2(result)
    # reuse variables from dense2 layer
    result = self.dense2(result)
    return result

model = MNISTModel()
pip install --user --upgrade tfp-nightly
For all the code and details, check out github.com/tensorflow/probability. We’re excited to collaborate with you via GitHub, whether you’re a user or contributor!