Diffstat (limited to 'sol.py')
-rw-r--r--  sol.py  32
1 file changed, 24 insertions(+), 8 deletions(-)
diff --git a/sol.py b/sol.py
index 7e3a0d4..fa1de97 100644
--- a/sol.py
+++ b/sol.py
@@ -135,12 +135,13 @@ def visualize(maze, Q):
else:
f = lambda s : s
- maxdir = argmax(Q[y][x])
+ Qev = Q.eval(x,y)
+ maxdir = argmax(Qev)
line3 += f("' " + ("^" if maxdir == NORTH else " ") + " ")
line5 += f(" " + ("v" if maxdir == SOUTH else " ") + " ")
- line1 += f(" %04.2f " % Q[y][x][NORTH])
- line2 += f("%s%04.2f %04.2f%s" % ("<" if maxdir == WEST else " ",Q[y][x][WEST], Q[y][x][EAST], ">" if maxdir == EAST else " "))
- line4 += f(" %04.2f " % Q[y][x][SOUTH])
+ line1 += f(" %04.2f " % Qev[NORTH])
+ line2 += f("%s%04.2f %04.2f%s" % ("<" if maxdir == WEST else " ",Qev[WEST], Qev[EAST], ">" if maxdir == EAST else " "))
+ line4 += f(" %04.2f " % Qev[SOUTH])
print(line3)
print(line1)
print(line2)
@@ -202,8 +203,23 @@ class World:
def is_final(self,s):
return self.maze[s[1]][s[0]] == 2
+
+# abstracts the Q-array. semantics of .eval(x,y) is `Q[y][x]`. semantics of .change((x,y),ac,diff) is `Q[y][x][ac]+=diff`
+class QArray:
+ def __init__(self):
+ self.Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ]
+
+ def eval(self,x,y = None):
+ if y==None: x,y = x
+
+ return self.Q[y][x]
+
+ def change(self, s, action, diff):
+ self.Q[s[1]][s[0]][action] += diff
+
+
a = World(maze, start)
-Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ]
+Q=QArray()
i=0
stopstate = -1
@@ -213,7 +229,7 @@ for i in range(1000000):
maxdiff=0.
for j in range(100):
# epsilon-greedy
- greedy = argmax(Q[s[1]][s[0]])
+ greedy = argmax(Q.eval(s))
rnd = random()
action = None
if rnd < epsilon:
@@ -223,8 +239,8 @@ for i in range(1000000):
r,ss = a.take_action(s[0],s[1], action)
#print ((r,ss))
- diff = alpha * (r + gamma * max( [ Q[ss[1]][ss[0]][aa] for aa in directions ] ) - Q[s[1]][s[0]][action])
- Q[s[1]][s[0]][action] += diff
+ diff = alpha * (r + gamma * max( [ Q.eval(ss)[aa] for aa in directions ] ) - Q.eval(s)[action])
+ Q.change(s,action,diff)
maxdiff = max(abs(diff),maxdiff)
total_reward += r
s = ss
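
For reference, a minimal, self-contained sketch of the QArray abstraction this commit introduces. The committed class reads the maze dimensions from the global World instance `a`; here the width and height are passed explicitly (an assumption made only so the snippet runs on its own), and `y == None` is written as the more idiomatic `y is None`.

class QArray:
    # Q[y][x] is the list of action values for cell (x, y), one per direction.
    def __init__(self, width, height, n_actions=4):
        self.Q = [[[0.0 for _ in range(n_actions)]
                   for _ in range(width)]
                  for _ in range(height)]

    def eval(self, x, y=None):
        # Accepts eval(x, y) as well as eval((x, y)); returns the list Q[y][x].
        if y is None:
            x, y = x
        return self.Q[y][x]

    def change(self, s, action, diff):
        # In-place update Q[y][x][action] += diff, with s = (x, y).
        self.Q[s[1]][s[0]][action] += diff

With this accessor in place, callers read a single value as Q.eval(s)[action] and apply updates through Q.change(s, action, diff), so the reversed [y][x] index order stays hidden inside the class.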
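The inner loop performs the standard tabular Q-learning update through the new accessor. Below is a hedged sketch of one step, assuming the names sol.py already defines (argmax, directions, World.take_action returning (reward, next_state), and the hyperparameters alpha, gamma, epsilon); the epsilon-greedy choice and the TD update are folded into one function, and the stand-in argmax is included only so the snippet is self-contained.

from random import random, choice

def argmax(values):
    # Stand-in for sol.py's own argmax: index of the largest entry.
    return max(range(len(values)), key=values.__getitem__)

def q_step(world, Q, s, directions, alpha, gamma, epsilon):
    # epsilon-greedy: explore with probability epsilon, otherwise act greedily.
    if random() < epsilon:
        action = choice(directions)
    else:
        action = argmax(Q.eval(s))
    r, ss = world.take_action(s[0], s[1], action)
    # TD error: reward plus discounted best value in the successor state,
    # minus the current estimate for (s, action).
    diff = alpha * (r + gamma * max(Q.eval(ss)[aa] for aa in directions) - Q.eval(s)[action])
    Q.change(s, action, diff)
    return ss, r, abs(diff)

Returning abs(diff) mirrors the maxdiff bookkeeping in the loop above, which the script uses to track how much the value table is still changing.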