summaryrefslogtreecommitdiff
path: root/qlearning.py
diff options
context:
space:
mode:
Diffstat (limited to 'qlearning.py')
-rwxr-xr-xqlearning.py263
1 files changed, 129 insertions, 134 deletions
diff --git a/qlearning.py b/qlearning.py
index 389db5a..5a247aa 100755
--- a/qlearning.py
+++ b/qlearning.py
@@ -5,143 +5,138 @@ import numpy as np
# Import snake game
from snake import Snake
-
-
-# Setup QTable
-# Boolean features:
-# Snake go up?
-# Snake go right?
-# Snake go down?
-# Snake go left?
-# Apple at up?
-# Apple at right?
-# Apple at down?
-# Apple at left?
-# Obstacle at up?
-# Obstacle at right?
-# Obstacle at down?
-# Obstacle at left?
-# Queue in front?
-##### Totally 13 boolean features so 2^13=8192 states
-##### Totally 4 actions for the AI (up, right,down,left)
-##### Totally 4*2^13 thus 32768 table entries
-##### Reward +1 when eat an apple
-##### Reward -10 when hit obstacle
-
-qtable=np.zeros((2**13, 4))
-
-
-
-game=Snake(length=1,fps=200,startat=(10,10))
-
-def isWall(h,game):
- if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
- return(True)
- return(False)
-
-
+class QTable:
+ """
+ # Boolean features:
+ # Snake go up?
+ # Snake go right?
+ # Snake go down?
+ # Snake go left?
+ # Apple at up?
+ # Apple at right?
+ # Apple at down?
+ # Apple at left?
+ # Obstacle at up?
+ # Obstacle at right?
+ # Obstacle at down?
+ # Obstacle at left?
+ # Tail in front?
+ ##### Totally 13 boolean features so 2^13=8192 states
+ ##### Totally 4 actions for the AI (up, right,down,left)
+ ##### Totally 4*2^13 thus 32768 table entries
+ ##### Reward +1 when eat an apple
+ ##### Reward -10 when hit obstacle
+ """
+ def __init__(self, file, save_every=10):
+ self.file=file
+ self.save_every=save_every
+ self.update_counter=0
+ if os.path.exists(file):
+ self.qtable=np.loadtxt(file)
+ else:
+ self.qtable=np.zeros((2**13, 4))
+
+
+ def isWall(self,h,game):
+ if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
+ return(True)
+ return(False)
+
+ def get_state(self,game):
+ # First compute usefull values
+ h=game.snake[0]
+ left=(h[0]-1,h[1])
+ right=(h[0]+1,h[1])
+ up=(h[0],h[1]-1)
+ down=(h[0],h[1]+1)
+ a=game.apple
+
+ snake_go_up=(game.direction==12)
+ snake_go_right=(game.direction==3)
+ snake_go_down=(game.direction==6)
+ snake_go_left=(game.direction==9)
+
+ apple_up=(a[1]<h[1])
+ apple_right=(a[0]>h[0])
+ apple_down=(a[1]>h[1])
+ apple_left=(a[0]<h[0])
+
+ obstacle_up=(up in game.snake or self.isWall(up, game))
+ obstacle_right=(right in game.snake or self.isWall(right, game))
+ obstacle_down=(down in game.snake or self.isWall(down, game))
+ obstacle_left=(left in game.snake or self.isWall(left, game))
+
+ tail_in_front=0
+ if game.direction == 3:
+ for x in range(h[0],game.grid_width):
+ if (x,h[1]) in game.snake[1:]:
+ tail_in_front=1
+ break
+ elif game.direction == 9:
+ for x in range(0,h[0]):
+ if (x,h[1]) in game.snake[1:]:
+ tail_in_front=1
+ break
+ elif game.direction == 12:
+ for y in range(0,h[1]):
+ if (h[0],y) in game.snake[1:]:
+ tail_in_front=1
+ break
+ elif game.direction == 6:
+ for y in range(h[1],game.grid_height):
+ if (h[0],y) in game.snake[1:]:
+ tail_in_front=1
+ break
+ # This come from me I do not now if it is the best way to identify a state
+ state=2**12*tail_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
+ return(state)
+
+ def apply_bellman(self,state,action,new_state,reward):
+ alpha=0.5
+ gamma=0.9
+ self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
+ self.update_counter+=1
+ if self.update_counter>=self.save_every:
+ np.savetxt(self.file,self.qtable)
+ self.update_counter=0
+
+ def get_action(self,state):
+ # Choose an action
+ action=random.choice((0,1,2,3))
+ if np.max(self.qtable[state]) > 0:
+ #qactions=qtable[state]
+ #options=np.flatnonzero(qactions == np.max(qactions)) # Since Q value might be equals for several actions
+ #action = random.choice(options)
+ action=np.argmax(self.qtable[state])
+ return(action)
+
+
+
+
+
+# Perform learning
+perf=0
last_state=None
last_action=None
-attempt=0
-def event_handler(game,event):
- global last_state,last_action,attempt
-
- h=game.snake[0]
- left=(h[0]-1,h[1])
- right=(h[0]+1,h[1])
- up=(h[0],h[1]-1)
- down=(h[0],h[1]+1)
- a=game.apple
-
- snake_go_up=(game.direction==12)
- snake_go_right=(game.direction==3)
- snake_go_down=(game.direction==6)
- snake_go_left=(game.direction==9)
-
- apple_up=(a[1]<h[1])
- apple_right=(a[0]>h[0])
- apple_down=(a[1]>h[1])
- apple_left=(a[0]<h[0])
-
- obstacle_up=(up in game.snake or isWall(up, game))
- obstacle_right=(right in game.snake or isWall(right, game))
- obstacle_down=(down in game.snake or isWall(down, game))
- obstacle_left=(left in game.snake or isWall(left, game))
-
- queue_in_front=0
- if game.direction == 3:
- for x in range(h[0],game.grid_width):
- if (x,h[1]) in game.snake[1:]:
- queue_in_front=1
- break
- elif game.direction == 9:
- for x in range(0,h[0]):
- if (x,h[1]) in game.snake[1:]:
- queue_in_front=1
- break
- elif game.direction == 12:
- for y in range(0,h[1]):
- if (h[0],y) in game.snake[1:]:
- queue_in_front=1
- break
- elif game.direction == 6:
- for y in range(h[1],game.grid_height):
- if (h[0],y) in game.snake[1:]:
- queue_in_front=1
- break
-
- reward=0
- if event==0:
- attempt+=1
- if event==-1:
- reward=-10
- attempt=0
- elif event==1:
- reward=5
- attempt=0
-
- # This come from me I do not now if it is the best way to identify a state
- state=2**12*queue_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
-
- # Choose an action
- action=random.choice((0,1,2,3))
- if np.max(qtable[state]) > 0:
- #qactions=qtable[state]
- #options=np.flatnonzero(qactions == np.max(qactions)) # Since Q value might be equals for several actions
- #action = random.choice(options)
- action=np.argmax(qtable[state])
-
- # Avoid infinite loop
- if attempt>game.grid_height*game.grid_width:
- return(-1)
-
- # Update current state Q
- if last_state != None:
- qtable[last_state,last_action]=qtable[last_state,last_action]+0.7*(reward+0.9*np.max(qtable[state])-qtable[last_state,last_action])
- last_state=state
- last_action=action
-
- # Apply the action
- snake_action=12
- if action==1:
- snake_action=3
- elif action==2:
- snake_action=6
- elif action==3:
- snake_action=9
- game.direction=snake_action
- return(0)
-
-if os.path.exists("qtable.txt"):
- qtable=np.loadtxt("qtable.txt")
+game=Snake(length=4,fps=300,startat=(10,10))
+qtable=QTable("qtable.txt")
-perf=0
for i in range(0,10000):
- last_state=None
- last_action=None
- score=game.run(event_handler=event_handler)
- attempt=0
- if i%10 == 0:
- np.savetxt('qtable.txt',qtable)
+ result=0
+ while result >= 0:
+ state=qtable.get_state(game)
+ action=qtable.get_action(state)
+ result=game.play3(action)
+ if last_state!=None:
+ reward=0
+ if result==-1:
+ reward=-10
+ elif result==1:
+ reward=1
+ qtable.apply_bellman(last_state,last_action,state,reward)
+ last_state=state
+ last_action=action
+ # Measurements
+ score=game.last_score
perf=max(perf,score)
print("Game ended with "+str(score)+" best so far is "+str(perf)) \ No newline at end of file