The Sarsa update rule is: $Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma\, Q(s',a') - Q(s,a) \right]$, where $\alpha$ is the learning rate (`lr` in the code below) and $\gamma$ is the discount factor.
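To make the arithmetic concrete, here is a single hypothetical update with made-up numbers (none of these values come from the maze program):

```python
# One hypothetical Sarsa update; all values are made up for illustration.
lr, gamma = 0.01, 0.9                 # learning rate and discount factor
q_sa, reward, q_next = 0.5, 1.0, 0.2  # Q(s,a), r, and Q(s',a')

target = reward + gamma * q_next      # 1.0 + 0.9 * 0.2 = 1.18
q_sa += lr * (target - q_sa)          # 0.5 + 0.01 * (1.18 - 0.5)
print(f'{q_sa:.4f}')                  # 0.5068
```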
Complete code

The Sarsa class is as follows:
```python
from Brain import Brain


class Sarsa(Brain):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super().__init__(actions, learning_rate, reward_decay, e_greedy)

    def q_value(self, s, action, reward, s_next, action_next):
        # Make sure the next state already has a row in the Q-table
        self.check_state_exits(s_next)
        predict = self.table.loc[s, action]
        if s_next != 'terminal':
            # On-policy target: use the action actually chosen for s_next
            target = reward + self.gamma * self.table.loc[s_next, action_next]
        else:
            target = reward  # a terminal state has no future reward
        self.table.loc[s, action] += self.lr * (target - predict)
```
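The `Brain` base class (`Brain.py`) is not shown in this section. As a rough idea of what `Sarsa` relies on, here is a hypothetical sketch of a `Brain` providing the `table`, `lr`, `gamma`, `check_state_exits`, and `choose_action` members used above; the project's real `Brain.py` may differ in detail:

```python
import numpy as np
import pandas as pd


class Brain:
    """Hypothetical sketch only; the project's real Brain.py may differ."""

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions            # e.g. list(range(4)) for up/down/left/right
        self.lr = learning_rate           # learning rate (alpha)
        self.gamma = reward_decay         # discount factor
        self.epsilon = e_greedy           # probability of acting greedily
        # Q-table: one row per visited state, one column per action
        self.table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exits(self, state):   # spelling kept to match the calls above
        # The first time a state is seen, add an all-zero row for it
        if state not in self.table.index:
            self.table.loc[state] = [0.0] * len(self.actions)

    def choose_action(self, state):
        self.check_state_exits(state)
        if np.random.uniform() < self.epsilon:
            # Exploit: pick one of the best actions, breaking ties at random
            row = self.table.loc[state]
            return np.random.choice(row[row == row.max()].index)
        # Explore: pick a random action
        return np.random.choice(self.actions)
```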
The complete main program is as follows:
```python
import pandas as pd
from Maze import Maze
from Sarsa import Sarsa


def update():
    for epoch in range(100):
        s = maze.reset()
        # Sarsa chooses the first action before the episode loop and then keeps
        # executing the action it just evaluated; this is what makes it on-policy
        action = rl.choose_action(str(s))
        while True:
            # Compute the next state, the reward, and whether the episode is over
            s_next, reward, done = maze.step(action)
            action_next = rl.choose_action(str(s_next))
            # Use the reward to update the Q value of (s, action)
            rl.q_value(str(s), action, reward, str(s_next), action_next)
            s, action = s_next, action_next
            maze.render()  # without render() you cannot see the red rectangle move
            if done:
                break
        print(f'epoch : {epoch}')
    print("================final table==================")
    df = pd.DataFrame(
        {
            "up": rl.table[0],
            "down": rl.table[1],
            "left": rl.table[2],
            "right": rl.table[3],
            "max": rl.table.max(axis=1),
        },
        index=rl.table.index,
    )
    print(df.applymap(lambda x: '%.5f' % x))


if __name__ == '__main__':
    maze = Maze()
    rl = Sarsa(actions=list(range(4)))
    maze.after(100, update)
    maze.mainloop()
```
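For contrast, off-policy Q-learning does not need `action_next` at all: its target takes the max over all actions in `s_next`, regardless of which action the policy actually takes. A minimal sketch, assuming the same `Brain` base class and Q-table layout:

```python
from Brain import Brain


class QLearning(Brain):
    # Shown only for contrast with Sarsa: the target below is greedy in
    # s_next, independent of the action the policy will actually execute.
    def q_value(self, s, action, reward, s_next):
        self.check_state_exits(s_next)
        predict = self.table.loc[s, action]
        if s_next != 'terminal':
            target = reward + self.gamma * self.table.loc[s_next, :].max()
        else:
            target = reward
        self.table.loc[s, action] += self.lr * (target - predict)
```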