Sarsa(λ) Algorithm Implementation
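The listing below implements Sarsa(λ) with accumulating eligibility traces over a tabular Q. As a sketch in conventional notation (the code's td_error, E, and Q play the roles of \delta, E, and Q; the code bumps the trace first and folds the decay into the sweep, which is equivalent), each time step performs:

\delta_t = R_{t+1} + \gamma\, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)
E(S_t, A_t) \leftarrow E(S_t, A_t) + 1
\text{for all } (s, a):\quad Q(s, a) \leftarrow Q(s, a) + \alpha\, \delta_t\, E(s, a), \qquad E(s, a) \leftarrow \gamma \lambda\, E(s, a)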
from random import random
from gym import Env
from gridworld import *  # the course's grid-world module (provides SimpleGridWorld)
class Agent():
    def __init__(self, env: Env):
        self.env = env      # the agent holds a reference to its environment
        self.Q = {}         # the agent maintains a table of action values Q(s, a)
        self.E = {}         # eligibility-trace table E(s, a)
        self.state = None   # the agent's current observation
        self._init_agent()

    def _init_agent(self):
        self.state = self.env.reset()
        s_name = self._get_state_name(self.state)
        self._init_state_value(s_name, randomized=False)
    def performPolicy(self, state):  # pick an action under the current policy
        pass

    def act(self, a):  # take action a in the environment
        return self.env.step(a)

    def learning(self):  # the learning loop
        pass

    def _get_state_name(self, state):  # turn an observation into a dictionary key
        return str(state)

    def _is_state_in_Q(self, s):  # whether Q values exist for state s
        return self.Q.get(s) is not None

    def _is_state_in_E(self, s):  # whether E values exist for state s
        return self.E.get(s) is not None
    def _init_state_value(self, s_name, randomized=True):  # initialize Q and E for a state
        if not self._is_state_in_Q(s_name):
            self.Q[s_name], self.E[s_name] = {}, {}
            for action in range(self.env.action_space.n):
                default_v = random() / 10 if randomized else 0.0
                self.Q[s_name][action] = default_v
                self.E[s_name][action] = 0.0
    def _assert_state_in_Q(self, s, randomized=True):  # make sure Q values exist for s
        if not self._is_state_in_Q(s):
            self._init_state_value(s, randomized)

    def _assert_state_in_E(self, s, randomized=True):  # make sure E values exist for s
        if not self._is_state_in_E(s):
            self._init_state_value(s, randomized)

    def _get_Q(self, s, a):  # read Q(s, a)
        self._assert_state_in_Q(s, randomized=True)
        return self.Q[s][a]

    def _set_Q(self, s, a, value):  # write Q(s, a)
        self._assert_state_in_Q(s, randomized=True)
        self.Q[s][a] = value

    def _get_E(self, s, a):  # read E(s, a)
        self._assert_state_in_E(s, randomized=True)
        return self.E[s][a]

    def _set_E(self, s, a, value):  # write E(s, a)
        self._assert_state_in_E(s, randomized=True)
        self.E[s][a] = value

    def _resetEValue(self):  # reset all eligibility traces to zero
        for value_dic in self.E.values():
            for action in range(self.env.action_space.n):
                value_dic[action] = 0.0
    def _performPolicy(self, s, episode_num, use_epsilon):
        # epsilon-greedy action selection with a decaying epsilon:
        # exploration fades as the episode count grows
        epsilon = 1.00 / (episode_num + 1)
        Q_s = self.Q[s]
        if use_epsilon and random() < epsilon:
            action = self.env.action_space.sample()  # explore: a random action
        else:
            action = max(Q_s, key=Q_s.get)           # exploit: the greedy action
        return action
    def _learning(self, lambda_, gamma, alpha, max_episode_num):
        total_time, time_in_episode, num_episode = 0, 0, 1
        while num_episode <= max_episode_num:   # termination condition for training
            self._resetEValue()                 # traces are reset at the start of every episode
            self.state = self.env.reset()       # reset the environment
            s0 = self._get_state_name(self.state)  # the agent's name for the initial observation
            self._assert_state_in_Q(s0, randomized=True)
            self.env.render()                   # draw the UI
            a0 = self._performPolicy(s0, num_episode, use_epsilon=True)
            time_in_episode = 0
            is_done = False
            while not is_done:                  # loop over one episode
                s1, r1, is_done, info = self.act(a0)  # old Gym API: step returns (obs, reward, done, info)
                self.env.render()               # refresh the UI
                s1 = self._get_state_name(s1)   # the agent's name for the new observation
                self._assert_state_in_Q(s1, randomized=True)
                # Sarsa: bootstrap from the action the policy actually selects next
                a1 = self._performPolicy(s1, num_episode, use_epsilon=True)
                old_q = self._get_Q(s0, a0)
                q_prime = self._get_Q(s1, a1)
                td_target = r1 + gamma * q_prime
                td_error = td_target - old_q
                e = self._get_E(s0, a0)         # accumulating trace: bump E(s0, a0) by 1
                e = e + 1
                self._set_E(s0, a0, e)
                # back up the TD error to every visited (s, a) in proportion to its
                # trace, then decay all traces by gamma * lambda
                for s, a_es in list(self.E.items()):
                    for a in range(self.env.action_space.n):
                        e_value = a_es[a]
                        old_q = self._get_Q(s, a)
                        new_q = old_q + alpha * td_error * e_value
                        new_e = gamma * lambda_ * e_value
                        self._set_Q(s, a, new_q)
                        self._set_E(s, a, new_e)
                if num_episode == max_episode_num:  # trace the final episode step by step
                    print("t:{0:>2}: s:{1}, a:{2:2}, s1:{3}".format(
                        time_in_episode, s0, a0, s1))
                s0, a0 = s1, a1
                time_in_episode += 1
            print("Episode {0} takes {1} steps.".format(num_episode, time_in_episode))
            total_time += time_in_episode
            num_episode += 1
        return
def main():
    env = SimpleGridWorld()
    agent = Agent(env)
    print("learning...")
    agent._learning(lambda_=0.01, gamma=0.9, alpha=0.1, max_episode_num=800)
    return

if __name__ == '__main__':
    main()
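The listing depends on the course's gridworld.py, whose SimpleGridWorld is a gym.Env subclass with a discrete action space and the old four-tuple step API. If that file is not at hand, a minimal stand-in such as the sketch below is enough to smoke-test the agent. TinyGridWorld is hypothetical: its corridor layout, rewards, and no-op render are assumptions, not the course environment.

# A hypothetical, minimal stand-in for gridworld.SimpleGridWorld, assuming only
# what the agent above actually uses: reset(), a four-tuple step(), render(),
# and a discrete action_space with .n and .sample().
from gym import Env
from gym.spaces import Discrete

class TinyGridWorld(Env):
    """A 1-D corridor of n cells; reaching the right end ends the episode."""
    def __init__(self, n=10):
        self.n = n
        self.action_space = Discrete(2)       # 0: step left, 1: step right
        self.observation_space = Discrete(n)
        self.pos = 0

    def reset(self):
        self.pos = 0
        return self.pos                       # old Gym API: reset() returns the observation

    def step(self, a):
        self.pos = min(self.n - 1, self.pos + 1) if a == 1 else max(0, self.pos - 1)
        done = self.pos == self.n - 1
        reward = 0.0 if done else -1.0        # -1 per step until the goal is reached
        return self.pos, reward, done, {}     # old Gym API: (obs, reward, done, info)

    def render(self, mode='human'):
        pass                                  # no UI here; the real SimpleGridWorld draws one

# Usage: agent = Agent(TinyGridWorld()); agent._learning(0.9, 0.9, 0.1, 100)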