July 31, 2018
By Raymond Yuan, Software Engineering Intern
In this tutorial we will learn how to train a model that is able to win at the simple game CartPole using deep reinforcement learning. We’ll use tf.keras and OpenAI’s gym to train an agent using a technique known as Asynchronous Advantage Actor Critic (A3C). Reinforcement learning has been receiving an enormous amount of attention, but what is it exact…
class RandomAgent:
    """Random Agent that will play the specified game

    Arguments:
        env_name: Name of the environment to be played
        max_eps: Maximum number of episodes to run agent for.
    """
    def __init__(self, env_name, max_eps):
        self.env = gym.make(env_name)
        self.max_episodes = max_eps
        self.global_moving_average_reward = 0
        self.res_queue = Queue()

    def run(self):
        reward_avg = 0
        for episode in range(self.max_episodes):
            done = False
            self.env.reset()
            reward_sum = 0.0
            steps = 0
            while not done:
                # Sample randomly from the action space and step
                _, reward, done, _ = self.env.step(self.env.action_space.sample())
                steps += 1
                reward_sum += reward
            # Record statistics
            self.global_moving_average_reward = record(episode,
                                                       reward_sum,
                                                       0,
                                                       self.global_moving_average_reward,
                                                       self.res_queue, 0, steps)
            reward_avg += reward_sum
        final_avg = reward_avg / float(self.max_episodes)
        print("Average score across {} episodes: {}".format(self.max_episodes, final_avg))
        return final_avg
For the game CartPole we get an average of ~20 across 4000 episodes. To run the random agent, run the provided py file: python a3c_cartpole.py --algorithm=random --max-eps=4000
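The record helper called by the random agent is not shown in this excerpt. As a rough sketch of what such a helper might do, assuming it keeps an exponential moving average of episode rewards and pushes each new value onto the result queue, something like the following could work. The name update_moving_average and the 0.99 smoothing factor are illustrative assumptions, not the tutorial's actual code.

```python
from queue import Queue


def update_moving_average(episode_reward, moving_average, result_queue, decay=0.99):
    """Hypothetical stand-in for the tutorial's `record` helper.

    Blends the newest episode reward into a running average and publishes
    the updated value on the queue so a plotting loop can consume it.
    """
    if moving_average == 0:
        # Assumption: seed the average with the first observed reward.
        moving_average = episode_reward
    else:
        moving_average = decay * moving_average + (1 - decay) * episode_reward
    result_queue.put(moving_average)
    return moving_average


def run_demo(rewards):
    """Feed a list of episode rewards through the helper."""
    queue = Queue()
    average = 0
    for reward in rewards:
        average = update_moving_average(reward, average, queue)
    return average, queue.qsize()
```

With a decay close to 1 the curve changes slowly, which smooths out the large episode-to-episode variance of a random policy.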
class ActorCriticModel(keras.Model):
    def __init__(self, state_size, action_size):
        super(ActorCriticModel, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        # Actor head: hidden layer followed by one logit per action
        self.dense1 = layers.Dense(100, activation='relu')
        self.policy_logits = layers.Dense(action_size)
        # Critic head: separate hidden layer followed by a scalar state value
        self.dense2 = layers.Dense(100, activation='relu')
        self.values = layers.Dense(1)

    def call(self, inputs):
        # Forward pass
        x = self.dense1(inputs)
        logits = self.policy_logits(x)
        v1 = self.dense2(inputs)
        values = self.values(v1)
        return logits, values
As you can see from our forward pass, our model will take inputs and return policy probability logits and values.
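To make the role of the logits concrete, here is a minimal pure-Python sketch of how raw logits can be turned into action probabilities with a softmax and then sampled. It is only an illustration: the helper names softmax and sample_action and the example logits are assumptions, and the tutorial itself relies on tf.nn.softmax and np.random.choice for these steps.

```python
import math
import random


def softmax(logits):
    """Convert raw logits to probabilities (shifted by the max for stability)."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, rng=random):
    """Sample an action index in proportion to its softmax probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Example logits for CartPole's two actions (made-up numbers).
example_probs = softmax([2.0, 0.0])
example_action = sample_action([2.0, 0.0], random.Random(0))
```

Sampling rather than always taking the most likely action is what lets the agent keep exploring during training.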
class MasterAgent():
    def __init__(self):
        self.game_name = 'CartPole-v0'
        save_dir = args.save_dir
        self.save_dir = save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        env = gym.make(self.game_name)
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.n
        self.opt = tf.train.AdamOptimizer(args.lr, use_locking=True)
        print(self.state_size, self.action_size)

        self.global_model = ActorCriticModel(self.state_size, self.action_size)  # global network
        # Call the model once on a dummy input so that its weights are created
        self.global_model(tf.convert_to_tensor(np.random.random((1, self.state_size)), dtype=tf.float32))
The master agent will run the train function to instantiate and start each of the agents. The master agent handles the coordination and supervision of each agent. Each of these agents will run asynchronously. (This is technically not true asynchrony: because of Python’s GIL (Global Interpreter Lock), a single Python process cannot run threads in parallel, that is, it cannot utilize multiple cores. It can, however, run them concurrently by context switching during I/O-bound operations. We implement with threads for the sake of simplicity and clarity of example.)
    def train(self):
        if args.algorithm == 'random':
            random_agent = RandomAgent(self.game_name, args.max_eps)
            random_agent.run()
            return

        res_queue = Queue()

        # One worker thread per CPU core, all sharing the global model and optimizer
        workers = [Worker(self.state_size,
                          self.action_size,
                          self.global_model,
                          self.opt, res_queue,
                          i, game_name=self.game_name,
                          save_dir=self.save_dir) for i in range(multiprocessing.cpu_count())]

        for i, worker in enumerate(workers):
            print("Starting worker {}".format(i))
            worker.start()

        moving_average_rewards = []  # record episode reward to plot
        while True:
            reward = res_queue.get()
            if reward is not None:
                moving_average_rewards.append(reward)
            else:
                break
        for w in workers:
            w.join()

        plt.plot(moving_average_rewards)
        plt.ylabel('Moving average ep reward')
        plt.xlabel('Step')
        plt.savefig(os.path.join(self.save_dir,
                                 '{} Moving Average.png'.format(self.game_name)))
        plt.show()
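The thread-and-queue pattern used by train can be seen in isolation in the following stdlib-only sketch. ToyWorker and collect are hypothetical names introduced for illustration, and the workers here just emit dummy results instead of training. One deliberate difference from the code above: this sketch waits for a None sentinel from every worker before it stops reading, whereas train stops at the first None it receives.

```python
import threading
from queue import Queue


class ToyWorker(threading.Thread):
    """Hypothetical worker that reports a few results and then a sentinel."""

    def __init__(self, idx, result_queue, n_results=3):
        super().__init__()
        self.idx = idx
        self.result_queue = result_queue
        self.n_results = n_results

    def run(self):
        # Overriding run() defines what the thread does once start() is called.
        for step in range(self.n_results):
            self.result_queue.put((self.idx, step))
        self.result_queue.put(None)  # signal that this worker is finished


def collect(num_workers=4):
    """Start the workers and drain the queue until all of them have finished."""
    queue = Queue()
    workers = [ToyWorker(i, queue) for i in range(num_workers)]
    for worker in workers:
        worker.start()

    results, finished = [], 0
    while finished < num_workers:
        item = queue.get()  # blocks until some worker produces an item
        if item is None:
            finished += 1
        else:
            results.append(item)

    for worker in workers:
        worker.join()
    return results
```

Because queue.Queue is thread-safe, the workers need no extra locking to report results back to the coordinating thread.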
Next comes the Memory class. This class simply keeps track of the states, actions, and rewards that occur at each step.
class Memory:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []

    def store(self, state, action, reward):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
Now, we get to the crux of the algorithm: the worker agent. The worker agent inherits from the threading.Thread class, and we override its run method. This will allow us to achieve the first A in A3C, asynchronous. First we’ll begin by instantiating a local model and setting up the specific training parameters.
class Worker(threading.Thread):
    # Set up global variables across different threads
    global_episode = 0
    # Moving average reward
    global_moving_average_reward = 0
    best_score = 0
    save_lock = threading.Lock()

    def __init__(self,
                 state_size,
                 action_size,
                 global_model,
                 opt,
                 result_queue,
                 idx,
                 game_name='CartPole-v0',
                 save_dir='/tmp'):
        super(Worker, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.result_queue = result_queue
        self.global_model = global_model
        self.opt = opt
        self.local_model = ActorCriticModel(self.state_size, self.action_size)
        self.worker_idx = idx
        self.game_name = game_name
        self.env = gym.make(self.game_name).unwrapped
        self.save_dir = save_dir
        self.ep_loss = 0.0
Next comes the run function. This will actually run our algorithm. We’ll run all the threads for a given global maximum number of episodes. This is where the third A, actor, of A3C comes into play. Our agent will “act” according to our policy function, becoming the actor, while the action is judged by the “critic,” which is our value function. While this section of the code may seem dense, it really isn’t doing much. Within each episode, the code samples an action from the current policy and steps the environment with it. If the agent has taken a set number of steps (args.update_freq) or has reached a terminal state (has died), it then updates the global model with gradients computed from the local model.
    def run(self):
        total_step = 1
        mem = Memory()
        while Worker.global_episode < args.max_eps:
            current_state = self.env.reset()
            mem.clear()
            ep_reward = 0.
            ep_steps = 0
            self.ep_loss = 0

            time_count = 0
            done = False
            while not done:
                logits, _ = self.local_model(
                    tf.convert_to_tensor(current_state[None, :],
                                         dtype=tf.float32))
                probs = tf.nn.softmax(logits)

                # Sample an action from the policy's probability distribution
                action = np.random.choice(self.action_size, p=probs.numpy()[0])
                new_state, reward, done, _ = self.env.step(action)
                if done:
                    reward = -1
                ep_reward += reward
                mem.store(current_state, action, reward)

                if time_count == args.update_freq or done:
                    # Calculate gradient wrt to local model. We do so by tracking the
                    # variables involved in computing the loss by using tf.GradientTape
                    with tf.GradientTape() as tape:
                        total_loss = self.compute_loss(done,
                                                       new_state,
                                                       mem,
                                                       args.gamma)
                    self.ep_loss += total_loss
                    # Calculate local gradients
                    grads = tape.gradient(total_loss, self.local_model.trainable_weights)
                    # Push local gradients to global model
                    self.opt.apply_gradients(zip(grads,
                                                 self.global_model.trainable_weights))
                    # Update local model with new weights
                    self.local_model.set_weights(self.global_model.get_weights())

                    mem.clear()
                    time_count = 0

                    if done:  # done and print information
                        Worker.global_moving_average_reward = \
                            record(Worker.global_episode, ep_reward, self.worker_idx,
                                   Worker.global_moving_average_reward, self.result_queue,
                                   self.ep_loss, ep_steps)
                        # We must use a lock to save our model and to print to prevent data races.
                        if ep_reward > Worker.best_score:
                            with Worker.save_lock:
                                print("Saving best model to {}, "
                                      "episode score: {}".format(self.save_dir, ep_reward))
                                self.global_model.save_weights(
                                    os.path.join(self.save_dir,
                                                 'model_{}.h5'.format(self.game_name))
                                )
                                Worker.best_score = ep_reward
                        Worker.global_episode += 1
                ep_steps += 1

                time_count += 1
                current_state = new_state
                total_step += 1
        # Signal the master agent that this worker has finished
        self.result_queue.put(None)
    def compute_loss(self,
                     done,
                     new_state,
                     memory,
                     gamma=0.99):
        if done:
            reward_sum = 0.  # terminal
        else:
            # Bootstrap from the critic's value estimate of the next state.
            # Index [0, 0] extracts a scalar from the (1, 1) value output.
            reward_sum = self.local_model(
                tf.convert_to_tensor(new_state[None, :],
                                     dtype=tf.float32))[-1].numpy()[0, 0]

        # Get discounted rewards
        discounted_rewards = []
        for reward in memory.rewards[::-1]:  # reverse buffer r
            reward_sum = reward + gamma * reward_sum
            discounted_rewards.append(reward_sum)
        discounted_rewards.reverse()

        logits, values = self.local_model(
            tf.convert_to_tensor(np.vstack(memory.states),
                                 dtype=tf.float32))
        # Get our advantages. Squeeze values to shape (N,) so that every
        # per-step term below has the same shape and nothing broadcasts to (N, N).
        advantage = tf.convert_to_tensor(np.array(discounted_rewards),
                                         dtype=tf.float32) - tf.squeeze(values, axis=1)
        # Value loss
        value_loss = advantage ** 2

        # Calculate our policy loss
        actions_one_hot = tf.one_hot(memory.actions, self.action_size, dtype=tf.float32)

        policy = tf.nn.softmax(logits)
        # Policy entropy (non-negative); subtracting it from the loss rewards exploration
        entropy = -tf.reduce_sum(policy * tf.log(policy + 1e-20), axis=1)

        policy_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=actions_one_hot,
                                                                 logits=logits)
        policy_loss *= tf.stop_gradient(advantage)
        policy_loss -= 0.01 * entropy
        total_loss = tf.reduce_mean((0.5 * value_loss + policy_loss))
        return total_loss
That’s it! The worker agent repeatedly resets its network parameters to those of the global network, interacts with its environment, computes the loss, and applies the gradients to the global network. Train your algorithm by running the command: python a3c_cartpole.py --train
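As a small worked example of the return and advantage calculation inside compute_loss, the sketch below reproduces the backward pass over the reward buffer in plain Python. The function names discounted_returns and advantages and the numbers in the usage note are illustrative assumptions; bootstrap_value plays the role of the critic's estimate for the state after the last stored step (0 if the episode ended).

```python
def discounted_returns(rewards, bootstrap_value, gamma=0.99):
    """Compute the discounted return for every step, working backwards.

    Each return is the step's reward plus gamma times the return of the
    following step; the final step bootstraps from `bootstrap_value`.
    """
    returns = []
    running = bootstrap_value
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def advantages(returns, values):
    """Advantage = observed discounted return minus the critic's estimate."""
    return [g - v for g, v in zip(returns, values)]
```

For rewards [1, 1, 1] with gamma=0.5 and a terminal state (bootstrap_value=0), the returns are [1.75, 1.5, 1.0]. If the critic had predicted 1.0 for every state, the advantages would be [0.75, 0.5, 0.0], so the earlier actions turned out better than expected.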
    def play(self):
        env = gym.make(self.game_name).unwrapped
        state = env.reset()
        model = self.global_model
        model_path = os.path.join(self.save_dir, 'model_{}.h5'.format(self.game_name))
        print('Loading model from: {}'.format(model_path))
        model.load_weights(model_path)
        done = False
        step_counter = 0
        reward_sum = 0

        try:
            while not done:
                env.render(mode='rgb_array')
                policy, value = model(tf.convert_to_tensor(state[None, :], dtype=tf.float32))
                policy = tf.nn.softmax(policy)
                # Act greedily at test time: pick the most probable action
                action = np.argmax(policy)
                state, reward, done, _ = env.step(action)
                reward_sum += reward
                print("{}. Reward: {}, action: {}".format(step_counter, reward_sum, action))
                step_counter += 1
        except KeyboardInterrupt:
            print("Received Keyboard Interrupt. Shutting down.")
        finally:
            env.close()
Once you’ve trained your model, you can run this with the following command: python a3c_cartpole.py