From 2a174deb128de6f3221ad1329361d3126e3d90d2 Mon Sep 17 00:00:00 2001 From: Florian Jung Date: Tue, 5 Jan 2016 17:26:08 +0100 Subject: abstract Q array through object --- sol.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sol.py b/sol.py index 7e3a0d4..fa1de97 100644 --- a/sol.py +++ b/sol.py @@ -135,12 +135,13 @@ def visualize(maze, Q): else: f = lambda s : s - maxdir = argmax(Q[y][x]) + Qev = Q.eval(x,y) + maxdir = argmax(Qev) line3 += f("' " + ("^" if maxdir == NORTH else " ") + " ") line5 += f(" " + ("v" if maxdir == SOUTH else " ") + " ") - line1 += f(" %04.2f " % Q[y][x][NORTH]) - line2 += f("%s%04.2f %04.2f%s" % ("<" if maxdir == WEST else " ",Q[y][x][WEST], Q[y][x][EAST], ">" if maxdir == EAST else " ")) - line4 += f(" %04.2f " % Q[y][x][SOUTH]) + line1 += f(" %04.2f " % Qev[NORTH]) + line2 += f("%s%04.2f %04.2f%s" % ("<" if maxdir == WEST else " ",Qev[WEST], Qev[EAST], ">" if maxdir == EAST else " ")) + line4 += f(" %04.2f " % Qev[SOUTH]) print(line3) print(line1) print(line2) @@ -202,8 +203,23 @@ class World: def is_final(self,s): return self.maze[s[1]][s[0]] == 2 + +# abstracts the Q-array. semantics of .eval(x,y) is `Q[y][x]`. semantics of .change((x,y),ac,diff) is `Q[y][x][ac]+=diff` +class QArray: + def __init__(self): + self.Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ] + + def eval(self,x,y = None): + if y==None: x,y = x + + return self.Q[y][x] + + def change(self, s, action, diff): + self.Q[s[1]][s[0]][action] += diff + + a = World(maze, start) -Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ] +Q=QArray() i=0 stopstate = -1 @@ -213,7 +229,7 @@ for i in range(1000000): maxdiff=0. for j in range(100): # epsilon-greedy - greedy = argmax(Q[s[1]][s[0]]) + greedy = argmax(Q.eval(s)) rnd = random() action = None if rnd < epsilon: @@ -223,8 +239,8 @@ for i in range(1000000): r,ss = a.take_action(s[0],s[1], action) #print ((r,ss)) - diff = alpha * (r + gamma * max( [ Q[ss[1]][ss[0]][aa] for aa in directions ] ) - Q[s[1]][s[0]][action]) - Q[s[1]][s[0]][action] += diff + diff = alpha * (r + gamma * max( [ Q.eval(ss)[aa] for aa in directions ] ) - Q.eval(s)[action]) + Q.change(s,action,diff) maxdiff = max(abs(diff),maxdiff) total_reward += r s = ss -- cgit v1.2.3