Sarsa(lambda) Algorithm Implementation

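The listing below is a tabular, backward-view Sarsa(lambda) agent with accumulating eligibility traces and an epsilon-greedy behaviour policy, run on a small grid world. For reference, the per-step update implemented in _learning below is:

    \delta_t = r_{t+1} + \gamma\, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)
    E(s_t, a_t) \leftarrow E(s_t, a_t) + 1
    Q(s, a) \leftarrow Q(s, a) + \alpha\, \delta_t\, E(s, a)    for all (s, a)
    E(s, a) \leftarrow \gamma \lambda\, E(s, a)                 for all (s, a)
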
from random import random
from gym import Env
import gym
from gridworld import *

class Agent():
    def __init__(self,env: Env):
        self.env = env  # the agent holds a reference to the environment
        self.Q = {}  # action-value table: {state_name: {action: value}}
        self.E = {}  # eligibility traces, same layout as Q
        self.state = None  # the agent's current observation
        self._init_agent()
    def _init_agent(self):
        self.state = self.env.reset()
        s_name = self._get_state_name(self.state)
        self._init_state_value(s_name, randomized=False)


    def performPolicy(self, state):  # execute a policy
        pass

    def act(self,a):  # take an action in the environment
        return self.env.step(a)
    def learning(self):  # learning process
        pass
    def _get_state_name(self,state):  # convert an observation into a dictionary key
        return str(state)
    def _is_state_in_Q(self, s):  # check whether state s already has Q values
        return self.Q.get(s) is not None
    def _is_state_in_E(self, s):  # check whether state s already has E values
        return self.E.get(s) is not None
    def _init_state_value(self, s_name, randomized = True):  # initialize the Q and E entries of a state
        if not self._is_state_in_Q(s_name):
            self.Q[s_name],self.E[s_name] = {},{}
            for action in range(self.env.action_space.n):
                default_v = random() / 10 if randomized is True else 0.0
                self.Q[s_name][action] = default_v
                self.E[s_name][action] = 0.0
    def _assert_state_in_Q(self, s, randomized=True):  # make sure the Q values of state s exist
        if not self._is_state_in_Q(s):
            self._init_state_value(s, randomized)
    def _assert_state_in_E(self, s, randomized=True):  # make sure the E values of state s exist
        if not self._is_state_in_E(s):
            self._init_state_value(s, randomized)
    def _get_Q(self, s, a):  # get Q(s,a)
        self._assert_state_in_Q(s, randomized=True)
        return self.Q[s][a]
    def _set_Q(self, s, a, value):  # set Q(s,a)
        self._assert_state_in_Q(s, randomized=True)
        self.Q[s][a] = value
    def _get_E(self, s, a):  # get E(s,a)
        self._assert_state_in_E(s, randomized=True)
        return self.E[s][a]
    def _set_E(self, s, a, value):  # set E(s,a)
        self._assert_state_in_E(s, randomized=True)
        self.E[s][a] = value
    def _resetEValue(self):  # reset all eligibility traces to zero at the start of an episode
        for value_dic in self.E.values():
            for action in range(self.env.action_space.n):
                value_dic[action] = 0.00
    def _performPolicy(self, s, episode_num, use_epsilon):
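        # epsilon-greedy action selection; epsilon = 1/(episode_num+1) decays
        # exploration as training progresses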
        epsilon = 1.00 / (episode_num+1)
        Q_s = self.Q[s]
        rand_value = random()
        if use_epsilon and rand_value < epsilon:
            action = self.env.action_space.sample()  # explore: random action
        else:
            action = max(Q_s, key=Q_s.get)  # exploit: greedy action w.r.t. Q
        return action

    def _learning(self, lambda_, gamma, alpha, max_episode_num):
        total_time, time_in_episode, num_episode = 0,0,1
        while num_episode <= max_episode_num:  # termination condition
            self._resetEValue()  # reset the E table at the start of each episode
            self.state = self.env.reset()  # reset the environment
            s0 = self._get_state_name(self.state)  # get the agent's key for the initial observation
            self._assert_state_in_Q(s0, randomized=True)
            self.env.render()  # render the UI
            a0 = self._performPolicy(s0, num_episode, use_epsilon=True)
            time_in_episode = 0
            is_done = False
            while not is_done:  # loop over the steps of one episode
                s1,r1,is_done,info = self.act(a0)  # take the chosen action in the environment
                self.env.render()  # update the UI
                s1 = self._get_state_name(s1)  # get the agent's key for the new state
                self._assert_state_in_Q(s1, randomized=True)
                a1 = self._performPolicy(s1, num_episode, use_epsilon=True)
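                # Sarsa(lambda): one-step TD error for (s0, a0), using the action a1
                # actually chosen by the (epsilon-greedy) policy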
                old_q = self._get_Q(s0, a0)
                q_prime = self._get_Q(s1, a1)
                td_target = r1 + gamma*q_prime
                td_error = td_target - old_q
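                # accumulating trace: bump the eligibility of the pair just visited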
                e = self._get_E(s0, a0)
                e = e+1
                self._set_E(s0,a0,e)
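                # backward view: sweep every known (state, action) pair, move its Q
                # toward the TD target in proportion to its eligibility, then decay
                # the trace by gamma * lambda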
                state_action_list = list(self.E.items())
                for s,a_es in state_action_list:
                    for a in range(self.env.action_space.n):
                        e_value = a_es[a]
                        old_q = self._get_Q(s,a)
                        new_q = old_q + alpha*td_error*e_value
                        new_e = gamma*lambda_*e_value
                        self._set_Q(s,a,new_q)
                        self._set_E(s,a,new_e)
                if num_episode == max_episode_num:
                    print("t:{0:>2}: s:{1}, a:{2:2}, s1:{3}".format(time_in_episode,s0,a0,s1))
                s0,a0 = s1,a1
                time_in_episode += 1
            print("Eposide {0} takes {1} steps.".format(num_episode, time_in_episode))
            total_time += time_in_episode
            num_episode += 1
        return

def main():
    env = SimpleGridWorld()
    agent = Agent(env)
    print("learning...")
    agent._learning(lambda_=0.01, gamma=0.9, alpha=0.1, max_episode_num=800)
    return

if __name__ == '__main__':
    main()
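
A minimal usage sketch, assuming the SimpleGridWorld interface above: after the learning call in main(), the learned greedy action at the start state can be read back with the methods already defined (with use_epsilon=False, _performPolicy ignores epsilon and simply returns the highest-Q action):

    # sketch: add right after agent._learning(...) inside main()
    s = agent._get_state_name(agent.env.reset())       # key of the start state
    a = agent._performPolicy(s, 1, use_epsilon=False)  # greedy w.r.t. the learned Q
    print("greedy action at the start state:", a)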