Reinforcement Learning: Applying the DQN Algorithm to Games

In game AI, reinforcement learning's capacity for autonomous learning has enabled intelligent decision-making in everything from simple mini-games to complex AAA titles. Deep Q-Network (DQN), the milestone algorithm that bridged deep learning and reinforcement learning, was the first to let an AI reach human-level play on Atari games. Starting from the fundamentals, this article walks through how DQN learns to make decisions on its own in games, using a hands-on example.

I. Core Principles of the DQN Algorithm

The core idea of DQN (Deep Q-Network) is to approximate the Q-value function with a deep neural network, solving the problem that traditional reinforcement learning cannot handle high-dimensional state spaces.

1. Q-learning Basics

Q-learning guides decisions by learning an action-value function Q(s,a), where s is the current state, a is an action, and Q(s,a) is the expected cumulative reward of taking action a in state s:

def q_learning_update(q_table, state, action, reward, next_state, alpha=0.1, gamma=0.9):
    """Q-learning update rule"""
    old_value = q_table[state][action]
    # Maximum Q value over actions in the next state
    next_max = max(q_table[next_state].values())
    # Update: Q(s,a) = Q(s,a) + α[r + γ·max Q(s',a') - Q(s,a)]
    new_value = old_value + alpha * (reward + gamma * next_max - old_value)
    q_table[state][action] = new_value
    return q_table

Traditional Q-learning stores Q values in a table, so it only works when the state space is small (e.g., maze problems).
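To make the tabular setting concrete, here is a small illustrative sketch of a Q-table for a toy 4x4 grid world; the states and action names are made up purely for illustration.

# Illustrative sketch: a tabular Q-table for a tiny 4x4 grid world.
# Each state maps to a dict of action -> Q value, matching q_learning_update above.
q_table = {
    s: {a: 0.0 for a in ["up", "down", "left", "right"]}
    for s in [(x, y) for x in range(4) for y in range(4)]  # 16 states in total
}

# One hypothetical transition: moving right from (0, 0) to (1, 0) with reward 1.
q_table = q_learning_update(q_table, state=(0, 0), action="right",
                            reward=1.0, next_state=(1, 0))
print(q_table[(0, 0)]["right"])  # 0.1 after one update with alpha=0.1

A single game frame, by contrast, has far more possible pixel configurations than any table could hold, which is exactly the gap DQN closes with a neural network.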

2. Key Innovations of DQN

DQN addresses the instability of deep neural networks in reinforcement learning with two key techniques:

  • Experience replay: transitions (s, a, r, s') are stored in a replay buffer and sampled at random, breaking the correlation between consecutive samples
  • Target network: a separate target network computes the target Q values and is periodically synchronized with the policy network's parameters

import random
from collections import deque

import numpy as np


class ReplayBuffer:
    """Experience replay buffer"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # bounded double-ended queue

    def add(self, state, action, reward, next_state, done):
        """Store one transition"""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Randomly sample a batch of transitions"""
        transitions = random.sample(self.buffer, batch_size)
        # Unzip into batched arrays
        state, action, reward, next_state, done = zip(*transitions)
        return (np.array(state), np.array(action), np.array(reward),
                np.array(next_state), np.array(done, dtype=np.float32))

    def size(self):
        """Number of transitions currently stored"""
        return len(self.buffer)

Experience replay makes the network's training samples closer to the i.i.d. assumption, which significantly improves training stability.
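As a quick sanity check, the buffer can be exercised on dummy transitions; the states and rewards below are arbitrary and only illustrate the interface.

# Minimal usage sketch for ReplayBuffer with made-up transitions.
buffer = ReplayBuffer(capacity=1000)
for i in range(100):
    state = np.random.rand(4)        # dummy 4-dimensional state
    next_state = np.random.rand(4)
    buffer.add(state, action=i % 2, reward=1.0, next_state=next_state, done=False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)
print(states.shape, actions.shape)   # (32, 4) (32,)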

II. DQN Network Architecture and Implementation

For pixel inputs, DQN uses a convolutional neural network to process game frames; for low-dimensional states a small fully connected network suffices. Either way, the network outputs a Q-value estimate for each action.

1. Network Architecture Design

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random


class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        # Fully connected network for low-dimensional states (e.g., Snake)
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # one Q value per action

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


class DQN_CNN(nn.Module):
    def __init__(self, action_dim):
        super(DQN_CNN, self).__init__()
        # Convolutional network for image input (e.g., Atari games)
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)  # stack of 4 grayscale frames
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 512)  # flattened conv output for 84x84 input
        self.fc2 = nn.Linear(512, action_dim)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

The output layer's dimension equals the size of the action space, and each output value is the Q-value estimate for the corresponding action.
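To make that concrete, the following sketch feeds a dummy state through the fully connected DQN and picks the greedy action; the dimensions here are arbitrary.

# Sketch: query the network for Q values and act greedily (dummy dimensions).
state_dim, action_dim = 11, 3
net = DQN(state_dim, action_dim)

state = torch.rand(1, state_dim)        # batch of one state
q_values = net(state)                   # shape (1, action_dim), one Q value per action
greedy_action = q_values.argmax(dim=1).item()
print(q_values.shape, greedy_action)    # torch.Size([1, 3]) and an index in {0, 1, 2}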

2. DQN Algorithm Implementation

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, epsilon=1.0):
        self.action_dim = action_dim
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995  # exploration decay

        # Policy (online) network and target network
        self.policy_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())  # start identical

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(capacity=10000)
        self.batch_size = 64

    def select_action(self, state):
        """Epsilon-greedy action selection"""
        if random.random() < self.epsilon:
            # Random exploration
            return random.randint(0, self.action_dim - 1)
        else:
            # Greedy: pick the action with the largest Q value
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.policy_net(state)
                return q_values.max(1)[1].item()

    def update(self):
        """Sample from the replay buffer and update the policy network"""
        if self.buffer.size() < self.batch_size:
            return  # not enough experience yet

        # Sample a batch of transitions
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Convert to tensors
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Current Q values: policy network outputs for the actions actually taken
        current_q = self.policy_net(states).gather(1, actions).squeeze(1)

        # Target Q values: max next-state Q from the target network
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q

        # Compute the loss and optimize
        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay the exploration rate
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss.item()

    def sync_target_network(self):
        """Copy the policy network weights into the target network"""
        self.target_net.load_state_dict(self.policy_net.state_dict())

Synchronizing the target network periodically (e.g., every 1000 steps) prevents violent swings in the Q-value targets and improves training stability.
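The Snake training script in the next section syncs per episode for simplicity. A step-counter variant closer to the "every 1000 steps" schedule mentioned above might look like the following sketch, assuming an environment whose step() returns (next_state, reward, done), like the Snake environment defined below.

# Sketch: step-based target network synchronization (assumed schedule of 1000 steps).
SYNC_EVERY = 1000
global_step = 0

def training_step(agent, env, state):
    """One interaction step followed by a learning update."""
    global global_step
    action = agent.select_action(state)
    next_state, reward, done = env.step(action)
    agent.buffer.add(state, action, reward, next_state, done)
    agent.update()

    global_step += 1
    if global_step % SYNC_EVERY == 0:
        agent.sync_target_network()  # copy policy weights into the target network
    return next_state, done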

III. Hands-On Example: Playing Snake with DQN

Using the classic game Snake as an example, this section shows how DQN learns on its own to avoid obstacles and eat food.

1. Game Environment Design

import random

import numpy as np
import pygame  # used for rendering; the rendering loop is omitted here


class SnakeGame:
    def __init__(self, width=400, height=400, block_size=20):
        self.width = width
        self.height = height
        self.block_size = block_size
        self.reset()

    def reset(self):
        """Reset the game state"""
        self.snake = [(100, 100), (80, 100), (60, 100)]  # snake body, head first
        self.direction = (self.block_size, 0)  # initially moving right
        self.food = self._generate_food()
        self.score = 0
        self.done = False
        return self._get_state()

    def _generate_food(self):
        """Place food at a random cell that does not overlap the snake"""
        while True:
            x = random.randint(0, (self.width - self.block_size) // self.block_size) * self.block_size
            y = random.randint(0, (self.height - self.block_size) // self.block_size) * self.block_size
            food = (x, y)
            if food not in self.snake:
                return food

    def _get_state(self):
        """Extract the state features for the agent"""
        head_x, head_y = self.snake[0]

        # Direction features
        dir_left = self.direction == (-self.block_size, 0)
        dir_right = self.direction == (self.block_size, 0)
        dir_up = self.direction == (0, -self.block_size)
        dir_down = self.direction == (0, self.block_size)

        # Danger features (wall or own body)
        danger_straight = (
            (dir_right and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_left and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_up and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_down and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )

        danger_right = (
            (dir_up and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_down and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_left and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_right and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )

        danger_left = (
            (dir_down and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_up and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_right and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_left and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )

        # Food position relative to the head
        food_up = head_y > self.food[1]
        food_down = head_y < self.food[1]
        food_left = head_x > self.food[0]
        food_right = head_x < self.food[0]

        # Combined state vector (11 dimensions)
        return np.array([
            danger_straight, danger_right, danger_left,
            dir_left, dir_right, dir_up, dir_down,
            food_left, food_right, food_up, food_down
        ], dtype=int)

    def step(self, action):
        """Apply an action; return the new state, reward, and done flag"""
        # Action mapping: 0 = straight, 1 = turn right, 2 = turn left
        dx, dy = self.direction
        if action == 1:  # turn right
            dx, dy = dy, -dx
        elif action == 2:  # turn left
            dx, dy = -dy, dx
        self.direction = (dx, dy)

        # Move the head
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)

        # Collision check
        if (new_head[0] < 0 or new_head[0] >= self.width or
                new_head[1] < 0 or new_head[1] >= self.height or
                new_head in self.snake):
            self.done = True
            reward = -10  # collision penalty
        else:
            # Move the body
            self.snake.insert(0, new_head)
            # Check whether the food was eaten
            if new_head == self.food:
                self.score += 1
                reward = 10  # food reward
                self.food = self._generate_food()
            else:
                self.snake.pop()  # no food: drop the tail
                reward = 0  # ordinary step, no reward

        return self._get_state(), reward, self.done

The environment encodes the snake's position, its direction, the food's location, and nearby dangers into an 11-dimensional state vector that serves as the DQN's input.
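A quick way to sanity-check the environment is to drive it with random actions; this sketch is not part of the training code, just an inspection aid.

# Sketch: run the environment with random actions to inspect states and rewards.
env = SnakeGame()
state = env.reset()
print(state.shape, state)  # (11,) binary feature vector

done = False
while not done:
    action = random.randint(0, 2)         # 0 = straight, 1 = right, 2 = left
    state, reward, done = env.step(action)
print("Random play ended with score:", env.score)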

2. Training Process

def train_snake(episodes=500):
    """Train the Snake AI"""
    env = SnakeGame()
    state_dim = 11  # state feature dimension
    action_dim = 3  # straight, turn right, turn left
    agent = DQNAgent(state_dim, action_dim)

    scores = []
    target_update_freq = 10  # sync the target network every 10 episodes

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            # Select an action
            action = agent.select_action(state)
            # Apply it to the environment
            next_state, reward, done = env.step(action)
            # Store the transition
            agent.buffer.add(state, action, reward, next_state, done)
            # Update the network (epsilon also decays inside update())
            agent.update()

            state = next_state
            total_reward += reward

        scores.append(env.score)

        # Periodically sync the target network
        if episode % target_update_freq == 0:
            agent.sync_target_network()

        # Print progress
        if (episode + 1) % 10 == 0:
            avg_score = np.mean(scores[-10:])
            print(f"Episode {episode+1}/{episodes}, score: {env.score}, "
                  f"avg score: {avg_score:.1f}, epsilon: {agent.epsilon:.3f}")

    # Save the trained model
    torch.save(agent.policy_net.state_dict(), "snake_dqn.pth")
    return scores


# Start training
scores = train_snake()

# Plot the training curve
import matplotlib.pyplot as plt

plt.plot(scores)
plt.xlabel("Episode")
plt.ylabel("Score")
plt.title("Snake DQN training curve")
plt.show()

During training, the AI progresses from random collisions to learning obstacle avoidance and food tracking, and the score climbs steadily with the number of episodes, demonstrating DQN's capacity for autonomous learning.
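After training, the saved weights can be loaded and the policy run greedily with no exploration. The evaluation sketch below assumes the snake_dqn.pth file produced above and adds a step cap so a looping policy cannot run forever.

# Sketch: evaluate the trained policy greedily using the saved weights.
def play_snake(model_path="snake_dqn.pth", episodes=5, max_steps=1000):
    env = SnakeGame()
    net = DQN(state_dim=11, action_dim=3)
    net.load_state_dict(torch.load(model_path))
    net.eval()

    for episode in range(episodes):
        state = env.reset()
        done = False
        steps = 0
        while not done and steps < max_steps:  # step cap in case the policy loops
            with torch.no_grad():
                q_values = net(torch.FloatTensor(state).unsqueeze(0))
            action = q_values.argmax(dim=1).item()  # always act greedily
            state, _, done = env.step(action)
            steps += 1
        print(f"Evaluation episode {episode + 1}: score {env.score}")

play_snake()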

IV. Applying DQN to More Complex Games

For pixel-level games such as Atari, the agent needs a CNN to process image input, plus an extra preprocessing step:

import cv2


def preprocess_frame(frame):
    """Preprocess a raw game frame"""
    # Convert to grayscale
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    # Downsample to 84x84
    resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    # Binarize
    _, binary = cv2.threshold(resized, 1, 255, cv2.THRESH_BINARY)
    # Normalize to [0, 1]
    return binary / 255.0


class AtariAgent(DQNAgent):
    """DQN agent for Atari games"""
    def __init__(self, action_dim):
        # Reuse the base agent setup, then swap in the convolutional networks
        super().__init__(state_dim=1, action_dim=action_dim)  # state_dim is a placeholder; networks replaced below
        self.policy_net = DQN_CNN(action_dim)
        self.target_net = DQN_CNN(action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        # Re-create the optimizer so it tracks the CNN parameters
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-3)
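DQN_CNN expects a stack of four preprocessed 84x84 frames rather than a single frame. A minimal sketch of that stacking step, assuming a deque-based helper that is not part of the code above, could look like this:

# Sketch: keep a rolling stack of the 4 most recent preprocessed frames
# so that DQN_CNN receives input of shape (1, 4, 84, 84).
class FrameStack:
    def __init__(self, num_frames=4):
        self.frames = deque(maxlen=num_frames)

    def reset(self, frame):
        """Fill the stack with copies of the first frame of an episode."""
        processed = preprocess_frame(frame)
        for _ in range(self.frames.maxlen):
            self.frames.append(processed)
        return self.state()

    def push(self, frame):
        """Add a new frame and return the updated stacked state."""
        self.frames.append(preprocess_frame(frame))
        return self.state()

    def state(self):
        # Stack along the channel dimension and add a batch dimension
        return torch.FloatTensor(np.stack(self.frames, axis=0)).unsqueeze(0)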
