summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py58
1 files changed, 8 insertions, 50 deletions
diff --git a/sol.py b/sol.py
index f98421d..7e3a0d4 100644
--- a/sol.py
+++ b/sol.py
@@ -10,9 +10,7 @@ maze = [ [0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 2] ]
start=(1,1)
-theta = None
-theta_1 = 0.00001
-theta_2 = 0.01
+theta = 0.01
gamma=0.9 # discount
epsilon = 0.1
epsilon_reduction = 1
@@ -53,8 +51,8 @@ if arg['-h']:
print(" MODE: -1 / --policy-evaluation or\n" +
" -2 / --q-learning\n" +
" OPTIONS: --theta NUM # convergence threshold\n" +
- " # default: %f / %f for -1 / -2\n" % (theta_1, theta_2) +
- " --gamma NUM # learning discount for value iteration\n" +
+ " # default: %f\n" % theta +
+ " --gamma NUM # learning discount\n" +
" # default: %f\n" % gamma +
" --alpha NUM # learning rate for q-learning\n" +
" # default: %f\n" % alpha +
@@ -78,7 +76,7 @@ if arg['-q'] or arg['--quiet']:
visu = False
if arg['--theta']:
- theta = theta_1 = theta_2 = float(arg['--theta'])
+ theta = float(arg['--theta'])
if arg['--gamma']:
gamma = float(arg['--gamma'])
@@ -113,13 +111,6 @@ dir_coords = [(0,-1), (1,0), (0,1), (-1,0)]
def argmax(l):
return max(range(len(l)), key=lambda i:l[i])
-def merge_dicts(dicts):
- result = defaultdict(lambda:0.)
- for factor, d in dicts:
- for k in d:
- result[k] += factor * d[k]
- return result
-
def draw_randomly(d):
c = 0.
rnd = random()
@@ -128,34 +119,7 @@ def draw_randomly(d):
if rnd < c:
return k
-def visualize(maze, P):
- n=0
- for y in range(len(maze)):
- line1=""
- line2=""
- line3=""
- for x in range(len(maze[0])):
- if maze[y][x] == 1:
- line1 += "@" * (2+7)
- line3 += "@" * (2+7)
- line2 += "@@@%03.1f@@@" % P[y][x]
- elif maze[y][x] == 2:
- line1 += "." * (2+7)
- line3 += "." * (2+7)
- line2 += ".%07.5f." % P[y][x]
- else:
- line1 += "'" + " " * (7) + " "
- line3 += " " + " " * (7) + " "
- line2 += " %07.5f " % P[y][x]
- print(line1)
- print(line3)
- print(line2)
- print(line3)
- print(line3)
- n+=5
- return n
-
-def visualize2(maze, Q):
+def visualize(maze, Q):
n=0
for y in range(len(maze)):
line1=""
@@ -229,10 +193,6 @@ class World:
return self.R((x,y),ss, None), ss
- def P(self, s, ss, pi):
- return merge_dicts([(pi[d], self.action(s[0],s[1], d)) for d in directions])[ss]
-
-
def R(self, s, ss, pi):
if s!=ss and self.maze[ss[1]][ss[0]] == 2: # goal
return 10.0
@@ -242,8 +202,6 @@ class World:
def is_final(self,s):
return self.maze[s[1]][s[0]] == 2
-theta = theta_2
-
a = World(maze, start)
Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ]
@@ -277,7 +235,7 @@ for i in range(1000000):
print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,alpha,epsilon,maxdiff))
n = 0
if visu:
- n = visualize2(maze,Q)
+ n = visualize(maze,Q)
cursor_up(n+2)
if (logfile != None):
@@ -301,9 +259,9 @@ for i in range(1000000):
if stopstate == 0:
break
else:
- stopstate = 100
+ stopstate = 1000
print("finished after %.3d iterations, alpha=%.3e, epsilon=%.3e"%(i,alpha,epsilon))
-visualize2(maze,Q)
+visualize(maze,Q)
if logfile != None:
logfile.close()