October 03, 2022
Posted by Wei Wei, Developer Advocate
In our previous blog post, “Building a board game app with TensorFlow: a new TensorFlow Lite reference app”, we showed you how to use TensorFlow and TensorFlow Agents to train a reinforcement learning (RL) agent to play a simple board game, ‘Plane Strike’. We also converted the trained model to TensorFlow Lite and deployed it in a fully functional Android app. In this post, we demonstrate a different path: training the same RL agent with Flax/JAX and deploying it in the same Android app we built before. The complete code is open-sourced in the tensorflow/examples repository for your reference.
To refresh your memory: our RL agent needs to predict a strike position based on the state of the human player’s board, so that it can finish the game before the human player does. For the detailed game rules, please refer to our previous blog post.
Demo game play in ‘Plane Strike’
We define a 3-layer MLP in Flax as our policy network. It predicts the agent’s next strike position by assigning a probability to every cell on the board.
import flax.linen as nn
import jax.numpy as jnp

import common  # game helper module (BOARD_SIZE, play_game, compute_rewards)


class PolicyGradient(nn.Module):
  """Neural network to predict the next strike position."""

  @nn.compact
  def __call__(self, x):
    dtype = jnp.float32
    # Flatten each BOARD_SIZE x BOARD_SIZE board into a vector.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(
        features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
    # Softmax turns the logits into one probability per board cell.
    policy_probabilities = nn.softmax(x)
    return policy_probabilities
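import functools

from tqdm import tqdm

# params, optimizer, opt_state and iterations are set up before this loop.
for _ in tqdm(range(iterations)):
  # Play one game with the current policy parameters and log every step.
  predict_fn = functools.partial(run_inference, params)
  board_log, action_log, result_log = common.play_game(predict_fn)
  # Turn the action results into rewards, then update the policy.
  rewards = common.compute_rewards(result_log)
  optimizer, params, opt_state = train_step(optimizer, params, opt_state,
                                            board_log, action_log, rewards)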
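The loop relies on common.compute_rewards, whose body is not shown here. Policy-gradient methods usually weight each action by its discounted return G_t = r_t + γ·G_{t+1}. The sketch below computes that, assuming each action result has already been mapped to a scalar reward. The function name discounted_returns and the discount factor GAMMA = 0.5 are hypothetical, and the repository's implementation may differ.

```python
import numpy as np

GAMMA = 0.5  # hypothetical discount factor; the repository may use another value


def discounted_returns(step_rewards, gamma=GAMMA):
  """Hypothetical stand-in for common.compute_rewards.

  Turns per-step rewards r_t into returns G_t = r_t + gamma * G_{t+1},
  accumulating from the last step of the game backwards.
  """
  returns = np.zeros(len(step_rewards), dtype=np.float32)
  running = 0.0
  for t in reversed(range(len(step_rewards))):
    running = step_rewards[t] + gamma * running
    returns[t] = running
  return returns


print(discounted_returns([0.0, 1.0, 1.0]))
```

Working backwards means each step's return already contains the discounted value of everything that followed it, so an early move that set up later hits still receives credit.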
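import jax


def compute_loss(logits, labels, rewards):
  # `logits` are the softmax probabilities returned by PolicyGradient.
  one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
  # Log-probability of each action taken, weighted by its reward.
  loss = -jnp.mean(
      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) *
      jnp.asarray(rewards))
  return loss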
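Written out, compute_loss evaluates the objective below for a game of T steps, where s_t is the board, a_t the strike taken, G_t the reward passed in, y_{t,k} the one-hot label, π_θ the network's output probabilities, and N = common.BOARD_SIZE. The one-hot product simply picks out the log-probability of the action actually taken, and the resulting gradient has the form of the REINFORCE policy-gradient estimator (our reading of the code rather than a statement from the post).

```latex
\mathcal{L}(\theta)
  = -\frac{1}{T}\sum_{t=1}^{T} G_t \sum_{k=1}^{N^2} y_{t,k}\,\log \pi_\theta(k \mid s_t)
  = -\frac{1}{T}\sum_{t=1}^{T} G_t \,\log \pi_\theta(a_t \mid s_t),
\qquad
\nabla_\theta \mathcal{L}
  = -\frac{1}{T}\sum_{t=1}^{T} G_t \,\nabla_\theta \log \pi_\theta(a_t \mid s_t)
```

Minimizing this loss raises the probability of strikes that earned a high reward and lowers it for strikes that earned a negative one.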
import optax


def train_step(model_optimizer, params, opt_state, game_board_log,
               predicted_action_log, action_result_log):
  """Run one training step."""

  def loss_fn(model_params):
    logits = run_inference(model_params, game_board_log)
    # action_result_log holds the rewards computed by common.compute_rewards.
    loss = compute_loss(logits, predicted_action_log, action_result_log)
    return loss

  def compute_grads(params):
    return jax.grad(loss_fn)(params)

  grads = compute_grads(params)
  # The optax optimizer turns gradients into parameter updates.
  updates, opt_state = model_optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return model_optimizer, params, opt_state
@jax.jit
def run_inference(model_params, board):
  # Forward pass of the policy network; returns per-cell probabilities.
  logits = PolicyGradient().apply({'params': model_params}, board)
  return logits
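run_inference returns a probability for every cell, not a single position. The hypothetical helper choose_strike below turns that output into a (row, col) strike. It either samples from the distribution, which is the usual choice while training a policy-gradient agent, or takes the argmax. The flat index maps back to a row and column in row-major order because the network flattens each board with reshape. BOARD_SIZE = 8 is assumed, and how the repository's play_game actually picks moves is not shown in this post.

```python
import jax
import jax.numpy as jnp

BOARD_SIZE = 8  # assumption: stand-in for common.BOARD_SIZE


def choose_strike(probs, key, greedy=False):
  """Hypothetical helper: turn a (1, BOARD_SIZE**2) policy output into (row, col).

  Sampling keeps exploring during training; argmax is the greedy choice.
  """
  flat = probs[0]
  if greedy:
    index = jnp.argmax(flat)
  else:
    # jax.random.categorical expects (unnormalized) log-probabilities.
    index = jax.random.categorical(key, jnp.log(flat))
  # Row-major flat index -> (row, col), matching the network's reshape.
  return divmod(int(index), BOARD_SIZE)


key = jax.random.PRNGKey(0)
uniform = jnp.full((1, BOARD_SIZE**2), 1.0 / BOARD_SIZE**2)
print(choose_strike(uniform, key))
```

Cells with probability zero get a log-probability of -inf, so sampling can never select them.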
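import os

import tensorflow as tf
from jax.experimental import jax2tf

# Convert to tflite model
model = PolicyGradient()
jax_predict_fn = lambda x: model.apply({'params': params}, x)

# Wrap the JAX function as a TensorFlow function with a fixed input signature.
# enable_xla=False avoids XLA-only ops, which TFLite does not support.
tf_predict = tf.function(
    jax2tf.convert(jax_predict_fn, enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
            dtype=tf.float32,
            name='input')
    ],
    autograph=False,
)

# tf_predict is passed as the trackable object that owns the concrete function.
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)

tflite_model = converter.convert()

# Save the model (modeldir is the output directory)
with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
  f.write(tflite_model)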
Visualizing the TFLite model converted from Flax/JAX using Netron
convertBoardStateToByteBuffer(board);
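On the Android side, the board must reach the model in the layout fixed by the converter's input signature: a float32 tensor of shape [1, BOARD_SIZE, BOARD_SIZE]. The Python sketch below mirrors what a call like convertBoardStateToByteBuffer has to produce. The name board_to_bytes is hypothetical, BOARD_SIZE = 8 is assumed, and little-endian order is assumed because Java's ByteOrder.nativeOrder() is little-endian on common Android hardware.

```python
import numpy as np

BOARD_SIZE = 8  # assumption: stand-in for common.BOARD_SIZE


def board_to_bytes(board):
  """Hypothetical Python analogue of convertBoardStateToByteBuffer.

  Packs a BOARD_SIZE x BOARD_SIZE board into the raw bytes of a float32
  tensor of shape [1, BOARD_SIZE, BOARD_SIZE]: row-major, little-endian,
  4 bytes per cell.
  """
  tensor = np.asarray(board, dtype='<f4').reshape(1, BOARD_SIZE, BOARD_SIZE)
  return tensor.tobytes(order='C')


buffer = board_to_bytes(np.zeros((BOARD_SIZE, BOARD_SIZE)))
print(len(buffer))  # 256
```

The buffer size is 1 × 8 × 8 cells × 4 bytes = 256 bytes; a buffer of any other size or element type would not match the model's input tensor.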