From 4b8bd673ff44319c7f2d955e07937b2fed805326 Mon Sep 17 00:00:00 2001 From: Florian Jung Date: Tue, 5 Jan 2016 16:59:27 +0100 Subject: cleanup --- sol.py | 58 ++++++++-------------------------------------------------- 1 file changed, 8 insertions(+), 50 deletions(-) diff --git a/sol.py b/sol.py index f98421d..7e3a0d4 100644 --- a/sol.py +++ b/sol.py @@ -10,9 +10,7 @@ maze = [ [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 2] ] start=(1,1) -theta = None -theta_1 = 0.00001 -theta_2 = 0.01 +theta = 0.01 gamma=0.9 # discount epsilon = 0.1 epsilon_reduction = 1 @@ -53,8 +51,8 @@ if arg['-h']: print(" MODE: -1 / --policy-evaluation or\n" + " -2 / --q-learning\n" + " OPTIONS: --theta NUM # convergence threshold\n" + - " # default: %f / %f for -1 / -2\n" % (theta_1, theta_2) + - " --gamma NUM # learning discount for value iteration\n" + + " # default: %f\n" % theta + + " --gamma NUM # learning discount\n" + " # default: %f\n" % gamma + " --alpha NUM # learning rate for q-learning\n" + " # default: %f\n" % alpha + @@ -78,7 +76,7 @@ if arg['-q'] or arg['--quiet']: visu = False if arg['--theta']: - theta = theta_1 = theta_2 = float(arg['--theta']) + theta = float(arg['--theta']) if arg['--gamma']: gamma = float(arg['--gamma']) @@ -113,13 +111,6 @@ dir_coords = [(0,-1), (1,0), (0,1), (-1,0)] def argmax(l): return max(range(len(l)), key=lambda i:l[i]) -def merge_dicts(dicts): - result = defaultdict(lambda:0.) - for factor, d in dicts: - for k in d: - result[k] += factor * d[k] - return result - def draw_randomly(d): c = 0. rnd = random() @@ -128,34 +119,7 @@ def draw_randomly(d): if rnd < c: return k -def visualize(maze, P): - n=0 - for y in range(len(maze)): - line1="" - line2="" - line3="" - for x in range(len(maze[0])): - if maze[y][x] == 1: - line1 += "@" * (2+7) - line3 += "@" * (2+7) - line2 += "@@@%03.1f@@@" % P[y][x] - elif maze[y][x] == 2: - line1 += "." * (2+7) - line3 += "." * (2+7) - line2 += ".%07.5f." % P[y][x] - else: - line1 += "'" + " " * (7) + " " - line3 += " " + " " * (7) + " " - line2 += " %07.5f " % P[y][x] - print(line1) - print(line3) - print(line2) - print(line3) - print(line3) - n+=5 - return n - -def visualize2(maze, Q): +def visualize(maze, Q): n=0 for y in range(len(maze)): line1="" @@ -229,10 +193,6 @@ class World: return self.R((x,y),ss, None), ss - def P(self, s, ss, pi): - return merge_dicts([(pi[d], self.action(s[0],s[1], d)) for d in directions])[ss] - - def R(self, s, ss, pi): if s!=ss and self.maze[ss[1]][ss[0]] == 2: # goal return 10.0 @@ -242,8 +202,6 @@ class World: def is_final(self,s): return self.maze[s[1]][s[0]] == 2 -theta = theta_2 - a = World(maze, start) Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ] @@ -277,7 +235,7 @@ for i in range(1000000): print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,alpha,epsilon,maxdiff)) n = 0 if visu: - n = visualize2(maze,Q) + n = visualize(maze,Q) cursor_up(n+2) if (logfile != None): @@ -301,9 +259,9 @@ for i in range(1000000): if stopstate == 0: break else: - stopstate = 100 + stopstate = 1000 print("finished after %.3d iterations, alpha=%.3e, epsilon=%.3e"%(i,alpha,epsilon)) -visualize2(maze,Q) +visualize(maze,Q) if logfile != None: logfile.close() -- cgit v1.2.3