diff options
-rw-r--r-- | sol.py | 18 |
1 files changed, 12 insertions, 6 deletions
@@ -48,10 +48,8 @@
 print(arg)
 mode = None
 if arg['-h']:
-    print("Usage: %s MODE [OPTIONS]" % argv[0])
-    print(" MODE: -1 / --policy-evaluation or\n" +
-          " -2 / --q-learning\n" +
-          " OPTIONS: --theta NUM # convergence threshold\n" +
+    print("Usage: %s [OPTIONS]" % argv[0])
+    print(" OPTIONS: --theta NUM # convergence threshold\n" +
           " # default: %f\n" % theta +
           " --gamma NUM # learning discount\n" +
           " # default: %f\n" % gamma +
@@ -66,6 +64,10 @@ if arg['-h']:
           " # default: %f\n" % epsilon +
           " --epsred NUM # reduction of epsilon per episode\n" +
           " # default: %f\n\n" % epsilon_reduction +
+          " --qfunc TYPE # type of the Q function's representation\n" +
+          " arr / array -> plain standard array\n" +
+          " nn -> neural network representation\n" +
+          " default: array" +
           " --frameskip NUM # frameskip for visualisation\n" +
           " # default: %f\n" % frameskip +
           " --quiet # disable visualisation\n" +
@@ -249,8 +251,12 @@ class QNN:
         self.NN.train(list(s), [x/10. for x in newval])

 a = World(maze, start)
-#Q=QArray()
-Q=QNN()
+
+Q = None
+if arg['--qfunc'] == "nn":
+    Q = QNN()
+else:
+    Q = QArray()

 i=0
 stopstate = -1