Deep Reinforcement Learning: Playing CartPole through Asynchronous Advantage Actor Critic (A3C) with tf.keras and eager execution
July 31, 2018
By Raymond Yuan, Software Engineering Intern

In this tutorial we will learn how to train a model that is able to win at the simple game CartPole using deep reinforcement learning. We’ll use tf.keras and OpenAI’s gym to train an agent using a technique known as Asynchronous Advantage Actor Critic (A3C). Reinforcement learning has been receiving an enormous amount of attention, but what is it exactly? Reinforcement learning is an area of machine learning that involves agents that should take certain actions from within an environment to maximize or attain some reward.

In the process, we’ll build practical experience and develop intuition around the following concepts:
  • Eager execution — Eager execution is an imperative, define-by-run interface where operations are executed immediately as they are called from Python. This makes it easier to get started with TensorFlow, and can make research and development more intuitive.
  • Model subclassing — Model subclassing allows one to make fully-customizable models by subclassing tf.keras.Model and defining your own forward pass. Model subclassing is particularly useful when eager execution is enabled since the forward pass can be written imperatively.
  • Custom training loops
We’ll follow the basic workflow:
  • Build our master agent supervisor
  • Build our worker agents
  • Implement the A3C algorithm
  • Train our agents
  • Visualize our performance
Audience: This tutorial is targeted towards anybody interested in reinforcement learning. While we won’t go into too much depth on the basics of machine learning, we’ll cover topics such as policy and value networks at a high level. In addition, I suggest reading the paper Asynchronous Methods for Deep Reinforcement Learning by Volodymyr Mnih et al., which is a great read and provides much more detail on the algorithm.

What is CartPole?

CartPole is a game in which a pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The starting state (cart position, cart velocity, pole angle, and pole velocity at tip) is randomly initialized between +/-0.05. The system is controlled by applying a force of +1 or -1 to the cart (moving left or right). The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.
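The termination rule described above can be written as a small predicate. This is a minimal sketch that uses only the two thresholds quoted in this post; the function and constant names are illustrative, and the values are not read from gym itself, whose implementation may use different ones.

```python
import math

# Thresholds as quoted in the text above (assumed, not read from gym).
MAX_ANGLE_RAD = math.radians(15.0)
MAX_POSITION = 2.4


def episode_over(cart_position, pole_angle_rad):
    """Return True when the pole has tipped too far or the cart left the track."""
    return abs(pole_angle_rad) > MAX_ANGLE_RAD or abs(cart_position) > MAX_POSITION


print(episode_over(0.0, 0.0))               # False: upright and centered
print(episode_over(2.5, 0.0))               # True: cart too far from center
print(episode_over(0.0, math.radians(16)))  # True: pole tipped too far
```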

Code:

Update 3/32/2021: You can find the complete code for this article at this link.

Establishing a baseline

It’s often very useful to establish a baseline in order to properly judge your model’s actual performance and the metrics by which you evaluate it. For example, your model may seem to be performing well when you see high scores being returned, but in reality the high scores may not reflect a good algorithm and could just as well be the result of random actions. In a classification example, we can establish baseline performance by simply analyzing the class distribution and predicting our most common class. But how do we establish a baseline for reinforcement learning? To do this, we will create a random agent that simply performs random actions in our environment.
class RandomAgent:
  """Random Agent that will play the specified game

    Arguments:
      env_name: Name of the environment to be played
      max_eps: Maximum number of episodes to run agent for.
  """
  def __init__(self, env_name, max_eps):
    self.env = gym.make(env_name)
    self.max_episodes = max_eps
    self.global_moving_average_reward = 0
    self.res_queue = Queue()

  def run(self):
    reward_avg = 0
    for episode in range(self.max_episodes):
      done = False
      self.env.reset()
      reward_sum = 0.0
      steps = 0
      while not done:
        # Sample randomly from the action space and step
        _, reward, done, _ = self.env.step(self.env.action_space.sample())
        steps += 1
        reward_sum += reward
      # Record statistics
      self.global_moving_average_reward = record(episode, 
                                                 reward_sum, 
                                                 0,
                                                 self.global_moving_average_reward,
                                                 self.res_queue, 0, steps)

      reward_avg += reward_sum
    final_avg = reward_avg / float(self.max_episodes)
    print("Average score across {} episodes: {}".format(self.max_episodes, final_avg))
    return final_avg
For the game CartPole we get an average of ~20 across 4000 episodes. To run the random agent, run the provided py file: python a3c_cartpole.py --algorithm=random --max-eps=4000.
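The agent code calls a record helper that is not shown in this post. The sketch below is one plausible version, inferred only from how it is called: the argument order mirrors the call sites, while the body, the exponential moving average, and the smoothing parameter are assumptions made for illustration.

```python
from queue import Queue


def record(episode, episode_reward, worker_idx, global_ep_reward,
           result_queue, total_loss, num_steps, smoothing=0.99):
    """Update the moving average reward, log it, and push it onto the queue.

    The argument order mirrors the call sites in this post; the body and the
    `smoothing` parameter are assumptions made for illustration.
    """
    if global_ep_reward == 0:
        # First episode: start the average at the observed reward.
        global_ep_reward = episode_reward
    else:
        global_ep_reward = (smoothing * global_ep_reward
                            + (1.0 - smoothing) * episode_reward)
    print("Episode: {} | Moving average reward: {:.1f} | "
          "Episode reward: {:.1f} | Loss: {:.3f} | Steps: {} | Worker: {}".format(
              episode, global_ep_reward, episode_reward,
              float(total_loss), num_steps, worker_idx))
    # The master thread reads these values to plot the learning curve.
    result_queue.put(global_ep_reward)
    return global_ep_reward


q = Queue()
avg = record(0, 10.0, 0, 0, q, 0, 10)      # first call seeds the average
avg = record(1, 20.0, 0, avg, q, 0.5, 20)  # later calls smooth it
```

With a smoothing factor close to 1, a single lucky episode barely moves the curve, which makes the plotted trend easier to read.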

What is the Asynchronous Advantage Actor Critic algorithm?

Asynchronous Advantage Actor Critic is quite a mouthful! Let’s start by breaking down the name, and then the mechanics behind the algorithm itself.
  • Asynchronous: The algorithm is an asynchronous algorithm where multiple worker agents are trained in parallel, each with their own copy of the model and environment. This allows our algorithm to not only train faster as more workers are training in parallel, but also to attain a more diverse training experience as each workers’ experience is independent.
  • Advantage: Advantage is a metric to judge not only how good the agent’s actions were, but also how they turned out. This allows the algorithm to focus on where the network’s predictions were lacking. Intuitively, this allows us to measure the advantage of taking action a over following the policy π at the given time step.
  • Actor-Critic: The Actor-Critic aspect of the algorithm uses an architecture that shares layers between the policy and value function.
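To make the advantage concrete, here is a minimal sketch in plain Python. It assumes the discounted return R is used as the estimate of how the actions turned out and that the advantage is R minus the critic's value prediction; the function names and the example numbers are made up for illustration.

```python
def discounted_returns(rewards, bootstrap_value, gamma=0.99):
    """Compute R_t = r_t + gamma * R_{t+1}, seeded with a bootstrap value."""
    running = bootstrap_value
    returns = []
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def advantages(returns, values):
    """Advantage estimate A_t = R_t - V(s_t) for each step."""
    return [r - v for r, v in zip(returns, values)]


rewards = [1.0, 1.0, 1.0]
values = [1.5, 1.5, 0.5]  # made-up critic predictions
returns = discounted_returns(rewards, bootstrap_value=0.0, gamma=0.5)
print(returns)                      # [1.75, 1.5, 1.0]
print(advantages(returns, values))  # [0.25, 0.0, 0.5]
```

A positive advantage means the step turned out better than the critic predicted, so the corresponding action should become more likely; an advantage of zero leaves the policy unchanged.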

But how does it work?

At a high level, the A3C algorithm uses an asynchronous updating scheme that operates on fixed-length time steps of experience. It will use these segments to compute estimators of the rewards and the advantage function. Each worker performs the following workflow cycle:
  1. Fetch the global network parameters
  2. Interact with the environment by following the local policy for min(t_max, steps to terminal state) number of steps.
  3. Calculate value and policy loss
  4. Get gradients from losses
  5. Update the global network with the gradients
  6. Repeat
With this training configuration, we expect to see a linear speed up with the number of agents. However, the number of agents your machine can support is bound by the number of CPU cores available, and adding more agents than that may prove detrimental to speed and performance. A3C can also scale to more than one machine, and some newer research (such as IMPALA) supports scaling it even further. Check out the paper for more in-depth information!
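The worker cycle above can be illustrated with a toy problem that needs no neural network. In this sketch, which is an assumption-laden stand-in rather than A3C itself, the "model" is a single shared number, the loss is (w - TARGET)^2, and each thread fetches the global value, computes a gradient on its local copy, and applies it to the global value. All names and constants are hypothetical.

```python
import threading

TARGET = 3.0           # toy optimum; stands in for "a good policy"
LEARNING_RATE = 0.1
STEPS_PER_WORKER = 200

global_params = {"w": 0.0}
params_lock = threading.Lock()


def worker():
    for _ in range(STEPS_PER_WORKER):
        # 1. Fetch the global parameters into a local copy.
        with params_lock:
            local_w = global_params["w"]
        # 2-4. Compute the loss gradient locally. The toy loss is
        # (w - TARGET)^2, so its gradient is 2 * (w - TARGET).
        grad = 2.0 * (local_w - TARGET)
        # 5. Apply the local gradient to the global parameters. The local
        # copy may be slightly stale by now, as in asynchronous training.
        with params_lock:
            global_params["w"] -= LEARNING_RATE * grad


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(abs(global_params["w"] - TARGET) < 1e-3)
```

Because the fetch and the update are separate critical sections, another worker can change the global value in between. The update still makes progress, which is the property asynchronous training relies on.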

Refresher for policy and value functions

If you are already comfortable with policy gradients, then I would recommend skipping this section. Otherwise, if you don’t know what policies/values are or just want a quick refresher, read on!

The idea of a policy is to parametrise the action probability distribution given some input state. We do so by creating a network that takes the state of the game and decides what we should do. As such, while the agent is playing the game, whenever it sees a certain state (or similar states), it will compute the probability of each available action given an input state, and then sample an action according to this probability distribution. To delve into the mathematics more formally, policy gradients are a special case of the more general score function gradient estimator. The general case is expressed in the form of E_{x∼p(x|θ)}[f(x)]; in other words and in our case, the expectation of our reward (or advantage) function, f, under some policy network, p. Then, using the log derivative trick, we figure out how to update our network’s parameters such that the action samples obtain higher rewards and end up with ∇_θ E_x[f(x)] = E_x[f(x) ∇_θ log p(x|θ)]. In English, this equation explains how shifting θ in the direction of our gradient will maximize our scores according to our reward function f.
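For readers who want the intermediate steps of the log derivative trick, here is a short derivation. It assumes x ranges over a discrete set (such as the available actions) and that f itself does not depend on θ; for continuous x the sums become integrals.

```latex
\begin{aligned}
\nabla_\theta \, \mathbb{E}_{x \sim p(x \mid \theta)}[f(x)]
  &= \nabla_\theta \sum_x p(x \mid \theta)\, f(x) \\
  &= \sum_x \nabla_\theta p(x \mid \theta)\, f(x) \\
  &= \sum_x p(x \mid \theta)\, \nabla_\theta \log p(x \mid \theta)\, f(x) \\
  &= \mathbb{E}_{x \sim p(x \mid \theta)}\!\left[ f(x)\, \nabla_\theta \log p(x \mid \theta) \right]
\end{aligned}
```

The third line uses the identity ∇_θ p = p ∇_θ log p. The result is an expectation under the policy, so it can be estimated from sampled actions without differentiating through the environment.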

Value functions essentially judge how good a certain state is to be in. Formally, value functions define the expected sum of rewards when starting in state s and following policy p. This is where the “Critic” part of the name becomes relevant. The agent uses the value estimate (the critic) to update the policy (the actor).

Implementation

Let’s first define what kind of model we’ll be using. The master agent will have the global network and each local worker agent will have a copy of this network in their own process. We’ll instantiate the model using Model Subclassing. Model subclassing gives us maximum flexibility at the cost of higher verbosity.
class ActorCriticModel(keras.Model):
  def __init__(self, state_size, action_size):
    super(ActorCriticModel, self).__init__()
    self.state_size = state_size
    self.action_size = action_size
    self.dense1 = layers.Dense(100, activation='relu')
    self.policy_logits = layers.Dense(action_size)
    self.dense2 = layers.Dense(100, activation='relu')
    self.values = layers.Dense(1)

  def call(self, inputs):
    # Forward pass
    x = self.dense1(inputs)
    logits = self.policy_logits(x)
    v1 = self.dense2(inputs)
    values = self.values(v1)
    return logits, values
As you can see from our forward pass, our model will take inputs and return policy probability logits and values.
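Logits are unnormalized scores, so they need a softmax before they can be read as action probabilities. The sketch below shows that conversion and the sampling step in plain Python, without TensorFlow; the helper names are hypothetical.

```python
import math
import random


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, rng=random):
    """Sample an action index according to the softmax probabilities."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


probs = softmax([2.0, 0.0])
print(probs)                      # roughly [0.88, 0.12]
print(sample_action([2.0, 0.0]))  # 0 most of the time, 1 occasionally
```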

Master Agent — Main thread

Let’s get to the brain of the operation. The master agent holds a shared optimizer that updates its global network. This agent instantiates the global network that each worker agent will update as well as the optimizer that we will use to update it. A3C is shown to be quite resilient to a spread of learning rates, but we’ll use the AdamOptimizer with a learning rate of 5e-4 for the game CartPole.
class MasterAgent():
  def __init__(self):
    self.game_name = 'CartPole-v0'
    save_dir = args.save_dir
    self.save_dir = save_dir
    if not os.path.exists(save_dir):
      os.makedirs(save_dir)

    env = gym.make(self.game_name)
    self.state_size = env.observation_space.shape[0]
    self.action_size = env.action_space.n
    self.opt = tf.train.AdamOptimizer(args.lr, use_locking=True)
    print(self.state_size, self.action_size)

    self.global_model = ActorCriticModel(self.state_size, self.action_size)  # global network
    self.global_model(tf.convert_to_tensor(np.random.random((1, self.state_size)), dtype=tf.float32))
The master agent will run the train function to instantiate and start each of the agents. The master agent handles the coordination and supervision of each agent. Each of these agents will run asynchronously. (This is technically not true asynchrony: because of Python’s GIL (Global Interpreter Lock), a single Python process cannot run threads in parallel across multiple cores. It can, however, run them concurrently by context switching during I/O-bound operations. We implement with threads for the sake of simplicity and clarity of example.)
def train(self):
    if args.algorithm == 'random':
      random_agent = RandomAgent(self.game_name, args.max_eps)
      random_agent.run()
      return

    res_queue = Queue()

    workers = [Worker(self.state_size,
                      self.action_size,
                      self.global_model,
                      self.opt, res_queue,
                      i, game_name=self.game_name,
                      save_dir=self.save_dir) for i in range(multiprocessing.cpu_count())]

    for i, worker in enumerate(workers):
      print("Starting worker {}".format(i))
      worker.start()

    moving_average_rewards = []  # record episode reward to plot
    while True:
      reward = res_queue.get()
      if reward is not None:
        moving_average_rewards.append(reward)
      else:
        break
    [w.join() for w in workers]

    plt.plot(moving_average_rewards)
    plt.ylabel('Moving average ep reward')
    plt.xlabel('Step')
    plt.savefig(os.path.join(self.save_dir,
                             '{} Moving Average.png'.format(self.game_name)))
    plt.show()
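The result-collection loop in train is a producer/consumer pattern: workers put numbers on a queue and signal completion with None. The sketch below isolates that pattern with toy workers. Note one deliberate difference, which is an assumption of this sketch and not the behavior of the code above: it waits for one None per worker before stopping, so results from slower workers are not dropped.

```python
import threading
from queue import Queue


def toy_worker(worker_idx, result_queue, num_results=3):
    """Stand-in for a training worker: report a few numbers, then a sentinel."""
    for step in range(num_results):
        result_queue.put(float(worker_idx * 10 + step))
    result_queue.put(None)  # signal that this worker is finished


def collect(result_queue, num_workers):
    """Drain the queue until every worker has sent its None sentinel."""
    results, finished = [], 0
    while finished < num_workers:
        item = result_queue.get()
        if item is None:
            finished += 1
        else:
            results.append(item)
    return results


result_queue = Queue()
workers = [threading.Thread(target=toy_worker, args=(i, result_queue))
           for i in range(4)]
for w in workers:
    w.start()
results = collect(result_queue, num_workers=len(workers))
for w in workers:
    w.join()
print(len(results))  # 12
```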

Memory Class — Holding our experience

In addition, to make keeping track of things easier, we’ll also implement a Memory class. This class simply gives us the functionality to keep track of the actions, rewards, and states that occur at each step.
class Memory:
  def __init__(self):
    self.states = []
    self.actions = []
    self.rewards = []
    
  def store(self, state, action, reward):
    self.states.append(state)
    self.actions.append(action)
    self.rewards.append(reward)
    
  def clear(self):
    self.states = []
    self.actions = []
    self.rewards = []
Now, we get to the crux of the algorithm: the worker agent. The worker agent inherits from the threading.Thread class, and we override its run method. This will allow us to achieve the first A in A3C, asynchronous. First we’ll begin by instantiating a local model and setting up the specific training parameters.
class Worker(threading.Thread):
  # Set up global variables across different threads
  global_episode = 0
  # Moving average reward
  global_moving_average_reward = 0
  best_score = 0
  save_lock = threading.Lock()

  def __init__(self,
               state_size,
               action_size,
               global_model,
               opt,
               result_queue,
               idx,
               game_name='CartPole-v0',
               save_dir='/tmp'):
    super(Worker, self).__init__()
    self.state_size = state_size
    self.action_size = action_size
    self.result_queue = result_queue
    self.global_model = global_model
    self.opt = opt
    self.local_model = ActorCriticModel(self.state_size, self.action_size)
    self.worker_idx = idx
    self.game_name = game_name
    self.env = gym.make(self.game_name).unwrapped
    self.save_dir = save_dir
    self.ep_loss = 0.0

Running the algorithm

The next step is to implement the run function. This will actually run our algorithm. We’ll run all the threads for a given global maximum number of episodes. This is where the third A, actor, of A3C comes into play. Our agent will “act” according to our policy function, becoming the actor while the action is judged by the “critic,” which is our value function. While this section of the code may seem dense, it really isn’t doing much. Within each episode, the code simply does:
  1. Get our policy (action probability distribution) based on the current state
  2. Step with action chosen according to policy
  3. If the agent has taken a set number of steps (args.update_freq) or the agent has reached a terminal state (has died), then update the global model with gradients computed from local model
  4. Repeat
def run(self):
    total_step = 1
    mem = Memory()
    while Worker.global_episode < args.max_eps:
      current_state = self.env.reset()
      mem.clear()
      ep_reward = 0.
      ep_steps = 0
      self.ep_loss = 0

      time_count = 0
      done = False
      while not done:
        logits, _ = self.local_model(
            tf.convert_to_tensor(current_state[None, :],
                                 dtype=tf.float32))
        probs = tf.nn.softmax(logits)

        action = np.random.choice(self.action_size, p=probs.numpy()[0])
        new_state, reward, done, _ = self.env.step(action)
        if done:
          reward = -1
        ep_reward += reward
        mem.store(current_state, action, reward)

        if time_count == args.update_freq or done:
          # Calculate gradient wrt to local model. We do so by tracking the
          # variables involved in computing the loss by using tf.GradientTape
          with tf.GradientTape() as tape:
            total_loss = self.compute_loss(done,
                                           new_state,
                                           mem,
                                           args.gamma)
          self.ep_loss += total_loss
          # Calculate local gradients
          grads = tape.gradient(total_loss, self.local_model.trainable_weights)
          # Push local gradients to global model
          self.opt.apply_gradients(zip(grads,
                                       self.global_model.trainable_weights))
          # Update local model with new weights
          self.local_model.set_weights(self.global_model.get_weights())

          mem.clear()
          time_count = 0

          if done:  # done and print information
            Worker.global_moving_average_reward = \
              record(Worker.global_episode, ep_reward, self.worker_idx,
                     Worker.global_moving_average_reward, self.result_queue,
                     self.ep_loss, ep_steps)
            # We must use a lock to save our model and to print to prevent data races.
            if ep_reward > Worker.best_score:
              with Worker.save_lock:
                print("Saving best model to {}, "
                      "episode score: {}".format(self.save_dir, ep_reward))
                self.global_model.save_weights(
                    os.path.join(self.save_dir,
                                 'model_{}.h5'.format(self.game_name))
                )
                Worker.best_score = ep_reward
            Worker.global_episode += 1
        ep_steps += 1

        time_count += 1
        current_state = new_state
        total_step += 1
    self.result_queue.put(None)

How are the losses computed?

The worker agent calculates the losses to obtain gradients with respect to all of its own network parameters. These gradients are then applied to the global network. This is where the last A of A3C, the advantage, comes into play. The losses are calculated as:
  1. Value Loss: L = ∑(R - V(s))²
  2. Policy Loss: L = -log(π(s)) * A(s)
Where R is the discounted rewards, V our value function (of an input state), π our policy function (of an input state as well), and A our advantage function. We use the discounted rewards to estimate our Q value since we don’t directly determine the Q value with A3C.
def compute_loss(self,
                   done,
                   new_state,
                   memory,
                   gamma=0.99):
    if done:
      reward_sum = 0.  # terminal
    else:
      # Bootstrap from the critic's value of the next state. Indexing with
      # [0, 0] extracts a scalar from the [1, 1] value output.
      reward_sum = self.local_model(
          tf.convert_to_tensor(new_state[None, :],
                               dtype=tf.float32))[-1].numpy()[0, 0]

    # Get discounted rewards
    discounted_rewards = []
    for reward in memory.rewards[::-1]:  # reverse buffer r
      reward_sum = reward + gamma * reward_sum
      discounted_rewards.append(reward_sum)
    discounted_rewards.reverse()

    logits, values = self.local_model(
        tf.convert_to_tensor(np.vstack(memory.states),
                             dtype=tf.float32))
    # Get our advantages; the returns and the values both have shape [N, 1]
    advantage = tf.convert_to_tensor(np.array(discounted_rewards)[:, None],
                            dtype=tf.float32) - values
    # Value loss
    value_loss = advantage ** 2

    # Calculate our policy loss
    actions_one_hot = tf.one_hot(memory.actions, self.action_size, dtype=tf.float32)

    policy = tf.nn.softmax(logits)
    # Entropy of the policy, H = -sum(p * log p), shape [N]
    entropy = -tf.reduce_sum(policy * tf.log(policy + 1e-20), axis=1)

    policy_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=actions_one_hot,
                                                             logits=logits)
    # Flatten the advantage to shape [N] so it scales each step's loss
    # element-wise instead of broadcasting to [N, N]
    policy_loss *= tf.stop_gradient(tf.squeeze(advantage, axis=1))
    # Subtracting the entropy rewards exploration
    policy_loss -= 0.01 * entropy
    total_loss = tf.reduce_mean(0.5 * tf.squeeze(value_loss, axis=1) + policy_loss)
    return total_loss
That’s it! The worker agent repeats this cycle: it resets its network parameters to those of the global network, interacts with its environment, computes the loss, and applies the gradients to the global network. Train your algorithm by running the command: python a3c_cartpole.py --train.
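To see the loss terms with concrete numbers, here is a single-step version in plain Python. It is a sketch under stated assumptions: it reuses the 0.5 and 0.01 coefficients from compute_loss, treats the advantage as a constant, and subtracts the policy entropy as an exploration bonus, which is the usual A3C convention. The function name is hypothetical.

```python
import math


def a3c_step_loss(logits, action, discounted_return, value,
                  value_coef=0.5, entropy_coef=0.01):
    """Loss for one time step: value loss + policy loss - entropy bonus."""
    # Softmax over the logits gives the policy's action probabilities.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    advantage = discounted_return - value               # A = R - V(s)
    value_loss = advantage ** 2                         # (R - V(s))^2
    policy_loss = -math.log(probs[action]) * advantage  # -log(pi) * A
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return value_coef * value_loss + policy_loss - entropy_coef * entropy


loss = a3c_step_loss(logits=[0.0, 0.0], action=0,
                     discounted_return=1.0, value=0.5)
print(round(loss, 4))  # 0.4646
```

With equal logits both actions have probability 0.5, so the policy term is log(2) * 0.5, the value term is 0.5 * 0.25, and the entropy bonus removes 0.01 * log(2).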

Testing the algorithm

Let’s test the algorithm by spinning up a new environment and simply following the policy output by our trained model. This will render our environment and, at each step, take the most probable action under our model’s policy distribution.
 def play(self):
    env = gym.make(self.game_name).unwrapped
    state = env.reset()
    model = self.global_model
    model_path = os.path.join(self.save_dir, 'model_{}.h5'.format(self.game_name))
    print('Loading model from: {}'.format(model_path))
    model.load_weights(model_path)
    done = False
    step_counter = 0
    reward_sum = 0

    try:
      while not done:
        env.render(mode='rgb_array')
        policy, value = model(tf.convert_to_tensor(state[None, :], dtype=tf.float32))
        policy = tf.nn.softmax(policy)
        action = np.argmax(policy)
        state, reward, done, _ = env.step(action)
        reward_sum += reward
        print("{}. Reward: {}, action: {}".format(step_counter, reward_sum, action))
        step_counter += 1
    except KeyboardInterrupt:
      print("Received Keyboard Interrupt. Shutting down.")
    finally:
      env.close()
You can run this with the following command: python a3c_cartpole.py once you’ve trained your model.
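The code in this post reads its settings from an args object (args.algorithm, args.lr, args.update_freq, args.max_eps, args.gamma, args.save_dir). A parser along the following lines would provide them. This is a sketch: the attribute names come from the code above, but every default value and help string here is a placeholder, not the script's actual configuration.

```python
import argparse


def build_parser():
    """Build a parser for the settings referenced in this post.

    The attribute names match the `args.*` reads in the code; every default
    below is a placeholder chosen for this sketch.
    """
    parser = argparse.ArgumentParser(description="A3C on CartPole (sketch)")
    parser.add_argument("--algorithm", default="a3c", type=str,
                        help="'a3c' or 'random'")
    parser.add_argument("--train", action="store_true",
                        help="train instead of playing a saved model")
    parser.add_argument("--lr", default=5e-4, type=float)
    parser.add_argument("--update-freq", default=20, type=int)
    parser.add_argument("--max-eps", default=1000, type=int)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--save-dir", default="/tmp/", type=str)
    return parser


# argparse turns dashes into underscores: --max-eps becomes args.max_eps.
args = build_parser().parse_args(["--algorithm=random", "--max-eps=4000"])
print(args.algorithm, args.max_eps, args.train)  # random 4000 False
```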

To examine our moving average of scores: We should see the moving average converge to scores >200. The game is defined as “solved” when the agent gets an average reward of 195.0 over 100 consecutive trials.
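The "solved" criterion can be checked directly on a list of per-episode rewards. The sketch below slides a 100-episode window over the history; the function name is hypothetical, and the threshold and window size are the ones quoted above.

```python
def is_solved(episode_rewards, threshold=195.0, window=100):
    """Return True if any `window` consecutive episodes average >= threshold."""
    if len(episode_rewards) < window:
        return False
    window_sum = sum(episode_rewards[:window])
    if window_sum / window >= threshold:
        return True
    for i in range(window, len(episode_rewards)):
        # Slide the window forward by one episode.
        window_sum += episode_rewards[i] - episode_rewards[i - window]
        if window_sum / window >= threshold:
            return True
    return False


print(is_solved([200.0] * 100))                # True
print(is_solved([20.0] * 500))                 # False
print(is_solved([20.0] * 50 + [200.0] * 100))  # True
```

Note that this works on raw episode rewards, not on the smoothed moving average that the training plot shows.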

And an example performance on a new environment:

Key Takeaways

What we covered:

  • We solved CartPole by implementing A3C!
  • We did so by using Eager Execution, Model Subclassing, and Custom Training loops.
  • Eager is an easy way to develop training loops that makes coding easier and clearer since we’re able to print and debug tensors directly.
  • We learned the basics of reinforcement learning with policy and value networks, and then we tied them together to implement A3C.
  • We iteratively updated our global network by computing gradients with tf.GradientTape and applying our optimizer’s update rules.