CartPole Q Learning#
Title: CartPole Q Learning
Subtitle: Using Q Learning to Play cartpole
Date: 2018-10-21 13:20
Category: reinforcement-learning
Tags: rl, gym, qlearning, python
Authors: Varun Nayyar
I’ve been learning about reinforcement learning and wondering how it can be applied in my work. A common paradigm of machine learning/data science is to develop a model that can predict, and to feed those predictions into relatively simple code (or into a report for humans) to make decisions. For example, a bank might use a model to predict the probability of default, and use that to suggest an interest rate. In many cases the hard part is estimating the risk; once that’s been worked out, the decision is easy. Or a charity might look into the profile of its donors and combine that with census data to see what can be done.
However, this is not always the case. Games are a fantastic example of needing to process information and then decide how to react. Here the prediction problem is only one aspect, and optimizing for the current state alone isn’t necessarily a good idea; we need to optimize over a longer horizon. Lots of control problems can also be tackled with reinforcement learning.
Q-Learning#
We’re going to dive into this problem using Q-Learning. This Blog is generally fantastic, and its section on Q-Learning is a better reference than anything I’d write, so I can skip the maths. The basic intuition is that, given a state, we try to choose the action that has the better long-term reward. This is done by recording all playthroughs (in a smart way) and seeing which actions have which results in each state. In the early stages, the algorithm has no idea what it’s doing, so it makes choices randomly and sees what happens. It will learn that certain actions in a state tend to result in failure (walking into a goomba in Mario) and certain actions result in some kind of reward (jumping onto a goomba). More explicitly, each state-action pair’s value is equal to its immediate reward plus a discounted amount of expected future reward,
i.e. for a deterministic game with finite states, for action \(a\) in state \(s\) that moves it to state \(s'\) and gives reward \(r\),
\begin{align} Q(s,a) = r + \gamma \max_{a'} Q(s',a') \end{align}
where \(\gamma\) is the discount factor (the more heavily future reward is discounted, i.e. the smaller \(\gamma\), the more short-term the behaviour; \(\gamma\) close to 1 puts more weight on long-term gain). The above is called the Bellman update.
In Q-learning, not only is the immediate reward looked at, but also the best possible reward from the new state. This is a recursive definition, so in reality, as we play the game, values have a way of filtering back through the states. An easy way to explain this is a FrozenLake-style game on a 1D grid, where going left leads to your death and going right leads to you surviving and obtaining a reward of 1. Let’s make the topology look like this:
Position:  0      1     2     3     4     5     6
State:     Death  Safe  Safe  You   Safe  Safe  Reward
At the beginning you’re in state 3, and both left and right have unknown reward. Let’s say you go all the way to the right and get the reward. Now in position 5, you know going right leads to a reward of 1. However, restarting the game, position 3 still doesn’t know which way to go, since states 4 and 2 still have no information about Q. Let’s say you went right again: this time, on entering state 5, you already know that Q(5, right) = 1 while going left is still 0, hence Q(4, right) = \(\gamma\). Playing again, Q(3, left/right) is still 0, but if you go right once more you find that Q(4, right) = \(\gamma\), which is better than left, so Q(3, right) = \(\gamma^2\). From then on the game would only go right.
On the other hand, heading left towards your death results in no reward, and thus no information filters back. Setting the death reward to -1 would have a similar effect, actively driving the agent away from the left.
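To make the filtering-back concrete, here is a minimal sketch of that corridor, assuming deterministic moves, a single terminal reward, and an agent that (for illustration only) always goes right. The layout and the little step helper are made up for this toy example; only the Bellman update mirrors the equation above.

import numpy as np

GAMMA = 0.9
n_states, n_actions = 7, 2              # actions: 0 = left, 1 = right
Q = np.zeros((n_states, n_actions))

def step(state, action):
    """Deterministic corridor: state 0 is death, state 6 pays a reward of 1."""
    new_state = state - 1 if action == 0 else state + 1
    reward = 1 if new_state == 6 else 0
    done = new_state in (0, 6)
    return new_state, reward, done

for episode in range(3):
    state, done = 3, False
    while not done:
        new_state, reward, done = step(state, action=1)   # always go right
        # Bellman update; the game is deterministic, so we can overwrite directly
        Q[state, 1] = reward if done else reward + GAMMA * Q[new_state].max()
        state = new_state
    print(f"after episode {episode + 1}: Q(:, right) = {Q[:, 1].round(3)}")
# episode 1 only learns Q(5, right); each replay pushes the value one state further back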
However, most games are not deterministic and have elements of chance. Furthermore, instead of death, there might be a reward of 2 to the left that we’ve never checked on. As a result, we add two things. Firstly, we don’t always choose the best action: we add an element of randomness using epsilon-greedy exploration (see the blog above) so that we sometimes choose a random action. Secondly, we update the Q(s,a) score a little more conservatively each time:
\begin{align} Q_{n+1}(s,a) = (1-\alpha) Q_n(s,a) + \alpha \left[ r + \gamma \max_{a'} Q_n(s',a') \right] \end{align}
For \(\alpha\) close to 1, new information changes the values we hold more strongly, and for small \(\alpha\) it has less impact. Some of you might realise this can be rewritten as
\begin{align} Q_{n+1}(s,a) &= \alpha \sum_n (1-\alpha)^n B_n \\ B_n &= r + \gamma \max_{a'} Q_n(s',a') \end{align}
Noting that \(\alpha \sum_{n=0}^{\infty} (1-\alpha)^n = 1\), the above can be treated as a weighted expectation over the Bellman targets, with smaller values of \(\alpha\) bringing it closer to an unweighted average. Hence we can think of choosing the best action in the new state as choosing the action with the best expected value in the new state.
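A quick numerical sketch of that interpretation, assuming stationary noisy targets (the target distribution here is made up purely for illustration): repeatedly applying the \(\alpha\)-update pulls the Q value towards the mean of the targets.

import random

alpha, q = 0.1, 0.0
targets = [random.gauss(5.0, 1.0) for _ in range(5000)]    # noisy Bellman targets around 5
for b in targets:
    q = (1 - alpha) * q + alpha * b                        # the update from above
print(round(q, 2), round(sum(targets) / len(targets), 2))  # both end up close to 5.0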
Implementing Q-Learning#
Q-Learning has two common implementations: tabular Q-learning, which works well for games with finite states, and neural Q-learning, which uses a neural net to map a state to the Q values of the various actions. Neural Q-learning has been extended to DQN, which uses a variety of tricks to improve its performance (it’s quite brittle as is). In neural Q-learning the \(\alpha\) updating isn’t necessary, since the neural net sees all the Q(s,a) combinations as part of its training set. The neural net is being used purely as a function approximator and could be replaced by any other approximator (trees, k-NN etc.), but in general neural nets have very efficient implementations (GPUs) that buy an extra order of magnitude of speed.
Before we start implementing, I’m going to define an interface so I can reuse my agent code with both approaches.
class QlearningInterface:

    def __init__(self, statedim, num_actions):
        """
        It's only necessary to know the number of actions and dim of state.
        In a finite state approach, you can pass that through too
        if you have more efficient code planned
        """
        raise NotImplementedError

    def __getitem__(self, item):
        """
        item is (state, action) and it returns the Q value
        """
        raise NotImplementedError

    def __setitem__(self, key, value):
        """
        key is (state, action) and value is the Bellman update
        """
        raise NotImplementedError

    def get_max(self, state):
        """
        Returns max Q(s, a) for a fixed state, s
        """
        raise NotImplementedError

    def get_arg_max(self, state):
        """
        Returns the a that maximises Q(s, a) for a fixed state, s
        """
        raise NotImplementedError
Also, let’s introduce the game and the agent.
Cartpole#
For the demonstration of Q-Learning, we’re going to use OpenAI’s CartPole. The agent here is actually quite straightforward; the magic happens in the Q-function approximator.
Cartpole state:
Cart position (roughly -2.4 to 2.4; beyond that the episode fails)
Cart velocity (-inf to inf)
Pole angle (about -0.418 to 0.418 radians, i.e. roughly ±24°). Failure at abs(angle) > 12°
Pole tip velocity (-inf to inf)
Cartpole actions:
0. Push left
1. Push right
For the tabular approach, I’m going to discretize the state to the best of my ability.
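The bucketing primitive the table below relies on is np.digitize: given a sorted array of edges, it returns which bucket a continuous value falls into. A tiny illustration, using the same three edges the cart-position bins end up with:

import numpy as np

edges = np.array([-0.1, 0.0, 0.1])            # three edges -> four buckets (0..3)
for x in (-0.2, -0.05, 0.05, 0.3):
    print(x, "->", int(np.digitize(x, edges)))
# -0.2 -> 0, -0.05 -> 1, 0.05 -> 2, 0.3 -> 3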
GAMMA = 0.9
import random
import os
import gym
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def exploit(epsilon):
    """exploitation increases as epsilon decreases"""
    return random.random() > epsilon


class Agent:
    VID_DIR = './extra/video'

    def __init__(self, Qstate, gamma=GAMMA):
        # gym < 0.26 API assumed: reset() returns the observation,
        # step() returns a 4-tuple (obs, reward, done, info)
        self.env = gym.make('CartPole-v1')
        self.gamma = gamma
        self.Qstate = Qstate(statedim=self.env.observation_space.shape[0],
                             num_actions=self.env.action_space.n)

    def get_action(self, state, epsilon):
        if exploit(epsilon):
            return self.Qstate.get_arg_max(state)
        else:
            return self.env.action_space.sample()

    def train(self, num_episodes, initeps=1, finaleps=0.05):
        """
        Simple linear drop for epsilon.
        """
        epsdecay = (initeps - finaleps) / num_episodes
        epsilon = initeps
        test_window = int(num_episodes / 20)
        num_steps = np.zeros(num_episodes)
        eps_vals = np.zeros(num_episodes)
        for i in range(num_episodes):
            state = self.env.reset()
            steps = 0
            epsilon -= epsdecay
            done = False
            while not done:
                action = self.get_action(state, epsilon=epsilon)
                new_state, reward, done, info = self.env.step(action)
                if not done:
                    # add the future reward * decay if we're still going
                    reward += self.gamma * self.Qstate.get_max(new_state)
                steps += 1
                self.Qstate[state, action] = reward
                state = new_state
            num_steps[i] = steps
            eps_vals[i] = epsilon
            if i and i % test_window == 0:
                # progress report every 5%
                upp = i
                low = upp - test_window
                print(f"{i}: eps:{epsilon:.2f}, max: {np.max(num_steps[low:upp])}"
                      f" ave: {np.mean(num_steps[low:upp]):.2f}"
                      f" std: {np.std(num_steps[low:upp]):.2f}")
        return num_steps, eps_vals

    def run(self):
        from gym.wrappers.monitor import Monitor
        env = Monitor(self.env, Agent.VID_DIR, force=True)
        done = False
        steps = 0
        state = env.reset()
        while not done:
            env.render(mode="rgb_array")
            action = self.get_action(state, epsilon=0)
            print(f"{steps} {state} {action}")
            new_state, reward, done, info = env.step(action)
            state = new_state
            steps += 1
        print(f"Numsteps: {steps}")
        env.close()
The test window is just a progress update. The epsilon-greedy schedule is a simple linear drop over training (this may not be optimal; a different decay function, or even a different approach to exploration, may work better; a sketch of one alternative follows). The run method just plays an episode with the greedy (optimal) action chosen and records some pretty pictures.
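For example, a geometric (exponential) decay drops quickly at first and then flattens out near the floor. The function below is just an illustrative alternative, not what the agent above uses:

def exp_epsilon(episode, num_episodes, initeps=1.0, finaleps=0.05):
    """Exponential decay from initeps down to finaleps over num_episodes."""
    decay = (finaleps / initeps) ** (1 / num_episodes)   # per-episode multiplier
    return max(finaleps, initeps * decay ** episode)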
Let’s define QTable. Note the bins have been specified after sampling the env and plotting the observations. I’ve also used the knowledge that cart position and velocity are generally less important than pole angle and tip velocity. This results in 4 * 4 * 29 * 8 bins, which is about 3.7k states.
class QTable:
    ALPHA = 0.1

    def __init__(self, statedim, num_actions):
        from collections import defaultdict
        self.qdict = defaultdict(lambda: [0] * num_actions)
        # manually specify bins after inspection
        self.bins = [0, 0, 0, 0]
        self.bins[0] = np.array([-0.1, 0, 0.1])
        self.bins[1] = np.array([-0.75, 0, 0.75])
        thresh = np.pi / 180 * 12
        self.bins[2] = np.linspace(-thresh, thresh, 30)[1:-1]
        self.bins[3] = np.linspace(-1.7, 1.7, 9)[1:-1]

    def _discretize(self, obs):
        """assume a tuple is already discretized"""
        if not isinstance(obs, tuple):
            state = [int(np.digitize(obs[i], self.bins[i])) for i in range(len(self.bins))]
            return tuple(state)
        return obs

    def __getitem__(self, item):
        try:
            state, action = item
        except ValueError:
            # can't unpack, assume we've only been passed in a state
            return self.qdict[self._discretize(item)]
        else:
            return self.qdict[self._discretize(state)][action]

    def __setitem__(self, key, value):
        """
        Only allowed to set a (state, action) pair.
        Update is as per the alpha value.
        """
        try:
            state, action = key
        except ValueError:
            raise ValueError("can only set item on (state, action) pair")
        else:
            state = self._discretize(state)
            self.qdict[state][action] = (self.qdict[state][action] * (1 - QTable.ALPHA)
                                         + value * QTable.ALPHA)

    def get_max(self, state):
        """get maximum q score for a state over actions"""
        state = self._discretize(state)
        return max(self.qdict[state])

    def get_arg_max(self, state):
        """which action gives the max q score. If identical max q scores exist,
        this'll return a random choice
        """
        state = self._discretize(state)
        mx = self.get_max(state)
        return random.choice([i for i, j in enumerate(self.qdict[state]) if j == mx])
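A quick sanity check of how the table behaves (the observation values here are made up for illustration): passing a raw observation returns the Q values for both actions, and indexing with (state, action) returns a single value.

qt = QTable(statedim=4, num_actions=2)
obs = np.array([0.05, -0.3, 0.02, 0.5])   # cart pos, cart vel, pole angle, tip velocity
print(qt._discretize(obs))                # the tuple of bin indices used as the table key
print(qt[obs])                            # Q values for both actions, initially [0, 0]
print(qt[obs, 1])                         # Q value for pushing right, initially 0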
def show_video():
    """
    Helper function to show video in an ipython notebook, courtesy of star ai.
    The render process is mega-janky so be warned.
    """
    import glob, io, base64
    from IPython.display import HTML
    from IPython import display as ipythondisplay
    # clear the json files out
    [os.unlink(fl) for fl in glob.glob(os.path.join(Agent.VID_DIR, "*.json"))]
    mp4list = glob.glob(os.path.join(Agent.VID_DIR, "*.mp4"))
    if len(mp4list) > 0:
        mp4 = mp4list[0]
        video = io.open(mp4, 'r+b').read()
        encoded = base64.b64encode(video)
        ipythondisplay.display(HTML(data='''<video alt="test" autoplay
            loop controls style="height: 400px;">
            <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>'''.format(encoded.decode('ascii'))))
    else:
        print("Could not find video")
# set gamma quite high since we want to stay upright as long as possible - this has quite
# a large effect on the reward.
# a lot of the apparent growth is likely to be due to the falling epsilon value, not just the
# training; using a faster epsilon decay makes this very obvious.
# a more successful agent is also slower to train, since the better it gets the longer each
# episode runs. In this case, the env is limited to 200 timesteps.
ag = Agent(QTable, gamma=0.99)
ns, eps = ag.train(25000)
# record,
ag.run()
0 [ 0.01739771 -0.02898541 0.02717045 -0.03584134] 0
1 [ 0.016818 -0.22448624 0.02645362 0.2652888 ] 0
2 [ 0.01232828 -0.41997557 0.0317594 0.56619666] 1
3 [ 0.00392876 -0.22531322 0.04308333 0.28368608] 1
4 [-0.0005775 -0.0308314 0.04875705 0.00489663] 1
...
198 [-0.17605256 -0.40592899 -0.06559505 0.25630562] 0
199 [-0.18417114 -0.60005608 -0.06046893 0.52759891] 1
Numsteps: 200
show_video()
This approach is heavily dependent on my discretization and on the gamma and alpha values. The epsilon schedule also matters; with a lot of tweaking you can get much better behaviour.
Neural Q Learning#
In this case, we use a neural net to map state to action values. The basic topology is that we feed in the state and the neural net predicts the Q value for each action (4 inputs and 2 outputs in this case). To train it, we feed in the state, the action taken and the Bellman value; the network masks its outputs so that only the chosen action contributes to the loss function.
More precisely, we predict Q(s,a) for the pair we took, and Q(s',a) for all actions in the new state. We then fit Q(s,a) towards the Bellman target \(r + \gamma \max_{a} Q(s',a)\). This means every step has our neural net running a fit, which makes this quite slow compared to the tabular version.
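In miniature, one training step looks roughly like this (the Q values and transition below are made-up numbers, purely to show the masking and the target):

import numpy as np

GAMMA = 0.6
r, done = 1.0, False
q_s = np.array([0.7, 1.3])        # network outputs Q(s, .) for (left, right)
q_s_next = np.array([0.9, 1.1])   # network outputs Q(s', .)
target = r + GAMMA * q_s_next.max() if not done else r   # Bellman target
mask = np.array([0.0, 1.0])       # one-hot for the action actually taken (right)
# only the chosen action's Q value enters the mean-squared-error loss
print(float(q_s @ mask), "is fit towards", target)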
The chosen topology has 2 hidden layers of 20 nodes each with relu activations, combined into the 2 action outputs with a linear layer. This neural net is quite brittle and the hyperparameters have been copied from an existing implementation. Intuitively, we’re updating the neural net on every single observation, which means the target we’re descending towards is changing constantly.
Note that Keras expects everything in the form (n_obs, obs_dim), so a state of dim 4 should be passed in as a (1, 4) matrix. I’ve created some convenience interface code, since in neural Q we only fit a single observation and predict a single state at a time. Extensions like DQN will require some extra changes here.
def to_row(array):
    if array.ndim == 1:
        return array.reshape(1, len(array))
    return array


HIDDEN_NODES = 20


class NeuralQ:

    def __init__(self, statedim, num_actions, hidden_nodes=HIDDEN_NODES):
        from keras import Model
        from keras.layers import Dense, Input, Dot
        self.state_dim = statedim
        self.action_dim = num_actions
        state = Input((self.state_dim,))
        h1 = Dense(hidden_nodes, activation="relu")(state)
        h2 = Dense(hidden_nodes, activation="relu")(h1)
        qvals = Dense(self.action_dim, activation="linear")(h2)
        # this is the qval model, however we're going to add a mask to train
        # on. This submodel will be trained too and can be used later
        self.qvalmod = Model(inputs=state, outputs=qvals)
        # now mask with the chosen action
        action_in = Input((self.action_dim,))
        # the dot product with axis=1 will give us what we need
        max_sel = Dot(1, name='max_sel')([qvals, action_in])
        # input is state and action chosen, output is the q-val for the given action
        # we build the model to train with
        model = Model(inputs=[state, action_in], outputs=max_sel)
        model.compile("adam", loss='mean_squared_error')
        self.model = model

    def fit_single(self, state, action, output):
        """
        Args:
            state (np.array): (self.statedim)
            action (int):
            output (float):
        """
        state = to_row(state)
        a = np.zeros(self.action_dim)
        a[action] = 1
        action = to_row(a)
        output = np.array([output])
        self.model.fit([state, action], output, verbose=False)

    def pred_qval(self, state):
        return self.qvalmod.predict(to_row(np.array(state)))

    def __getitem__(self, item):
        try:
            state, action = item
        except ValueError:
            # can't unpack, assume we've only been passed in a state
            v1 = self.pred_qval(item)
            return v1.reshape(self.action_dim)
        else:
            v1 = self.pred_qval(state)
            return v1.reshape(self.action_dim)[action]

    def __setitem__(self, key, value):
        """
        Only allowed to set a (state, action) pair.
        Update is a single fit towards the Bellman value.
        """
        try:
            state, action = key
        except ValueError:
            raise ValueError("can only set item on (state, action) pair")
        else:
            self.fit_single(state, action, value)

    def get_max(self, state):
        """get maximum q score for a state over actions"""
        qvals = self[state]
        return np.max(qvals)

    def get_arg_max(self, state):
        """which action gives the max q score. If identical max q scores exist,
        this'll return a random choice
        """
        qvals = self[state]
        mx = np.max(qvals)
        return random.choice([i for i, j in enumerate(qvals) if j == mx])
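A quick interface check (shapes only; before training the weights are random, so the printed numbers are meaningless):

nq = NeuralQ(statedim=4, num_actions=2)
s = np.random.randn(4)
print(nq[s])              # two Q values, one per action
print(nq.get_arg_max(s))  # 0 or 1
nq[s, 1] = 0.5            # one small gradient step towards a Bellman target of 0.5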
nag = Agent(NeuralQ, gamma=0.6)
ns, eps= nag.train(10000)
Using TensorFlow backend.
500: eps:0.95, max: 103.0 ave: 22.02 std: 12.57
1000: eps:0.90, max: 104.0 ave: 23.72 std: 13.80
1500: eps:0.86, max: 126.0 ave: 26.31 std: 16.73
2000: eps:0.81, max: 110.0 ave: 27.43 std: 18.50
2500: eps:0.76, max: 134.0 ave: 28.84 std: 20.50
3000: eps:0.71, max: 199.0 ave: 36.84 std: 29.18
3500: eps:0.67, max: 199.0 ave: 43.02 std: 36.96
4000: eps:0.62, max: 199.0 ave: 43.03 std: 39.05
4500: eps:0.57, max: 199.0 ave: 45.78 std: 41.47
5000: eps:0.52, max: 199.0 ave: 52.35 std: 46.85
5500: eps:0.48, max: 199.0 ave: 50.70 std: 48.51
6000: eps:0.43, max: 199.0 ave: 49.47 std: 47.86
6500: eps:0.38, max: 199.0 ave: 47.80 std: 50.16
7000: eps:0.33, max: 199.0 ave: 61.79 std: 55.29
7500: eps:0.29, max: 199.0 ave: 69.46 std: 61.90
8000: eps:0.24, max: 199.0 ave: 70.47 std: 66.59
8500: eps:0.19, max: 199.0 ave: 81.92 std: 67.67
9000: eps:0.14, max: 199.0 ave: 89.98 std: 76.12
9500: eps:0.10, max: 199.0 ave: 85.25 std: 66.59
nag.run()
0 [-0.0418108 -0.02693399 -0.03569369 0.04776645] 1
1 [-0.04234948 0.16868112 -0.03473836 -0.25596104] 1
2 [-0.03897586 0.36428138 -0.03985759 -0.55939545] 1
3 [-0.03169023 0.55993945 -0.05104549 -0.86436448] 1
4 [-0.02049144 0.75571771 -0.06833278 -1.17265022] 1
...
90 [-1.84561556 -2.25444872 0.18474986 1.06430592] 1
91 [-1.89070453 -2.0621882 0.20603598 0.83483088] 1
Numsteps: 92
show_video()
Conclusion#
Hope you’ve enjoyed this brief intro to Q Learning.
Neural Q-learning is quite brittle and unstable compared to tabular Q in this case (these issues are addressed in an upgrade called DQN). We can see the neural Q get quite good and then get worse again, not to mention how much slower it is. The problem with tabular Q is memory: the table can grow out of control in more complex problems, especially ones with large or less structured state spaces, and of course when we need to interpret pixels/video rather than a nice, small numeric state.
Acknowledgements#
Thanks to starai (sans a website) for providing forums and lectures to kickstart this, and to Alexander Long for providing some code skeletons to work with.