summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/sol.py b/sol.py
index fa1de97..cbccf67 100644
--- a/sol.py
+++ b/sol.py
@@ -2,6 +2,7 @@ from random import random
from itertools import product
from collections import defaultdict
from sys import argv
+from fann2 import libfann
maze = [ [0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
@@ -218,8 +219,38 @@ class QArray:
self.Q[s[1]][s[0]][action] += diff
+# implements the Q function not through an array, but through a neuronal network instead.
+class QNN:
+ def __init__(self):
+ connection_rate = 1
+ num_input = 2
+ hidden = (50,50)
+ num_output = 4
+ learning_rate = 0.7
+
+ self.NN = libfann.neural_net()
+ self.NN.create_sparse_array(connection_rate, (num_input,)+hidden+(num_output,))
+ self.NN.set_learning_rate(learning_rate)
+ #self.NN.set_activation_function_input(libfann.SIGMOID_SYMMETRIC_STEPWISE)
+ self.NN.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC_STEPWISE)
+ self.NN.set_activation_function_output(libfann.SIGMOID_SYMMETRIC_STEPWISE)
+ #self.NN.set_activation_function_output(libfann.LINEAR)
+
+ def eval(self,x,y = None):
+ if y==None: x,y = x
+
+ return [x*10. for x in self.NN.run([x,y])]
+
+ def change(self, s, action, diff):
+ oldval = self.eval(s)
+ newval = list(oldval) # copy list
+ newval[action] += diff
+
+ self.NN.train(list(s), [x/10. for x in newval])
+
a = World(maze, start)
-Q=QArray()
+#Q=QArray()
+Q=QNN()
i=0
stopstate = -1