Diffstat (limited to 'qlearning.py')
-rwxr-xr-x  qlearning.py  263
1 file changed, 129 insertions, 134 deletions
diff --git a/qlearning.py b/qlearning.py
index 389db5a..5a247aa 100755
--- a/qlearning.py
+++ b/qlearning.py
@@ -5,143 +5,138 @@ import numpy as np
 
 # Import snake game
 from snake import Snake
-
-
-# Setup QTable
-# Boolean features:
-# Snake go up?
-# Snake go right?
-# Snake go down?
-# Snake go left?
-# Apple at up?
-# Apple at right?
-# Apple at down?
-# Apple at left?
-# Obstacle at up?
-# Obstacle at right?
-# Obstacle at down?
-# Obstacle at left?
-# Queue in front?
-##### Totally 13 boolean features so 2^13=8192 states
-##### Totally 4 actions for the AI (up, right,down,left)
-##### Totally 4*2^13 thus 32768 table entries
-##### Reward +1 when eat an apple
-##### Reward -10 when hit obstacle
-
-qtable=np.zeros((2**13, 4))
-
-
-
-game=Snake(length=1,fps=200,startat=(10,10))
-
-def isWall(h,game):
-    if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
-        return(True)
-    return(False)
-
-
+class QTable:
+    """
+    # Boolean features:
+    # Snake go up?
+    # Snake go right?
+    # Snake go down?
+    # Snake go left?
+    # Apple at up?
+    # Apple at right?
+    # Apple at down?
+    # Apple at left?
+    # Obstacle at up?
+    # Obstacle at right?
+    # Obstacle at down?
+    # Obstacle at left?
+    # Tail in front?
+    ##### Totally 13 boolean features so 2^13=8192 states
+    ##### Totally 4 actions for the AI (up, right,down,left)
+    ##### Totally 4*2^13 thus 32768 table entries
+    ##### Reward +1 when eat an apple
+    ##### Reward -10 when hit obstacle
+    """
+    def __init__(self, file, save_every=10):
+        self.file=file
+        self.save_every=save_every
+        self.update_counter=0
+        if os.path.exists(file):
+            self.qtable=np.loadtxt(file)
+        else:
+            self.qtable=np.zeros((2**13, 4))
+
+
+    def isWall(self,h,game):
+        if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
+            return(True)
+        return(False)
+
+    def get_state(self,game):
+        # First compute usefull values
+        h=game.snake[0]
+        left=(h[0]-1,h[1])
+        right=(h[0]+1,h[1])
+        up=(h[0],h[1]-1)
+        down=(h[0],h[1]+1)
+        a=game.apple
+
+        snake_go_up=(game.direction==12)
+        snake_go_right=(game.direction==3)
+        snake_go_down=(game.direction==6)
+        snake_go_left=(game.direction==9)
+
+        apple_up=(a[1]<h[1])
+        apple_right=(a[0]>h[0])
+        apple_down=(a[1]>h[1])
+        apple_left=(a[0]<h[0])
+
+        obstacle_up=(up in game.snake or self.isWall(up, game))
+        obstacle_right=(right in game.snake or self.isWall(right, game))
+        obstacle_down=(down in game.snake or self.isWall(down, game))
+        obstacle_left=(left in game.snake or self.isWall(left, game))
+
+        tail_in_front=0
+        if game.direction == 3:
+            for x in range(h[0],game.grid_width):
+                if (x,h[1]) in game.snake[1:]:
+                    tail_in_front=1
+                    break
+        elif game.direction == 9:
+            for x in range(0,h[0]):
+                if (x,h[1]) in game.snake[1:]:
+                    tail_in_front=1
+                    break
+        elif game.direction == 12:
+            for y in range(0,h[1]):
+                if (h[0],y) in game.snake[1:]:
+                    tail_in_front=1
+                    break
+        elif game.direction == 6:
+            for y in range(h[1],game.grid_height):
+                if (h[0],y) in game.snake[1:]:
+                    tail_in_front=1
+                    break
+        # This come from me I do not now if it is the best way to identify a state
+        state=2**12*tail_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
+        return(state)
+
+    def apply_bellman(self,state,action,new_state,reward):
+        alpha=0.5
+        gamma=0.9
+        self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
+        self.update_counter+=1
+        if self.update_counter>=self.save_every:
+            np.savetxt(self.file,self.qtable)
+            self.update_counter=0
+
+    def get_action(self,state):
+        # Choose an action
+        action=random.choice((0,1,2,3))
+        if np.max(self.qtable[state]) > 0:
+            #qactions=qtable[state]
+            #options=np.flatnonzero(qactions == np.max(qactions)) # Since Q value might be equals for several actions
+            #action = random.choice(options)
+            action=np.argmax(self.qtable[state])
+        return(action)
+
+
+
+
+
+# Perform learning
+perf=0
 last_state=None
 last_action=None
-attempt=0
-def event_handler(game,event):
-    global last_state,last_action,attempt
-
-    h=game.snake[0]
-    left=(h[0]-1,h[1])
-    right=(h[0]+1,h[1])
-    up=(h[0],h[1]-1)
-    down=(h[0],h[1]+1)
-    a=game.apple
-
-    snake_go_up=(game.direction==12)
-    snake_go_right=(game.direction==3)
-    snake_go_down=(game.direction==6)
-    snake_go_left=(game.direction==9)
-
-    apple_up=(a[1]<h[1])
-    apple_right=(a[0]>h[0])
-    apple_down=(a[1]>h[1])
-    apple_left=(a[0]<h[0])
-
-    obstacle_up=(up in game.snake or isWall(up, game))
-    obstacle_right=(right in game.snake or isWall(right, game))
-    obstacle_down=(down in game.snake or isWall(down, game))
-    obstacle_left=(left in game.snake or isWall(left, game))
-
-    queue_in_front=0
-    if game.direction == 3:
-        for x in range(h[0],game.grid_width):
-            if (x,h[1]) in game.snake[1:]:
-                queue_in_front=1
-                break
-    elif game.direction == 9:
-        for x in range(0,h[0]):
-            if (x,h[1]) in game.snake[1:]:
-                queue_in_front=1
-                break
-    elif game.direction == 12:
-        for y in range(0,h[1]):
-            if (h[0],y) in game.snake[1:]:
-                queue_in_front=1
-                break
-    elif game.direction == 6:
-        for y in range(h[1],game.grid_height):
-            if (h[0],y) in game.snake[1:]:
-                queue_in_front=1
-                break
-
-    reward=0
-    if event==0:
-        attempt+=1
-    if event==-1:
-        reward=-10
-        attempt=0
-    elif event==1:
-        reward=5
-        attempt=0
-
-    # This come from me I do not now if it is the best way to identify a state
-    state=2**12*queue_in_front+2**11*snake_go_up+2**10*snake_go_right+2**9*snake_go_down+2**8*snake_go_left+2**7*apple_up+2**6*apple_right+2**5*apple_down+2**4*apple_left+2**3*obstacle_up+2**2*obstacle_right+2**1*obstacle_down+obstacle_left
-
-    # Choose an action
-    action=random.choice((0,1,2,3))
-    if np.max(qtable[state]) > 0:
-        #qactions=qtable[state]
-        #options=np.flatnonzero(qactions == np.max(qactions)) # Since Q value might be equals for several actions
-        #action = random.choice(options)
-        action=np.argmax(qtable[state])
-
-    # Avoid infinite loop
-    if attempt>game.grid_height*game.grid_width:
-        return(-1)
-
-    # Update current state Q
-    if last_state != None:
-        qtable[last_state,last_action]=qtable[last_state,last_action]+0.7*(reward+0.9*np.max(qtable[state])-qtable[last_state,last_action])
-    last_state=state
-    last_action=action
-
-    # Apply the action
-    snake_action=12
-    if action==1:
-        snake_action=3
-    elif action==2:
-        snake_action=6
-    elif action==3:
-        snake_action=9
-    game.direction=snake_action
-    return(0)
-
-if os.path.exists("qtable.txt"):
-    qtable=np.loadtxt("qtable.txt")
+game=Snake(length=4,fps=300,startat=(10,10))
+qtable=QTable("qtable.txt")
 
-perf=0
 for i in range(0,10000):
-    last_state=None
-    last_action=None
-    score=game.run(event_handler=event_handler)
-    attempt=0
-    if i%10 == 0:
-        np.savetxt('qtable.txt',qtable)
+    result=0
+    while result >= 0:
+        state=qtable.get_state(game)
+        action=qtable.get_action(state)
+        result=game.play3(action)
+        if last_state!=None:
+            reward=0
+            if result==-1:
+                reward=-10
+            elif result==1:
+                reward=1
+            qtable.apply_bellman(last_state,last_action,state,reward)
+        last_state=state
+        last_action=action
+    # Measurements
+    score=game.last_score
     perf=max(perf,score)
     print("Game ended with "+str(score)+" best so far is "+str(perf))
\ No newline at end of file
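
Side note (not part of the commit): the new QTable packs the 13 boolean features into a single table index, one bit per feature in the order used by get_state, and apply_bellman performs the standard tabular Q-learning update Q[s,a] += alpha*(r + gamma*max(Q[s']) - Q[s,a]) with alpha=0.5 and gamma=0.9. Below is a minimal, self-contained sketch of that idea; the helper names pack_state and bellman_update are illustrative only and do not exist in the repo.

import numpy as np

# Illustrative sketch, not from the repo: pack 13 booleans into one index
# using the same bit order as QTable.get_state (tail_in_front is the most
# significant bit, obstacle_left the least), then apply the update used in
# QTable.apply_bellman: Q[s,a] += alpha * (r + gamma * max(Q[s']) - Q[s,a]).

def pack_state(features):
    # features: 13 booleans, most significant bit first
    state = 0
    for f in features:
        state = (state << 1) | int(f)
    return state  # value in 0 .. 2**13 - 1

def bellman_update(qtable, state, action, new_state, reward, alpha=0.5, gamma=0.9):
    # Tabular Q-learning update with fixed learning rate and discount factor
    qtable[state, action] += alpha * (reward + gamma * np.max(qtable[new_state]) - qtable[state, action])

qtable = np.zeros((2**13, 4))
# Snake moving right with the apple to its right, no obstacles, no tail ahead
s = pack_state([0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
bellman_update(qtable, s, 1, s, reward=1)   # action 1 = right, apple eaten
print(qtable[s])                            # [0.  0.5 0.  0. ]

Starting from an all-zero table, a single rewarded step raises only the visited (state, action) cell, which is why get_action falls back to a random move whenever the row for the current state is still all zeros.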
