Sarsa公式

      在〈Sarsa公式〉中尚無留言

Sarsa 公式如下 : $(Q(s,a)=Q(s,a)+lr[r+\gamma*Q(s’,a’)-Q(s,a)])$

完整代碼

Sarsa類別如下

from Brain import Brain
class Sarsa(Brain):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super().__init__(actions, learning_rate, reward_decay, e_greedy)
    def q_value(self, s, action, reward, s_next, action_next):
        self.check_state_exits(s_next)
        predict=self.table.loc[s, action]
        if s_next !='terminal':
            target=reward+self.gamma*self.table.loc[s_next, action_next]
        else:
            target=reward
        self.table.loc[s, action]+=self.lr*(target - predict)

主程式完整代碼如下

import pandas as pd
from Brain import Brain
from Maze import Maze
from Sarsa import Sarsa

def update():
    for epoch in range(100):
        s=maze.reset()
        while True:
            action=rl.choose_action(str(s))
            #計算下一個狀態,回報值,是否終止
            s_next, reward, done=maze.step(action)
            action_next=rl.choose_action(str(s_next))
            #使用 reward 計算 s 狀態每個動作的 Q 值
            rl.q_value(str(s), action, reward, str(s_next), action_next)
            s=s_next
            maze.render()#沒有render,看不到紅色矩型在移動
            if done:
                break
        print(f'epoch : {epoch}')
    print("================final table==================")
    df=pd.DataFrame(
        {"up": rl.table[0],
         "down" : rl.table[1],
         "left": rl.table[2],
         "right" : rl.table[3],
         "max": rl.table.max(axis=1)
         },
        index=rl.table.index
    )
    print(df.applymap(lambda x:'%.5f' % x))
if __name__=='__main__':
    maze=Maze()
    rl=Sarsa(actions=list(range(4)))
    maze.after(100,update)
    maze.mainloop()

todo

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *