The Sarsa update formula is as follows: Q(s,a) ← Q(s,a) + lr × [r + γ × Q(s′,a′) − Q(s,a)]
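For example, with hypothetical values lr = 0.01, γ = 0.9, r = 1, Q(s,a) = 0, and Q(s′,a′) = 0.5, one update gives Q(s,a) ← 0 + 0.01 × (1 + 0.9 × 0.5 − 0) = 0.0145.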
Complete code
The Sarsa class is as follows:
from Brain import Brain

class Sarsa(Brain):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super().__init__(actions, learning_rate, reward_decay, e_greedy)

    def q_value(self, s, action, reward, s_next, action_next):
        # Make sure the next state has a row in the Q-table
        self.check_state_exits(s_next)
        predict = self.table.loc[s, action]
        if s_next != 'terminal':
            # On-policy target: use the action actually chosen for the next state
            target = reward + self.gamma * self.table.loc[s_next, action_next]
        else:
            # A terminal state has no successor, so the target is just the reward
            target = reward
        self.table.loc[s, action] += self.lr * (target - predict)
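The Sarsa class above inherits from Brain, which is not listed in this section. The following is a minimal sketch of what that base class is assumed to provide (a pandas Q-table named table, the lr/gamma/epsilon hyperparameters, check_state_exits, and an ε-greedy choose_action); it illustrates the assumed interface, not the original module:

# A minimal sketch of the Brain base class assumed by Sarsa above;
# details are an assumption, not the original module.
import numpy as np
import pandas as pd

class Brain:
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions          # e.g. [0, 1, 2, 3]
        self.lr = learning_rate         # learning rate (lr in the formula)
        self.gamma = reward_decay       # discount factor γ
        self.epsilon = e_greedy         # exploitation probability for ε-greedy
        self.table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exits(self, state):
        # Add a zero-initialized row the first time a state is seen
        if state not in self.table.index:
            self.table.loc[state] = [0.0] * len(self.actions)

    def choose_action(self, state):
        # ε-greedy: exploit with probability epsilon, otherwise explore
        self.check_state_exits(state)
        if np.random.uniform() < self.epsilon:
            row = self.table.loc[state]
            # Break ties among equally good actions at random
            return np.random.choice(row[row == row.max()].index)
        return np.random.choice(self.actions)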
The complete main program is as follows:
import pandas as pd
from Maze import Maze
from Sarsa import Sarsa

def update():
    for epoch in range(100):
        s = maze.reset()
        # Sarsa is on-policy: choose the first action before entering the loop
        action = rl.choose_action(str(s))
        while True:
            # Take the action; get the next state, the reward, and the done flag
            s_next, reward, done = maze.step(action)
            action_next = rl.choose_action(str(s_next))
            # Update Q(s, action) using the reward and Q(s_next, action_next)
            rl.q_value(str(s), action, reward, str(s_next), action_next)
            s = s_next
            # Execute the same action that was used in the update
            action = action_next
            maze.render()  # without render() the red rectangle's movement is invisible
            if done:
                break
        print(f'epoch : {epoch}')
    print("================final table==================")
    df = pd.DataFrame(
        {"up": rl.table[0],
         "down": rl.table[1],
         "left": rl.table[2],
         "right": rl.table[3],
         "max": rl.table.max(axis=1)
         },
        index=rl.table.index
    )
    print(df.applymap(lambda x: '%.5f' % x))

if __name__ == '__main__':
    maze = Maze()
    rl = Sarsa(actions=list(range(4)))
    maze.after(100, update)  # start training 100 ms after the window opens
    maze.mainloop()
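The Maze environment itself is also not listed here. Below is a rough sketch of the interface the main program assumes it exposes; Maze subclassing tkinter.Tk would explain the after() and mainloop() calls, and the action encoding (0=up, 1=down, 2=left, 3=right) follows the column labels in the final table. The reward values in the comments are assumptions:

# A rough sketch of the assumed Maze interface, not the original module.
import tkinter as tk

class Maze(tk.Tk):
    def reset(self):
        # Put the red rectangle back at the start cell and return the
        # initial state (for example, its canvas coordinates).
        raise NotImplementedError

    def step(self, action):
        # Apply one move (0=up, 1=down, 2=left, 3=right) and return
        # (s_next, reward, done); s_next is 'terminal' when the episode
        # ends (assumed rewards: +1 at the goal, -1 in a trap, 0 otherwise).
        raise NotImplementedError

    def render(self):
        # Force a Tk redraw so the red rectangle's movement is visible
        self.update()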
