summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sol.py71
1 files changed, 55 insertions, 16 deletions
diff --git a/sol.py b/sol.py
index 84c533e..d0faf70 100644
--- a/sol.py
+++ b/sol.py
@@ -213,10 +213,35 @@ class World:
return self.maze[s[1]][s[0]] == 2
+class QCommon:
+ def __init__(self):
+ self.alpha = alpha
+ self.gamma = gamma
+ self.maxdiff = 0.0
+
+ # returns a pair of (diff, self.eval(oldstate)).
+ # diff is meant to be added to self.eval(oldstate)[action]
+ def value_update(self, oldstate, action, newstate, reward):
+ eval_newstate = self.eval(newstate)
+ eval_oldstate = self.eval(oldstate)
+ diff = self.alpha * (reward + self.gamma * max( [ eval_newstate[aa] for aa in directions ] ) - eval_oldstate[action])
+
+ if self.maxdiff < diff: self.maxdiff = diff
+
+ return diff, eval_oldstate
+
+ def episode(self):
+ maxdiff = self.maxdiff
+ self.maxdiff = 0.0
+ return maxdiff
+
+
# abstracts the Q-array. semantics of .eval(x,y) is `Q[y][x]`. semantics of .change((x,y),ac,diff) is `Q[y][x][ac]+=diff`
-class QArray:
+class QArray(QCommon):
def __init__(self):
+ super().__init__()
self.Q = [ [ [0. for k in range(4)] for i in range(a.xlen) ] for j in range(a.ylen) ]
+ self.learnbuffer = []
# calculates Q(x,y)
def eval(self,x,y = None):
@@ -224,16 +249,24 @@ class QArray:
return self.Q[y][x]
+ def learn(self, oldstate, action, newstate, reward):
+ self.learnbuffer += [(oldstate,action,newstate,reward)]
+ # self.flush_learnbuffer() # TODO TRYME
- def change(self, s, action, diff):
- self.Q[s[1]][s[0]][action] += diff
+ def flush_learnbuffer(self):
+ for oldstate, action, newstate, reward in reversed(self.learnbuffer):
+ diff,_ = self.value_update(oldstate,action,newstate,reward)
+ self.Q[oldstate[1]][oldstate[0]][action] += diff
+ self.learnbuffer = []
def episode(self):
- pass
+ self.flush_learnbuffer()
+ return super().episode()
# implements the Q function not through an array, but through a neuronal network instead.
-class QNN:
+class QNN (QCommon):
def __init__(self):
+ super().__init__()
connection_rate = 1
num_input = 2
hidden = (40,40)
@@ -260,8 +293,16 @@ class QNN:
self.NN.train(list(s), [x/10. for x in newval])
- def episode(self):
- pass
+ # learn a transition "from oldstate by action into newstate with reward `reward`"
+ # this does not necessarily mean that the action is instantly trained into the function
+ # representation. It may be held back in a list, to be batch-trained lated.
+ def learn(self, oldstate, action, newstate, reward):
+ diff = self.value_update(oldstate,action,newstate,reward)
+ Q.change(oldstate,action,diff)
+
+ # must be called on every end-of-episode. might trigger batch-training or whatever.
+ #def episode(self):
+ # pass
a = World(maze, start)
@@ -277,7 +318,6 @@ total_reward = 0.
for i in range(n_episodes):
s = start
- maxdiff=0.
for j in range(100):
# epsilon-greedy
greedy = argmax(Q.eval(s))
@@ -289,18 +329,17 @@ for i in range(n_episodes):
action = greedy
r,ss = a.take_action(s[0],s[1], action)
- #print ((r,ss))
- diff = alpha * (r + gamma * max( [ Q.eval(ss)[aa] for aa in directions ] ) - Q.eval(s)[action])
- Q.change(s,action,diff)
- maxdiff = max(abs(diff),maxdiff)
+
+ Q.learn(s,action,ss,r)
+
total_reward += r
s = ss
if a.is_final(ss):
break
- Q.episode()
+ maxdiff = Q.episode()
if (i % (frameskip+1) == 0):
- print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,alpha,epsilon,maxdiff))
+ print("iteration %.3d, alpha=%.3e, epsilon=%.3e maxdiff=%.7f"%(i,Q.alpha,epsilon,maxdiff))
n = 0
if visu:
n = visualize(maze,Q)
@@ -317,7 +356,7 @@ for i in range(n_episodes):
# However, if we set the friendlyness of our system to 1.0, then it would
# also converge without this learning rate reduction, because we have a
# non-stochastic but a deterministic system now.
- alpha *= alpha_reduction
+ Q.alpha *= alpha_reduction
epsilon *= epsilon_reduction
@@ -329,7 +368,7 @@ for i in range(n_episodes):
else:
stopstate = 1000
-print("finished after %.3d iterations, alpha=%.3e, epsilon=%.3e"%(i,alpha,epsilon))
+print("finished after %.3d iterations, alpha=%.3e, epsilon=%.3e"%(i,Q.alpha,epsilon))
visualize(maze,Q)
if logfile != None:
logfile.close()