Hyperparameter tuning with Keras Tuner
January 29, 2020
Posted by Tom O’Malley

The success of a machine learning project is often crucially dependent on the choice of good hyperparameters. As machine learning continues to mature as a field, relying on trial and error to find good values for these parameters (also known as “grad student descent”) simply doesn’t scale. In fact, many of today’s state-of-the-art results, such as EfficientNet, were discovered via sophisticated hyperparameter optimization algorithms.

Keras Tuner is an easy-to-use, distributable hyperparameter optimization framework that solves the pain points of performing a hyperparameter search. Keras Tuner makes it easy to define a search space and leverage included algorithms to find the best hyperparameter values. Keras Tuner comes with Bayesian Optimization, Hyperband, and Random Search algorithms built-in, and is also designed to be easy for researchers to extend in order to experiment with new search algorithms.
Keras Tuner in action. You can find complete code below.
Here’s a simple end-to-end example. First, we define a model-building function. It takes an hp argument from which you can sample hyperparameters, such as hp.Int('units', min_value=32, max_value=512, step=32) (an integer from a certain range). Notice how the hyperparameters can be defined inline with the model-building code. The example below creates a simple tunable model that we’ll train on CIFAR-10:
import tensorflow as tf

def build_model(hp):
  inputs = tf.keras.Input(shape=(32, 32, 3))
  x = inputs
  for i in range(hp.Int('conv_blocks', 3, 5, default=3)):
    filters = hp.Int('filters_' + str(i), 32, 256, step=32)
    for _ in range(2):
      x = tf.keras.layers.Convolution2D(
        filters, kernel_size=(3, 3), padding='same')(x)
      x = tf.keras.layers.BatchNormalization()(x)
      x = tf.keras.layers.ReLU()(x)
    if hp.Choice('pooling_' + str(i), ['avg', 'max']) == 'max':
      x = tf.keras.layers.MaxPool2D()(x)
    else:
      x = tf.keras.layers.AvgPool2D()(x)
  x = tf.keras.layers.GlobalAvgPool2D()(x)
  x = tf.keras.layers.Dense(
      hp.Int('hidden_size', 30, 100, step=10, default=50),
      activation='relu')(x)
  x = tf.keras.layers.Dropout(
      hp.Float('dropout', 0, 0.5, step=0.1, default=0.5))(x)
  outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

  model = tf.keras.Model(inputs, outputs)
  model.compile(
    optimizer=tf.keras.optimizers.Adam(
      hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])
  return model
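To build some intuition for the ranges declared above, here is a small sketch in plain Python. It enumerates the values a stepped integer range covers and draws learning rates uniformly in log space. This is an illustration only, not Keras Tuner's own sampling code, and the helper names `int_choices` and `sample_log_uniform` are made up for this example.

```python
import math
import random

def int_choices(min_value, max_value, step):
    """All values a stepped integer range can take, endpoints included."""
    return list(range(min_value, max_value + 1, step))

def sample_log_uniform(rng, min_value, max_value):
    """Draw a float so that every order of magnitude is equally likely."""
    low, high = math.log10(min_value), math.log10(max_value)
    return 10 ** rng.uniform(low, high)

rng = random.Random(0)

# Same range as the 'filters_i' hyperparameters: 32 to 256 in steps of 32.
filters = int_choices(32, 256, 32)

# Same range as 'learning_rate': 1e-4 to 1e-2, sampled on a log scale.
learning_rates = [sample_log_uniform(rng, 1e-4, 1e-2) for _ in range(1000)]
below_1e3 = sum(lr < 1e-3 for lr in learning_rates)

print(filters)
print(below_1e3, "of 1000 samples fall below 1e-3")
```

With log sampling, roughly half of the draws land between 1e-4 and 1e-3, whereas uniform sampling on a linear scale would put only about 9% of them there. That is why a log scale is the usual choice for learning rates.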
Next, instantiate a tuner. You should specify the model-building function and the name of the objective to optimize. Whether to minimize or maximize is inferred automatically for built-in metrics; for custom metrics you can specify the direction via the kerastuner.Objective class. In this example, Keras Tuner will use the Hyperband algorithm for the hyperparameter search:
import kerastuner as kt

tuner = kt.Hyperband(
    build_model,
    objective='val_accuracy',
    max_epochs=30,
    hyperband_iterations=2)
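Hyperband is built around successive halving: train many candidates for a few epochs, keep the most promising fraction, and give the survivors a larger budget. The sketch below shows that core idea in plain Python. It is a simplified illustration under stated assumptions, not Keras Tuner's implementation: the reduction factor of 3, the function names, and the toy scoring function are all chosen for this example, and real Hyperband runs several such brackets with different starting budgets.

```python
import math
import random

def successive_halving(configs, evaluate, min_epochs=1, max_epochs=30, factor=3):
    """Keep the top 1/factor of the configs each round while growing the epoch budget."""
    survivors = list(configs)
    epochs = min_epochs
    schedule = []  # (epochs, number of configs trained) for each round
    while True:
        ranked = sorted(survivors, key=lambda c: evaluate(c, epochs), reverse=True)
        schedule.append((epochs, len(ranked)))
        if len(ranked) == 1 or epochs >= max_epochs:
            return ranked[0], schedule
        survivors = ranked[:max(1, len(ranked) // factor)]
        epochs = min(max_epochs, epochs * factor)

def toy_evaluate(learning_rate, epochs):
    # Stand-in for validation accuracy: best near 1e-3, slightly better with more epochs.
    return -abs(math.log10(learning_rate) + 3) + 0.01 * epochs

rng = random.Random(0)
candidates = [10 ** rng.uniform(-4, -2) for _ in range(27)]
best, schedule = successive_halving(candidates, toy_evaluate)

print(schedule)  # [(1, 27), (3, 9), (9, 3), (27, 1)]
print(best)
```

Most of the 27 candidates are discarded after a single epoch, so the bulk of the compute goes to the few configurations that look promising early on.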
Next we’ll download the CIFAR-10 dataset using TensorFlow Datasets, and then begin the hyperparameter search. To start the search, call the search method. This method has the same signature as keras.Model.fit:
import tensorflow_datasets as tfds

data = tfds.load('cifar10')
train_ds, test_ds = data['train'], data['test']

def standardize_record(record):
  return tf.cast(record['image'], tf.float32) / 255., record['label']

# Shuffle individual examples before batching so that batches differ between epochs.
train_ds = train_ds.map(standardize_record).cache().shuffle(10000).batch(64)
test_ds = test_ds.map(standardize_record).cache().batch(64)

tuner.search(train_ds,
             validation_data=test_ds,
             epochs=30,
             callbacks=[tf.keras.callbacks.EarlyStopping(patience=1)])
Each model will train for at most 30 epochs, and two iterations of the Hyperband algorithm will be run. Afterwards, you can retrieve the best models found during the search by using the get_best_models method:
best_model = tuner.get_best_models(1)[0]
You can also view the optimal hyperparameter values found by the search:

best_hyperparameters = tuner.get_best_hyperparameters(1)[0]
And that’s all the code that is needed to perform a sophisticated hyperparameter search!

You can find the complete code for the example above here.

Built-in Tunable Models

In addition to allowing you to define your own tunable models, Keras Tuner provides two built-in tunable models: HyperResNet and HyperXception. These models search over various permutations of the ResNet and Xception architectures, respectively. These models can be used with a Tuner like this:
tuner = kt.tuners.BayesianOptimization(
  kt.applications.HyperResNet(input_shape=(256, 256, 3), classes=10),
  objective='val_accuracy',
  max_trials=50)

Distributed Tuning

With Keras Tuner, you can do both data-parallel and trial-parallel distribution. That is, you can use tf.distribute.Strategy to run each Model on multiple GPUs, and you can also search over multiple different hyperparameter combinations in parallel on different workers.

No code changes are needed to perform a trial-parallel search. Simply set the KERASTUNER_TUNER_ID, KERASTUNER_ORACLE_IP, and KERASTUNER_ORACLE_PORT environment variables, for example as shown in the bash script here:
export KERASTUNER_TUNER_ID="chief"
export KERASTUNER_ORACLE_IP="127.0.0.1"
export KERASTUNER_ORACLE_PORT="8000"
python run_my_search.py
The tuners coordinate their search via a central Oracle service that tells each tuner which hyperparameter values to try next. For more information, see our Distributed Tuning guide.
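To make the division of labor concrete, here is a toy version of that pattern using threads in a single process. It is only a sketch: `ToyOracle`, `worker`, and the scoring rule are hypothetical stand-ins, and the real Oracle runs as a separate service that the tuners reach over the network.

```python
import random
import threading

class ToyOracle:
    """Hypothetical stand-in for the central Oracle: hands out trials, records scores."""

    def __init__(self, max_trials, seed=0):
        self._rng = random.Random(seed)
        self._lock = threading.Lock()
        self._remaining = max_trials
        self.results = []  # (worker_id, hyperparameters, score)

    def create_trial(self):
        # Only one worker at a time may ask for the next hyperparameter values.
        with self._lock:
            if self._remaining == 0:
                return None
            self._remaining -= 1
            return {'units': self._rng.randrange(32, 513, 32)}

    def update_trial(self, worker_id, hyperparameters, score):
        with self._lock:
            self.results.append((worker_id, hyperparameters, score))

def worker(worker_id, oracle):
    """Each worker keeps requesting trials until the oracle has none left."""
    while True:
        hyperparameters = oracle.create_trial()
        if hyperparameters is None:
            return
        score = -abs(hyperparameters['units'] - 256)  # Toy objective, best at 256 units.
        oracle.update_trial(worker_id, hyperparameters, score)

oracle = ToyOracle(max_trials=12)
threads = [threading.Thread(target=worker, args=('tuner%d' % i, oracle))
           for i in range(3)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

best = max(oracle.results, key=lambda result: result[2])
print(len(oracle.results), "trials finished; best:", best[1], best[2])
```

Because the workers never talk to each other and only exchange small messages with the oracle, adding more workers does not require any change to the search code itself.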

Custom Training Loops

The `kerastuner.Tuner` class can be subclassed to support advanced uses such as:
  • Custom training loops (GANs, reinforcement learning, etc.)
  • Adding hyperparameters outside of the model building function (preprocessing, data augmentation, test time augmentation, etc.)
Here’s a simple example:
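class MyTuner(kt.Tuner):

    def run_trial(self, trial, *args, **kwargs):
        # Build a fresh model from this trial's hyperparameter values.
        model = self.hypermodel.build(trial.hyperparameters)
        score = ...  # Run the training loop and return the result.
        # Report the result to the Oracle, then save the model through the Tuner.
        self.oracle.update_trial(trial.trial_id, {'score': score})
        self.save_model(trial.trial_id, model)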
For more information, see our Tuner Subclassing guide.
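If the subclassing pattern feels abstract, the following framework-free sketch may help. It is not Keras Tuner code: `ToyTuner` and `GradientDescentTuner` are hypothetical classes written for this example, and the "training loop" is plain gradient descent on a one-dimensional quadratic. The point is only the shape of the design, where the base class owns the search and the subclass owns what happens inside a single trial.

```python
import random

class ToyTuner:
    """Hypothetical minimal tuner: samples hyperparameters and delegates to run_trial."""

    def __init__(self, max_trials, seed=0):
        self.max_trials = max_trials
        self.rng = random.Random(seed)
        self.trials = []  # (hyperparameters, score)

    def sample(self):
        # Log-uniform learning rate between 1e-4 and 1e-1.
        return {'learning_rate': 10 ** self.rng.uniform(-4, -1)}

    def run_trial(self, hyperparameters):
        raise NotImplementedError('Subclasses define how a single trial is run.')

    def search(self):
        for _ in range(self.max_trials):
            hyperparameters = self.sample()
            self.trials.append((hyperparameters, self.run_trial(hyperparameters)))
        return max(self.trials, key=lambda trial: trial[1])

class GradientDescentTuner(ToyTuner):

    def run_trial(self, hyperparameters):
        # Custom training loop: minimize f(w) = (w - 3)^2 with gradient descent.
        w = 0.0
        for _ in range(50):
            w -= hyperparameters['learning_rate'] * 2 * (w - 3)
        return -(w - 3) ** 2  # Higher is better, so return the negative loss.

tuner = GradientDescentTuner(max_trials=20)
best_hyperparameters, best_score = tuner.search()
print(best_hyperparameters, best_score)
```

Swapping in a GAN or reinforcement learning loop would only change the body of `run_trial`; the search logic stays untouched.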

Tuning Scikit-learn Models

Despite its name, Keras Tuner can be used to tune a wide variety of machine learning models. In addition to built-in Tuners for Keras models, Keras Tuner provides a built-in Tuner that works with Scikit-learn models. Here’s a simple example of how to use this tuner:
from sklearn import ensemble
from sklearn import linear_model

def build_model(hp):
    model_type = hp.Choice('model_type', ['random_forest', 'ridge'])
    if model_type == 'random_forest':
        with hp.conditional_scope('model_type', 'random_forest'):
            model = ensemble.RandomForestClassifier(
                n_estimators=hp.Int('n_estimators', 10, 50, step=10),
                max_depth=hp.Int('max_depth', 3, 10))
    elif model_type == 'ridge':
        with hp.conditional_scope('model_type', 'ridge'):
            model = linear_model.RidgeClassifier(
                alpha=hp.Float('alpha', 1e-3, 1, sampling='log'))
    else:
        raise ValueError('Unrecognized model_type')
    return model

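tuner = kt.tuners.Sklearn(
    oracle=kt.oracles.BayesianOptimization(
        objective=kt.Objective('score', 'max'),
        max_trials=10),
    hypermodel=build_model,
    directory=tmp_dir)  # tmp_dir: a writable directory where search results are stored.
X, y = ...  # Load your features and labels here.
tuner.search(X, y)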
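The conditional scopes in build_model mean that a hyperparameter such as alpha only exists when the model type it belongs to has been chosen. The plain-Python sketch below mimics that behavior with random sampling. It is an illustration of the idea rather than Keras Tuner's implementation, and `sample_config` is a hypothetical helper.

```python
import random

def sample_config(rng):
    """Sample a model type first, then only the hyperparameters that apply to it."""
    config = {'model_type': rng.choice(['random_forest', 'ridge'])}
    if config['model_type'] == 'random_forest':
        config['n_estimators'] = rng.choice(range(10, 51, 10))
        config['max_depth'] = rng.randint(3, 10)
    else:
        # Log-uniform between 1e-3 and 1, matching sampling='log' above.
        config['alpha'] = 10 ** rng.uniform(-3, 0)
    return config

rng = random.Random(0)
configs = [sample_config(rng) for _ in range(10)]
for config in configs:
    print(config)
```

Keeping inactive hyperparameters out of each configuration matters for algorithms like Bayesian optimization, which would otherwise try to learn from values that had no effect on the score.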
For more information on Keras Tuner, please see the Keras Tuner website or the Keras Tuner GitHub. Keras Tuner is an open-source project developed entirely on GitHub. If there are features you’d like to see in Keras Tuner, please open a GitHub issue with a feature request, and if you’re interested in contributing, please take a look at our contribution guidelines and send us a PR!