January 29, 2020
Posted by Tom O’Malley
The success of a machine learning project is often crucially dependent on the choice of good hyperparameters. As machine learning continues to mature as a field, relying on trial and error to find good values for these parameters (also known as “grad student descent”) simply doesn’t scale. In fact, many of today’s state-of-the-art results, such as EfficientNet, were discove…
Keras Tuner in action. You can find complete code below.
The model-building function takes an hp argument from which you can sample hyperparameters, such as hp.Int('units', min_value=32, max_value=512, step=32) (an integer from 32 to 512 in steps of 32). Notice how the hyperparameters can be defined inline with the model-building code. The example below creates a simple tunable model that we’ll train on CIFAR-10:

import tensorflow as tf

def build_model(hp):
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = inputs
    # Tune the number of convolutional blocks and the width of each block.
    for i in range(hp.Int('conv_blocks', 3, 5, default=3)):
        filters = hp.Int('filters_' + str(i), 32, 256, step=32)
        for _ in range(2):
            x = tf.keras.layers.Convolution2D(
                filters, kernel_size=(3, 3), padding='same')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.ReLU()(x)
        # Tune the pooling type used after each block.
        if hp.Choice('pooling_' + str(i), ['avg', 'max']) == 'max':
            x = tf.keras.layers.MaxPool2D()(x)
        else:
            x = tf.keras.layers.AvgPool2D()(x)
    x = tf.keras.layers.GlobalAvgPool2D()(x)
    x = tf.keras.layers.Dense(
        hp.Int('hidden_size', 30, 100, step=10, default=50),
        activation='relu')(x)
    x = tf.keras.layers.Dropout(
        hp.Float('dropout', 0, 0.5, step=0.1, default=0.5))(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(
        # Sample the learning rate on a log scale.
        optimizer=tf.keras.optimizers.Adam(
            hp.Float('learning_rate', 1e-4, 1e-2, sampling='log')),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model
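To make the search-space definitions above more concrete, here is a minimal standard-library sketch of what "an integer from a range with a step" and "a float sampled on a log scale" can mean. The helper names sample_int and sample_log_float are hypothetical and this is not Keras Tuner's implementation, only an illustration of the idea under those assumptions.

```python
import math
import random


def sample_int(rng, min_value, max_value, step=1):
    """Pick uniformly from min_value, min_value + step, ... up to max_value."""
    n_steps = (max_value - min_value) // step
    return min_value + step * rng.randint(0, n_steps)


def sample_log_float(rng, min_value, max_value):
    """Pick a float whose logarithm is uniform between the two bounds."""
    log_value = rng.uniform(math.log(min_value), math.log(max_value))
    # Clamp to guard against floating-point rounding at the edges.
    return min(max(math.exp(log_value), min_value), max_value)


rng = random.Random(0)
units = sample_int(rng, 32, 512, step=32)
learning_rate = sample_log_float(rng, 1e-4, 1e-2)
print(units, learning_rate)
```

Sampling on a log scale spreads the draws evenly across orders of magnitude, so values between 1e-4 and 1e-3 are as likely as values between 1e-3 and 1e-2.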
Next, instantiate a tuner. You should specify the model-building function and the name of the objective to optimize. Whether to minimize or maximize is inferred automatically for built-in metrics; for custom metrics you can specify this via the kerastuner.Objective class. In this example, Keras Tuner will use the Hyperband algorithm for the hyperparameter search:

import kerastuner as kt

tuner = kt.Hyperband(
    build_model,
    objective='val_accuracy',
    max_epochs=30,
    hyperband_iterations=2)
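For readers unfamiliar with Hyperband, its core building block is successive halving: train many configurations briefly, keep the best fraction, and give the survivors a larger budget. The sketch below illustrates only that inner loop with a toy scoring function. The successive_halving function, its parameters, and the evaluate callback are hypothetical names chosen for this example, and Keras Tuner's actual Hyperband implementation differs (for instance, Hyperband runs several such brackets with different trade-offs).

```python
import math


def successive_halving(configs, evaluate, min_epochs=1, max_epochs=27, factor=3):
    """Keep the top 1/factor of configs each round while growing the budget.

    evaluate(config, epochs) is a hypothetical callback that returns a score
    where higher is better. Returns the winning config and the schedule of
    (number of configs, epochs) pairs that were used.
    """
    survivors = list(configs)
    epochs = min_epochs
    schedule = []
    while True:
        schedule.append((len(survivors), epochs))
        ranked = sorted(survivors, key=lambda c: evaluate(c, epochs), reverse=True)
        if len(ranked) == 1 or epochs >= max_epochs:
            return ranked[0], schedule
        survivors = ranked[:max(1, len(ranked) // factor)]
        epochs = min(epochs * factor, max_epochs)


def toy_score(learning_rate, epochs):
    """Toy objective: pretend the ideal learning rate is 1e-3."""
    return -abs(math.log10(learning_rate) + 3)


# 27 learning rates spaced evenly on a log scale between 1e-4 and 1e-2.
candidates = [10 ** (-4 + 2 * i / 26) for i in range(27)]
best, schedule = successive_halving(candidates, toy_score)
print(best, schedule)
```

Most of the budget goes to the few configurations that survive the early rounds, which is why this family of algorithms can explore many candidates cheaply.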
Next we’ll download the CIFAR-10 dataset using TensorFlow Datasets, and then begin the hyperparameter search. To start the search, call the search method. This method has the same signature as keras.Model.fit:

import tensorflow_datasets as tfds

data = tfds.load('cifar10')
train_ds, test_ds = data['train'], data['test']

def standardize_record(record):
    # Scale pixel values to [0, 1] and return (image, label) pairs.
    return tf.cast(record['image'], tf.float32) / 255., record['label']

# Shuffle before batching so that individual examples, not whole batches,
# are reshuffled each epoch.
train_ds = train_ds.map(standardize_record).cache().shuffle(10000).batch(64)
test_ds = test_ds.map(standardize_record).cache().batch(64)

tuner.search(train_ds,
             validation_data=test_ds,
             epochs=30,
             callbacks=[tf.keras.callbacks.EarlyStopping(patience=1)])
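The EarlyStopping(patience=1) callback above means a trial can end well before its epoch limit. As a rough illustration, the following sketch applies a simplified version of that rule to a list of validation losses. The epochs_run helper is a hypothetical name, and the sketch assumes the callback's default behavior of monitoring validation loss with no minimum improvement threshold.

```python
def epochs_run(val_losses, patience=1):
    """Return how many epochs run before a patience-based stop triggers."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # Improvement: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)


stopped_at = epochs_run([0.9, 0.7, 0.8, 0.6], patience=1)
print(stopped_at)
```

With patience=1, training stops at the first epoch whose validation loss fails to improve, here the third epoch, even though the fourth would have been better. A larger patience trades extra compute for tolerance to noisy validation curves.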
Each model will train for at most 30 epochs, and two iterations of the Hyperband algorithm will be run. Afterwards, you can retrieve the best models found during the search by using the get_best_models method:

best_model = tuner.get_best_models(1)[0]

You can also view the optimal hyperparameter values found by the search:
best_hyperparameters = tuner.get_best_hyperparameters(1)[0]
And that’s all the code needed to perform a sophisticated hyperparameter search!

A built-in tunable model such as kt.applications.HyperResNet can be passed to a Tuner like this:

tuner = kt.tuners.BayesianOptimization(
    kt.applications.HyperResNet(input_shape=(256, 256, 3), classes=10),
    objective='val_accuracy',
    max_trials=50)

You can use tf.distribute.Strategy to run each Model on multiple GPUs, and you can also search over multiple different hyperparameter combinations in parallel on different workers. To run the search in parallel, set the KERASTUNER_TUNER_ID, KERASTUNER_ORACLE_IP, and KERASTUNER_ORACLE_PORT environment variables, for example as shown in the bash script here:

export KERASTUNER_TUNER_ID="chief"
export KERASTUNER_ORACLE_IP="127.0.0.1"
export KERASTUNER_ORACLE_PORT="8000"
python run_my_search.py
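To illustrate how these three variables identify a process, here is a small standard-library sketch that reads them and reports the role. Keras Tuner reads the variables itself, so a script like run_my_search.py does not need this code. The describe_tuner_role helper is hypothetical, and the worker ID "tuner0" and the description of the chief as the process hosting the Oracle are assumptions made for this example rather than details stated in this post.

```python
import os


def describe_tuner_role(env=None):
    """Summarize the distributed-tuning role implied by the environment."""
    env = os.environ if env is None else env
    tuner_id = env.get("KERASTUNER_TUNER_ID")
    if tuner_id is None:
        return "single process: no distribution variables set"
    address = "{}:{}".format(
        env["KERASTUNER_ORACLE_IP"], env["KERASTUNER_ORACLE_PORT"])
    if tuner_id == "chief":
        # Assumption: the chief is the process that hosts the Oracle.
        return "chief: Oracle at " + address
    return "worker {}: requests trials from {}".format(tuner_id, address)


chief_env = {
    "KERASTUNER_TUNER_ID": "chief",
    "KERASTUNER_ORACLE_IP": "127.0.0.1",
    "KERASTUNER_ORACLE_PORT": "8000",
}
worker_env = dict(chief_env, KERASTUNER_TUNER_ID="tuner0")  # hypothetical worker ID

print(describe_tuner_role(chief_env))
print(describe_tuner_role(worker_env))
```

Every process points at the same IP and port, and only the tuner ID differs between them.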
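The tuners coordinate their search via a central Oracle service that tells each tuner which hyperparameter values to try next. For more information, see our Distributed Tuning guide.

The kt.Tuner class can be subclassed to customize how each trial is run, for example with your own training loop:

class MyTuner(kt.Tuner):

    # *args and **kwargs stand in for whatever is passed to tuner.search().
    def run_trial(self, trial, *args, **kwargs):
        model = self.hypermodel.build(trial.hyperparameters)
        score = ...  # Run the training loop and return the result.
        # Report the result to the Oracle, then save the model via the Tuner.
        self.oracle.update_trial(trial.trial_id, {'score': score})
        self.save_model(trial.trial_id, model)
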
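The division of labor between an Oracle (which proposes hyperparameter values and records scores) and a tuner's run_trial step (which evaluates them) can be sketched without Keras Tuner at all. Everything below is a toy: the ToyOracle class, its create_trial, update_trial and best_trial methods, and the scoring function are hypothetical stand-ins, and the random proposals are a much simpler strategy than the library's built-in algorithms.

```python
import random


class ToyOracle:
    """Hypothetical stand-in for an Oracle: proposes values, records scores."""

    def __init__(self, max_trials, seed=0):
        self.max_trials = max_trials
        self.rng = random.Random(seed)
        self.trials = {}  # trial_id -> {"hyperparameters": ..., "score": ...}

    def create_trial(self):
        """Return (trial_id, values), or None once the budget is spent."""
        if len(self.trials) >= self.max_trials:
            return None
        trial_id = len(self.trials)
        values = {"units": self.rng.randrange(32, 513, 32)}
        self.trials[trial_id] = {"hyperparameters": values, "score": None}
        return trial_id, values

    def update_trial(self, trial_id, score):
        self.trials[trial_id]["score"] = score

    def best_trial(self):
        return max(self.trials.values(), key=lambda t: t["score"])


def run_trial(values):
    """Toy 'training loop': the score peaks when units == 256."""
    return -abs(values["units"] - 256)


oracle = ToyOracle(max_trials=20)
while True:
    trial = oracle.create_trial()
    if trial is None:
        break
    trial_id, values = trial
    oracle.update_trial(trial_id, run_trial(values))

best = oracle.best_trial()
print(best)
```

Because the Oracle only sees hyperparameter values and scores, the evaluation step is free to use any training procedure.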
For more information, see our Tuner Subclassing guide.

Keras Tuner also includes a tuner for scikit-learn models, kt.tuners.Sklearn:

from sklearn import ensemble
from sklearn import linear_model

def build_model(hp):
    model_type = hp.Choice('model_type', ['random_forest', 'ridge'])
    if model_type == 'random_forest':
        # These hyperparameters are only active for the random forest.
        with hp.conditional_scope('model_type', 'random_forest'):
            model = ensemble.RandomForestClassifier(
                n_estimators=hp.Int('n_estimators', 10, 50, step=10),
                max_depth=hp.Int('max_depth', 3, 10))
    elif model_type == 'ridge':
        # This hyperparameter is only active for the ridge classifier.
        with hp.conditional_scope('model_type', 'ridge'):
            model = linear_model.RidgeClassifier(
                alpha=hp.Float('alpha', 1e-3, 1, sampling='log'))
    else:
        raise ValueError('Unrecognized model_type')
    return model

tuner = kt.tuners.Sklearn(
    oracle=kt.oracles.BayesianOptimization(
        objective=kt.Objective('score', 'max'),
        max_trials=10),
    hypermodel=build_model,
    directory=tmp_dir)  # tmp_dir: a writable directory for search results.

X, y = ...
tuner.search(X, y)
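The conditional scopes in the example above express that some hyperparameters only exist for certain parent choices. A minimal standard-library sketch of the same idea follows: the child values are sampled only when their parent choice is active. The sample_config helper is a hypothetical name, and the uniform random sampling here is a simplification, not how the Bayesian optimization oracle proposes values.

```python
import math
import random


def sample_config(rng):
    """Sample a model type, then only the hyperparameters that apply to it."""
    config = {"model_type": rng.choice(["random_forest", "ridge"])}
    if config["model_type"] == "random_forest":
        config["n_estimators"] = rng.randrange(10, 51, 10)
        config["max_depth"] = rng.randint(3, 10)
    else:
        # Log-uniform draw between 1e-3 and 1.
        config["alpha"] = math.exp(rng.uniform(math.log(1e-3), math.log(1.0)))
    return config


rng = random.Random(0)
configs = [sample_config(rng) for _ in range(5)]
for config in configs:
    print(config)
```

Keeping inactive hyperparameters out of a configuration means a search algorithm does not have to treat, say, alpha as meaningful when the model is a random forest.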
For more information on Keras Tuner, please see the Keras Tuner website or the Keras Tuner GitHub. Keras Tuner is an open-source project developed entirely on GitHub. If there are features you’d like to see in Keras Tuner, please open a GitHub issue with a feature request, and if you’re interested in contributing, please take a look at our contribution guidelines and send us a PR!