summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py23
1 files changed, 14 insertions, 9 deletions
diff --git a/sol.py b/sol.py
index c01d7e1..e369b79 100644
--- a/sol.py
+++ b/sol.py
@@ -280,32 +280,36 @@ class QNN (QCommon):
self.dumbtraining = False
connection_rate = 1
num_input = 2
- #hidden = (20,20)
- hidden = (20,10,7)
+ #hidden = (40,40)
+ hidden = (50,)
+ #hidden = (20,10,7)
num_output = 4
learning_rate = 0.7
self.NN = libfann.neural_net()
#self.NN.set_training_algorithm(libfann.TRAIN_BATCH)
- self.NN.set_training_algorithm(libfann.TRAIN_RPROP)
+ #self.NN.set_training_algorithm(libfann.TRAIN_RPROP)
#self.NN.set_training_algorithm(libfann.TRAIN_QUICKPROP)
self.NN.create_sparse_array(connection_rate, (num_input,)+hidden+(num_output,))
+ self.NN.randomize_weights(-1,1)
self.NN.set_learning_rate(learning_rate)
self.NN.set_activation_function_hidden(libfann.SIGMOID_SYMMETRIC_STEPWISE)
self.NN.set_activation_function_output(libfann.SIGMOID_SYMMETRIC_STEPWISE)
#self.NN.set_activation_function_output(libfann.LINEAR)
def eval(self,x,y = None):
+ #print ("eval "+str(x)+", "+str(y))
if y==None: x,y = x
- return self.NN.run([x,y])
+ #print ("self.NN.run("+str([x/7.,y/5.])+")")
+ return self.NN.run([x/7.,y/5.])
def change(self, s, action, diff):
oldval = self.eval(s)
newval = list(oldval) # copy list
newval[action] += diff
- self.NN.train(list(s), newval)
+ self.NN.train([s[0]/7.,s[1]/5.], newval)
# learn a transition "from oldstate by action into newstate with reward `reward`"
# this does not necessarily mean that the action is instantly trained into the function
@@ -321,7 +325,8 @@ class QNN (QCommon):
self.train_on_minibatch()
def train_on_minibatch(self):
- n = min(30, len(self.learnbuffer))
+ n = min(300, len(self.learnbuffer))
+ if n < 300: return
minibatch = random.sample(self.learnbuffer, n)
inputs = []
@@ -329,15 +334,15 @@ class QNN (QCommon):
for oldstate, action, newstate, reward in minibatch:
diff, val = self.value_update(oldstate, action, newstate, reward)
val[action] += diff
- inputs += [ list(oldstate) ]
+ inputs += [ [oldstate[0]/7., oldstate[1]/5.] ]
outputs += [ val ]
#print("training minibatch of size %i:\n%s\n%s\n\n"%(n, str(inputs), str(outputs)))
training_data = libfann.training_data()
training_data.set_train_data(inputs, outputs)
- #self.NN.train_epoch(training_data)
- self.NN.train_on_data(training_data, 5, 0, 0)
+ self.NN.train_epoch(training_data)
+ #self.NN.train_on_data(training_data, 5, 0, 0)
#print(".")
# must be called on every end-of-episode. might trigger batch-training or whatever.