Industrial AI: BHGE’s Physics-based, Probabilistic Deep Learning Using TensorFlow Probability — Part 1
October 11, 2018
By Arun Subramaniyan, VP Data Science & Analytics at BHGE Digital

Baker Hughes, a GE Company (BHGE), is the world’s leading fullstream oil and gas company with a mission to find better ways to deliver energy to the world. The BHGE Digital team develops enterprise grade, AI-driven, SaaS solutions to improve efficiency and reduce non-productive time for the oil and gas industry. Consider mission-critical issues such as predicting the failure of a gas turbine or optimizing a large system like a petrochemical plant; these issues require building and maintaining sophisticated analytics at scale. We have developed an analytics-driven strategy to enable company-wide digital transformation for our customers.

Through years of solving the hardest problems facing the industrial world, we have learned that the most elegant and enduring solutions lie at the intersection of:

  • domain knowledge
  • traditional machine learning (ML)
  • probabilistic techniques
  • deep learning

Classes of problems that were previously intractable are now solvable by combining seemingly unrelated techniques and deploying them with modern scalable software and hardware. Our collaboration with the TensorFlow Probability (TFP) team and the Cloud ML teams at Google has accelerated our journey to develop and deploy these techniques at scale.

We intend to showcase the innovations occurring at these intersections with this series of blogs and hope to motivate a Cambrian explosion in industrial applications of probabilistic deep learning techniques. An example set of these applications that we showcased at Google Next 2018 can be viewed here.

The need for probabilistic deep learning

Physics-based (i.e., domain-based) analytics have been used successfully for decades to design and operate systems in industries as diverse as aerospace, automotive, and oil and gas. They provide a powerful way to generalize complex behavior from a few observations. Newton’s equation of motion can be used to predict precisely the movement of atoms and galaxies alike. However, innovative solutions to real-world problems require a combination of these laws with effective ways to account for unknowns. Explicitly tracking the certainty of our belief is paramount to defendable science.

For example, take the “simple” case of a projectile’s trajectory. To predict the location where a projectile would land, the only observation (“data”) required to make an accurate prediction is the velocity (speed and angle) with which the projectile is thrown. However, adding a small uncertainty (~5%) in the initial angle and wind speed introduces a large uncertainty (~200%) to where the projectile will land, as shown in the figure below.
[Figure: projectile trajectories under uncertain launch angle and wind speed]
Making accurate predictions for even a simple system becomes difficult with uncertainty. Imagine the complexity of optimizing a large non-linear system!
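The sensitivity described above can be explored with a small Monte Carlo sketch. It assumes a drag-free projectile and uses a constant horizontal acceleration as a crude, hypothetical stand-in for wind. The launch speed, nominal angle, and wind scatter are illustrative values, not the setup behind the figure, so the spread it reports should not be expected to match the ~200% quoted above.

```python
import math
import random
import statistics

G = 9.81  # gravitational acceleration, m/s^2


def landing_distance(speed, angle_deg, wind_accel=0.0):
    """Range of a drag-free projectile launched from ground level.

    wind_accel is a hypothetical constant horizontal acceleration (m/s^2)
    standing in for wind; it is not a real aerodynamic drag model.
    """
    angle = math.radians(angle_deg)
    flight_time = 2.0 * speed * math.sin(angle) / G
    return speed * math.cos(angle) * flight_time + 0.5 * wind_accel * flight_time**2


def simulate(n=10_000, speed=30.0, angle_deg=45.0, rel_sigma=0.05,
             wind_sigma=1.0, seed=0):
    """Sample landing points with Gaussian scatter on launch angle and wind."""
    rng = random.Random(seed)
    return [
        landing_distance(
            speed,
            rng.gauss(angle_deg, rel_sigma * angle_deg),  # ~5% angle scatter
            rng.gauss(0.0, wind_sigma),                   # assumed wind scatter
        )
        for _ in range(n)
    ]


nominal = landing_distance(30.0, 45.0)
landings = simulate()
print(f"nominal range: {nominal:.1f} m")
print(f"sampled mean:  {statistics.mean(landings):.1f} m")
print(f"sampled stdev: {statistics.stdev(landings):.1f} m")
```

Swapping in a proper drag model or different scatter levels changes the numbers, but the pattern of pushing input uncertainty through a deterministic model stays the same.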

Digital Twin is a construct that the BHGE Analytics team has honed over years of solving actual industrial problems. We define it as a digital representation of a physical system, continuously tuned to represent the current conditions and predict future states. Depicted in the diagram below, the three pillars required to build Digital Twins are domain knowledge, probabilistic inference, and deep learning:

  • Domain expertise (i.e., physics) combined with traditional ML allows us to solve known problems precisely, a.k.a. known knowns. By traditional ML, we are referring to techniques such as polynomial regression, kernel density methods, and state-space estimation methods (e.g., Kalman filters). Predicting the length of a crack in a component under various thermo-mechanical loads is an example of domain expertise: it requires a thorough understanding of the physics of the problem combined with knowledge of the boundary conditions such as temperature and pressure. Even when the phenomenon is well understood, several traditional ML techniques are required to compute coefficients that are specific to the problem (e.g., material properties).
  • Probabilistic inference provides a systematic way to quantify uncertainty, a.k.a. known unknowns. In the example above, in addition to understanding the phenomena, we also need to account for uncertainty in measurements and un-modeled factors like manufacturing variability. It is well known that there are variabilities (uncertainty) in the system that will affect the crack growth. Predicting crack length then becomes an uncertainty quantification exercise. More specifically, the material property calibration becomes a probabilistic inference problem when we add uncertainty to the measurements.
  • Deep learning and modern ML have the potential to identify and predict unknown patterns and behavior, a.k.a. unknown unknowns. In the crack propagation example, even with probabilistic inference coupled with domain models, we can predict crack propagation only under a particular set of loading conditions with known uncertainties. Consider an anomaly detection model based on an auto-encoder that monitors the loading conditions along with a variety of other conditions on the equipment. This deep learning model can catch anomalous conditions that the physics-based model would not pick up. We call this prediction an unknown unknown because we don’t have any data on the anomaly to train the model. Though the deep learning model trains on data from normal operation (and some abnormalities for validation), it catches deviations from normal behavior that were not previously observed. This is one of many examples where deep learning models come in handy. One can argue that in some cases a simple anomaly detection model based on traditional techniques, such as PCA reconstruction error, can also do the trick. In our experience, simpler techniques can match or sometimes beat deep learning models when known process characteristics align with the simplifying assumptions of these methods, leading to known anomalies. Deep learning models really shine when the anomalies are not known a priori, that is, when you have not seen a particular failure mode before.

[Figure: the three pillars of a Digital Twin: domain knowledge, probabilistic inference, and deep learning]
In this article, we will focus on probabilistic inference with domain models for predicting known unknowns. We will demonstrate the power of Bayesian calibration with a crack propagation problem formulated as a physics-based probabilistic inference model.

Probabilistic deep learning allows us to leverage all of the capabilities highlighted above in a “self-learning” package. We will address probabilistic deep learning with TFP in our subsequent blogs, which will cover model discrepancy, anomaly detection, missing data estimation and time series forecasting.

Predicting life of components: A probabilistic crack propagation example

Predicting the life of a component that is prone to cracking is an age-old problem that has been studied ad nauseam by the fracture mechanics community. Crack propagation models are at the core of prognostics and health management (PHM) solutions for engineering systems. The aptly titled book Prognostics and Health Management of Engineering Systems: An Introduction provides a great example of how real world data can be used to calibrate engineering models. With the example below, we would like to motivate the use of “hybrid models” that combine probabilistic learning techniques and engineering domain models.

The phenomenon of fatigue crack propagation can be modeled with Paris’ law. Paris’ law relates the rate of crack growth (da/dN) to the stress intensity factor range (ΔK = Δσ√(𝜋a)) through the equation below:
$$\frac{da}{dN} = C\,(\Delta K)^m, \qquad \Delta K = \Delta\sigma\sqrt{\pi a} \tag{1}$$
where a is the crack length, N is the number of loading cycles, Δσ is the stress range, and (C, m) are material properties.

Integrating Paris’ law for a specific geometry and loading configuration, we arrive at the analytical formulation for the size of a crack as a function of the loading cycle as shown below:
$$a(N) = \left[\, N\,C\left(1 - \frac{m}{2}\right)\left(\Delta\sigma\sqrt{\pi}\right)^{m} + a_0^{\,1 - m/2} \right]^{\frac{2}{2 - m}} \tag{2}$$
where aₒ is the initial crack length. By knowing aₒ and the factor Δσ√𝜋 on a given application, one can use Equation 2 to calculate the size of the crack in the future, given that the application will accumulate N cycles with time. This forecasted crack length can be compared against thresholds for safe operation. For instance, if the forecasted crack length goes above a geometric limit (e.g. half the thickness of a component), it is time for a repair.
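As a sanity check on the closed form, the sketch below evaluates Equation 2, compares it against a direct numerical integration of Paris’ law, and inverts it to estimate the cycles remaining before a crack-length limit is reached. The stress range (75) and the values log C = -23 and m = 4 are taken from the code and priors used later in this article. The initial crack length of 0.01, the limit of 0.05, and the function names are illustrative assumptions, and consistent units are assumed throughout.

```python
import math


def crack_length(n_cycles, a0, log_c, m, dsig):
    """Closed-form crack length after n_cycles (Equation 2); requires m != 2."""
    b = dsig * math.sqrt(math.pi)
    base = n_cycles * math.exp(log_c) * (1.0 - m / 2.0) * b**m + a0 ** (1.0 - m / 2.0)
    if base <= 0.0:
        return math.inf  # past the cycle count at which the closed form diverges
    return base ** (2.0 / (2.0 - m))


def growth_rate(a, log_c, m, dsig):
    """Paris' law: da/dN = C * (dsig * sqrt(pi * a))**m."""
    return math.exp(log_c) * (dsig * math.sqrt(math.pi * a)) ** m


def crack_length_rk4(n_cycles, a0, log_c, m, dsig, steps=2000):
    """Integrate Paris' law numerically with the classical RK4 scheme."""
    h = n_cycles / steps
    a = a0
    for _ in range(steps):
        k1 = growth_rate(a, log_c, m, dsig)
        k2 = growth_rate(a + 0.5 * h * k1, log_c, m, dsig)
        k3 = growth_rate(a + 0.5 * h * k2, log_c, m, dsig)
        k4 = growth_rate(a + h * k3, log_c, m, dsig)
        a += h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
    return a


def cycles_to_limit(a_limit, a0, log_c, m, dsig):
    """Invert Equation 2 for the cycle count at which a reaches a_limit."""
    b = dsig * math.sqrt(math.pi)
    p = 1.0 - m / 2.0
    return (a_limit**p - a0**p) / (math.exp(log_c) * p * b**m)


# dsig, log C and m follow the article; a0 and the 0.05 limit are assumptions.
A0, LOG_C, M, DSIG = 0.01, -23.0, 4.0, 75.0
closed = crack_length(2000, A0, LOG_C, M, DSIG)
numeric = crack_length_rk4(2000, A0, LOG_C, M, DSIG)
n_limit = cycles_to_limit(0.05, A0, LOG_C, M, DSIG)
print(f"closed form a(2000) = {closed:.5f}")
print(f"RK4         a(2000) = {numeric:.5f}")
print(f"cycles until a = 0.05: {n_limit:.0f}")
```

The guard on `base` matters in practice: for m > 2 the bracketed term in Equation 2 shrinks with N and reaches zero at a finite cycle count, beyond which the closed form no longer describes a physical crack.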

The parameters C and m need to be calibrated for each physical component, given crack length a vs loading cycles N data. In other words, given the measured crack in the field, we wish to infer C and m so we can then estimate how many cycles we can safely operate under, before the crack length becomes large enough to risk component failure.

In this example, we will use the sample dataset from the PHM book by Kim, An and Choi, demonstrating a probabilistic calibration of C and m using TFP. At BHGE Digital, we leverage our Depend-on-Docker project for automating analytics development. A sample of such automation is available here, with the complete code of the example below.

The diagram below shows the dataset provided in Table 4.2 of the PHM book by Kim, An and Choi. As is typical for most crack propagation datasets, the underlying trends are not obvious from the observed data points.
[Figure: measured crack length versus loading cycles (Table 4.2 of the PHM book)]
For Bayesian calibration, we need to define the prior distributions for the calibration variables. In a real application, these priors can be informed by a subject matter expert. For this example, we will assume that log C and m are Gaussian and independent:
$$\log C \sim \mathcal{N}(-23,\ 1.1), \qquad m \sim \mathcal{N}(4,\ 0.2)$$
In TFP, we can encode this information as follows:
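import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
# Second entry of each pair is square-rooted in joint_log_prob, i.e. used as a variance
prio_par_logC = [-23., 1.1]  # Normal prior parameters for logC
prio_par_m = [4., 0.2]  # Normal prior parameters for m
rv_logC, rv_m = tfd.Normal(0., 1., name='logC_norm'), tfd.Normal(0., 1., name='m_norm')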
We have defined external parameters and standard Normal distributions for both variables, just to sample from a normalized space. Therefore, we will need to de-normalize both random variables when computing the crack model.
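The de-normalization step can be illustrated without TensorFlow. The sketch below maps standard-normal draws back to the original space with the same transform that `joint_log_prob` applies in this article, which takes the square root of the second prior parameter and therefore treats it as a variance. The helper name `denormalize` and the sample size are illustrative assumptions.

```python
import random
import statistics


def denormalize(z, loc, variance):
    """Map a standard-normal draw z to a Normal with the given mean and variance."""
    return z * variance**0.5 + loc


rng = random.Random(0)
z_draws = [rng.gauss(0.0, 1.0) for _ in range(50_000)]
log_c_draws = [denormalize(z, -23.0, 1.1) for z in z_draws]

print("mean:", round(statistics.mean(log_c_draws), 2))
print("variance:", round(statistics.variance(log_c_draws), 2))
```

Sampling in the normalized space keeps both variables on a comparable scale, which generally makes a single sampler step size easier to choose.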

Now we define the joint log probability for the random variables being calibrated and the associated crack model defined by Equation 2:
def joint_log_prob(cycles, observations, y0, logC_norm, m_norm):
    """Joint log probability of the normalized parameters and the observations."""
    # Some constants
    dsig = 75.
    B = tf.constant(dsig * np.pi**0.5, tf.float32)
    # Computing m and logC on original space
    logC = logC_norm * prio_par_logC[1]**0.5 + prio_par_logC[0]
    m = m_norm * prio_par_m[1]**0.5 + prio_par_m[0]

    # Crack propagation model (Equation 2)
    crack_model = (cycles * tf.exp(logC) * (1. - m / 2.) * B**m
                   + y0**(1. - m / 2.))**(2. / (2. - m))
    # Prediction error between the measurements and the physics model
    y_model = observations - crack_model

    # Defining child model random variable: zero-mean Gaussian on the errors
    rv_model = tfd.Independent(
        tfd.Normal(loc=tf.zeros(observations.shape), scale=0.001),
        reinterpreted_batch_ndims=1, name='model')
    # Sum of log probabilities
    return (rv_logC.log_prob(logC_norm) + rv_m.log_prob(m_norm)
            + rv_model.log_prob(y_model))
Finally, it is time to set up the sampler and run a TensorFlow session:
# This cell can take 12 minutes to run in Graph mode
# Number of samples and burnin for the MCMC sampler
samples = 10000
burnin = 10000

# Initial state for the HMC
initial_state = [0., 0.]
# Converting the data into tensors
# (t_ holds the loading cycles and y_ the measured crack lengths)
cycles = tf.convert_to_tensor(t_, tf.float32)
observations = tf.convert_to_tensor(y_, tf.float32)
y0 = tf.convert_to_tensor(y_[0], tf.float32)
# Setting up a target posterior for our joint log probability
unnormalized_target_posterior = lambda *args: joint_log_prob(
    cycles, observations, y0, *args)
# And finally setting up the MCMC sampler
[logC_samples, m_samples], kernel_results = tfp.mcmc.sample_chain(
    num_results=samples,
    num_burnin_steps=burnin,
    current_state=initial_state,
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_target_posterior,
        step_size=0.045,
        num_leapfrog_steps=6))


# Tracking the acceptance rate for the sampled chain
acceptance_rate = tf.reduce_mean(tf.cast(kernel_results.is_accepted, tf.float32))

# Actually running the sampler
# The evaluate() function, defined at the top of this notebook, runs `sess.run()` 
# in graph mode and allows code to be executed eagerly when Eager mode is enabled
[logC_samples_, m_samples_, acceptance_rate_] = evaluate([
    logC_samples, m_samples, acceptance_rate])

# Some initial results
print('acceptance_rate:', acceptance_rate_)
[Figure: posterior samples of log C and m]
It is interesting to note that although we started off with two independent Gaussian distributions for C and m, the posterior distributions are highly correlated. This correlation arises because the solution space dictates that for high values of m, the only physically meaningful results lie at small values of C and vice-versa. If we were to have used any number of deterministic optimization techniques to find the “best fit” for C and m that fits this dataset, we would have ended up with some value that lies on the straight line depending on our starting point and constraints. Performing a probabilistic optimization (a.k.a. Bayesian calibration) provides us with a global view of all possible solutions that can explain the dataset.
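The correlation can be reproduced on a toy problem without a sampler. The sketch below is a brute-force grid approximation of the posterior, not the HMC run used in this article. It relies on synthetic, noise-free observations generated from Equation 2 at log C = -23 and m = 4, an assumed initial crack length of 0.01, and an assumed noise level of 0.002. The priors follow the article, with the second prior parameter treated as a variance.

```python
import math

DSIG = 75.0
B = DSIG * math.sqrt(math.pi)
A0 = 0.01      # assumed initial crack length
SIGMA = 0.002  # assumed measurement noise (standard deviation)


def crack_length(n, log_c, m):
    """Equation 2; returns None where the closed form has diverged."""
    p = 1.0 - m / 2.0
    base = n * math.exp(log_c) * p * B**m + A0**p
    return base ** (1.0 / p) if base > 0.0 else None


# Synthetic, noise-free "observations" generated from known parameters.
CYCLES = range(0, 2001, 250)
OBSERVED = [crack_length(n, -23.0, 4.0) for n in CYCLES]


def log_posterior(log_c, m):
    """Independent Gaussian priors plus a Gaussian likelihood on the residuals."""
    lp = -0.5 * (log_c + 23.0) ** 2 / 1.1 - 0.5 * (m - 4.0) ** 2 / 0.2
    for n, obs in zip(CYCLES, OBSERVED):
        pred = crack_length(n, log_c, m)
        if pred is None:
            return -math.inf
        lp += -0.5 * ((obs - pred) / SIGMA) ** 2
    return lp


# Evaluate the unnormalized posterior on a regular grid.
log_c_grid = [-26.0 + 6.0 * i / 199 for i in range(200)]
m_grid = [2.75 + 2.5 * j / 199 for j in range(200)]
log_post = [[log_posterior(lc, m) for m in m_grid] for lc in log_c_grid]
peak = max(max(row) for row in log_post)
weights = [[math.exp(v - peak) for v in row] for row in log_post]

# Posterior-weighted moments and correlation.
total = sum(map(sum, weights))
mean_c = sum(lc * sum(row) for lc, row in zip(log_c_grid, weights)) / total
mean_m = sum(m * w for row in weights for m, w in zip(m_grid, row)) / total
var_c = sum((lc - mean_c) ** 2 * sum(row)
            for lc, row in zip(log_c_grid, weights)) / total
var_m = sum((m - mean_m) ** 2 * w
            for row in weights for m, w in zip(m_grid, row)) / total
cov = sum((lc - mean_c) * (m - mean_m) * w
          for lc, row in zip(log_c_grid, weights)
          for m, w in zip(m_grid, row)) / total
correlation = cov / math.sqrt(var_c * var_m)
print(f"posterior correlation between log C and m: {correlation:.2f}")
```

A grid is only practical for a couple of parameters. It is used here because it makes the ridge in the posterior visible without any sampler tuning.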

Sampling the posterior for prognostics

And now for the final act, we shall define a posterior function as a log-normal distribution with the expectation defined by the physics model:
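[Equation: posterior predictive distribution of crack length, with expectation given by the physics model of Equation 2]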
The damage (in this case, crack length) is known to be skewed, so log-normal or Gumbel distributions are commonly used to model it. Expressing the model in TFP looks like this:
def posterior(logC_samples, m_samples, time):
    """Propagate posterior samples of (logC, m) through the crack model.

    Returns an array of shape [len(time), n_samples]. Uses the global y0.
    """
    n_s = len(logC_samples)

    # Some constants
    dsig = 75.
    B = tf.constant(dsig * np.pi**0.5, tf.float32)

    # Broadcast cycles along rows and posterior samples along columns
    cycles = tf.constant(np.asarray(time, dtype=np.float32))[:, None]
    logC = tf.cast(logC_samples, tf.float32)[None, :]
    m = tf.cast(m_samples, tf.float32)[None, :]

    # Crack propagation model (Equation 2)
    y_model = (cycles * tf.exp(logC) * (1. - m / 2.) * B**m
               + y0**(1. - m / 2.))**(2. / (2. - m))

    # Additive Gaussian noise, one draw per posterior sample
    noise = tfd.Normal(loc=0., scale=0.001)
    samples = y_model + noise.sample(n_s)[tf.newaxis, :]
    # The evaluate() function, defined at the top of this notebook, runs `sess.run()`
    # in graph mode and allows code to be executed eagerly when Eager mode is enabled
    samples_ = evaluate(samples)
    return samples_

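# Predict for a range of cycles
time = np.arange(0, 3000, 100)
# The de-normalized samples are not defined in this excerpt; the two lines below
# are an assumption that mirrors the transform used in joint_log_prob.
logC_samples_scale = logC_samples_ * prio_par_logC[1]**0.5 + prio_par_logC[0]
m_samples_scale = m_samples_ * prio_par_m[1]**0.5 + prio_par_m[0]
y_samples = posterior(logC_samples_scale, m_samples_scale, time)
print(y_samples.shape)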
The predicted mean with the 95% uncertainty bounds of crack length using the hybrid-physics-probabilistic model is shown below. Clearly, the model captures not only the mean behavior but also an estimate of uncertainty of the model predictions for every time point. As we move away from the observations, the uncertainty in the model predictions increases exponentially.
[Figure: predicted mean crack length with 95% uncertainty bounds versus loading cycles]
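Turning parameter samples into an uncertainty band takes only a few lines. The sketch below is a minimal illustration that uses hypothetical, negatively correlated (log C, m) pairs as a stand-in for real MCMC output, since the calibrated samples are not reproduced here. The initial crack length of 0.01, the slope of -2.9 between the two parameters, and the scatter levels are all assumptions chosen for illustration.

```python
import math
import random
import statistics

B = 75.0 * math.sqrt(math.pi)  # stress term used in this article (dsig = 75)
A0 = 0.01                      # assumed initial crack length


def crack_length(n, log_c, m):
    """Equation 2, guarded against the region where the closed form diverges."""
    p = 1.0 - m / 2.0
    base = n * math.exp(log_c) * p * B**m + A0**p
    return base ** (1.0 / p) if base > 0.0 else math.inf


# Hypothetical stand-in for MCMC output: negatively correlated (log C, m) pairs.
rng = random.Random(1)
draws = []
for _ in range(2000):
    m = rng.gauss(4.0, 0.1)
    log_c = -23.0 - 2.9 * (m - 4.0) + rng.gauss(0.0, 0.02)
    draws.append((log_c, m))

# Percentile band of the predicted crack length at each cycle count.
cycles = list(range(0, 2001, 250))
bands = []
for n in cycles:
    predictions = [crack_length(n, log_c, m) for log_c, m in draws]
    cuts = statistics.quantiles(predictions, n=40)  # cut points every 2.5%
    bands.append((cuts[0], statistics.median(predictions), cuts[-1]))

for n, (low, mid, high) in zip(cycles, bands):
    print(f"N={n:5d}  2.5%={low:.4f}  median={mid:.4f}  97.5%={high:.4f}")
```

With real posterior samples, the same loop yields the band plotted above. Only the `draws` list would change.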

Next steps

We have chosen this example carefully to provide a gentle introduction to probabilistic models. There are several challenges in applying probabilistic models to real-world applications. Even in a simple problem like the one shown here, the solution becomes intractable if we relax the assumption that the priors are Gaussian. The more inquisitive reader can try using uniform priors to see the dramatic decrease in the predictive capability of the model.

Another subtle point is the way the likelihood is defined. We have chosen to define it on the prediction errors rather than the predicted value directly. Note that the individual blue lines depicted below in “Likelihood based on prediction error” are samples of the model prediction errors. Changing the formulation to model predictions directly — instead of the prediction errors — produces non-monotonic, non-physical results as shown in “Likelihood based on actual predictions”. We call these results non-physical because crack lengths cannot decrease over time. If we were to look at just the “smeared” percentile view shown above, we might not notice the subtle difference in model formulations that could lead to large prediction errors. We will address some of these more practical challenges and their mitigation in the next blogs.
[Figure: two panels, “Likelihood based on prediction error” and “Likelihood based on actual predictions”]
This is the first of a series of blogs aimed at expanding the use of probabilistic and deep learning techniques for industrial applications with TFP. We (@sarunkarthi) would love to hear about your applications and look forward to seeing these methods used in unique ways. Stay tuned to this blog feed for more updates and examples on anomaly detection, missing data estimation, and forecasting with variational inference.

Acknowledgments

This blog is a result of the hard work and deep collaboration of several teams at Google and BHGE. We are particularly grateful to Josh Dillon, Mike Shwe, Scott Fitzharris, Alex Walker, Christian Hillaire and Fabio Nonato for their many edits, results, code snippets, and — most importantly — enthusiastic support.



