Google Article
Speeding up neural networks using TensorNetwork in Keras
February 12, 2020
Posted by Marina Munkhoeva, PhD student at Skolkovo Institute of Science and Technology and AI Resident at Alphabet's X, Chase Roberts, Research Engineer at Alphabet's X, and Stefan Leichenauer, Research Scientist at Alphabet's X

Introduction

In this post, we’re going to talk about TensorNetwork and how it can be used to supercharge a feed-forward neural network in TensorFlow. TensorNetwork is an open-source library released in June 2019 to facilitate computations with tensor networks. Usually, the first question people ask us is “What is a tensor network?”, closely followed by “Why should I care about tensor networks?”. There are many resources out there addressing the first question (such as our previous Google AI blog post), and here we’re going to focus on answering the second. The basic idea is called “tensorizing” a neural network and has its roots in a 2015 paper from Novikov et al. Using the TensorNetwork library, it’s straightforward to implement this procedure. Below we’ll give an explicit and pedagogical example using Keras and TensorFlow 2.0.

Getting started with TensorNetwork is easy. The library can be installed using pip:
pip install tensornetwork
The example code we’ll discuss in this post is also available in a Colab.

TN Layers

The basic idea of a TN-enhanced neural network is to replace one or more of the layers of the network with a TN layer. You can think of the TN layer as a compressed version of the original layer. The best results can be expected when you start out with a dense, highly-connected layer that has a lot of potential for compression. For example, consider the following:
import tensorflow as tf

Dense = tf.keras.layers.Dense
fc_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(2,)),
        Dense(1024, activation=tf.nn.swish),
        Dense(1024, activation=tf.nn.swish),
        Dense(1, activation=None),
    ]
)
This network has a two-dimensional input, one-dimensional output, and contains two hidden layers with 1024 neurons each. The total number of parameters (including weights and biases) is (2+1)*1024 + (1024+1)*1024 + (1024 +1)*1 = 1,053,697. The bulk of these parameters, (1024+1)*1024 = 1,049,600 of them, are associated with the second hidden layer. We’re going to replace those with a TN layer that is less than 1% of the original size.
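As a quick check of the arithmetic above, the following sketch recomputes the parameter count with a small helper (`dense_params` is a name introduced here for illustration). For the Keras model itself, `fc_model.count_params()` should report the same total.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return (n_in + 1) * n_out

layer_params = [
    dense_params(2, 1024),     # first hidden layer
    dense_params(1024, 1024),  # second hidden layer
    dense_params(1024, 1),     # output layer
]
total = sum(layer_params)
print(layer_params, total)  # [3072, 1049600, 1025] 1053697
print(f"{layer_params[1] / total:.2%} of the parameters sit in the second hidden layer")
```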

The 1,049,600 parameters of the hidden layer consist of 1024*1024 = 1,048,576 weights arranged in a 1024-by-1024 matrix, plus 1024 biases. Our savings are going to come from replacing the 1024-by-1024 matrix of weights by a tensor network. To preserve compatibility with the rest of the network, we don’t change the fact that there are 1024 inputs and 1024 outputs. That way we can treat the TN layer as a drop-in replacement for the original layer.

Ordinarily, we think of the 1024 inputs to the layer as a simple array X of shape (1024,). These inputs are multiplied by the weight matrix W to produce a new vector Y. The next step is to apply a nonlinearity to Y, but we won’t focus on that because the TN modifications only affect the linear part of the layer. In tensor network notation, we would write the computation of Y as follows:
Ordinarily, the weights W of a neural network layer act on the inputs X as a matrix multiplication, producing the output Y = WX. Here we illustrate that matrix multiplication diagrammatically.
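To make the index notation concrete, here is a small NumPy sketch (random values stand in for trained weights) showing that the matrix multiplication Y = WX is a sum over the shared index j, which is the edge joining W and X in the diagram. Note that Keras's Dense layer stores its kernel with shape (inputs, outputs) and computes X @ kernel, which is the same operation with W transposed.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(1024, 1024))  # weight matrix of the dense layer
X = rng.normal(size=(1024,))       # one input vector

Y = W @ X                          # the linear part of the layer
# In index notation, Y[i] = sum_j W[i, j] * X[j]; the shared index j
# is the edge connecting W and X in the diagram.
Y_einsum = np.einsum("ij,j->i", W, X)
print(Y.shape, np.allclose(Y, Y_einsum))  # (1024,) True
```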
At this point, we can discuss tensorizing the layer. First, we reshape the input array from shape (1024,) into shape (32,32). Pictorially, the reshape looks like this:
Reshaping the input X from a vector of shape (1024,) to an array of shape (32,32) is the first step of the tensorization process. In principle, any reshape is allowed; we just chose this particularly simple one for our example.
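The reshape does not change any values; it only re-indexes them. Assuming the row-major (C-order) convention that both NumPy and tf.reshape use, element (j, l) of the (32,32) array is element 32*j + l of the original vector, as this short NumPy sketch shows:

```python
import numpy as np

X = np.arange(1024, dtype=np.float64)  # stand-in input vector
X_mat = X.reshape(32, 32)               # row-major (C-order) reshape

# Element (j, l) of the matrix is element 32*j + l of the vector,
# so no information is lost and the reshape is trivially reversible.
j, l = 5, 17
print(X_mat[j, l] == X[32 * j + l])             # True
print(np.array_equal(X_mat.reshape(1024), X))   # True
```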
Now, instead of applying the (1024,1024) weight matrix W, we apply a tensor network operation consisting of two cores:
Diagrammatic representation of the tensorized weight multiplication. The output Y can be reshaped into a vector if needed.
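Here is a NumPy sketch of the two-core contraction. The leg ordering (output, input, bond) and the order of the two contractions are chosen to mirror the TNLayer code later in this post; the random cores and the name `chi` for the size of the leg joining the cores (discussed next, 2 here) are illustrative assumptions. The sketch also builds the equivalent 1024-by-1024 weight matrix explicitly, to confirm that the tensor network is a structured way of writing an ordinary dense linear map:

```python
import numpy as np

rng = np.random.default_rng(0)
chi = 2                                   # size of the leg joining the cores
A = rng.normal(size=(32, 32, chi))        # legs: (output i, input j, bond b)
B = rng.normal(size=(32, 32, chi))        # legs: (output k, input l, bond b)
X_mat = rng.normal(size=(32, 32))         # reshaped input, legs (j, l)

# Step 1: contract A's input leg j with the first leg of X.
C = np.tensordot(A, X_mat, axes=([1], [0]))        # legs (i, b, l)
# Step 2: contract the remaining legs (b, l) with B's (bond, input) legs.
Y_mat = np.tensordot(C, B, axes=([1, 2], [2, 1]))  # legs (i, k)

# The same operation written as an explicit 1024x1024 matrix:
# W[(i, k), (j, l)] = sum_b A[i, j, b] * B[k, l, b]
W_full = np.einsum("ijb,klb->ikjl", A, B).reshape(1024, 1024)
Y_dense = W_full @ X_mat.reshape(1024)

print(Y_mat.shape, np.allclose(Y_mat.reshape(1024), Y_dense))  # (32, 32) True
```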
We need to make a choice about the dimensionality of the leg connecting the two cores (often called the bond dimension). For simplicity, we’ll take it to be two-dimensional. Then each core has two 32-dimensional legs along with the single two-dimensional bond leg, meaning that its shape is (32,32,2). The bond dimension controls the number of parameters in the model, and in many cases an appropriate choice of bond dimension achieves a large parameter reduction with little or no drop in performance.
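The sketch below shows how the bond dimension, here called `chi` (our notation), sets the parameter count, and what constraint it imposes. The effective 1024-by-1024 weight matrix is exactly a sum of `chi` Kronecker products of 32-by-32 matrices. Regrouping its legs so that A's legs (i, j) index rows and B's legs (k, l) index columns gives a matrix of rank at most `chi`. This is not ordinary low rank: the weight matrix itself, viewed as a map from inputs to outputs, is generally still full rank.

```python
import numpy as np

def tn_weight_params(chi, d=32):
    """Weights in two (d, d, chi) cores; biases are counted separately."""
    return 2 * d * d * chi

for chi in (1, 2, 4, 8, 16):
    n = tn_weight_params(chi)
    print(f"chi={chi:2d}: {n:6d} weights ({n / 1024**2:.2%} of the dense 1024x1024 matrix)")

rng = np.random.default_rng(0)
chi = 2
A = rng.normal(size=(32, 32, chi))  # legs (i, j, bond)
B = rng.normal(size=(32, 32, chi))  # legs (k, l, bond)

# Effective weight matrix with rows (i, k) and columns (j, l).
W_full = np.einsum("ijb,klb->ikjl", A, B).reshape(1024, 1024)
# It is exactly a sum of chi Kronecker products of 32x32 matrices.
W_kron = sum(np.kron(A[:, :, b], B[:, :, b]) for b in range(chi))
print(np.allclose(W_full, W_kron))  # True

# Regrouping the legs as rows (i, j) and columns (k, l) exposes the bond:
# the matrix factors through a chi-dimensional index, so its rank is <= chi.
M = A.reshape(1024, chi) @ B.reshape(1024, chi).T
print(np.linalg.matrix_rank(M))  # 2
```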

The result of using a TN layer is that we’ve replaced the 1,048,576 weights of the fully-connected weight matrix with the 2*(32*32*2) = 4,096 parameters of the tensor network. That’s a tremendous reduction! Even after accounting for the other layers, the total model size is down to 9,217 parameters, compared to the original 1,053,697.
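The same bookkeeping for the tensorized model reproduces the totals quoted above. The TN layer contributes its two cores plus a (32,32) bias, and the other two layers are unchanged:

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return (n_in + 1) * n_out

tn_layer = 2 * (32 * 32 * 2) + 32 * 32   # two (32, 32, 2) cores plus a (32, 32) bias
tn_total = dense_params(2, 1024) + tn_layer + dense_params(1024, 1)
fc_total = dense_params(2, 1024) + dense_params(1024, 1024) + dense_params(1024, 1)

print(tn_layer, tn_total, fc_total)      # 5120 9217 1053697
print(f"TN layer: {tn_layer / dense_params(1024, 1024):.2%} of the dense layer it replaces")
print(f"Whole model: {tn_total / fc_total:.2%} of the original size")
```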

For this simple example, it would not be too hard to translate the two-core tensor network into an einsum expression. However, for a tensor network with more cores and more complex connectivity, debugging or extending the einsum can be quite difficult. Thus, we instead use the TensorNetwork library for a more object-oriented way of constructing the network. In the code examples below we’ll stick to the simple two-core example to make things easy to follow, but at the end we’ll come back to discuss other possibilities.
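For reference, here is what that einsum could look like, sketched in NumPy with random cores. Adding a leading batch index `n` lets one einsum process a whole batch, which is one possible alternative to the tf.vectorized_map approach used in TNLayer below (tf.einsum accepts the same subscript strings). `tn_forward` is a hypothetical name, and the ReLU mirrors TNLayer; the per-example loop plays the role of vectorized_map as a reference:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(32, 32, 2))      # legs (i, j, bond)
B = rng.normal(size=(32, 32, 2))      # legs (k, l, bond)
bias = np.zeros((32, 32))
X_batch = rng.normal(size=(8, 1024))  # a hypothetical batch of 8 inputs

def tn_forward(x_batch):
    x = x_batch.reshape(-1, 32, 32)                     # (batch, j, l)
    y = np.einsum("ijb,klb,njl->nik", A, B, x)          # (batch, i, k)
    return np.maximum(y + bias, 0.0).reshape(-1, 1024)  # ReLU, then flatten

# Reference: apply the single-example einsum one row at a time.
ref = np.stack([
    np.maximum(np.einsum("ijb,klb,jl->ik", A, B, v.reshape(32, 32)) + bias, 0.0).reshape(1024)
    for v in X_batch
])
out = tn_forward(X_batch)
print(out.shape, np.allclose(out, ref))  # (8, 1024) True
```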

Code for a TN Layer

Here is some sample code (also available in this Colab) for creating a TN layer in Keras, specialized to the 1024-by-1024 case discussed above:
import tensorflow as tf
import tensornetwork as tn

class TNLayer(tf.keras.layers.Layer):
 
  def __init__(self):
    super(TNLayer, self).__init__()
    # Create the variables for the layer. Each core has shape (32, 32, 2):
    # an output leg, an input leg, and the bond leg (bond dimension 2).
    self.a_var = tf.Variable(
        tf.random.normal(shape=(32, 32, 2), stddev=1.0 / 32.0),
        name="a", trainable=True)
    self.b_var = tf.Variable(
        tf.random.normal(shape=(32, 32, 2), stddev=1.0 / 32.0),
        name="b", trainable=True)
    self.bias = tf.Variable(tf.zeros(shape=(32, 32)), name="bias", trainable=True)
 
  def call(self, inputs):
    # Define the contraction.
    # We break it out so we can parallelize a batch using
    # tf.vectorized_map (see below).
    def f(input_vec, a_var, b_var, bias_var):
      # Reshape to a matrix instead of a vector.
      input_vec = tf.reshape(input_vec, (32,32))
 
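      # Now we create the network.
      a = tn.Node(a_var, backend="tensorflow")
      b = tn.Node(b_var, backend="tensorflow")
      x_node = tn.Node(input_vec, backend="tensorflow")
      # Connect each core's input leg (axis 1) to one leg of the input,
      # and connect the two cores through their bond legs (axis 2).
      a[1] ^ x_node[0]
      b[1] ^ x_node[1]
      a[2] ^ b[2]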
 
      # The TN should now look like this
      #   |     |
      #   a --- b
      #    \   /
      #      x
 
      # Now we begin the contraction. The only dangling legs left are
      # a[0] and b[0], so the result has shape (32, 32).
      c = a @ x_node
      result = (c @ b).tensor

      # To make the code shorter, we also could've used ncon.
      # The above few lines of code are equivalent to this:
      # result = tn.ncon([input_vec, a_var, b_var],
      #                  [[1, 2], [-1, 1, 3], [-2, 2, 3]],
      #                  backend="tensorflow")
 
      # Finally, add bias.
      return result + bias_var
  
    # To deal with a batch of items, we can use the tf.vectorized_map
    # function.
    # https://www.tensorflow.org/api_docs/python/tf/vectorized_map
    result = tf.vectorized_map(
        lambda vec: f(vec, self.a_var, self.b_var, self.bias), inputs)
    return tf.nn.relu(tf.reshape(result, (-1, 1024)))
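The commented-out tn.ncon call uses the ncon labeling convention: each tensor gets a list of leg labels, positive labels are contracted, and negative labels become the output legs in the order -1, -2, and so on. To make the convention concrete, the following is a minimal NumPy sketch of a hypothetical helper, `ncon_via_einsum` (not part of TensorNetwork, and it ignores features such as choosing a contraction order), that translates those labels into an einsum string and checks them against the edges wired up in TNLayer:

```python
import string

import numpy as np

def ncon_via_einsum(tensors, labels):
    """Hypothetical helper: evaluate an ncon-style network with np.einsum.
    Positive labels are summed over; negative labels become output legs,
    ordered -1, -2, -3, ..."""
    letters = {}

    def letter(label):
        # Assign a fresh einsum letter to each distinct label.
        if label not in letters:
            letters[label] = string.ascii_letters[len(letters)]
        return letters[label]

    inputs = ["".join(letter(lab) for lab in legs) for legs in labels]
    out_labels = sorted({lab for legs in labels for lab in legs if lab < 0}, reverse=True)
    output = "".join(letter(lab) for lab in out_labels)
    return np.einsum(",".join(inputs) + "->" + output, *tensors)

rng = np.random.default_rng(0)
input_vec = rng.normal(size=(32, 32))
a_var = rng.normal(size=(32, 32, 2))
b_var = rng.normal(size=(32, 32, 2))

# Same labels as the commented-out tn.ncon call in TNLayer.
result = ncon_via_einsum([input_vec, a_var, b_var],
                         [[1, 2], [-1, 1, 3], [-2, 2, 3]])

# Same contraction as the edges a[1]^x[0], b[1]^x[1], a[2]^b[2].
expected = np.einsum("ijb,klb,jl->ik", a_var, b_var, input_vec)
print(result.shape, np.allclose(result, expected))  # (32, 32) True
```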
In the TNLayer example, we hard-coded the size of the layer (32-by-32 inputs and outputs, bond dimension 2), but that is fairly easy to adjust. Having made this layer, we can use it as part of a Keras model very simply:
tn_model = tf.keras.Sequential(
  [
    tf.keras.Input(shape=(2,)),
    Dense(1024, activation=tf.nn.relu),
    # Here we use a TN layer instead of the second dense layer.
    TNLayer(),
    Dense(1, activation=None)
  ]
)
The model can be trained as usual using the Keras fit method.
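TNLayer hard-codes its sizes. One way to generalize them is sketched below in NumPy: the hypothetical helpers `init_cores` and `tn_dense_forward` accept an arbitrary input factorization (j, l), output factorization (i, k), and bond dimension. The initialization scale is an assumption, chosen only so that it reduces to the stddev=1.0/32.0 used in TNLayer for (32,32) inputs, and the activation is left to the caller. A Keras version would store the cores as trainable variables and apply tf.einsum in the same way.

```python
import numpy as np

def init_cores(in_dims, out_dims, chi, rng):
    """Hypothetical helper: random cores for a two-core TN layer mapping
    inputs reshaped to in_dims=(d_j, d_l) onto outputs of shape out_dims=(d_i, d_k)."""
    (d_j, d_l), (d_i, d_k) = in_dims, out_dims
    # Assumption: reduces to stddev=1/32 for (32, 32) inputs, as in TNLayer.
    stddev = 1.0 / np.sqrt(d_j * d_l)
    A = rng.normal(scale=stddev, size=(d_i, d_j, chi))  # legs (i, j, bond)
    B = rng.normal(scale=stddev, size=(d_k, d_l, chi))  # legs (k, l, bond)
    bias = np.zeros((d_i, d_k))
    return A, B, bias

def tn_dense_forward(x_batch, A, B, bias):
    """Linear part of the layer for a batch of flat inputs; no activation."""
    x = x_batch.reshape(-1, A.shape[1], B.shape[1])    # (batch, j, l)
    y = np.einsum("ijb,klb,njl->nik", A, B, x) + bias  # (batch, i, k)
    return y.reshape(x.shape[0], -1)                    # flatten the outputs

# Usage: a hypothetical 784 -> 512 layer with inputs viewed as (28, 28),
# outputs as (16, 32), and bond dimension 4.
rng = np.random.default_rng(0)
A, B, bias = init_cores((28, 28), (16, 32), chi=4, rng=rng)
x_batch = rng.normal(size=(5, 784))
y = tn_dense_forward(x_batch, A, B, bias)
print(y.shape)                      # (5, 512)
print(A.size + B.size + bias.size)  # 5888, versus 785 * 512 = 401920 for a Dense layer
```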

Compression of a Transformer and More

While the simple example we discussed above is a nice illustration of the ideas of tensorization, that particular model would probably not be very useful in practice. For something more realistic, we recently experimented with tensorizing a Transformer model in a very similar way. The model we looked at had dense layers that were much larger: the original model had 236M parameters! We tensorized fully-connected layers in 8 Transformer blocks with four-core tensor networks:
A four-core tensor network that we used to tensorize the fully-connected layers of a Transformer model.
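Here is a sketch of how a four-core network like the one pictured can act on an input, assuming a tensor-train (MPO) chain, the type used throughout this post. All sizes here (4-dimensional input and output legs per core, bond dimension 3) are hypothetical and far smaller than the Transformer layers; they are chosen so that the equivalent dense 256-by-256 matrix is cheap to build and compare against. The outer bonds have size 1, so the chain has open ends:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, chi = 4, 4, 3  # hypothetical per-core leg sizes and bond dimension
# Each core has legs (left bond, output, input, right bond).
bonds = [1, chi, chi, chi, 1]
cores = [rng.normal(size=(bonds[n], d_out, d_in, bonds[n + 1])) for n in range(4)]

x_in = rng.normal(size=(d_in,) * 4)  # a 256-dimensional input, reshaped to (4, 4, 4, 4)

# Bonds a..e, outputs p..s, inputs w..z; every repeated letter is summed over.
y = np.einsum("apwb,bqxc,cryd,dsze,wxyz->pqrs", *cores, x_in)

# Equivalent dense matrix with rows (p, q, r, s) and columns (w, x, y, z).
W = np.einsum("apwb,bqxc,cryd,dsze->pqrswxyz", *cores).reshape(256, 256)
n_params = sum(core.size for core in cores)

print(np.allclose(y.reshape(256), W @ x_in.reshape(256)))  # True
print(n_params, W.size)                                    # 384 65536
```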
After tensorization, the model shrank to about 101M parameters! Besides being much smaller, the tensorized model also generated English sentences significantly faster. Here is a side-by-side video showing the difference in performance:



There are a number of other types of tensor networks and tensor decompositions that are applicable to various types of layers in neural networks, and we won’t attempt to give an exhaustive reference list here. But for a start, check out Khrulkov et al. for tensor networks in embedding layers, Ma et al. for attention layers, and Lebedev et al. for convolutional layers.

Tensorization of neural networks is still in its infancy. There are still many unanswered questions, and many experiments to try. All of the tensor networks considered in this post are of the tensor train type, known in physics as an MPO (see also this paper), but other well-studied tensor networks like PEPS and MERA could also be used. The TensorNetwork library was designed to handle any possible tensor network, and specifically to do so in the context of machine learning as we have presented here. We look forward to seeing all of the new and exciting results that you can cook up using it.