summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/sol.py b/sol.py
index af35140..20e9fa9 100644
--- a/sol.py
+++ b/sol.py
@@ -48,10 +48,8 @@ print(arg)
mode = None
if arg['-h']:
- print("Usage: %s MODE [OPTIONS]" % argv[0])
- print(" MODE: -1 / --policy-evaluation or\n" +
- " -2 / --q-learning\n" +
- " OPTIONS: --theta NUM # convergence threshold\n" +
+ print("Usage: %s [OPTIONS]" % argv[0])
+ print(" OPTIONS: --theta NUM # convergence threshold\n" +
" # default: %f\n" % theta +
" --gamma NUM # learning discount\n" +
" # default: %f\n" % gamma +
@@ -66,6 +64,10 @@ if arg['-h']:
" # default: %f\n" % epsilon +
" --epsred NUM # reduction of epsilon per episode\n" +
" # default: %f\n\n" % epsilon_reduction +
+ " --qfunc TYPE # type of the Q function's representation\n" +
+ " arr / array -> plain standard array\n" +
+ " nn -> neural network representation\n" +
+ " default: array" +
" --frameskip NUM # frameskip for visualisation\n" +
" # default: %f\n" % frameskip +
" --quiet # disable visualisation\n" +
@@ -249,8 +251,12 @@ class QNN:
self.NN.train(list(s), [x/10. for x in newval])
a = World(maze, start)
-#Q=QArray()
-Q=QNN()
+
+Q = None
+if arg['--qfunc'] == "nn":
+ Q = QNN()
+else:
+ Q = QArray()
i=0
stopstate = -1