summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/sol.py b/sol.py
index 7b21ef4..f7b3a61 100644
--- a/sol.py
+++ b/sol.py
@@ -73,6 +73,7 @@ if arg['-h']:
" default: array" +
" --frameskip NUM # frameskip for visualisation\n" +
" # default: %f\n" % frameskip +
+ " --sleep # sleep after every <frameskip> frames\n"+
" --quiet # disable visualisation\n" +
" --file FILE # output file for q learning")
exit()
@@ -81,6 +82,10 @@ if arg['-h']:
if arg['-q'] or arg['--quiet']:
visu = False
+do_sleep = False
+if arg['--sleep']:
+ do_sleep = True
+
if arg['--frameskip']:
frameskip = int(arg['--frameskip'])
@@ -350,8 +355,19 @@ else:
i=0
stopstate = -1
total_reward = 0.
+maxdiff = -1
for i in range(n_episodes):
+ if (i % (frameskip+1) == 0):
+ print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,Q.alpha,epsilon,maxdiff))
+ n = 0
+ if visu:
+ n = visualize(maze,Q)
+ cursor_up(n+2)
+
+ if do_sleep:
+ input()
+
s = start
for j in range(100):
# epsilon-greedy
@@ -373,13 +389,6 @@ for i in range(n_episodes):
break
maxdiff = Q.episode()
- if (i % (frameskip+1) == 0):
- print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,Q.alpha,epsilon,maxdiff))
- n = 0
- if visu:
- n = visualize(maze,Q)
- cursor_up(n+2)
-
if (logfile != None):
print("%d\t%f" % (i, total_reward), file=logfile)